mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Add StateGroupStorage interface
This commit is contained in:
parent
b7fe62b766
commit
5db03535d5
@ -30,6 +30,7 @@ stored in `synapse.storage.schema`.
|
||||
from synapse.storage.data_stores import DataStores
|
||||
from synapse.storage.data_stores.main import DataStore
|
||||
from synapse.storage.persist_events import EventsPersistenceStorage
|
||||
from synapse.storage.state import StateGroupStorage
|
||||
|
||||
__all__ = ["DataStores", "DataStore"]
|
||||
|
||||
@ -45,6 +46,7 @@ class Storage(object):
|
||||
self.main = stores.main
|
||||
|
||||
self.persistence = EventsPersistenceStorage(hs, stores)
|
||||
self.state = StateGroupStorage(hs, stores)
|
||||
|
||||
|
||||
def are_all_users_on_domain(txn, database_engine, domain):
|
||||
|
@ -550,7 +550,7 @@ class EventsPersistenceStorage(object):
|
||||
|
||||
if missing_event_ids:
|
||||
# Now pull out the state groups for any missing events from DB
|
||||
event_to_groups = yield self.state_store._get_state_group_for_events(
|
||||
event_to_groups = yield self.main_store._get_state_group_for_events(
|
||||
missing_event_ids
|
||||
)
|
||||
event_id_to_state_group.update(event_to_groups)
|
||||
|
@ -19,6 +19,8 @@ from six import iteritems, itervalues
|
||||
|
||||
import attr
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -322,3 +324,233 @@ class StateFilter(object):
|
||||
)
|
||||
|
||||
return member_filter, non_member_filter
|
||||
|
||||
|
||||
class StateGroupStorage(object):
|
||||
"""High level interface to fetching state for event.
|
||||
"""
|
||||
|
||||
def __init__(self, hs, stores):
|
||||
self.stores = stores
|
||||
|
||||
def get_state_group_delta(self, state_group):
|
||||
"""Given a state group try to return a previous group and a delta between
|
||||
the old and the new.
|
||||
|
||||
Returns:
|
||||
(prev_group, delta_ids), where both may be None.
|
||||
"""
|
||||
|
||||
return self.stores.main.get_state_group_delta(state_group)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_groups_ids(self, _room_id, event_ids):
|
||||
"""Get the event IDs of all the state for the state groups for the given events
|
||||
|
||||
Args:
|
||||
_room_id (str): id of the room for these events
|
||||
event_ids (iterable[str]): ids of the events
|
||||
|
||||
Returns:
|
||||
Deferred[dict[int, dict[tuple[str, str], str]]]:
|
||||
dict of state_group_id -> (dict of (type, state_key) -> event id)
|
||||
"""
|
||||
if not event_ids:
|
||||
return {}
|
||||
|
||||
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
|
||||
|
||||
groups = set(itervalues(event_to_groups))
|
||||
group_to_state = yield self.stores.main._get_state_for_groups(groups)
|
||||
|
||||
return group_to_state
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_ids_for_group(self, state_group):
|
||||
"""Get the event IDs of all the state in the given state group
|
||||
|
||||
Args:
|
||||
state_group (int)
|
||||
|
||||
Returns:
|
||||
Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
|
||||
"""
|
||||
group_to_state = yield self._get_state_for_groups((state_group,))
|
||||
|
||||
return group_to_state[state_group]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_groups(self, room_id, event_ids):
|
||||
""" Get the state groups for the given list of event_ids
|
||||
Returns:
|
||||
Deferred[dict[int, list[EventBase]]]:
|
||||
dict of state_group_id -> list of state events.
|
||||
"""
|
||||
if not event_ids:
|
||||
return {}
|
||||
|
||||
group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
|
||||
|
||||
state_event_map = yield self.stores.main.get_events(
|
||||
[
|
||||
ev_id
|
||||
for group_ids in itervalues(group_to_ids)
|
||||
for ev_id in itervalues(group_ids)
|
||||
],
|
||||
get_prev_content=False,
|
||||
)
|
||||
|
||||
return {
|
||||
group: [
|
||||
state_event_map[v]
|
||||
for v in itervalues(event_id_map)
|
||||
if v in state_event_map
|
||||
]
|
||||
for group, event_id_map in iteritems(group_to_ids)
|
||||
}
|
||||
|
||||
def _get_state_groups_from_groups(self, groups, state_filter):
|
||||
"""Returns the state groups for a given set of groups, filtering on
|
||||
types of state events.
|
||||
|
||||
Args:
|
||||
groups(list[int]): list of state group IDs to query
|
||||
state_filter (StateFilter): The state filter used to fetch state
|
||||
from the database.
|
||||
Returns:
|
||||
Deferred[dict[int, dict[tuple[str, str], str]]]:
|
||||
dict of state_group_id -> (dict of (type, state_key) -> event id)
|
||||
"""
|
||||
|
||||
return self.stores.main._get_state_groups_from_groups(groups, state_filter)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
|
||||
"""Given a list of event_ids and type tuples, return a list of state
|
||||
dicts for each event.
|
||||
Args:
|
||||
event_ids (list[string])
|
||||
state_filter (StateFilter): The state filter used to fetch state
|
||||
from the database.
|
||||
Returns:
|
||||
deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
|
||||
"""
|
||||
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
|
||||
|
||||
groups = set(itervalues(event_to_groups))
|
||||
group_to_state = yield self.stores.main._get_state_for_groups(
|
||||
groups, state_filter
|
||||
)
|
||||
|
||||
state_event_map = yield self.stores.main.get_events(
|
||||
[ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
|
||||
get_prev_content=False,
|
||||
)
|
||||
|
||||
event_to_state = {
|
||||
event_id: {
|
||||
k: state_event_map[v]
|
||||
for k, v in iteritems(group_to_state[group])
|
||||
if v in state_event_map
|
||||
}
|
||||
for event_id, group in iteritems(event_to_groups)
|
||||
}
|
||||
|
||||
return {event: event_to_state[event] for event in event_ids}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
|
||||
"""
|
||||
Get the state dicts corresponding to a list of events, containing the event_ids
|
||||
of the state events (as opposed to the events themselves)
|
||||
|
||||
Args:
|
||||
event_ids(list(str)): events whose state should be returned
|
||||
state_filter (StateFilter): The state filter used to fetch state
|
||||
from the database.
|
||||
|
||||
Returns:
|
||||
A deferred dict from event_id -> (type, state_key) -> event_id
|
||||
"""
|
||||
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
|
||||
|
||||
groups = set(itervalues(event_to_groups))
|
||||
group_to_state = yield self.stores.main._get_state_for_groups(
|
||||
groups, state_filter
|
||||
)
|
||||
|
||||
event_to_state = {
|
||||
event_id: group_to_state[group]
|
||||
for event_id, group in iteritems(event_to_groups)
|
||||
}
|
||||
|
||||
return {event: event_to_state[event] for event in event_ids}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
|
||||
"""
|
||||
Get the state dict corresponding to a particular event
|
||||
|
||||
Args:
|
||||
event_id(str): event whose state should be returned
|
||||
state_filter (StateFilter): The state filter used to fetch state
|
||||
from the database.
|
||||
|
||||
Returns:
|
||||
A deferred dict from (type, state_key) -> state_event
|
||||
"""
|
||||
state_map = yield self.get_state_for_events([event_id], state_filter)
|
||||
return state_map[event_id]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
|
||||
"""
|
||||
Get the state dict corresponding to a particular event
|
||||
|
||||
Args:
|
||||
event_id(str): event whose state should be returned
|
||||
state_filter (StateFilter): The state filter used to fetch state
|
||||
from the database.
|
||||
|
||||
Returns:
|
||||
A deferred dict from (type, state_key) -> state_event
|
||||
"""
|
||||
state_map = yield self.get_state_ids_for_events([event_id], state_filter)
|
||||
return state_map[event_id]
|
||||
|
||||
def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
|
||||
"""Gets the state at each of a list of state groups, optionally
|
||||
filtering by type/state_key
|
||||
|
||||
Args:
|
||||
groups (iterable[int]): list of state groups for which we want
|
||||
to get the state.
|
||||
state_filter (StateFilter): The state filter used to fetch state
|
||||
from the database.
|
||||
Returns:
|
||||
Deferred[dict[int, dict[tuple[str, str], str]]]:
|
||||
dict of state_group_id -> (dict of (type, state_key) -> event id)
|
||||
"""
|
||||
return self.stores.main._get_state_for_groups(groups, state_filter)
|
||||
|
||||
def store_state_group(
|
||||
self, event_id, room_id, prev_group, delta_ids, current_state_ids
|
||||
):
|
||||
"""Store a new set of state, returning a newly assigned state group.
|
||||
|
||||
Args:
|
||||
event_id (str): The event ID for which the state was calculated
|
||||
room_id (str)
|
||||
prev_group (int|None): A previous state group for the room, optional.
|
||||
delta_ids (dict|None): The delta between state at `prev_group` and
|
||||
`current_state_ids`, if `prev_group` was given. Same format as
|
||||
`current_state_ids`.
|
||||
current_state_ids (dict): The state to store. Map of (type, state_key)
|
||||
to event_id.
|
||||
|
||||
Returns:
|
||||
Deferred[int]: The state group ID
|
||||
"""
|
||||
return self.stores.main.store_state_group(
|
||||
event_id, room_id, prev_group, delta_ids, current_state_ids
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user