Port to use state storage

This commit is contained in:
Erik Johnston 2019-10-23 17:25:54 +01:00
parent 5db03535d5
commit 69f0054ce6
19 changed files with 216 additions and 115 deletions

View file

@ -30,6 +30,9 @@ class AdminHandler(BaseHandler):
def __init__(self, hs):
super(AdminHandler, self).__init__(hs)
self.storage = hs.get_storage()
self.state_store = self.storage.state
@defer.inlineCallbacks
def get_whois(self, user):
connections = []
@ -205,7 +208,7 @@ class AdminHandler(BaseHandler):
from_key = events[-1].internal_metadata.after
events = yield filter_events_for_client(self.store, user_id, events)
events = yield filter_events_for_client(self.storage, user_id, events)
writer.write_events(room_id, events)
@ -241,7 +244,7 @@ class AdminHandler(BaseHandler):
for event_id in extremities:
if not event_to_unseen_prevs[event_id]:
continue
state = yield self.store.get_state_for_event(event_id)
state = yield self.state_store.get_state_for_event(event_id)
writer.write_state(room_id, event_id, state)
return writer.finished()

View file

@ -46,6 +46,7 @@ class DeviceWorkerHandler(BaseHandler):
self.hs = hs
self.state = hs.get_state_handler()
self.state_store = hs.get_storage().state
self._auth_handler = hs.get_auth_handler()
@trace
@ -178,7 +179,7 @@ class DeviceWorkerHandler(BaseHandler):
continue
# mapping from event_id -> state_dict
prev_state_ids = yield self.store.get_state_ids_for_events(event_ids)
prev_state_ids = yield self.state_store.get_state_ids_for_events(event_ids)
# Check if we've joined the room? If so we just blindly add all the users to
# the "possibly changed" users.

View file

