mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Rename storage classes (#12913)
This commit is contained in:
parent
e541bb9eed
commit
1e453053cb
1
changelog.d/12913.misc
Normal file
1
changelog.d/12913.misc
Normal file
@ -0,0 +1 @@
|
||||
Rename storage classes.
|
@ -22,7 +22,7 @@ from synapse.events import EventBase
|
||||
from synapse.types import JsonDict, StateMap
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.storage import Storage
|
||||
from synapse.storage.controllers import StorageControllers
|
||||
from synapse.storage.databases.main import DataStore
|
||||
from synapse.storage.state import StateFilter
|
||||
|
||||
@ -84,7 +84,7 @@ class EventContext:
|
||||
incomplete state.
|
||||
"""
|
||||
|
||||
_storage: "Storage"
|
||||
_storage: "StorageControllers"
|
||||
rejected: Union[Literal[False], str] = False
|
||||
_state_group: Optional[int] = None
|
||||
state_group_before_event: Optional[int] = None
|
||||
@ -97,7 +97,7 @@ class EventContext:
|
||||
|
||||
@staticmethod
|
||||
def with_state(
|
||||
storage: "Storage",
|
||||
storage: "StorageControllers",
|
||||
state_group: Optional[int],
|
||||
state_group_before_event: Optional[int],
|
||||
state_delta_due_to_event: Optional[StateMap[str]],
|
||||
@ -117,7 +117,7 @@ class EventContext:
|
||||
|
||||
@staticmethod
|
||||
def for_outlier(
|
||||
storage: "Storage",
|
||||
storage: "StorageControllers",
|
||||
) -> "EventContext":
|
||||
"""Return an EventContext instance suitable for persisting an outlier event"""
|
||||
return EventContext(storage=storage)
|
||||
@ -147,7 +147,7 @@ class EventContext:
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def deserialize(storage: "Storage", input: JsonDict) -> "EventContext":
|
||||
def deserialize(storage: "StorageControllers", input: JsonDict) -> "EventContext":
|
||||
"""Converts a dict that was produced by `serialize` back into a
|
||||
EventContext.
|
||||
|
||||
|
@ -109,7 +109,6 @@ class FederationServer(FederationBase):
|
||||
super().__init__(hs)
|
||||
|
||||
self.handler = hs.get_federation_handler()
|
||||
self.storage = hs.get_storage()
|
||||
self._spam_checker = hs.get_spam_checker()
|
||||
self._federation_event_handler = hs.get_federation_event_handler()
|
||||
self.state = hs.get_state_handler()
|
||||
|
@ -30,8 +30,8 @@ logger = logging.getLogger(__name__)
|
||||
class AdminHandler:
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastores().main
|
||||
self.storage = hs.get_storage()
|
||||
self.state_storage = self.storage.state
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self._state_storage_controller = self._storage_controllers.state
|
||||
|
||||
async def get_whois(self, user: UserID) -> JsonDict:
|
||||
connections = []
|
||||
@ -197,7 +197,9 @@ class AdminHandler:
|
||||
|
||||
from_key = events[-1].internal_metadata.after
|
||||
|
||||
events = await filter_events_for_client(self.storage, user_id, events)
|
||||
events = await filter_events_for_client(
|
||||
self._storage_controllers, user_id, events
|
||||
)
|
||||
|
||||
writer.write_events(room_id, events)
|
||||
|
||||
@ -233,7 +235,9 @@ class AdminHandler:
|
||||
for event_id in extremities:
|
||||
if not event_to_unseen_prevs[event_id]:
|
||||
continue
|
||||
state = await self.state_storage.get_state_for_event(event_id)
|
||||
state = await self._state_storage_controller.get_state_for_event(
|
||||
event_id
|
||||
)
|
||||
writer.write_state(room_id, event_id, state)
|
||||
|
||||
return writer.finished()
|
||||
|
@ -71,7 +71,7 @@ class DeviceWorkerHandler:
|
||||
self.store = hs.get_datastores().main
|
||||
self.notifier = hs.get_notifier()
|
||||
self.state = hs.get_state_handler()
|
||||
self.state_storage = hs.get_storage().state
|
||||
self._state_storage = hs.get_storage_controllers().state
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
self.server_name = hs.hostname
|
||||
|
||||
@ -204,7 +204,7 @@ class DeviceWorkerHandler:
|
||||
continue
|
||||
|
||||
# mapping from event_id -> state_dict
|
||||
prev_state_ids = await self.state_storage.get_state_ids_for_events(
|
||||
prev_state_ids = await self._state_storage.get_state_ids_for_events(
|
||||
event_ids
|
||||
)
|
||||
|
||||
|
@ -139,7 +139,7 @@ class EventStreamHandler:
|
||||
class EventHandler:
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastores().main
|
||||
self.storage = hs.get_storage()
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
|
||||
async def get_event(
|
||||
self,
|
||||
@ -177,7 +177,7 @@ class EventHandler:
|
||||
is_peeking = user.to_string() not in users
|
||||
|
||||
filtered = await filter_events_for_client(
|
||||
self.storage, user.to_string(), [event], is_peeking=is_peeking
|
||||
self._storage_controllers, user.to_string(), [event], is_peeking=is_peeking
|
||||
)
|
||||
|
||||
if not filtered:
|
||||
|
@ -125,8 +125,8 @@ class FederationHandler:
|
||||
self.hs = hs
|
||||
|
||||
self.store = hs.get_datastores().main
|
||||
self.storage = hs.get_storage()
|
||||
self.state_storage = self.storage.state
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self._state_storage_controller = self._storage_controllers.state
|
||||
self.federation_client = hs.get_federation_client()
|
||||
self.state_handler = hs.get_state_handler()
|
||||
self.server_name = hs.hostname
|
||||
@ -324,7 +324,7 @@ class FederationHandler:
|
||||
# We set `check_history_visibility_only` as we might otherwise get false
|
||||
# positives from users having been erased.
|
||||
filtered_extremities = await filter_events_for_server(
|
||||
self.storage,
|
||||
self._storage_controllers,
|
||||
self.server_name,
|
||||
events_to_check,
|
||||
redact=False,
|
||||
@ -660,7 +660,7 @@ class FederationHandler:
|
||||
# in the invitee's sync stream. It is stripped out for all other local users.
|
||||
event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"]
|
||||
|
||||
context = EventContext.for_outlier(self.storage)
|
||||
context = EventContext.for_outlier(self._storage_controllers)
|
||||
stream_id = await self._federation_event_handler.persist_events_and_notify(
|
||||
event.room_id, [(event, context)]
|
||||
)
|
||||
@ -849,7 +849,7 @@ class FederationHandler:
|
||||
)
|
||||
)
|
||||
|
||||
context = EventContext.for_outlier(self.storage)
|
||||
context = EventContext.for_outlier(self._storage_controllers)
|
||||
await self._federation_event_handler.persist_events_and_notify(
|
||||
event.room_id, [(event, context)]
|
||||
)
|
||||
@ -878,7 +878,7 @@ class FederationHandler:
|
||||
|
||||
await self.federation_client.send_leave(host_list, event)
|
||||
|
||||
context = EventContext.for_outlier(self.storage)
|
||||
context = EventContext.for_outlier(self._storage_controllers)
|
||||
stream_id = await self._federation_event_handler.persist_events_and_notify(
|
||||
event.room_id, [(event, context)]
|
||||
)
|
||||
@ -1027,7 +1027,7 @@ class FederationHandler:
|
||||
if event.internal_metadata.outlier:
|
||||
raise NotFoundError("State not known at event %s" % (event_id,))
|
||||
|
||||
state_groups = await self.state_storage.get_state_groups_ids(
|
||||
state_groups = await self._state_storage_controller.get_state_groups_ids(
|
||||
room_id, [event_id]
|
||||
)
|
||||
|
||||
@ -1078,7 +1078,9 @@ class FederationHandler:
|
||||
],
|
||||
)
|
||||
|
||||
events = await filter_events_for_server(self.storage, origin, events)
|
||||
events = await filter_events_for_server(
|
||||
self._storage_controllers, origin, events
|
||||
)
|
||||
|
||||
return events
|
||||
|
||||
@ -1109,7 +1111,9 @@ class FederationHandler:
|
||||
if not in_room:
|
||||
raise AuthError(403, "Host not in room.")
|
||||
|
||||
events = await filter_events_for_server(self.storage, origin, [event])
|
||||
events = await filter_events_for_server(
|
||||
self._storage_controllers, origin, [event]
|
||||
)
|
||||
event = events[0]
|
||||
return event
|
||||
else:
|
||||
@ -1138,7 +1142,7 @@ class FederationHandler:
|
||||
)
|
||||
|
||||
missing_events = await filter_events_for_server(
|
||||
self.storage, origin, missing_events
|
||||
self._storage_controllers, origin, missing_events
|
||||
)
|
||||
|
||||
return missing_events
|
||||
@ -1480,9 +1484,11 @@ class FederationHandler:
|
||||
# clear the lazy-loading flag.
|
||||
logger.info("Updating current state for %s", room_id)
|
||||
assert (
|
||||
self.storage.persistence is not None
|
||||
self._storage_controllers.persistence is not None
|
||||
), "TODO(faster_joins): support for workers"
|
||||
await self.storage.persistence.update_current_state(room_id)
|
||||
await self._storage_controllers.persistence.update_current_state(
|
||||
room_id
|
||||
)
|
||||
|
||||
logger.info("Clearing partial-state flag for %s", room_id)
|
||||
success = await self.store.clear_partial_state_room(room_id)
|
||||
|
@ -98,8 +98,8 @@ class FederationEventHandler:
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._store = hs.get_datastores().main
|
||||
self._storage = hs.get_storage()
|
||||
self._state_storage = self._storage.state
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self._state_storage_controller = self._storage_controllers.state
|
||||
|
||||
self._state_handler = hs.get_state_handler()
|
||||
self._event_creation_handler = hs.get_event_creation_handler()
|
||||
@ -535,7 +535,9 @@ class FederationEventHandler:
|
||||
)
|
||||
return
|
||||
await self._store.update_state_for_partial_state_event(event, context)
|
||||
self._state_storage.notify_event_un_partial_stated(event.event_id)
|
||||
self._state_storage_controller.notify_event_un_partial_stated(
|
||||
event.event_id
|
||||
)
|
||||
|
||||
async def backfill(
|
||||
self, dest: str, room_id: str, limit: int, extremities: Collection[str]
|
||||
@ -835,7 +837,9 @@ class FederationEventHandler:
|
||||
|
||||
try:
|
||||
# Get the state of the events we know about
|
||||
ours = await self._state_storage.get_state_groups_ids(room_id, seen)
|
||||
ours = await self._state_storage_controller.get_state_groups_ids(
|
||||
room_id, seen
|
||||
)
|
||||
|
||||
# state_maps is a list of mappings from (type, state_key) to event_id
|
||||
state_maps: List[StateMap[str]] = list(ours.values())
|
||||
@ -1436,7 +1440,7 @@ class FederationEventHandler:
|
||||
# we're not bothering about room state, so flag the event as an outlier.
|
||||
event.internal_metadata.outlier = True
|
||||
|
||||
context = EventContext.for_outlier(self._storage)
|
||||
context = EventContext.for_outlier(self._storage_controllers)
|
||||
try:
|
||||
validate_event_for_room_version(room_version_obj, event)
|
||||
check_auth_rules_for_event(room_version_obj, event, auth)
|
||||
@ -1613,7 +1617,7 @@ class FederationEventHandler:
|
||||
# given state at the event. This should correctly handle cases
|
||||
# like bans, especially with state res v2.
|
||||
|
||||
state_sets_d = await self._state_storage.get_state_groups_ids(
|
||||
state_sets_d = await self._state_storage_controller.get_state_groups_ids(
|
||||
event.room_id, extrem_ids
|
||||
)
|
||||
state_sets: List[StateMap[str]] = list(state_sets_d.values())
|
||||
@ -1885,7 +1889,7 @@ class FederationEventHandler:
|
||||
|
||||
# create a new state group as a delta from the existing one.
|
||||
prev_group = context.state_group
|
||||
state_group = await self._state_storage.store_state_group(
|
||||
state_group = await self._state_storage_controller.store_state_group(
|
||||
event.event_id,
|
||||
event.room_id,
|
||||
prev_group=prev_group,
|
||||
@ -1894,7 +1898,7 @@ class FederationEventHandler:
|
||||
)
|
||||
|
||||
return EventContext.with_state(
|
||||
storage=self._storage,
|
||||
storage=self._storage_controllers,
|
||||
state_group=state_group,
|
||||
state_group_before_event=context.state_group_before_event,
|
||||
state_delta_due_to_event=state_updates,
|
||||
@ -1984,11 +1988,14 @@ class FederationEventHandler:
|
||||
)
|
||||
return result["max_stream_id"]
|
||||
else:
|
||||
assert self._storage.persistence
|
||||
assert self._storage_controllers.persistence
|
||||
|
||||
# Note that this returns the events that were persisted, which may not be
|
||||
# the same as were passed in if some were deduplicated due to transaction IDs.
|
||||
events, max_stream_token = await self._storage.persistence.persist_events(
|
||||
(
|
||||
events,
|
||||
max_stream_token,
|
||||
) = await self._storage_controllers.persistence.persist_events(
|
||||
event_and_contexts, backfilled=backfilled
|
||||
)
|
||||
|
||||
|
@ -67,8 +67,8 @@ class InitialSyncHandler:
|
||||
]
|
||||
] = ResponseCache(hs.get_clock(), "initial_sync_cache")
|
||||
self._event_serializer = hs.get_event_client_serializer()
|
||||
self.storage = hs.get_storage()
|
||||
self.state_storage = self.storage.state
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self._state_storage_controller = self._storage_controllers.state
|
||||
|
||||
async def snapshot_all_rooms(
|
||||
self,
|
||||
@ -198,7 +198,8 @@ class InitialSyncHandler:
|
||||
event.stream_ordering,
|
||||
)
|
||||
deferred_room_state = run_in_background(
|
||||
self.state_storage.get_state_for_events, [event.event_id]
|
||||
self._state_storage_controller.get_state_for_events,
|
||||
[event.event_id],
|
||||
).addCallback(
|
||||
lambda states: cast(StateMap[EventBase], states[event.event_id])
|
||||
)
|
||||
@ -218,7 +219,7 @@ class InitialSyncHandler:
|
||||
).addErrback(unwrapFirstError)
|
||||
|
||||
messages = await filter_events_for_client(
|
||||
self.storage, user_id, messages
|
||||
self._storage_controllers, user_id, messages
|
||||
)
|
||||
|
||||
start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token)
|
||||
@ -355,7 +356,9 @@ class InitialSyncHandler:
|
||||
member_event_id: str,
|
||||
is_peeking: bool,
|
||||
) -> JsonDict:
|
||||
room_state = await self.state_storage.get_state_for_event(member_event_id)
|
||||
room_state = await self._state_storage_controller.get_state_for_event(
|
||||
member_event_id
|
||||
)
|
||||
|
||||
limit = pagin_config.limit if pagin_config else None
|
||||
if limit is None:
|
||||
@ -369,7 +372,7 @@ class InitialSyncHandler:
|
||||
)
|
||||
|
||||
messages = await filter_events_for_client(
|
||||
self.storage, user_id, messages, is_peeking=is_peeking
|
||||
self._storage_controllers, user_id, messages, is_peeking=is_peeking
|
||||
)
|
||||
|
||||
start_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, token)
|
||||
@ -474,7 +477,7 @@ class InitialSyncHandler:
|
||||
)
|
||||
|
||||
messages = await filter_events_for_client(
|
||||
self.storage, user_id, messages, is_peeking=is_peeking
|
||||
self._storage_controllers, user_id, messages, is_peeking=is_peeking
|
||||
)
|
||||
|
||||
start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token)
|
||||
|
@ -84,8 +84,8 @@ class MessageHandler:
|
||||
self.clock = hs.get_clock()
|
||||
self.state = hs.get_state_handler()
|
||||
self.store = hs.get_datastores().main
|
||||
self.storage = hs.get_storage()
|
||||
self.state_storage = self.storage.state
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self._state_storage_controller = self._storage_controllers.state
|
||||
self._event_serializer = hs.get_event_client_serializer()
|
||||
self._ephemeral_events_enabled = hs.config.server.enable_ephemeral_messages
|
||||
|
||||
@ -132,7 +132,7 @@ class MessageHandler:
|
||||
assert (
|
||||
membership_event_id is not None
|
||||
), "check_user_in_room_or_world_readable returned invalid data"
|
||||
room_state = await self.state_storage.get_state_for_events(
|
||||
room_state = await self._state_storage_controller.get_state_for_events(
|
||||
[membership_event_id], StateFilter.from_types([key])
|
||||
)
|
||||
data = room_state[membership_event_id].get(key)
|
||||
@ -193,7 +193,7 @@ class MessageHandler:
|
||||
|
||||
# check whether the user is in the room at that time to determine
|
||||
# whether they should be treated as peeking.
|
||||
state_map = await self.state_storage.get_state_for_event(
|
||||
state_map = await self._state_storage_controller.get_state_for_event(
|
||||
last_event.event_id,
|
||||
StateFilter.from_types([(EventTypes.Member, user_id)]),
|
||||
)
|
||||
@ -206,7 +206,7 @@ class MessageHandler:
|
||||
is_peeking = not joined
|
||||
|
||||
visible_events = await filter_events_for_client(
|
||||
self.storage,
|
||||
self._storage_controllers,
|
||||
user_id,
|
||||
[last_event],
|
||||
filter_send_to_client=False,
|
||||
@ -214,9 +214,11 @@ class MessageHandler:
|
||||
)
|
||||
|
||||
if visible_events:
|
||||
room_state_events = await self.state_storage.get_state_for_events(
|
||||
room_state_events = (
|
||||
await self._state_storage_controller.get_state_for_events(
|
||||
[last_event.event_id], state_filter=state_filter
|
||||
)
|
||||
)
|
||||
room_state: Mapping[Any, EventBase] = room_state_events[
|
||||
last_event.event_id
|
||||
]
|
||||
@ -244,9 +246,11 @@ class MessageHandler:
|
||||
assert (
|
||||
membership_event_id is not None
|
||||
), "check_user_in_room_or_world_readable returned invalid data"
|
||||
room_state_events = await self.state_storage.get_state_for_events(
|
||||
room_state_events = (
|
||||
await self._state_storage_controller.get_state_for_events(
|
||||
[membership_event_id], state_filter=state_filter
|
||||
)
|
||||
)
|
||||
room_state = room_state_events[membership_event_id]
|
||||
|
||||
now = self.clock.time_msec()
|
||||
@ -402,7 +406,7 @@ class EventCreationHandler:
|
||||
self.auth = hs.get_auth()
|
||||
self._event_auth_handler = hs.get_event_auth_handler()
|
||||
self.store = hs.get_datastores().main
|
||||
self.storage = hs.get_storage()
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self.state = hs.get_state_handler()
|
||||
self.clock = hs.get_clock()
|
||||
self.validator = EventValidator()
|
||||
@ -1032,7 +1036,7 @@ class EventCreationHandler:
|
||||
# after it is created
|
||||
if builder.internal_metadata.outlier:
|
||||
event.internal_metadata.outlier = True
|
||||
context = EventContext.for_outlier(self.storage)
|
||||
context = EventContext.for_outlier(self._storage_controllers)
|
||||
elif (
|
||||
event.type == EventTypes.MSC2716_INSERTION
|
||||
and state_event_ids
|
||||
@ -1445,7 +1449,7 @@ class EventCreationHandler:
|
||||
"""
|
||||
extra_users = extra_users or []
|
||||
|
||||
assert self.storage.persistence is not None
|
||||
assert self._storage_controllers.persistence is not None
|
||||
assert self._events_shard_config.should_handle(
|
||||
self._instance_name, event.room_id
|
||||
)
|
||||
@ -1679,7 +1683,7 @@ class EventCreationHandler:
|
||||
event,
|
||||
event_pos,
|
||||
max_stream_token,
|
||||
) = await self.storage.persistence.persist_event(
|
||||
) = await self._storage_controllers.persistence.persist_event(
|
||||
event, context=context, backfilled=backfilled
|
||||
)
|
||||
|
||||
|
@ -129,8 +129,8 @@ class PaginationHandler:
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastores().main
|
||||
self.storage = hs.get_storage()
|
||||
self.state_storage = self.storage.state
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self._state_storage_controller = self._storage_controllers.state
|
||||
self.clock = hs.get_clock()
|
||||
self._server_name = hs.hostname
|
||||
self._room_shutdown_handler = hs.get_room_shutdown_handler()
|
||||
@ -352,7 +352,7 @@ class PaginationHandler:
|
||||
self._purges_in_progress_by_room.add(room_id)
|
||||
try:
|
||||
async with self.pagination_lock.write(room_id):
|
||||
await self.storage.purge_events.purge_history(
|
||||
await self._storage_controllers.purge_events.purge_history(
|
||||
room_id, token, delete_local_events
|
||||
)
|
||||
logger.info("[purge] complete")
|
||||
@ -414,7 +414,7 @@ class PaginationHandler:
|
||||
if joined:
|
||||
raise SynapseError(400, "Users are still joined to this room")
|
||||
|
||||
await self.storage.purge_events.purge_room(room_id)
|
||||
await self._storage_controllers.purge_events.purge_room(room_id)
|
||||
|
||||
async def get_messages(
|
||||
self,
|
||||
@ -529,7 +529,10 @@ class PaginationHandler:
|
||||
events = await event_filter.filter(events)
|
||||
|
||||
events = await filter_events_for_client(
|
||||
self.storage, user_id, events, is_peeking=(member_event_id is None)
|
||||
self._storage_controllers,
|
||||
user_id,
|
||||
events,
|
||||
is_peeking=(member_event_id is None),
|
||||
)
|
||||
|
||||
# if after the filter applied there are no more events
|
||||
@ -550,7 +553,7 @@ class PaginationHandler:
|
||||
(EventTypes.Member, event.sender) for event in events
|
||||
)
|
||||
|
||||
state_ids = await self.state_storage.get_state_ids_for_event(
|
||||
state_ids = await self._state_storage_controller.get_state_ids_for_event(
|
||||
events[0].event_id, state_filter=state_filter
|
||||
)
|
||||
|
||||
@ -664,7 +667,7 @@ class PaginationHandler:
|
||||
400, "Users are still joined to this room"
|
||||
)
|
||||
|
||||
await self.storage.purge_events.purge_room(room_id)
|
||||
await self._storage_controllers.purge_events.purge_room(room_id)
|
||||
|
||||
logger.info("complete")
|
||||
self._delete_by_id[delete_id].status = DeleteStatus.STATUS_COMPLETE
|
||||
|
@ -69,7 +69,7 @@ class BundledAggregations:
|
||||
class RelationsHandler:
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._main_store = hs.get_datastores().main
|
||||
self._storage = hs.get_storage()
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self._auth = hs.get_auth()
|
||||
self._clock = hs.get_clock()
|
||||
self._event_handler = hs.get_event_handler()
|
||||
@ -143,7 +143,10 @@ class RelationsHandler:
|
||||
)
|
||||
|
||||
events = await filter_events_for_client(
|
||||
self._storage, user_id, events, is_peeking=(member_event_id is None)
|
||||
self._storage_controllers,
|
||||
user_id,
|
||||
events,
|
||||
is_peeking=(member_event_id is None),
|
||||
)
|
||||
|
||||
now = self._clock.time_msec()
|
||||
|
@ -1192,8 +1192,8 @@ class RoomContextHandler:
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastores().main
|
||||
self.storage = hs.get_storage()
|
||||
self.state_storage = self.storage.state
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self._state_storage_controller = self._storage_controllers.state
|
||||
self._relations_handler = hs.get_relations_handler()
|
||||
|
||||
async def get_event_context(
|
||||
@ -1236,7 +1236,10 @@ class RoomContextHandler:
|
||||
if use_admin_priviledge:
|
||||
return events
|
||||
return await filter_events_for_client(
|
||||
self.storage, user.to_string(), events, is_peeking=is_peeking
|
||||
self._storage_controllers,
|
||||
user.to_string(),
|
||||
events,
|
||||
is_peeking=is_peeking,
|
||||
)
|
||||
|
||||
event = await self.store.get_event(
|
||||
@ -1293,7 +1296,7 @@ class RoomContextHandler:
|
||||
# first? Shouldn't we be consistent with /sync?
|
||||
# https://github.com/matrix-org/matrix-doc/issues/687
|
||||
|
||||
state = await self.state_storage.get_state_for_events(
|
||||
state = await self._state_storage_controller.get_state_for_events(
|
||||
[last_event_id], state_filter=state_filter
|
||||
)
|
||||
|
||||
|
@ -17,7 +17,7 @@ class RoomBatchHandler:
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastores().main
|
||||
self.state_storage = hs.get_storage().state
|
||||
self._state_storage_controller = hs.get_storage_controllers().state
|
||||
self.event_creation_handler = hs.get_event_creation_handler()
|
||||
self.room_member_handler = hs.get_room_member_handler()
|
||||
self.auth = hs.get_auth()
|
||||
@ -141,7 +141,7 @@ class RoomBatchHandler:
|
||||
) = await self.store.get_max_depth_of(event_ids)
|
||||
# mapping from (type, state_key) -> state_event_id
|
||||
assert most_recent_event_id is not None
|
||||
prev_state_map = await self.state_storage.get_state_ids_for_event(
|
||||
prev_state_map = await self._state_storage_controller.get_state_ids_for_event(
|
||||
most_recent_event_id
|
||||
)
|
||||
# List of state event ID's
|
||||
|
@ -55,8 +55,8 @@ class SearchHandler:
|
||||
self.hs = hs
|
||||
self._event_serializer = hs.get_event_client_serializer()
|
||||
self._relations_handler = hs.get_relations_handler()
|
||||
self.storage = hs.get_storage()
|
||||
self.state_storage = self.storage.state
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self._state_storage_controller = self._storage_controllers.state
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def get_old_rooms_from_upgraded_room(self, room_id: str) -> Iterable[str]:
|
||||
@ -460,7 +460,7 @@ class SearchHandler:
|
||||
filtered_events = await search_filter.filter([r["event"] for r in results])
|
||||
|
||||
events = await filter_events_for_client(
|
||||
self.storage, user.to_string(), filtered_events
|
||||
self._storage_controllers, user.to_string(), filtered_events
|
||||
)
|
||||
|
||||
events.sort(key=lambda e: -rank_map[e.event_id])
|
||||
@ -559,7 +559,7 @@ class SearchHandler:
|
||||
filtered_events = await search_filter.filter([r["event"] for r in results])
|
||||
|
||||
events = await filter_events_for_client(
|
||||
self.storage, user.to_string(), filtered_events
|
||||
self._storage_controllers, user.to_string(), filtered_events
|
||||
)
|
||||
|
||||
room_events.extend(events)
|
||||
@ -644,11 +644,11 @@ class SearchHandler:
|
||||
)
|
||||
|
||||
events_before = await filter_events_for_client(
|
||||
self.storage, user.to_string(), res.events_before
|
||||
self._storage_controllers, user.to_string(), res.events_before
|
||||
)
|
||||
|
||||
events_after = await filter_events_for_client(
|
||||
self.storage, user.to_string(), res.events_after
|
||||
self._storage_controllers, user.to_string(), res.events_after
|
||||
)
|
||||
|
||||
context: JsonDict = {
|
||||
@ -677,7 +677,7 @@ class SearchHandler:
|
||||
[(EventTypes.Member, sender) for sender in senders]
|
||||
)
|
||||
|
||||
state = await self.state_storage.get_state_for_event(
|
||||
state = await self._state_storage_controller.get_state_for_event(
|
||||
last_event_id, state_filter
|
||||
)
|
||||
|
||||
|
@ -238,8 +238,8 @@ class SyncHandler:
|
||||
self.clock = hs.get_clock()
|
||||
self.state = hs.get_state_handler()
|
||||
self.auth = hs.get_auth()
|
||||
self.storage = hs.get_storage()
|
||||
self.state_storage = self.storage.state
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self._state_storage_controller = self._storage_controllers.state
|
||||
|
||||
# TODO: flush cache entries on subsequent sync request.
|
||||
# Once we get the next /sync request (ie, one with the same access token
|
||||
@ -512,7 +512,7 @@ class SyncHandler:
|
||||
current_state_ids = frozenset(current_state_ids_map.values())
|
||||
|
||||
recents = await filter_events_for_client(
|
||||
self.storage,
|
||||
self._storage_controllers,
|
||||
sync_config.user.to_string(),
|
||||
recents,
|
||||
always_include_ids=current_state_ids,
|
||||
@ -580,7 +580,7 @@ class SyncHandler:
|
||||
current_state_ids = frozenset(current_state_ids_map.values())
|
||||
|
||||
loaded_recents = await filter_events_for_client(
|
||||
self.storage,
|
||||
self._storage_controllers,
|
||||
sync_config.user.to_string(),
|
||||
loaded_recents,
|
||||
always_include_ids=current_state_ids,
|
||||
@ -630,7 +630,7 @@ class SyncHandler:
|
||||
event: event of interest
|
||||
state_filter: The state filter used to fetch state from the database.
|
||||
"""
|
||||
state_ids = await self.state_storage.get_state_ids_for_event(
|
||||
state_ids = await self._state_storage_controller.get_state_ids_for_event(
|
||||
event.event_id, state_filter=state_filter or StateFilter.all()
|
||||
)
|
||||
if event.is_state():
|
||||
@ -710,7 +710,7 @@ class SyncHandler:
|
||||
return None
|
||||
|
||||
last_event = last_events[-1]
|
||||
state_ids = await self.state_storage.get_state_ids_for_event(
|
||||
state_ids = await self._state_storage_controller.get_state_ids_for_event(
|
||||
last_event.event_id,
|
||||
state_filter=StateFilter.from_types(
|
||||
[(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")]
|
||||
@ -889,14 +889,16 @@ class SyncHandler:
|
||||
if full_state:
|
||||
if batch:
|
||||
current_state_ids = (
|
||||
await self.state_storage.get_state_ids_for_event(
|
||||
await self._state_storage_controller.get_state_ids_for_event(
|
||||
batch.events[-1].event_id, state_filter=state_filter
|
||||
)
|
||||
)
|
||||
|
||||
state_ids = await self.state_storage.get_state_ids_for_event(
|
||||
state_ids = (
|
||||
await self._state_storage_controller.get_state_ids_for_event(
|
||||
batch.events[0].event_id, state_filter=state_filter
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
current_state_ids = await self.get_state_at(
|
||||
@ -915,7 +917,7 @@ class SyncHandler:
|
||||
elif batch.limited:
|
||||
if batch:
|
||||
state_at_timeline_start = (
|
||||
await self.state_storage.get_state_ids_for_event(
|
||||
await self._state_storage_controller.get_state_ids_for_event(
|
||||
batch.events[0].event_id, state_filter=state_filter
|
||||
)
|
||||
)
|
||||
@ -950,7 +952,7 @@ class SyncHandler:
|
||||
|
||||
if batch:
|
||||
current_state_ids = (
|
||||
await self.state_storage.get_state_ids_for_event(
|
||||
await self._state_storage_controller.get_state_ids_for_event(
|
||||
batch.events[-1].event_id, state_filter=state_filter
|
||||
)
|
||||
)
|
||||
@ -982,7 +984,7 @@ class SyncHandler:
|
||||
# So we fish out all the member events corresponding to the
|
||||
# timeline here, and then dedupe any redundant ones below.
|
||||
|
||||
state_ids = await self.state_storage.get_state_ids_for_event(
|
||||
state_ids = await self._state_storage_controller.get_state_ids_for_event(
|
||||
batch.events[0].event_id,
|
||||
# we only want members!
|
||||
state_filter=StateFilter.from_types(
|
||||
|
@ -221,7 +221,7 @@ class Notifier:
|
||||
self.room_to_user_streams: Dict[str, Set[_NotifierUserStream]] = {}
|
||||
|
||||
self.hs = hs
|
||||
self.storage = hs.get_storage()
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self.event_sources = hs.get_event_sources()
|
||||
self.store = hs.get_datastores().main
|
||||
self.pending_new_room_events: List[_PendingRoomEventEntry] = []
|
||||
@ -623,7 +623,7 @@ class Notifier:
|
||||
|
||||
if name == "room":
|
||||
new_events = await filter_events_for_client(
|
||||
self.storage,
|
||||
self._storage_controllers,
|
||||
user.to_string(),
|
||||
new_events,
|
||||
is_peeking=is_peeking,
|
||||
|
@ -65,7 +65,7 @@ class HttpPusher(Pusher):
|
||||
|
||||
def __init__(self, hs: "HomeServer", pusher_config: PusherConfig):
|
||||
super().__init__(hs, pusher_config)
|
||||
self.storage = self.hs.get_storage()
|
||||
self._storage_controllers = self.hs.get_storage_controllers()
|
||||
self.app_display_name = pusher_config.app_display_name
|
||||
self.device_display_name = pusher_config.device_display_name
|
||||
self.pushkey_ts = pusher_config.ts
|
||||
@ -343,7 +343,9 @@ class HttpPusher(Pusher):
|
||||
}
|
||||
return d
|
||||
|
||||
ctx = await push_tools.get_context_for_event(self.storage, event, self.user_id)
|
||||
ctx = await push_tools.get_context_for_event(
|
||||
self._storage_controllers, event, self.user_id
|
||||
)
|
||||
|
||||
d = {
|
||||
"notification": {
|
||||
|
@ -114,10 +114,10 @@ class Mailer:
|
||||
|
||||
self.send_email_handler = hs.get_send_email_handler()
|
||||
self.store = self.hs.get_datastores().main
|
||||
self.state_storage = self.hs.get_storage().state
|
||||
self._state_storage_controller = self.hs.get_storage_controllers().state
|
||||
self.macaroon_gen = self.hs.get_macaroon_generator()
|
||||
self.state_handler = self.hs.get_state_handler()
|
||||
self.storage = hs.get_storage()
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self.app_name = app_name
|
||||
self.email_subjects: EmailSubjectConfig = hs.config.email.email_subjects
|
||||
|
||||
@ -456,7 +456,7 @@ class Mailer:
|
||||
}
|
||||
|
||||
the_events = await filter_events_for_client(
|
||||
self.storage, user_id, results.events_before
|
||||
self._storage_controllers, user_id, results.events_before
|
||||
)
|
||||
the_events.append(notif_event)
|
||||
|
||||
@ -494,7 +494,7 @@ class Mailer:
|
||||
)
|
||||
else:
|
||||
# Attempt to check the historical state for the room.
|
||||
historical_state = await self.state_storage.get_state_for_event(
|
||||
historical_state = await self._state_storage_controller.get_state_for_event(
|
||||
event.event_id, StateFilter.from_types((type_state_key,))
|
||||
)
|
||||
sender_state_event = historical_state.get(type_state_key)
|
||||
@ -767,9 +767,11 @@ class Mailer:
|
||||
member_event_ids.append(sender_state_event_id)
|
||||
else:
|
||||
# Attempt to check the historical state for the room.
|
||||
historical_state = await self.state_storage.get_state_for_event(
|
||||
historical_state = (
|
||||
await self._state_storage_controller.get_state_for_event(
|
||||
event_id, StateFilter.from_types((type_state_key,))
|
||||
)
|
||||
)
|
||||
sender_state_event = historical_state.get(type_state_key)
|
||||
if sender_state_event:
|
||||
member_events[event_id] = sender_state_event
|
||||
|
@ -16,7 +16,7 @@ from typing import Dict
|
||||
from synapse.api.constants import ReceiptTypes
|
||||
from synapse.events import EventBase
|
||||
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
|
||||
from synapse.storage import Storage
|
||||
from synapse.storage.controllers import StorageControllers
|
||||
from synapse.storage.databases.main import DataStore
|
||||
|
||||
|
||||
@ -52,7 +52,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
|
||||
|
||||
|
||||
async def get_context_for_event(
|
||||
storage: Storage, ev: EventBase, user_id: str
|
||||
storage: StorageControllers, ev: EventBase, user_id: str
|
||||
) -> Dict[str, str]:
|
||||
ctx = {}
|
||||
|
||||
|
@ -69,7 +69,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
|
||||
super().__init__(hs)
|
||||
|
||||
self.store = hs.get_datastores().main
|
||||
self.storage = hs.get_storage()
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self.clock = hs.get_clock()
|
||||
self.federation_event_handler = hs.get_federation_event_handler()
|
||||
|
||||
@ -133,7 +133,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
|
||||
event.internal_metadata.outlier = event_payload["outlier"]
|
||||
|
||||
context = EventContext.deserialize(
|
||||
self.storage, event_payload["context"]
|
||||
self._storage_controllers, event_payload["context"]
|
||||
)
|
||||
|
||||
event_and_contexts.append((event, context))
|
||||
|
@ -70,7 +70,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
|
||||
|
||||
self.event_creation_handler = hs.get_event_creation_handler()
|
||||
self.store = hs.get_datastores().main
|
||||
self.storage = hs.get_storage()
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
@staticmethod
|
||||
@ -127,7 +127,9 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
|
||||
event.internal_metadata.outlier = content["outlier"]
|
||||
|
||||
requester = Requester.deserialize(self.store, content["requester"])
|
||||
context = EventContext.deserialize(self.storage, content["context"])
|
||||
context = EventContext.deserialize(
|
||||
self._storage_controllers, content["context"]
|
||||
)
|
||||
|
||||
ratelimit = content["ratelimit"]
|
||||
extra_users = [UserID.from_string(u) for u in content["extra_users"]]
|
||||
|
@ -123,7 +123,8 @@ from synapse.server_notices.worker_server_notices_sender import (
|
||||
WorkerServerNoticesSender,
|
||||
)
|
||||
from synapse.state import StateHandler, StateResolutionHandler
|
||||
from synapse.storage import Databases, Storage
|
||||
from synapse.storage import Databases
|
||||
from synapse.storage.controllers import StorageControllers
|
||||
from synapse.streams.events import EventSources
|
||||
from synapse.types import DomainSpecificString, ISynapseReactor
|
||||
from synapse.util import Clock
|
||||
@ -729,8 +730,8 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||
return PasswordPolicyHandler(self)
|
||||
|
||||
@cache_in_self
|
||||
def get_storage(self) -> Storage:
|
||||
return Storage(self, self.get_datastores())
|
||||
def get_storage_controllers(self) -> StorageControllers:
|
||||
return StorageControllers(self, self.get_datastores())
|
||||
|
||||
@cache_in_self
|
||||
def get_replication_streamer(self) -> ReplicationStreamer:
|
||||
|
@ -127,10 +127,10 @@ class StateHandler:
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastores().main
|
||||
self.state_storage = hs.get_storage().state
|
||||
self._state_storage_controller = hs.get_storage_controllers().state
|
||||
self.hs = hs
|
||||
self._state_resolution_handler = hs.get_state_resolution_handler()
|
||||
self._storage = hs.get_storage()
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
|
||||
@overload
|
||||
async def get_current_state(
|
||||
@ -337,13 +337,15 @@ class StateHandler:
|
||||
#
|
||||
|
||||
if not state_group_before_event:
|
||||
state_group_before_event = await self.state_storage.store_state_group(
|
||||
state_group_before_event = (
|
||||
await self._state_storage_controller.store_state_group(
|
||||
event.event_id,
|
||||
event.room_id,
|
||||
prev_group=state_group_before_event_prev_group,
|
||||
delta_ids=deltas_to_state_group_before_event,
|
||||
current_state_ids=state_ids_before_event,
|
||||
)
|
||||
)
|
||||
|
||||
# Assign the new state group to the cached state entry.
|
||||
#
|
||||
@ -359,7 +361,7 @@ class StateHandler:
|
||||
|
||||
if not event.is_state():
|
||||
return EventContext.with_state(
|
||||
storage=self._storage,
|
||||
storage=self._storage_controllers,
|
||||
state_group_before_event=state_group_before_event,
|
||||
state_group=state_group_before_event,
|
||||
state_delta_due_to_event={},
|
||||
@ -382,16 +384,18 @@ class StateHandler:
|
||||
state_ids_after_event[key] = event.event_id
|
||||
delta_ids = {key: event.event_id}
|
||||
|
||||
state_group_after_event = await self.state_storage.store_state_group(
|
||||
state_group_after_event = (
|
||||
await self._state_storage_controller.store_state_group(
|
||||
event.event_id,
|
||||
event.room_id,
|
||||
prev_group=state_group_before_event,
|
||||
delta_ids=delta_ids,
|
||||
current_state_ids=state_ids_after_event,
|
||||
)
|
||||
)
|
||||
|
||||
return EventContext.with_state(
|
||||
storage=self._storage,
|
||||
storage=self._storage_controllers,
|
||||
state_group=state_group_after_event,
|
||||
state_group_before_event=state_group_before_event,
|
||||
state_delta_due_to_event=delta_ids,
|
||||
@ -416,7 +420,9 @@ class StateHandler:
|
||||
"""
|
||||
logger.debug("resolve_state_groups event_ids %s", event_ids)
|
||||
|
||||
state_groups = await self.state_storage.get_state_group_for_events(event_ids)
|
||||
state_groups = await self._state_storage_controller.get_state_group_for_events(
|
||||
event_ids
|
||||
)
|
||||
|
||||
state_group_ids = state_groups.values()
|
||||
|
||||
@ -424,8 +430,13 @@ class StateHandler:
|
||||
state_group_ids_set = set(state_group_ids)
|
||||
if len(state_group_ids_set) == 1:
|
||||
(state_group_id,) = state_group_ids_set
|
||||
state = await self.state_storage.get_state_for_groups(state_group_ids_set)
|
||||
prev_group, delta_ids = await self.state_storage.get_state_group_delta(
|
||||
state = await self._state_storage_controller.get_state_for_groups(
|
||||
state_group_ids_set
|
||||
)
|
||||
(
|
||||
prev_group,
|
||||
delta_ids,
|
||||
) = await self._state_storage_controller.get_state_group_delta(
|
||||
state_group_id
|
||||
)
|
||||
return _StateCacheEntry(
|
||||
@ -439,7 +450,7 @@ class StateHandler:
|
||||
|
||||
room_version = await self.store.get_room_version_id(room_id)
|
||||
|
||||
state_to_resolve = await self.state_storage.get_state_for_groups(
|
||||
state_to_resolve = await self._state_storage_controller.get_state_for_groups(
|
||||
state_group_ids_set
|
||||
)
|
||||
|
||||
|
@ -18,41 +18,20 @@ The storage layer is split up into multiple parts to allow Synapse to run
|
||||
against different configurations of databases (e.g. single or multiple
|
||||
databases). The `DatabasePool` class represents connections to a single physical
|
||||
database. The `databases` are classes that talk directly to a `DatabasePool`
|
||||
instance and have associated schemas, background updates, etc. On top of those
|
||||
there are classes that provide high level interfaces that combine calls to
|
||||
multiple `databases`.
|
||||
instance and have associated schemas, background updates, etc.
|
||||
|
||||
On top of the databases are the StorageControllers, located in the
|
||||
`synapse.storage.controllers` module. These classes provide high level
|
||||
interfaces that combine calls to multiple `databases`. They are bundled into the
|
||||
`StorageControllers` singleton for ease of use, and exposed via
|
||||
`HomeServer.get_storage_controllers()`.
|
||||
|
||||
There are also schemas that get applied to every database, regardless of the
|
||||
data stores associated with them (e.g. the schema version tables), which are
|
||||
stored in `synapse.storage.schema`.
|
||||
"""
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from synapse.storage.databases import Databases
|
||||
from synapse.storage.databases.main import DataStore
|
||||
from synapse.storage.persist_events import EventsPersistenceStorage
|
||||
from synapse.storage.purge_events import PurgeEventsStorage
|
||||
from synapse.storage.state import StateGroupStorage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
||||
__all__ = ["Databases", "DataStore"]
|
||||
|
||||
|
||||
class Storage:
|
||||
"""The high level interfaces for talking to various storage layers."""
|
||||
|
||||
def __init__(self, hs: "HomeServer", stores: Databases):
|
||||
# We include the main data store here mainly so that we don't have to
|
||||
# rewrite all the existing code to split it into high vs low level
|
||||
# interfaces.
|
||||
self.main = stores.main
|
||||
|
||||
self.purge_events = PurgeEventsStorage(hs, stores)
|
||||
self.state = StateGroupStorage(hs, stores)
|
||||
|
||||
self.persistence = None
|
||||
if stores.persist_events:
|
||||
self.persistence = EventsPersistenceStorage(hs, stores)
|
||||
|
46
synapse/storage/controllers/__init__.py
Normal file
46
synapse/storage/controllers/__init__.py
Normal file
@ -0,0 +1,46 @@
|
||||
# Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from synapse.storage.controllers.persist_events import (
|
||||
EventsPersistenceStorageController,
|
||||
)
|
||||
from synapse.storage.controllers.purge_events import PurgeEventsStorageController
|
||||
from synapse.storage.controllers.state import StateGroupStorageController
|
||||
from synapse.storage.databases import Databases
|
||||
from synapse.storage.databases.main import DataStore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
||||
__all__ = ["Databases", "DataStore"]
|
||||
|
||||
|
||||
class StorageControllers:
|
||||
"""The high level interfaces for talking to various storage controller layers."""
|
||||
|
||||
def __init__(self, hs: "HomeServer", stores: Databases):
|
||||
# We include the main data store here mainly so that we don't have to
|
||||
# rewrite all the existing code to split it into high vs low level
|
||||
# interfaces.
|
||||
self.main = stores.main
|
||||
|
||||
self.purge_events = PurgeEventsStorageController(hs, stores)
|
||||
self.state = StateGroupStorageController(hs, stores)
|
||||
|
||||
self.persistence = None
|
||||
if stores.persist_events:
|
||||
self.persistence = EventsPersistenceStorageController(hs, stores)
|
@ -272,7 +272,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
|
||||
pass
|
||||
|
||||
|
||||
class EventsPersistenceStorage:
|
||||
class EventsPersistenceStorageController:
|
||||
"""High level interface for handling persisting newly received events.
|
||||
|
||||
Takes care of batching up events by room, and calculating the necessary
|
@ -24,7 +24,7 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PurgeEventsStorage:
|
||||
class PurgeEventsStorageController:
|
||||
"""High level interface for purging rooms and event history."""
|
||||
|
||||
def __init__(self, hs: "HomeServer", stores: Databases):
|
351
synapse/storage/controllers/state.py
Normal file
351
synapse/storage/controllers/state.py
Normal file
@ -0,0 +1,351 @@
|
||||
# Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Awaitable,
|
||||
Collection,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
from synapse.events import EventBase
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker
|
||||
from synapse.types import MutableStateMap, StateMap
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.databases import Databases
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StateGroupStorageController:
|
||||
"""High level interface to fetching state for event."""
|
||||
|
||||
def __init__(self, hs: "HomeServer", stores: "Databases"):
|
||||
self._is_mine_id = hs.is_mine_id
|
||||
self.stores = stores
|
||||
self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
|
||||
|
||||
def notify_event_un_partial_stated(self, event_id: str) -> None:
|
||||
self._partial_state_events_tracker.notify_un_partial_stated(event_id)
|
||||
|
||||
async def get_state_group_delta(
|
||||
self, state_group: int
|
||||
) -> Tuple[Optional[int], Optional[StateMap[str]]]:
|
||||
"""Given a state group try to return a previous group and a delta between
|
||||
the old and the new.
|
||||
|
||||
Args:
|
||||
state_group: The state group used to retrieve state deltas.
|
||||
|
||||
Returns:
|
||||
A tuple of the previous group and a state map of the event IDs which
|
||||
make up the delta between the old and new state groups.
|
||||
"""
|
||||
|
||||
state_group_delta = await self.stores.state.get_state_group_delta(state_group)
|
||||
return state_group_delta.prev_group, state_group_delta.delta_ids
|
||||
|
||||
async def get_state_groups_ids(
|
||||
self, _room_id: str, event_ids: Collection[str]
|
||||
) -> Dict[int, MutableStateMap[str]]:
|
||||
"""Get the event IDs of all the state for the state groups for the given events
|
||||
|
||||
Args:
|
||||
_room_id: id of the room for these events
|
||||
event_ids: ids of the events
|
||||
|
||||
Returns:
|
||||
dict of state_group_id -> (dict of (type, state_key) -> event id)
|
||||
|
||||
Raises:
|
||||
RuntimeError if we don't have a state group for one or more of the events
|
||||
(ie they are outliers or unknown)
|
||||
"""
|
||||
if not event_ids:
|
||||
return {}
|
||||
|
||||
event_to_groups = await self.get_state_group_for_events(event_ids)
|
||||
|
||||
groups = set(event_to_groups.values())
|
||||
group_to_state = await self.stores.state._get_state_for_groups(groups)
|
||||
|
||||
return group_to_state
|
||||
|
||||
async def get_state_ids_for_group(
|
||||
self, state_group: int, state_filter: Optional[StateFilter] = None
|
||||
) -> StateMap[str]:
|
||||
"""Get the event IDs of all the state in the given state group
|
||||
|
||||
Args:
|
||||
state_group: A state group for which we want to get the state IDs.
|
||||
state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
|
||||
|
||||
Returns:
|
||||
Resolves to a map of (type, state_key) -> event_id
|
||||
"""
|
||||
group_to_state = await self.get_state_for_groups((state_group,), state_filter)
|
||||
|
||||
return group_to_state[state_group]
|
||||
|
||||
async def get_state_groups(
|
||||
self, room_id: str, event_ids: Collection[str]
|
||||
) -> Dict[int, List[EventBase]]:
|
||||
"""Get the state groups for the given list of event_ids
|
||||
|
||||
Args:
|
||||
room_id: ID of the room for these events.
|
||||
event_ids: The event IDs to retrieve state for.
|
||||
|
||||
Returns:
|
||||
dict of state_group_id -> list of state events.
|
||||
"""
|
||||
if not event_ids:
|
||||
return {}
|
||||
|
||||
group_to_ids = await self.get_state_groups_ids(room_id, event_ids)
|
||||
|
||||
state_event_map = await self.stores.main.get_events(
|
||||
[
|
||||
ev_id
|
||||
for group_ids in group_to_ids.values()
|
||||
for ev_id in group_ids.values()
|
||||
],
|
||||
get_prev_content=False,
|
||||
)
|
||||
|
||||
return {
|
||||
group: [
|
||||
state_event_map[v]
|
||||
for v in event_id_map.values()
|
||||
if v in state_event_map
|
||||
]
|
||||
for group, event_id_map in group_to_ids.items()
|
||||
}
|
||||
|
||||
def _get_state_groups_from_groups(
|
||||
self, groups: List[int], state_filter: StateFilter
|
||||
) -> Awaitable[Dict[int, StateMap[str]]]:
|
||||
"""Returns the state groups for a given set of groups, filtering on
|
||||
types of state events.
|
||||
|
||||
Args:
|
||||
groups: list of state group IDs to query
|
||||
state_filter: The state filter used to fetch state
|
||||
from the database.
|
||||
|
||||
Returns:
|
||||
Dict of state group to state map.
|
||||
"""
|
||||
|
||||
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
|
||||
|
||||
async def get_state_for_events(
|
||||
self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
|
||||
) -> Dict[str, StateMap[EventBase]]:
|
||||
"""Given a list of event_ids and type tuples, return a list of state
|
||||
dicts for each event.
|
||||
|
||||
Args:
|
||||
event_ids: The events to fetch the state of.
|
||||
state_filter: The state filter used to fetch state.
|
||||
|
||||
Returns:
|
||||
A dict of (event_id) -> (type, state_key) -> [state_events]
|
||||
|
||||
Raises:
|
||||
RuntimeError if we don't have a state group for one or more of the events
|
||||
(ie they are outliers or unknown)
|
||||
"""
|
||||
await_full_state = True
|
||||
if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
|
||||
await_full_state = False
|
||||
|
||||
event_to_groups = await self.get_state_group_for_events(
|
||||
event_ids, await_full_state=await_full_state
|
||||
)
|
||||
|
||||
groups = set(event_to_groups.values())
|
||||
group_to_state = await self.stores.state._get_state_for_groups(
|
||||
groups, state_filter or StateFilter.all()
|
||||
)
|
||||
|
||||
state_event_map = await self.stores.main.get_events(
|
||||
[ev_id for sd in group_to_state.values() for ev_id in sd.values()],
|
||||
get_prev_content=False,
|
||||
)
|
||||
|
||||
event_to_state = {
|
||||
event_id: {
|
||||
k: state_event_map[v]
|
||||
for k, v in group_to_state[group].items()
|
||||
if v in state_event_map
|
||||
}
|
||||
for event_id, group in event_to_groups.items()
|
||||
}
|
||||
|
||||
return {event: event_to_state[event] for event in event_ids}
|
||||
|
||||
async def get_state_ids_for_events(
|
||||
self,
|
||||
event_ids: Collection[str],
|
||||
state_filter: Optional[StateFilter] = None,
|
||||
) -> Dict[str, StateMap[str]]:
|
||||
"""
|
||||
Get the state dicts corresponding to a list of events, containing the event_ids
|
||||
of the state events (as opposed to the events themselves)
|
||||
|
||||
Args:
|
||||
event_ids: events whose state should be returned
|
||||
state_filter: The state filter used to fetch state from the database.
|
||||
|
||||
Returns:
|
||||
A dict from event_id -> (type, state_key) -> event_id
|
||||
|
||||
Raises:
|
||||
RuntimeError if we don't have a state group for one or more of the events
|
||||
(ie they are outliers or unknown)
|
||||
"""
|
||||
await_full_state = True
|
||||
if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
|
||||
await_full_state = False
|
||||
|
||||
event_to_groups = await self.get_state_group_for_events(
|
||||
event_ids, await_full_state=await_full_state
|
||||
)
|
||||
|
||||
groups = set(event_to_groups.values())
|
||||
group_to_state = await self.stores.state._get_state_for_groups(
|
||||
groups, state_filter or StateFilter.all()
|
||||
)
|
||||
|
||||
event_to_state = {
|
||||
event_id: group_to_state[group]
|
||||
for event_id, group in event_to_groups.items()
|
||||
}
|
||||
|
||||
return {event: event_to_state[event] for event in event_ids}
|
||||
|
||||
async def get_state_for_event(
|
||||
self, event_id: str, state_filter: Optional[StateFilter] = None
|
||||
) -> StateMap[EventBase]:
|
||||
"""
|
||||
Get the state dict corresponding to a particular event
|
||||
|
||||
Args:
|
||||
event_id: event whose state should be returned
|
||||
state_filter: The state filter used to fetch state from the database.
|
||||
|
||||
Returns:
|
||||
A dict from (type, state_key) -> state_event
|
||||
|
||||
Raises:
|
||||
RuntimeError if we don't have a state group for the event (ie it is an
|
||||
outlier or is unknown)
|
||||
"""
|
||||
state_map = await self.get_state_for_events(
|
||||
[event_id], state_filter or StateFilter.all()
|
||||
)
|
||||
return state_map[event_id]
|
||||
|
||||
async def get_state_ids_for_event(
|
||||
self, event_id: str, state_filter: Optional[StateFilter] = None
|
||||
) -> StateMap[str]:
|
||||
"""
|
||||
Get the state dict corresponding to a particular event
|
||||
|
||||
Args:
|
||||
event_id: event whose state should be returned
|
||||
state_filter: The state filter used to fetch state from the database.
|
||||
|
||||
Returns:
|
||||
A dict from (type, state_key) -> state_event_id
|
||||
|
||||
Raises:
|
||||
RuntimeError if we don't have a state group for the event (ie it is an
|
||||
outlier or is unknown)
|
||||
"""
|
||||
state_map = await self.get_state_ids_for_events(
|
||||
[event_id], state_filter or StateFilter.all()
|
||||
)
|
||||
return state_map[event_id]
|
||||
|
||||
def get_state_for_groups(
|
||||
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
|
||||
) -> Awaitable[Dict[int, MutableStateMap[str]]]:
|
||||
"""Gets the state at each of a list of state groups, optionally
|
||||
filtering by type/state_key
|
||||
|
||||
Args:
|
||||
groups: list of state groups for which we want to get the state.
|
||||
state_filter: The state filter used to fetch state.
|
||||
from the database.
|
||||
|
||||
Returns:
|
||||
Dict of state group to state map.
|
||||
"""
|
||||
return self.stores.state._get_state_for_groups(
|
||||
groups, state_filter or StateFilter.all()
|
||||
)
|
||||
|
||||
async def get_state_group_for_events(
|
||||
self,
|
||||
event_ids: Collection[str],
|
||||
await_full_state: bool = True,
|
||||
) -> Mapping[str, int]:
|
||||
"""Returns mapping event_id -> state_group
|
||||
|
||||
Args:
|
||||
event_ids: events to get state groups for
|
||||
await_full_state: if true, will block if we do not yet have complete
|
||||
state at these events.
|
||||
"""
|
||||
if await_full_state:
|
||||
await self._partial_state_events_tracker.await_full_state(event_ids)
|
||||
|
||||
return await self.stores.main._get_state_group_for_events(event_ids)
|
||||
|
||||
async def store_state_group(
|
||||
self,
|
||||
event_id: str,
|
||||
room_id: str,
|
||||
prev_group: Optional[int],
|
||||
delta_ids: Optional[StateMap[str]],
|
||||
current_state_ids: StateMap[str],
|
||||
) -> int:
|
||||
"""Store a new set of state, returning a newly assigned state group.
|
||||
|
||||
Args:
|
||||
event_id: The event ID for which the state was calculated.
|
||||
room_id: ID of the room for which the state was calculated.
|
||||
prev_group: A previous state group for the room, optional.
|
||||
delta_ids: The delta between state at `prev_group` and
|
||||
`current_state_ids`, if `prev_group` was given. Same format as
|
||||
`current_state_ids`.
|
||||
current_state_ids: The state to store. Map of (type, state_key)
|
||||
to event_id.
|
||||
|
||||
Returns:
|
||||
The state group ID
|
||||
"""
|
||||
return await self.stores.state.store_state_group(
|
||||
event_id, room_id, prev_group, delta_ids, current_state_ids
|
||||
)
|
@ -15,7 +15,6 @@
|
||||
import logging
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Collection,
|
||||
Dict,
|
||||
@ -32,15 +31,11 @@ import attr
|
||||
from frozendict import frozendict
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.events import EventBase
|
||||
from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker
|
||||
from synapse.types import MutableStateMap, StateKey, StateMap
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad
|
||||
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.databases import Databases
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -578,318 +573,3 @@ _ALL_NON_MEMBER_STATE_FILTER = StateFilter(
|
||||
types=frozendict({EventTypes.Member: frozenset()}), include_others=True
|
||||
)
|
||||
_NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False)
|
||||
|
||||
|
||||
class StateGroupStorage:
|
||||
"""High level interface to fetching state for event."""
|
||||
|
||||
def __init__(self, hs: "HomeServer", stores: "Databases"):
|
||||
self._is_mine_id = hs.is_mine_id
|
||||
self.stores = stores
|
||||
self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
|
||||
|
||||
def notify_event_un_partial_stated(self, event_id: str) -> None:
|
||||
self._partial_state_events_tracker.notify_un_partial_stated(event_id)
|
||||
|
||||
async def get_state_group_delta(
|
||||
self, state_group: int
|
||||
) -> Tuple[Optional[int], Optional[StateMap[str]]]:
|
||||
"""Given a state group try to return a previous group and a delta between
|
||||
the old and the new.
|
||||
|
||||
Args:
|
||||
state_group: The state group used to retrieve state deltas.
|
||||
|
||||
Returns:
|
||||
A tuple of the previous group and a state map of the event IDs which
|
||||
make up the delta between the old and new state groups.
|
||||
"""
|
||||
|
||||
state_group_delta = await self.stores.state.get_state_group_delta(state_group)
|
||||
return state_group_delta.prev_group, state_group_delta.delta_ids
|
||||
|
||||
async def get_state_groups_ids(
|
||||
self, _room_id: str, event_ids: Collection[str]
|
||||
) -> Dict[int, MutableStateMap[str]]:
|
||||
"""Get the event IDs of all the state for the state groups for the given events
|
||||
|
||||
Args:
|
||||
_room_id: id of the room for these events
|
||||
event_ids: ids of the events
|
||||
|
||||
Returns:
|
||||
dict of state_group_id -> (dict of (type, state_key) -> event id)
|
||||
|
||||
Raises:
|
||||
RuntimeError if we don't have a state group for one or more of the events
|
||||
(ie they are outliers or unknown)
|
||||
"""
|
||||
if not event_ids:
|
||||
return {}
|
||||
|
||||
event_to_groups = await self.get_state_group_for_events(event_ids)
|
||||
|
||||
groups = set(event_to_groups.values())
|
||||
group_to_state = await self.stores.state._get_state_for_groups(groups)
|
||||
|
||||
return group_to_state
|
||||
|
||||
async def get_state_ids_for_group(
|
||||
self, state_group: int, state_filter: Optional[StateFilter] = None
|
||||
) -> StateMap[str]:
|
||||
"""Get the event IDs of all the state in the given state group
|
||||
|
||||
Args:
|
||||
state_group: A state group for which we want to get the state IDs.
|
||||
state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
|
||||
|
||||
Returns:
|
||||
Resolves to a map of (type, state_key) -> event_id
|
||||
"""
|
||||
group_to_state = await self.get_state_for_groups((state_group,), state_filter)
|
||||
|
||||
return group_to_state[state_group]
|
||||
|
||||
async def get_state_groups(
|
||||
self, room_id: str, event_ids: Collection[str]
|
||||
) -> Dict[int, List[EventBase]]:
|
||||
"""Get the state groups for the given list of event_ids
|
||||
|
||||
Args:
|
||||
room_id: ID of the room for these events.
|
||||
event_ids: The event IDs to retrieve state for.
|
||||
|
||||
Returns:
|
||||
dict of state_group_id -> list of state events.
|
||||
"""
|
||||
if not event_ids:
|
||||
return {}
|
||||
|
||||
group_to_ids = await self.get_state_groups_ids(room_id, event_ids)
|
||||
|
||||
state_event_map = await self.stores.main.get_events(
|
||||
[
|
||||
ev_id
|
||||
for group_ids in group_to_ids.values()
|
||||
for ev_id in group_ids.values()
|
||||
],
|
||||
get_prev_content=False,
|
||||
)
|
||||
|
||||
return {
|
||||
group: [
|
||||
state_event_map[v]
|
||||
for v in event_id_map.values()
|
||||
if v in state_event_map
|
||||
]
|
||||
for group, event_id_map in group_to_ids.items()
|
||||
}
|
||||
|
||||
def _get_state_groups_from_groups(
|
||||
self, groups: List[int], state_filter: StateFilter
|
||||
) -> Awaitable[Dict[int, StateMap[str]]]:
|
||||
"""Returns the state groups for a given set of groups, filtering on
|
||||
types of state events.
|
||||
|
||||
Args:
|
||||
groups: list of state group IDs to query
|
||||
state_filter: The state filter used to fetch state
|
||||
from the database.
|
||||
|
||||
Returns:
|
||||
Dict of state group to state map.
|
||||
"""
|
||||
|
||||
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
|
||||
|
||||
async def get_state_for_events(
|
||||
self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
|
||||
) -> Dict[str, StateMap[EventBase]]:
|
||||
"""Given a list of event_ids and type tuples, return a list of state
|
||||
dicts for each event.
|
||||
|
||||
Args:
|
||||
event_ids: The events to fetch the state of.
|
||||
state_filter: The state filter used to fetch state.
|
||||
|
||||
Returns:
|
||||
A dict of (event_id) -> (type, state_key) -> [state_events]
|
||||
|
||||
Raises:
|
||||
RuntimeError if we don't have a state group for one or more of the events
|
||||
(ie they are outliers or unknown)
|
||||
"""
|
||||
await_full_state = True
|
||||
if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
|
||||
await_full_state = False
|
||||
|
||||
event_to_groups = await self.get_state_group_for_events(
|
||||
event_ids, await_full_state=await_full_state
|
||||
)
|
||||
|
||||
groups = set(event_to_groups.values())
|
||||
group_to_state = await self.stores.state._get_state_for_groups(
|
||||
groups, state_filter or StateFilter.all()
|
||||
)
|
||||
|
||||
state_event_map = await self.stores.main.get_events(
|
||||
[ev_id for sd in group_to_state.values() for ev_id in sd.values()],
|
||||
get_prev_content=False,
|
||||
)
|
||||
|
||||
event_to_state = {
|
||||
event_id: {
|
||||
k: state_event_map[v]
|
||||
for k, v in group_to_state[group].items()
|
||||
if v in state_event_map
|
||||
}
|
||||
for event_id, group in event_to_groups.items()
|
||||
}
|
||||
|
||||
return {event: event_to_state[event] for event in event_ids}
|
||||
|
||||
async def get_state_ids_for_events(
|
||||
self,
|
||||
event_ids: Collection[str],
|
||||
state_filter: Optional[StateFilter] = None,
|
||||
) -> Dict[str, StateMap[str]]:
|
||||
"""
|
||||
Get the state dicts corresponding to a list of events, containing the event_ids
|
||||
of the state events (as opposed to the events themselves)
|
||||
|
||||
Args:
|
||||
event_ids: events whose state should be returned
|
||||
state_filter: The state filter used to fetch state from the database.
|
||||
|
||||
Returns:
|
||||
A dict from event_id -> (type, state_key) -> event_id
|
||||
|
||||
Raises:
|
||||
RuntimeError if we don't have a state group for one or more of the events
|
||||
(ie they are outliers or unknown)
|
||||
"""
|
||||
await_full_state = True
|
||||
if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
|
||||
await_full_state = False
|
||||
|
||||
event_to_groups = await self.get_state_group_for_events(
|
||||
event_ids, await_full_state=await_full_state
|
||||
)
|
||||
|
||||
groups = set(event_to_groups.values())
|
||||
group_to_state = await self.stores.state._get_state_for_groups(
|
||||
groups, state_filter or StateFilter.all()
|
||||
)
|
||||
|
||||
event_to_state = {
|
||||
event_id: group_to_state[group]
|
||||
for event_id, group in event_to_groups.items()
|
||||
}
|
||||
|
||||
return {event: event_to_state[event] for event in event_ids}
|
||||
|
||||
async def get_state_for_event(
|
||||
self, event_id: str, state_filter: Optional[StateFilter] = None
|
||||
) -> StateMap[EventBase]:
|
||||
"""
|
||||
Get the state dict corresponding to a particular event
|
||||
|
||||
Args:
|
||||
event_id: event whose state should be returned
|
||||
state_filter: The state filter used to fetch state from the database.
|
||||
|
||||
Returns:
|
||||
A dict from (type, state_key) -> state_event
|
||||
|
||||
Raises:
|
||||
RuntimeError if we don't have a state group for the event (ie it is an
|
||||
outlier or is unknown)
|
||||
"""
|
||||
state_map = await self.get_state_for_events(
|
||||
[event_id], state_filter or StateFilter.all()
|
||||
)
|
||||
return state_map[event_id]
|
||||
|
||||
async def get_state_ids_for_event(
|
||||
self, event_id: str, state_filter: Optional[StateFilter] = None
|
||||
) -> StateMap[str]:
|
||||
"""
|
||||
Get the state dict corresponding to a particular event
|
||||
|
||||
Args:
|
||||
event_id: event whose state should be returned
|
||||
state_filter: The state filter used to fetch state from the database.
|
||||
|
||||
Returns:
|
||||
A dict from (type, state_key) -> state_event_id
|
||||
|
||||
Raises:
|
||||
RuntimeError if we don't have a state group for the event (ie it is an
|
||||
outlier or is unknown)
|
||||
"""
|
||||
state_map = await self.get_state_ids_for_events(
|
||||
[event_id], state_filter or StateFilter.all()
|
||||
)
|
||||
return state_map[event_id]
|
||||
|
||||
def get_state_for_groups(
|
||||
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
|
||||
) -> Awaitable[Dict[int, MutableStateMap[str]]]:
|
||||
"""Gets the state at each of a list of state groups, optionally
|
||||
filtering by type/state_key
|
||||
|
||||
Args:
|
||||
groups: list of state groups for which we want to get the state.
|
||||
state_filter: The state filter used to fetch state.
|
||||
from the database.
|
||||
|
||||
Returns:
|
||||
Dict of state group to state map.
|
||||
"""
|
||||
return self.stores.state._get_state_for_groups(
|
||||
groups, state_filter or StateFilter.all()
|
||||
)
|
||||
|
||||
async def get_state_group_for_events(
|
||||
self,
|
||||
event_ids: Collection[str],
|
||||
await_full_state: bool = True,
|
||||
) -> Mapping[str, int]:
|
||||
"""Returns mapping event_id -> state_group
|
||||
|
||||
Args:
|
||||
event_ids: events to get state groups for
|
||||
await_full_state: if true, will block if we do not yet have complete
|
||||
state at these events.
|
||||
"""
|
||||
if await_full_state:
|
||||
await self._partial_state_events_tracker.await_full_state(event_ids)
|
||||
|
||||
return await self.stores.main._get_state_group_for_events(event_ids)
|
||||
|
||||
async def store_state_group(
|
||||
self,
|
||||
event_id: str,
|
||||
room_id: str,
|
||||
prev_group: Optional[int],
|
||||
delta_ids: Optional[StateMap[str]],
|
||||
current_state_ids: StateMap[str],
|
||||
) -> int:
|
||||
"""Store a new set of state, returning a newly assigned state group.
|
||||
|
||||
Args:
|
||||
event_id: The event ID for which the state was calculated.
|
||||
room_id: ID of the room for which the state was calculated.
|
||||
prev_group: A previous state group for the room, optional.
|
||||
delta_ids: The delta between state at `prev_group` and
|
||||
`current_state_ids`, if `prev_group` was given. Same format as
|
||||
`current_state_ids`.
|
||||
current_state_ids: The state to store. Map of (type, state_key)
|
||||
to event_id.
|
||||
|
||||
Returns:
|
||||
The state group ID
|
||||
"""
|
||||
return await self.stores.state.store_state_group(
|
||||
event_id, room_id, prev_group, delta_ids, current_state_ids
|
||||
)
|
||||
|
@ -20,7 +20,7 @@ from typing_extensions import Final
|
||||
from synapse.api.constants import EventTypes, HistoryVisibility, Membership
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.utils import prune_event
|
||||
from synapse.storage import Storage
|
||||
from synapse.storage.controllers import StorageControllers
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.types import RetentionPolicy, StateMap, get_domain_from_id
|
||||
|
||||
@ -47,7 +47,7 @@ _HISTORY_VIS_KEY: Final[Tuple[str, str]] = (EventTypes.RoomHistoryVisibility, ""
|
||||
|
||||
|
||||
async def filter_events_for_client(
|
||||
storage: Storage,
|
||||
storage: StorageControllers,
|
||||
user_id: str,
|
||||
events: List[EventBase],
|
||||
is_peeking: bool = False,
|
||||
@ -268,7 +268,7 @@ async def filter_events_for_client(
|
||||
|
||||
|
||||
async def filter_events_for_server(
|
||||
storage: Storage,
|
||||
storage: StorageControllers,
|
||||
server_name: str,
|
||||
events: List[EventBase],
|
||||
redact: bool = True,
|
||||
@ -360,7 +360,7 @@ async def filter_events_for_server(
|
||||
|
||||
|
||||
async def _event_to_history_vis(
|
||||
storage: Storage, events: Collection[EventBase]
|
||||
storage: StorageControllers, events: Collection[EventBase]
|
||||
) -> Dict[str, str]:
|
||||
"""Get the history visibility at each of the given events
|
||||
|
||||
@ -407,7 +407,7 @@ async def _event_to_history_vis(
|
||||
|
||||
|
||||
async def _event_to_memberships(
|
||||
storage: Storage, events: Collection[EventBase], server_name: str
|
||||
storage: StorageControllers, events: Collection[EventBase], server_name: str
|
||||
) -> Dict[str, StateMap[EventBase]]:
|
||||
"""Get the remote membership list at each of the given events
|
||||
|
||||
|
@ -29,7 +29,7 @@ class TestEventContext(unittest.HomeserverTestCase):
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.store = hs.get_datastores().main
|
||||
self.storage = hs.get_storage()
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
|
||||
self.user_id = self.register_user("u1", "pass")
|
||||
self.user_tok = self.login("u1", "pass")
|
||||
@ -87,7 +87,7 @@ class TestEventContext(unittest.HomeserverTestCase):
|
||||
def _check_serialize_deserialize(self, event, context):
|
||||
serialized = self.get_success(context.serialize(event, self.store))
|
||||
|
||||
d_context = EventContext.deserialize(self.storage, serialized)
|
||||
d_context = EventContext.deserialize(self._storage_controllers, serialized)
|
||||
|
||||
self.assertEqual(context.state_group, d_context.state_group)
|
||||
self.assertEqual(context.rejected, d_context.rejected)
|
||||
|
@ -50,7 +50,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
|
||||
hs = self.setup_test_homeserver(federation_http_client=None)
|
||||
self.handler = hs.get_federation_handler()
|
||||
self.store = hs.get_datastores().main
|
||||
self.state_storage = hs.get_storage().state
|
||||
self.state_storage_controller = hs.get_storage_controllers().state
|
||||
self._event_auth_handler = hs.get_event_auth_handler()
|
||||
return hs
|
||||
|
||||
@ -338,7 +338,9 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
|
||||
# mapping from (type, state_key) -> state_event_id
|
||||
assert most_recent_prev_event_id is not None
|
||||
prev_state_map = self.get_success(
|
||||
self.state_storage.get_state_ids_for_event(most_recent_prev_event_id)
|
||||
self.state_storage_controller.get_state_ids_for_event(
|
||||
most_recent_prev_event_id
|
||||
)
|
||||
)
|
||||
# List of state event ID's
|
||||
prev_state_ids = list(prev_state_map.values())
|
||||
|
@ -70,7 +70,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
|
||||
) -> None:
|
||||
OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
|
||||
main_store = self.hs.get_datastores().main
|
||||
state_storage = self.hs.get_storage().state
|
||||
state_storage_controller = self.hs.get_storage_controllers().state
|
||||
|
||||
# create the room
|
||||
user_id = self.register_user("kermit", "test")
|
||||
@ -146,10 +146,11 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
|
||||
)
|
||||
if prev_exists_as_outlier:
|
||||
prev_event.internal_metadata.outlier = True
|
||||
persistence = self.hs.get_storage().persistence
|
||||
persistence = self.hs.get_storage_controllers().persistence
|
||||
self.get_success(
|
||||
persistence.persist_event(
|
||||
prev_event, EventContext.for_outlier(self.hs.get_storage())
|
||||
prev_event,
|
||||
EventContext.for_outlier(self.hs.get_storage_controllers()),
|
||||
)
|
||||
)
|
||||
else:
|
||||
@ -216,7 +217,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
|
||||
|
||||
# check that the state at that event is as expected
|
||||
state = self.get_success(
|
||||
state_storage.get_state_ids_for_event(pulled_event.event_id)
|
||||
state_storage_controller.get_state_ids_for_event(pulled_event.event_id)
|
||||
)
|
||||
expected_state = {
|
||||
(e.type, e.state_key): e.event_id for e in state_at_prev_event
|
||||
|
@ -37,7 +37,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.handler = self.hs.get_event_creation_handler()
|
||||
self.persist_event_storage = self.hs.get_storage().persistence
|
||||
self._persist_event_storage_controller = (
|
||||
self.hs.get_storage_controllers().persistence
|
||||
)
|
||||
|
||||
self.user_id = self.register_user("tester", "foobar")
|
||||
self.access_token = self.login("tester", "foobar")
|
||||
@ -65,7 +67,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
)
|
||||
self.get_success(
|
||||
self.persist_event_storage.persist_event(memberEvent, memberEventContext)
|
||||
self._persist_event_storage_controller.persist_event(
|
||||
memberEvent, memberEventContext
|
||||
)
|
||||
)
|
||||
|
||||
return memberEvent, memberEventContext
|
||||
@ -129,7 +133,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
||||
self.assertNotEqual(event1.event_id, event3.event_id)
|
||||
|
||||
ret_event3, event_pos3, _ = self.get_success(
|
||||
self.persist_event_storage.persist_event(event3, context)
|
||||
self._persist_event_storage_controller.persist_event(event3, context)
|
||||
)
|
||||
|
||||
# Assert that the returned values match those from the initial event
|
||||
@ -143,7 +147,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
||||
self.assertNotEqual(event1.event_id, event3.event_id)
|
||||
|
||||
events, _ = self.get_success(
|
||||
self.persist_event_storage.persist_events([(event3, context)])
|
||||
self._persist_event_storage_controller.persist_events([(event3, context)])
|
||||
)
|
||||
ret_event4 = events[0]
|
||||
|
||||
@ -166,7 +170,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
||||
self.assertNotEqual(event1.event_id, event2.event_id)
|
||||
|
||||
events, _ = self.get_success(
|
||||
self.persist_event_storage.persist_events(
|
||||
self._persist_event_storage_controller.persist_events(
|
||||
[(event1, context1), (event2, context2)]
|
||||
)
|
||||
)
|
||||
|
@ -954,7 +954,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
self.get_success(
|
||||
self.hs.get_storage().persistence.persist_event(event, context)
|
||||
self.hs.get_storage_controllers().persistence.persist_event(event, context)
|
||||
)
|
||||
|
||||
def test_local_user_leaving_room_remains_in_user_directory(self) -> None:
|
||||
|
@ -32,7 +32,7 @@ class BaseSlavedStoreTestCase(BaseStreamTestCase):
|
||||
|
||||
self.master_store = hs.get_datastores().main
|
||||
self.slaved_store = self.worker_hs.get_datastores().main
|
||||
self.storage = hs.get_storage()
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
|
||||
def replicate(self):
|
||||
"""Tell the master side of replication that something has happened, and then
|
||||
|
@ -262,7 +262,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
||||
)
|
||||
msg, msgctx = self.build_event()
|
||||
self.get_success(
|
||||
self.storage.persistence.persist_events([(j2, j2ctx), (msg, msgctx)])
|
||||
self._storage_controllers.persistence.persist_events(
|
||||
[(j2, j2ctx), (msg, msgctx)]
|
||||
)
|
||||
)
|
||||
self.replicate()
|
||||
|
||||
@ -323,12 +325,14 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
||||
|
||||
if backfill:
|
||||
self.get_success(
|
||||
self.storage.persistence.persist_events(
|
||||
self._storage_controllers.persistence.persist_events(
|
||||
[(event, context)], backfilled=True
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.get_success(self.storage.persistence.persist_event(event, context))
|
||||
self.get_success(
|
||||
self._storage_controllers.persistence.persist_event(event, context)
|
||||
)
|
||||
|
||||
return event
|
||||
|
||||
|
@ -31,7 +31,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
|
||||
def prepare(self, reactor, clock, homeserver):
|
||||
super().prepare(reactor, clock, homeserver)
|
||||
self.room_creator = homeserver.get_room_creation_handler()
|
||||
self.persist_event_storage = self.hs.get_storage().persistence
|
||||
self.persist_event_storage_controller = (
|
||||
self.hs.get_storage_controllers().persistence
|
||||
)
|
||||
|
||||
# Create a test user
|
||||
self.ourUser = UserID.from_string(OUR_USER_ID)
|
||||
@ -61,7 +63,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
|
||||
)
|
||||
)
|
||||
self.get_success(
|
||||
self.persist_event_storage.persist_event(memberEvent, memberEventContext)
|
||||
self.persist_event_storage_controller.persist_event(
|
||||
memberEvent, memberEventContext
|
||||
)
|
||||
)
|
||||
|
||||
# Join the second user to the second room
|
||||
@ -76,7 +80,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
|
||||
)
|
||||
)
|
||||
self.get_success(
|
||||
self.persist_event_storage.persist_event(memberEvent, memberEventContext)
|
||||
self.persist_event_storage_controller.persist_event(
|
||||
memberEvent, memberEventContext
|
||||
)
|
||||
)
|
||||
|
||||
def test_return_empty_with_no_data(self):
|
||||
|
@ -2579,7 +2579,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
|
||||
other_user_tok = self.login("user", "pass")
|
||||
event_builder_factory = self.hs.get_event_builder_factory()
|
||||
event_creation_handler = self.hs.get_event_creation_handler()
|
||||
storage = self.hs.get_storage()
|
||||
storage_controllers = self.hs.get_storage_controllers()
|
||||
|
||||
# Create two rooms, one with a local user only and one with both a local
|
||||
# and remote user.
|
||||
@ -2604,7 +2604,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
|
||||
event_creation_handler.create_new_client_event(builder)
|
||||
)
|
||||
|
||||
self.get_success(storage.persistence.persist_event(event, context))
|
||||
self.get_success(storage_controllers.persistence.persist_event(event, context))
|
||||
|
||||
# Now get rooms
|
||||
url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms"
|
||||
|
@ -130,7 +130,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
|
||||
We do this by setting a very long time between purge jobs.
|
||||
"""
|
||||
store = self.hs.get_datastores().main
|
||||
storage = self.hs.get_storage()
|
||||
storage_controllers = self.hs.get_storage_controllers()
|
||||
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
|
||||
|
||||
# Send a first event, which should be filtered out at the end of the test.
|
||||
@ -155,7 +155,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(2, len(events), "events retrieved from database")
|
||||
filtered_events = self.get_success(
|
||||
filter_events_for_client(storage, self.user_id, events)
|
||||
filter_events_for_client(storage_controllers, self.user_id, events)
|
||||
)
|
||||
|
||||
# We should only get one event back.
|
||||
|
@ -88,7 +88,7 @@ class RoomBatchTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.clock = clock
|
||||
self.storage = hs.get_storage()
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
|
||||
self.virtual_user_id, _ = self.register_appservice_user(
|
||||
"as_user_potato", self.appservice.token
|
||||
@ -168,7 +168,9 @@ class RoomBatchTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
# Fetch the state_groups
|
||||
state_group_map = self.get_success(
|
||||
self.storage.state.get_state_groups_ids(room_id, historical_event_ids)
|
||||
self._storage_controllers.state.get_state_groups_ids(
|
||||
room_id, historical_event_ids
|
||||
)
|
||||
)
|
||||
|
||||
# We expect all of the historical events to be using the same state_group
|
||||
|
@ -393,7 +393,8 @@ class EventChainStoreTestCase(HomeserverTestCase):
|
||||
# We need to persist the events to the events and state_events
|
||||
# tables.
|
||||
persist_events_store._store_event_txn(
|
||||
txn, [(e, EventContext(self.hs.get_storage())) for e in events]
|
||||
txn,
|
||||
[(e, EventContext(self.hs.get_storage_controllers())) for e in events],
|
||||
)
|
||||
|
||||
# Actually call the function that calculates the auth chain stuff.
|
||||
|
@ -31,7 +31,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||
|
||||
def prepare(self, reactor, clock, homeserver):
|
||||
self.state = self.hs.get_state_handler()
|
||||
self.persistence = self.hs.get_storage().persistence
|
||||
self._persistence = self.hs.get_storage_controllers().persistence
|
||||
self.store = self.hs.get_datastores().main
|
||||
|
||||
self.register_user("user", "pass")
|
||||
@ -71,7 +71,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||
context = self.get_success(
|
||||
self.state.compute_event_context(event, state_ids_before_event=state)
|
||||
)
|
||||
self.get_success(self.persistence.persist_event(event, context))
|
||||
self.get_success(self._persistence.persist_event(event, context))
|
||||
|
||||
def assert_extremities(self, expected_extremities):
|
||||
"""Assert the current extremities for the room"""
|
||||
@ -148,7 +148,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
self.get_success(self.persistence.persist_event(remote_event_2, context))
|
||||
self.get_success(self._persistence.persist_event(remote_event_2, context))
|
||||
|
||||
# Check that we haven't dropped the old extremity.
|
||||
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
|
||||
@ -353,7 +353,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
|
||||
|
||||
def prepare(self, reactor, clock, homeserver):
|
||||
self.state = self.hs.get_state_handler()
|
||||
self.persistence = self.hs.get_storage().persistence
|
||||
self._persistence = self.hs.get_storage_controllers().persistence
|
||||
self.store = self.hs.get_datastores().main
|
||||
|
||||
def test_remote_user_rooms_cache_invalidated(self):
|
||||
@ -390,7 +390,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
|
||||
)
|
||||
|
||||
context = self.get_success(self.state.compute_event_context(remote_event_1))
|
||||
self.get_success(self.persistence.persist_event(remote_event_1, context))
|
||||
self.get_success(self._persistence.persist_event(remote_event_1, context))
|
||||
|
||||
# Call `get_rooms_for_user` to add the remote user to the cache
|
||||
rooms = self.get_success(self.store.get_rooms_for_user(remote_user))
|
||||
@ -437,7 +437,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
|
||||
)
|
||||
|
||||
context = self.get_success(self.state.compute_event_context(remote_event_1))
|
||||
self.get_success(self.persistence.persist_event(remote_event_1, context))
|
||||
self.get_success(self._persistence.persist_event(remote_event_1, context))
|
||||
|
||||
# Call `get_users_in_room` to add the remote user to the cache
|
||||
users = self.get_success(self.store.get_users_in_room(room_id))
|
||||
|
@ -31,7 +31,7 @@ class PurgeTests(HomeserverTestCase):
|
||||
self.room_id = self.helper.create_room_as(self.user_id)
|
||||
|
||||
self.store = hs.get_datastores().main
|
||||
self.storage = self.hs.get_storage()
|
||||
self._storage_controllers = self.hs.get_storage_controllers()
|
||||
|
||||
def test_purge_history(self):
|
||||
"""
|
||||
@ -51,7 +51,9 @@ class PurgeTests(HomeserverTestCase):
|
||||
|
||||
# Purge everything before this topological token
|
||||
self.get_success(
|
||||
self.storage.purge_events.purge_history(self.room_id, token_str, True)
|
||||
self._storage_controllers.purge_events.purge_history(
|
||||
self.room_id, token_str, True
|
||||
)
|
||||
)
|
||||
|
||||
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
|
||||
@ -79,7 +81,9 @@ class PurgeTests(HomeserverTestCase):
|
||||
|
||||
# Purge everything before this topological token
|
||||
f = self.get_failure(
|
||||
self.storage.purge_events.purge_history(self.room_id, event, True),
|
||||
self._storage_controllers.purge_events.purge_history(
|
||||
self.room_id, event, True
|
||||
),
|
||||
SynapseError,
|
||||
)
|
||||
self.assertIn("greater than forward", f.value.args[0])
|
||||
@ -105,7 +109,9 @@ class PurgeTests(HomeserverTestCase):
|
||||
self.assertIsNotNone(create_event)
|
||||
|
||||
# Purge everything before this topological token
|
||||
self.get_success(self.storage.purge_events.purge_room(self.room_id))
|
||||
self.get_success(
|
||||
self._storage_controllers.purge_events.purge_room(self.room_id)
|
||||
)
|
||||
|
||||
# The events aren't found.
|
||||
self.store._invalidate_get_event_cache(create_event.event_id)
|
||||
|
@ -31,7 +31,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.store = hs.get_datastores().main
|
||||
self.storage = hs.get_storage()
|
||||
self._storage = hs.get_storage_controllers()
|
||||
self.event_builder_factory = hs.get_event_builder_factory()
|
||||
self.event_creation_handler = hs.get_event_creation_handler()
|
||||
|
||||
@ -71,7 +71,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||
self.event_creation_handler.create_new_client_event(builder)
|
||||
)
|
||||
|
||||
self.get_success(self.storage.persistence.persist_event(event, context))
|
||||
self.get_success(self._storage.persistence.persist_event(event, context))
|
||||
|
||||
return event
|
||||
|
||||
@ -93,7 +93,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||
self.event_creation_handler.create_new_client_event(builder)
|
||||
)
|
||||
|
||||
self.get_success(self.storage.persistence.persist_event(event, context))
|
||||
self.get_success(self._storage.persistence.persist_event(event, context))
|
||||
|
||||
return event
|
||||
|
||||
@ -114,7 +114,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||
self.event_creation_handler.create_new_client_event(builder)
|
||||
)
|
||||
|
||||
self.get_success(self.storage.persistence.persist_event(event, context))
|
||||
self.get_success(self._storage.persistence.persist_event(event, context))
|
||||
|
||||
return event
|
||||
|
||||
@ -268,7 +268,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
self.get_success(self.storage.persistence.persist_event(event_1, context_1))
|
||||
self.get_success(self._storage.persistence.persist_event(event_1, context_1))
|
||||
|
||||
event_2, context_2 = self.get_success(
|
||||
self.event_creation_handler.create_new_client_event(
|
||||
@ -287,7 +287,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
)
|
||||
)
|
||||
self.get_success(self.storage.persistence.persist_event(event_2, context_2))
|
||||
self.get_success(self._storage.persistence.persist_event(event_2, context_2))
|
||||
|
||||
# fetch one of the redactions
|
||||
fetched = self.get_success(self.store.get_event(redaction_event_id1))
|
||||
@ -411,7 +411,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
self.get_success(
|
||||
self.storage.persistence.persist_event(redaction_event, context)
|
||||
self._storage.persistence.persist_event(redaction_event, context)
|
||||
)
|
||||
|
||||
# Now lets jump to the future where we have censored the redaction event
|
||||
|
@ -72,7 +72,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
|
||||
# Room events need the full datastore, for persist_event() and
|
||||
# get_room_state()
|
||||
self.store = hs.get_datastores().main
|
||||
self.storage = hs.get_storage()
|
||||
self._storage = hs.get_storage_controllers()
|
||||
self.event_factory = hs.get_event_factory()
|
||||
|
||||
self.room = RoomID.from_string("!abcde:test")
|
||||
@ -88,7 +88,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
|
||||
|
||||
def inject_room_event(self, **kwargs):
|
||||
self.get_success(
|
||||
self.storage.persistence.persist_event(
|
||||
self._storage.persistence.persist_event(
|
||||
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
|
||||
)
|
||||
)
|
||||
|
@ -99,7 +99,9 @@ class EventSearchInsertionTest(HomeserverTestCase):
|
||||
prev_event_ids = self.get_success(store.get_prev_events_for_room(room_id))
|
||||
prev_event = self.get_success(store.get_event(prev_event_ids[0]))
|
||||
prev_state_map = self.get_success(
|
||||
self.hs.get_storage().state.get_state_ids_for_event(prev_event_ids[0])
|
||||
self.hs.get_storage_controllers().state.get_state_ids_for_event(
|
||||
prev_event_ids[0]
|
||||
)
|
||||
)
|
||||
|
||||
event_dict = {
|
||||
|
@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
|
||||
class StateStoreTestCase(HomeserverTestCase):
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.store = hs.get_datastores().main
|
||||
self.storage = hs.get_storage()
|
||||
self.storage = hs.get_storage_controllers()
|
||||
self.state_datastore = self.storage.state.stores.state
|
||||
self.event_builder_factory = hs.get_event_builder_factory()
|
||||
self.event_creation_handler = hs.get_event_creation_handler()
|
||||
|
@ -179,12 +179,12 @@ class Graph:
|
||||
class StateTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.dummy_store = _DummyStore()
|
||||
storage = Mock(main=self.dummy_store, state=self.dummy_store)
|
||||
storage_controllers = Mock(main=self.dummy_store, state=self.dummy_store)
|
||||
hs = Mock(
|
||||
spec_set=[
|
||||
"config",
|
||||
"get_datastores",
|
||||
"get_storage",
|
||||
"get_storage_controllers",
|
||||
"get_auth",
|
||||
"get_state_handler",
|
||||
"get_clock",
|
||||
@ -199,7 +199,7 @@ class StateTestCase(unittest.TestCase):
|
||||
hs.get_clock.return_value = MockClock()
|
||||
hs.get_auth.return_value = Auth(hs)
|
||||
hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
|
||||
hs.get_storage.return_value = storage
|
||||
hs.get_storage_controllers.return_value = storage_controllers
|
||||
|
||||
self.state = StateHandler(hs)
|
||||
self.event_id = 0
|
||||
|
@ -70,7 +70,7 @@ async def inject_event(
|
||||
"""
|
||||
event, context = await create_event(hs, room_version, prev_event_ids, **kwargs)
|
||||
|
||||
persistence = hs.get_storage().persistence
|
||||
persistence = hs.get_storage_controllers().persistence
|
||||
assert persistence is not None
|
||||
|
||||
await persistence.persist_event(event, context)
|
||||
|
@ -34,7 +34,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||
super(FilterEventsForServerTestCase, self).setUp()
|
||||
self.event_creation_handler = self.hs.get_event_creation_handler()
|
||||
self.event_builder_factory = self.hs.get_event_builder_factory()
|
||||
self.storage = self.hs.get_storage()
|
||||
self._storage_controllers = self.hs.get_storage_controllers()
|
||||
|
||||
self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
|
||||
|
||||
@ -60,7 +60,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||
events_to_filter.append(evt)
|
||||
|
||||
filtered = self.get_success(
|
||||
filter_events_for_server(self.storage, "test_server", events_to_filter)
|
||||
filter_events_for_server(
|
||||
self._storage_controllers, "test_server", events_to_filter
|
||||
)
|
||||
)
|
||||
|
||||
# the result should be 5 redacted events, and 5 unredacted events.
|
||||
@ -80,7 +82,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||
outlier = self._inject_outlier()
|
||||
self.assertEqual(
|
||||
self.get_success(
|
||||
filter_events_for_server(self.storage, "remote_hs", [outlier])
|
||||
filter_events_for_server(
|
||||
self._storage_controllers, "remote_hs", [outlier]
|
||||
)
|
||||
),
|
||||
[outlier],
|
||||
)
|
||||
@ -89,7 +93,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||
evt = self._inject_message("@unerased:local_hs")
|
||||
|
||||
filtered = self.get_success(
|
||||
filter_events_for_server(self.storage, "remote_hs", [outlier, evt])
|
||||
filter_events_for_server(
|
||||
self._storage_controllers, "remote_hs", [outlier, evt]
|
||||
)
|
||||
)
|
||||
self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}")
|
||||
self.assertEqual(filtered[0], outlier)
|
||||
@ -99,7 +105,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||
# ... but other servers should only be able to see the outlier (the other should
|
||||
# be redacted)
|
||||
filtered = self.get_success(
|
||||
filter_events_for_server(self.storage, "other_server", [outlier, evt])
|
||||
filter_events_for_server(
|
||||
self._storage_controllers, "other_server", [outlier, evt]
|
||||
)
|
||||
)
|
||||
self.assertEqual(filtered[0], outlier)
|
||||
self.assertEqual(filtered[1].event_id, evt.event_id)
|
||||
@ -132,7 +140,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
# ... and the filtering happens.
|
||||
filtered = self.get_success(
|
||||
filter_events_for_server(self.storage, "test_server", events_to_filter)
|
||||
filter_events_for_server(
|
||||
self._storage_controllers, "test_server", events_to_filter
|
||||
)
|
||||
)
|
||||
|
||||
for i in range(0, len(events_to_filter)):
|
||||
@ -168,7 +178,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||
event, context = self.get_success(
|
||||
self.event_creation_handler.create_new_client_event(builder)
|
||||
)
|
||||
self.get_success(self.storage.persistence.persist_event(event, context))
|
||||
self.get_success(
|
||||
self._storage_controllers.persistence.persist_event(event, context)
|
||||
)
|
||||
return event
|
||||
|
||||
def _inject_room_member(
|
||||
@ -194,7 +206,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||
self.event_creation_handler.create_new_client_event(builder)
|
||||
)
|
||||
|
||||
self.get_success(self.storage.persistence.persist_event(event, context))
|
||||
self.get_success(
|
||||
self._storage_controllers.persistence.persist_event(event, context)
|
||||
)
|
||||
return event
|
||||
|
||||
def _inject_message(
|
||||
@ -216,7 +230,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||
self.event_creation_handler.create_new_client_event(builder)
|
||||
)
|
||||
|
||||
self.get_success(self.storage.persistence.persist_event(event, context))
|
||||
self.get_success(
|
||||
self._storage_controllers.persistence.persist_event(event, context)
|
||||
)
|
||||
return event
|
||||
|
||||
def _inject_outlier(self) -> EventBase:
|
||||
@ -234,8 +250,8 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||
event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[]))
|
||||
event.internal_metadata.outlier = True
|
||||
self.get_success(
|
||||
self.storage.persistence.persist_event(
|
||||
event, EventContext.for_outlier(self.storage)
|
||||
self._storage_controllers.persistence.persist_event(
|
||||
event, EventContext.for_outlier(self._storage_controllers)
|
||||
)
|
||||
)
|
||||
return event
|
||||
@ -293,7 +309,9 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
|
||||
self.assertEqual(
|
||||
self.get_success(
|
||||
filter_events_for_client(
|
||||
self.hs.get_storage(), "@user:test", [invite_event, reject_event]
|
||||
self.hs.get_storage_controllers(),
|
||||
"@user:test",
|
||||
[invite_event, reject_event],
|
||||
)
|
||||
),
|
||||
[invite_event, reject_event],
|
||||
@ -303,7 +321,9 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
|
||||
self.assertEqual(
|
||||
self.get_success(
|
||||
filter_events_for_client(
|
||||
self.hs.get_storage(), "@other:test", [invite_event, reject_event]
|
||||
self.hs.get_storage_controllers(),
|
||||
"@other:test",
|
||||
[invite_event, reject_event],
|
||||
)
|
||||
),
|
||||
[],
|
||||
|
@ -264,7 +264,7 @@ class MockClock:
|
||||
async def create_room(hs, room_id: str, creator_id: str):
|
||||
"""Creates and persist a creation event for the given room"""
|
||||
|
||||
persistence_store = hs.get_storage().persistence
|
||||
persistence_store = hs.get_storage_controllers().persistence
|
||||
store = hs.get_datastores().main
|
||||
event_builder_factory = hs.get_event_builder_factory()
|
||||
event_creation_handler = hs.get_event_creation_handler()
|
||||
|
Loading…
Reference in New Issue
Block a user