Don't call SQLBaseStore methods from outside stores

This commit is contained in:
Erik Johnston 2019-12-04 10:16:44 +00:00
parent 3eb15c01d9
commit c2f525a525
4 changed files with 20 additions and 19 deletions

View File

@ -542,8 +542,8 @@ def phone_stats_home(hs, stats, stats_process=_stats_process):
# Database version # Database version
# #
stats["database_engine"] = hs.get_datastore().database_engine_name stats["database_engine"] = hs.database_engine.module.__name__
stats["database_server_version"] = hs.get_datastore().get_server_version() stats["database_server_version"] = hs.database_engine.server_version
logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)) logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
try: try:
yield hs.get_proxied_http_client().put_json( yield hs.get_proxied_http_client().put_json(

View File

@ -386,15 +386,7 @@ class RulesForRoom(object):
""" """
sequence = self.sequence sequence = self.sequence
rows = yield self.store._simple_select_many_batch( rows = yield self.store.get_membership_from_event_ids(member_event_ids.values())
table="room_memberships",
column="event_id",
iterable=member_event_ids.values(),
retcols=("user_id", "membership", "event_id"),
keyvalues={},
batch_size=500,
desc="_get_rules_for_member_event_ids",
)
members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows} members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows}

View File

@ -1496,14 +1496,6 @@ class SQLBaseStore(object):
return cls.cursor_to_dict(txn) return cls.cursor_to_dict(txn)
@property
def database_engine_name(self):
return self.database_engine.module.__name__
def get_server_version(self):
"""Returns a string describing the server version number"""
return self.database_engine.server_version
class _RollbackButIsFineException(Exception): class _RollbackButIsFineException(Exception):
""" This exception is used to rollback a transaction without implying """ This exception is used to rollback a transaction without implying

View File

@ -15,6 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Iterable, List
from six import iteritems, itervalues from six import iteritems, itervalues
@ -813,6 +814,22 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return set(room_ids) return set(room_ids)
def get_membership_from_event_ids(
self, member_event_ids: Iterable[str]
) -> List[dict]:
"""Get user_id and membership of a set of event IDs.
"""
return self._simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=member_event_ids,
retcols=("user_id", "membership", "event_id"),
keyvalues={},
batch_size=500,
desc="get_membership_from_event_ids",
)
class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):