@ -147,6 +147,10 @@ class EventStreamHandler(BaseHandler):
class EventHandler(BaseHandler):
def __init__(self, hs):
super(EventHandler, self).__init__(hs)
self.storage = hs.get_storage()
@defer.inlineCallbacks
def get_event(self, user, room_id, event_id):
"""Retrieve a single specified event.
@ -172,7 +176,7 @@ class EventHandler(BaseHandler):
is_peeking = user.to_string() not in users
filtered = yield filter_events_for_client(
self.store, user.to_string(), [event], is_peeking=is_peeking
self.storage, user.to_string(), [event], is_peeking=is_peeking
)
if not filtered:

View file

@ -110,6 +110,7 @@ class FederationHandler(BaseHandler):
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.state_store = self.storage.state
self.federation_client = hs.get_federation_client()
self.state_handler = hs.get_state_handler()
self.server_name = hs.hostname
@ -325,7 +326,7 @@ class FederationHandler(BaseHandler):
event_map = {event_id: pdu}
try:
# Get the state of the events we know about
ours = yield self.store.get_state_groups_ids(room_id, seen)
ours = yield self.state_store.get_state_groups_ids(room_id, seen)
# state_maps is a list of mappings from (type, state_key) to event_id
state_maps = list(
@ -889,7 +890,7 @@ class FederationHandler(BaseHandler):
# We set `check_history_visibility_only` as we might otherwise get false
# positives from users having been erased.
filtered_extremities = yield filter_events_for_server(
self.store,
self.storage,
self.server_name,
list(extremities_events.values()),
redact=False,
@ -1550,7 +1551,7 @@ class FederationHandler(BaseHandler):
event_id, allow_none=False, check_room_id=room_id
)
state_groups = yield self.store.get_state_groups(room_id, [event_id])
state_groups = yield self.state_store.get_state_groups(room_id, [event_id])
if state_groups:
_, state = list(iteritems(state_groups)).pop()
@ -1579,7 +1580,7 @@ class FederationHandler(BaseHandler):
event_id, allow_none=False, check_room_id=room_id
)
state_groups = yield self.store.get_state_groups_ids(room_id, [event_id])
state_groups = yield self.state_store.get_state_groups_ids(room_id, [event_id])
if state_groups:
_, state = list(state_groups.items()).pop()
@ -1607,7 +1608,7 @@ class FederationHandler(BaseHandler):
events = yield self.store.get_backfill_events(room_id, pdu_list, limit)
events = yield filter_events_for_server(self.store, origin, events)
events = yield filter_events_for_server(self.storage, origin, events)
return events
@ -1637,7 +1638,7 @@ class FederationHandler(BaseHandler):
if not in_room:
raise AuthError(403, "Host not in room.")
events = yield filter_events_for_server(self.store, origin, [event])
events = yield filter_events_for_server(self.storage, origin, [event])
event = events[0]
return event
else:
@ -1903,7 +1904,7 @@ class FederationHandler(BaseHandler):
# given state at the event. This should correctly handle cases
# like bans, especially with state res v2.
state_sets = yield self.store.get_state_groups(
state_sets = yield self.state_store.get_state_groups(
event.room_id, extrem_ids
)
state_sets = list(state_sets.values())
@ -1994,7 +1995,7 @@ class FederationHandler(BaseHandler):
)
missing_events = yield filter_events_for_server(
self.store, origin, missing_events
self.storage, origin, missing_events
)
return missing_events
@ -2235,7 +2236,7 @@ class FederationHandler(BaseHandler):
# create a new state group as a delta from the existing one.
prev_group = context.state_group
state_group = yield self.store.store_state_group(
state_group = yield self.state_store.store_state_group(
event.event_id,
event.room_id,
prev_group=prev_group,

View file

@ -43,6 +43,8 @@ class InitialSyncHandler(BaseHandler):
self.validator = EventValidator()
self.snapshot_cache = SnapshotCache()
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
self.state_store = self.storage.state
def snapshot_all_rooms(
self,
@ -169,7 +171,7 @@ class InitialSyncHandler(BaseHandler):
elif event.membership == Membership.LEAVE:
room_end_token = "s%d" % (event.stream_ordering,)
deferred_room_state = run_in_background(
self.store.get_state_for_events, [event.event_id]
self.state_store.get_state_for_events, [event.event_id]
)
deferred_room_state.addCallback(
lambda states: states[event.event_id]
@ -189,7 +191,9 @@ class InitialSyncHandler(BaseHandler):
)
).addErrback(unwrapFirstError)
messages = yield filter_events_for_client(self.store, user_id, messages)
messages = yield filter_events_for_client(
self.storage, user_id, messages
)
start_token = now_token.copy_and_replace("room_key", token)
end_token = now_token.copy_and_replace("room_key", room_end_token)
@ -307,7 +311,7 @@ class InitialSyncHandler(BaseHandler):
def _room_initial_sync_parted(
self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking
):
room_state = yield self.store.get_state_for_events([member_event_id])
room_state = yield self.state_store.get_state_for_events([member_event_id])
room_state = room_state[member_event_id]
@ -322,7 +326,7 @@ class InitialSyncHandler(BaseHandler):
)
messages = yield filter_events_for_client(
self.store, user_id, messages, is_peeking=is_peeking
self.storage, user_id, messages, is_peeking=is_peeking
)
start_token = StreamToken.START.copy_and_replace("room_key", token)
@ -414,7 +418,7 @@ class InitialSyncHandler(BaseHandler):
)
messages = yield filter_events_for_client(
self.store, user_id, messages, is_peeking=is_peeking
self.storage, user_id, messages, is_peeking=is_peeking
)
start_token = now_token.copy_and_replace("room_key", token)

View file

@ -59,6 +59,8 @@ class MessageHandler(object):
self.clock = hs.get_clock()
self.state = hs.get_state_handler()
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.state_store = self.storage.state
self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks
@ -82,7 +84,7 @@ class MessageHandler(object):
data = yield self.state.get_current_state(room_id, event_type, state_key)
elif membership == Membership.LEAVE:
key = (event_type, state_key)
room_state = yield self.store.get_state_for_events(
room_state = yield self.state_store.get_state_for_events(
[membership_event_id], StateFilter.from_types([key])
)
data = room_state[membership_event_id].get(key)
@ -135,12 +137,12 @@ class MessageHandler(object):
raise NotFoundError("Can't find event for token %s" % (at_token,))
visible_events = yield filter_events_for_client(
self.store, user_id, last_events
self.storage, user_id, last_events
)
event = last_events[0]
if visible_events:
room_state = yield self.store.get_state_for_events(
room_state = yield self.state_store.get_state_for_events(
[event.event_id], state_filter=state_filter
)
room_state = room_state[event.event_id]
@ -161,7 +163,7 @@ class MessageHandler(object):
)
room_state = yield self.store.get_events(state_ids.values())
elif membership == Membership.LEAVE:
room_state = yield self.store.get_state_for_events(
room_state = yield self.state_store.get_state_for_events(
[membership_event_id], state_filter=state_filter
)
room_state = room_state[membership_event_id]

View file

@ -69,6 +69,8 @@ class PaginationHandler(object):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.state_store = self.storage.state
self.clock = hs.get_clock()
self._server_name = hs.hostname
@ -255,7 +257,7 @@ class PaginationHandler(object):
events = event_filter.filter(events)
events = yield filter_events_for_client(
self.store, user_id, events, is_peeking=(member_event_id is None)
self.storage, user_id, events, is_peeking=(member_event_id is None)
)
if not events:
@ -274,7 +276,7 @@ class PaginationHandler(object):
(EventTypes.Member, event.sender) for event in events
)
state_ids = yield self.store.get_state_ids_for_event(
state_ids = yield self.state_store.get_state_ids_for_event(
events[0].event_id, state_filter=state_filter
)

View file

@ -822,6 +822,8 @@ class RoomContextHandler(object):
def __init__(self, hs):
self.hs = hs
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.state_store = self.storage.state
@defer.inlineCallbacks
def get_event_context(self, user, room_id, event_id, limit, event_filter):
@ -848,7 +850,7 @@ class RoomContextHandler(object):
def filter_evts(events):
return filter_events_for_client(
self.store, user.to_string(), events, is_peeking=is_peeking
self.storage, user.to_string(), events, is_peeking=is_peeking
)
event = yield self.store.get_event(
@ -890,7 +892,7 @@ class RoomContextHandler(object):
# first? Shouldn't we be consistent with /sync?
# https://github.com/matrix-org/matrix-doc/issues/687
state = yield self.store.get_state_for_events(
state = yield self.state_store.get_state_for_events(
[last_event_id], state_filter=state_filter
)
results["state"] = list(state[last_event_id].values())

View file

@ -35,6 +35,8 @@ class SearchHandler(BaseHandler):
def __init__(self, hs):
super(SearchHandler, self).__init__(hs)
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
self.state_store = self.storage.state
@defer.inlineCallbacks
def get_old_rooms_from_upgraded_room(self, room_id):
@ -221,7 +223,7 @@ class SearchHandler(BaseHandler):
filtered_events = search_filter.filter([r["event"] for r in results])
events = yield filter_events_for_client(
self.store, user.to_string(), filtered_events
self.storage, user.to_string(), filtered_events
)
events.sort(key=lambda e: -rank_map[e.event_id])
@ -271,7 +273,7 @@ class SearchHandler(BaseHandler):
filtered_events = search_filter.filter([r["event"] for r in results])
events = yield filter_events_for_client(
self.store, user.to_string(), filtered_events
self.storage, user.to_string(), filtered_events
)
room_events.extend(events)
@ -340,11 +342,11 @@ class SearchHandler(BaseHandler):
)
res["events_before"] = yield filter_events_for_client(
self.store, user.to_string(), res["events_before"]
self.storage, user.to_string(), res["events_before"]
)
res["events_after"] = yield filter_events_for_client(
self.store, user.to_string(), res["events_after"]
self.storage, user.to_string(), res["events_after"]
)
res["start"] = now_token.copy_and_replace(
@ -372,7 +374,7 @@ class SearchHandler(BaseHandler):
[(EventTypes.Member, sender) for sender in senders]
)
state = yield self.store.get_state_for_event(
state = yield self.state_store.get_state_for_event(
last_event_id, state_filter
)

View file

@ -230,6 +230,8 @@ class SyncHandler(object):
self.response_cache = ResponseCache(hs, "sync")
self.state = hs.get_state_handler()
self.auth = hs.get_auth()
self.storage = hs.get_storage()
self.state_store = self.storage.state
# ExpiringCache((User, Device)) -> LruCache(state_key => event_id)
self.lazy_loaded_members_cache = ExpiringCache(
@ -417,7 +419,7 @@ class SyncHandler(object):
current_state_ids = frozenset(itervalues(current_state_ids))
recents = yield filter_events_for_client(
self.store,
self.storage,
sync_config.user.to_string(),
recents,
always_include_ids=current_state_ids,
@ -470,7 +472,7 @@ class SyncHandler(object):
current_state_ids = frozenset(itervalues(current_state_ids))
loaded_recents = yield filter_events_for_client(
self.store,
self.storage,
sync_config.user.to_string(),
loaded_recents,
always_include_ids=current_state_ids,
@ -509,7 +511,7 @@ class SyncHandler(object):
Returns:
A Deferred map from ((type, state_key)->Event)
"""
state_ids = yield self.store.get_state_ids_for_event(
state_ids = yield self.state_store.get_state_ids_for_event(
event.event_id, state_filter=state_filter
)
if event.is_state():
@ -580,7 +582,7 @@ class SyncHandler(object):
return None
last_event = last_events[-1]
state_ids = yield self.store.get_state_ids_for_event(
state_ids = yield self.state_store.get_state_ids_for_event(
last_event.event_id,
state_filter=StateFilter.from_types(
[(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")]
@ -757,11 +759,11 @@ class SyncHandler(object):
if full_state:
if batch:
current_state_ids = yield self.store.get_state_ids_for_event(
current_state_ids = yield self.state_store.get_state_ids_for_event(
batch.events[-1].event_id, state_filter=state_filter
)
state_ids = yield self.store.get_state_ids_for_event(
state_ids = yield self.state_store.get_state_ids_for_event(
batch.events[0].event_id, state_filter=state_filter
)
@ -781,7 +783,7 @@ class SyncHandler(object):
)
elif batch.limited:
if batch:
state_at_timeline_start = yield self.store.get_state_ids_for_event(
state_at_timeline_start = yield self.state_store.get_state_ids_for_event(
batch.events[0].event_id, state_filter=state_filter
)
else:
@ -810,7 +812,7 @@ class SyncHandler(object):
)
if batch:
current_state_ids = yield self.store.get_state_ids_for_event(
current_state_ids = yield self.state_store.get_state_ids_for_event(
batch.events[-1].event_id, state_filter=state_filter
)
else:
@ -841,7 +843,7 @@ class SyncHandler(object):
# So we fish out all the member events corresponding to the
# timeline here, and then dedupe any redundant ones below.
state_ids = yield self.store.get_state_ids_for_event(
state_ids = yield self.state_store.get_state_ids_for_event(
batch.events[0].event_id,
# we only want members!
state_filter=StateFilter.from_types(