From e24ff8ebe3d4119d377355402245947f7de61c00 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 23 Feb 2022 11:04:02 +0000 Subject: [PATCH] Remove `HomeServer.get_datastore()` (#12031) The presence of this method was confusing, and mostly present for backwards compatibility. Let's get rid of it. Part of #11733 --- changelog.d/12031.misc | 1 + docs/manhole.md | 2 +- scripts/update_synapse_database | 2 +- synapse/api/auth.py | 2 +- synapse/api/auth_blocking.py | 2 +- synapse/api/filtering.py | 4 +-- synapse/app/_base.py | 2 +- synapse/app/generic_worker.py | 2 +- synapse/app/homeserver.py | 2 +- synapse/app/phone_stats_home.py | 14 +++++--- synapse/appservice/scheduler.py | 2 +- synapse/crypto/keyring.py | 4 +-- synapse/events/builder.py | 2 +- synapse/events/third_party_rules.py | 2 +- synapse/federation/federation_base.py | 2 +- synapse/federation/sender/__init__.py | 2 +- .../sender/per_destination_queue.py | 9 ++--- .../federation/sender/transaction_manager.py | 2 +- synapse/federation/transport/server/_base.py | 2 +- .../federation/transport/server/federation.py | 2 +- synapse/groups/attestations.py | 2 +- synapse/groups/groups_server.py | 2 +- synapse/handlers/account_data.py | 4 +-- synapse/handlers/account_validity.py | 2 +- synapse/handlers/admin.py | 2 +- synapse/handlers/appservice.py | 2 +- synapse/handlers/auth.py | 4 +-- synapse/handlers/cas.py | 2 +- synapse/handlers/deactivate_account.py | 2 +- synapse/handlers/device.py | 4 +-- synapse/handlers/devicemessage.py | 2 +- synapse/handlers/directory.py | 2 +- synapse/handlers/e2e_keys.py | 4 +-- synapse/handlers/e2e_room_keys.py | 2 +- synapse/handlers/event_auth.py | 2 +- synapse/handlers/events.py | 4 +-- synapse/handlers/federation.py | 2 +- synapse/handlers/federation_event.py | 2 +- synapse/handlers/groups_local.py | 2 +- synapse/handlers/identity.py | 2 +- synapse/handlers/initial_sync.py | 2 +- synapse/handlers/message.py | 4 +-- synapse/handlers/oidc.py | 2 +- synapse/handlers/pagination.py | 2 +- synapse/handlers/presence.py | 4 +-- synapse/handlers/profile.py | 2 +- synapse/handlers/read_marker.py | 2 +- synapse/handlers/receipts.py | 4 +-- synapse/handlers/register.py | 2 +- synapse/handlers/room.py | 10 +++--- synapse/handlers/room_batch.py | 2 +- synapse/handlers/room_list.py | 2 +- synapse/handlers/room_member.py | 2 +- synapse/handlers/room_summary.py | 2 +- synapse/handlers/saml.py | 2 +- synapse/handlers/search.py | 2 +- synapse/handlers/set_password.py | 2 +- synapse/handlers/sso.py | 2 +- synapse/handlers/state_deltas.py | 2 +- synapse/handlers/stats.py | 2 +- synapse/handlers/sync.py | 2 +- synapse/handlers/typing.py | 4 +-- synapse/handlers/ui_auth/checkers.py | 4 +-- synapse/handlers/user_directory.py | 2 +- synapse/http/matrixfederationclient.py | 2 +- synapse/module_api/__init__.py | 8 +++-- synapse/notifier.py | 2 +- synapse/push/__init__.py | 2 +- synapse/push/bulk_push_rule_evaluator.py | 4 +-- synapse/push/emailpusher.py | 2 +- synapse/push/httppusher.py | 4 +-- synapse/push/mailer.py | 2 +- synapse/push/pusherpool.py | 2 +- synapse/replication/http/devices.py | 2 +- synapse/replication/http/federation.py | 10 +++--- synapse/replication/http/membership.py | 10 +++--- synapse/replication/http/register.py | 4 +-- synapse/replication/http/send_event.py | 2 +- synapse/replication/tcp/client.py | 4 +-- synapse/replication/tcp/handler.py | 2 +- synapse/replication/tcp/resource.py | 2 +- synapse/replication/tcp/streams/_base.py | 24 ++++++------- synapse/replication/tcp/streams/events.py | 2 +- synapse/rest/admin/__init__.py | 2 +- synapse/rest/admin/background_updates.py | 2 +- synapse/rest/admin/devices.py | 6 ++-- synapse/rest/admin/event_reports.py | 4 +-- synapse/rest/admin/federation.py | 8 ++--- synapse/rest/admin/media.py | 20 +++++------ synapse/rest/admin/registration_tokens.py | 6 ++-- synapse/rest/admin/rooms.py | 16 ++++----- synapse/rest/admin/statistics.py | 2 +- synapse/rest/admin/users.py | 24 ++++++------- synapse/rest/client/account.py | 18 +++++----- synapse/rest/client/account_data.py | 4 +-- synapse/rest/client/directory.py | 6 ++-- synapse/rest/client/events.py | 2 +- synapse/rest/client/groups.py | 8 ++--- synapse/rest/client/initial_sync.py | 2 +- synapse/rest/client/keys.py | 2 +- synapse/rest/client/login.py | 4 +-- synapse/rest/client/notifications.py | 2 +- synapse/rest/client/openid.py | 2 +- synapse/rest/client/push_rule.py | 2 +- synapse/rest/client/pusher.py | 4 ++- synapse/rest/client/register.py | 10 +++--- synapse/rest/client/relations.py | 6 ++-- synapse/rest/client/report_event.py | 2 +- synapse/rest/client/room.py | 12 +++---- synapse/rest/client/room_batch.py | 2 +- synapse/rest/client/shared_rooms.py | 2 +- synapse/rest/client/sync.py | 2 +- synapse/rest/client/tags.py | 2 +- synapse/rest/consent/consent_resource.py | 2 +- synapse/rest/key/v2/remote_key_resource.py | 2 +- synapse/rest/media/v1/media_repository.py | 2 +- synapse/rest/media/v1/preview_url_resource.py | 2 +- synapse/rest/media/v1/thumbnail_resource.py | 2 +- synapse/rest/media/v1/upload_resource.py | 2 +- synapse/rest/synapse/client/password_reset.py | 2 +- synapse/server.py | 16 +++------ .../server_notices/consent_server_notices.py | 2 +- .../resource_limits_server_notices.py | 2 +- .../server_notices/server_notices_manager.py | 2 +- synapse/state/__init__.py | 2 +- .../databases/main/monthly_active_users.py | 2 +- synapse/streams/events.py | 2 +- tests/api/test_auth.py | 8 +++-- tests/api/test_filtering.py | 2 +- tests/api/test_ratelimiting.py | 18 +++++----- tests/app/test_phone_stats_home.py | 34 ++++++++++--------- tests/crypto/test_keyring.py | 10 +++--- tests/events/test_snapshot.py | 2 +- tests/federation/test_complexity.py | 4 +-- tests/federation/test_federation_catch_up.py | 26 +++++++------- tests/federation/test_federation_sender.py | 6 ++-- tests/federation/transport/test_knocking.py | 2 +- tests/handlers/test_appservice.py | 10 +++--- tests/handlers/test_auth.py | 12 +++---- tests/handlers/test_cas.py | 2 +- tests/handlers/test_deactivate_account.py | 2 +- tests/handlers/test_device.py | 4 +-- tests/handlers/test_directory.py | 6 ++-- tests/handlers/test_e2e_keys.py | 2 +- tests/handlers/test_federation.py | 2 +- tests/handlers/test_message.py | 2 +- tests/handlers/test_oidc.py | 6 ++-- tests/handlers/test_presence.py | 4 +-- tests/handlers/test_profile.py | 4 +-- tests/handlers/test_register.py | 6 ++-- tests/handlers/test_saml.py | 4 +-- tests/handlers/test_stats.py | 2 +- tests/handlers/test_sync.py | 4 +-- tests/handlers/test_typing.py | 2 +- tests/handlers/test_user_directory.py | 2 +- tests/module_api/test_api.py | 2 +- tests/push/test_email.py | 22 ++++++------ tests/push/test_http.py | 22 ++++++------ tests/replication/_base.py | 6 ++-- tests/replication/slave/storage/_base.py | 4 +-- .../tcp/streams/test_account_data.py | 4 +-- tests/replication/tcp/streams/test_events.py | 4 +-- .../replication/tcp/streams/test_receipts.py | 4 +-- .../test_federation_sender_shard.py | 2 +- tests/replication/test_pusher_shard.py | 2 +- .../test_sharded_event_persister.py | 8 ++--- tests/rest/admin/test_background_updates.py | 2 +- tests/rest/admin/test_federation.py | 4 +-- tests/rest/admin/test_media.py | 4 +-- tests/rest/admin/test_registration_tokens.py | 2 +- tests/rest/admin/test_room.py | 6 ++-- tests/rest/admin/test_server_notice.py | 2 +- tests/rest/admin/test_user.py | 20 +++++------ tests/rest/client/test_account.py | 10 +++--- tests/rest/client/test_filter.py | 2 +- tests/rest/client/test_login.py | 4 +-- tests/rest/client/test_profile.py | 2 +- tests/rest/client/test_register.py | 28 ++++++++------- tests/rest/client/test_relations.py | 4 +-- tests/rest/client/test_retention.py | 4 +-- tests/rest/client/test_rooms.py | 10 +++--- tests/rest/client/test_shadow_banned.py | 2 +- tests/rest/client/test_shared_rooms.py | 2 +- tests/rest/client/test_sync.py | 2 +- tests/rest/client/test_typing.py | 2 +- tests/rest/client/test_upgrade_room.py | 2 +- tests/rest/media/v1/test_media_storage.py | 2 +- .../test_resource_limits_server_notices.py | 2 +- .../databases/main/test_deviceinbox.py | 2 +- .../databases/main/test_events_worker.py | 6 ++-- tests/storage/databases/main/test_lock.py | 2 +- tests/storage/databases/main/test_room.py | 2 +- tests/storage/test__base.py | 2 +- tests/storage/test_account_data.py | 2 +- tests/storage/test_appservice.py | 2 +- tests/storage/test_background_update.py | 8 ++--- tests/storage/test_cleanup_extrems.py | 4 +-- tests/storage/test_client_ips.py | 4 +-- tests/storage/test_devices.py | 2 +- tests/storage/test_directory.py | 2 +- tests/storage/test_e2e_room_keys.py | 2 +- tests/storage/test_end_to_end_keys.py | 2 +- tests/storage/test_event_chain.py | 4 +-- tests/storage/test_event_federation.py | 2 +- tests/storage/test_event_push_actions.py | 2 +- tests/storage/test_events.py | 4 +-- tests/storage/test_id_generators.py | 6 ++-- tests/storage/test_keys.py | 4 +-- tests/storage/test_main.py | 2 +- tests/storage/test_monthly_active_users.py | 2 +- tests/storage/test_profile.py | 2 +- tests/storage/test_purge.py | 4 +-- tests/storage/test_redaction.py | 2 +- tests/storage/test_registration.py | 2 +- tests/storage/test_rollback_worker.py | 6 ++-- tests/storage/test_room.py | 4 +-- tests/storage/test_room_search.py | 2 +- tests/storage/test_roommember.py | 4 +-- tests/storage/test_state.py | 2 +- tests/storage/test_stream.py | 2 +- tests/storage/test_transactions.py | 2 +- tests/storage/test_user_directory.py | 4 +-- tests/test_federation.py | 8 +++-- tests/test_mau.py | 2 +- tests/test_state.py | 4 +-- tests/test_utils/event_injection.py | 4 ++- tests/test_visibility.py | 4 ++- tests/unittest.py | 14 ++++---- tests/util/test_retryutils.py | 4 +-- tests/utils.py | 2 +- 230 files changed, 526 insertions(+), 500 deletions(-) create mode 100644 changelog.d/12031.misc diff --git a/changelog.d/12031.misc b/changelog.d/12031.misc new file mode 100644 index 000000000..d4bedc6b9 --- /dev/null +++ b/changelog.d/12031.misc @@ -0,0 +1 @@ +Remove legacy `HomeServer.get_datastore()`. diff --git a/docs/manhole.md b/docs/manhole.md index 715ed840f..a82fad0f0 100644 --- a/docs/manhole.md +++ b/docs/manhole.md @@ -94,6 +94,6 @@ As a simple example, retrieving an event from the database: ```pycon >>> from twisted.internet import defer ->>> defer.ensureDeferred(hs.get_datastore().get_event('$1416420717069yeQaw:matrix.org')) +>>> defer.ensureDeferred(hs.get_datastores().main.get_event('$1416420717069yeQaw:matrix.org')) > ``` diff --git a/scripts/update_synapse_database b/scripts/update_synapse_database index 5c6453d77..f43676afa 100755 --- a/scripts/update_synapse_database +++ b/scripts/update_synapse_database @@ -44,7 +44,7 @@ class MockHomeserver(HomeServer): def run_background_updates(hs): - store = hs.get_datastore() + store = hs.get_datastores().main async def run_background_updates(): await store.db_pool.updates.run_background_updates(sleep=False) diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 683241201..01c32417d 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -60,7 +60,7 @@ class Auth: def __init__(self, hs: "HomeServer"): self.hs = hs self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state = hs.get_state_handler() self._account_validity_handler = hs.get_account_validity_handler() diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py index 08fe160c9..22348d2d8 100644 --- a/synapse/api/auth_blocking.py +++ b/synapse/api/auth_blocking.py @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) class AuthBlocking: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._server_notices_mxid = hs.config.servernotices.server_notices_mxid self._hs_disabled = hs.config.server.hs_disabled diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index d087c816d..fe4cc2e8e 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -150,7 +150,7 @@ def matrix_user_id_validator(user_id_str: str) -> UserID: class Filtering: def __init__(self, hs: "HomeServer"): self._hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.DEFAULT_FILTER_COLLECTION = FilterCollection(hs, {}) @@ -294,7 +294,7 @@ class FilterCollection: class Filter: def __init__(self, hs: "HomeServer", filter_json: JsonDict): self._hs = hs - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self.filter_json = filter_json self.limit = filter_json.get("limit", 10) diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 452c0c09d..3e59805ba 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -448,7 +448,7 @@ async def start(hs: "HomeServer") -> None: # It is now safe to start your Synapse. hs.start_listening() - hs.get_datastore().db_pool.start_profiling() + hs.get_datastores().main.db_pool.start_profiling() hs.get_pusherpool().start() # Log when we start the shut down process. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index aadc882bf..1536a4272 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -142,7 +142,7 @@ class KeyUploadServlet(RestServlet): """ super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.http_client = hs.get_simple_http_client() self.main_uri = hs.config.worker.worker_main_http_uri diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index bfb30003c..b9931001c 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -372,7 +372,7 @@ def setup(config_options: List[str]) -> SynapseHomeServer: await _base.start(hs) - hs.get_datastore().db_pool.updates.start_doing_background_updates() + hs.get_datastores().main.db_pool.updates.start_doing_background_updates() register_start(start) diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py index 899dba5c3..40dbdace8 100644 --- a/synapse/app/phone_stats_home.py +++ b/synapse/app/phone_stats_home.py @@ -82,7 +82,7 @@ async def phone_stats_home( # General statistics # - store = hs.get_datastore() + store = hs.get_datastores().main stats["homeserver"] = hs.config.server.server_name stats["server_context"] = hs.config.server.server_context @@ -170,18 +170,22 @@ def start_phone_stats_home(hs: "HomeServer") -> None: # Rather than update on per session basis, batch up the requests. # If you increase the loop period, the accuracy of user_daily_visits # table will decrease - clock.looping_call(hs.get_datastore().generate_user_daily_visits, 5 * 60 * 1000) + clock.looping_call( + hs.get_datastores().main.generate_user_daily_visits, 5 * 60 * 1000 + ) # monthly active user limiting functionality - clock.looping_call(hs.get_datastore().reap_monthly_active_users, 1000 * 60 * 60) - hs.get_datastore().reap_monthly_active_users() + clock.looping_call( + hs.get_datastores().main.reap_monthly_active_users, 1000 * 60 * 60 + ) + hs.get_datastores().main.reap_monthly_active_users() @wrap_as_background_process("generate_monthly_active_users") async def generate_monthly_active_users() -> None: current_mau_count = 0 current_mau_count_by_service = {} reserved_users: Sized = () - store = hs.get_datastore() + store = hs.get_datastores().main if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only: current_mau_count = await store.get_monthly_active_count() current_mau_count_by_service = ( diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index c42fa32ff..b4e602e88 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -92,7 +92,7 @@ class ApplicationServiceScheduler: def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.as_api = hs.get_application_service_api() self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 72d4a69aa..93d56c077 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -476,7 +476,7 @@ class StoreKeyFetcher(KeyFetcher): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def _fetch_keys( self, keys_to_fetch: List[_FetchKeyRequest] @@ -498,7 +498,7 @@ class BaseV2KeyFetcher(KeyFetcher): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.config = hs.config async def process_v2_response( diff --git a/synapse/events/builder.py b/synapse/events/builder.py index eb39e0ae3..1ea1bb7d3 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -189,7 +189,7 @@ class EventBuilderFactory: self.hostname = hs.hostname self.signing_key = hs.signing_key - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state = hs.get_state_handler() self._event_auth_handler = hs.get_event_auth_handler() diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 1bb8ca714..71ec100a7 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -143,7 +143,7 @@ class ThirdPartyEventRules: def __init__(self, hs: "HomeServer"): self.third_party_rules = None - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = [] self._on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = [] diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index fab6da3c0..41ac49fdc 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -39,7 +39,7 @@ class FederationBase: self.server_name = hs.hostname self.keyring = hs.get_keyring() self.spam_checker = hs.get_spam_checker() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._clock = hs.get_clock() async def _check_sigs_and_hash( diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 720d7bd74..6106a486d 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -228,7 +228,7 @@ class FederationSender(AbstractFederationSender): self.hs = hs self.server_name = hs.hostname - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state = hs.get_state_handler() self.clock = hs.get_clock() diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index c3132f731..c8768f22b 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -76,7 +76,7 @@ class PerDestinationQueue: ): self._server_name = hs.hostname self._clock = hs.get_clock() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._transaction_manager = transaction_manager self._instance_name = hs.get_instance_name() self._federation_shard_config = hs.config.worker.federation_shard_config @@ -381,9 +381,8 @@ class PerDestinationQueue: ) ) - last_successful_stream_ordering = self._last_successful_stream_ordering - - if last_successful_stream_ordering is None: + _tmp_last_successful_stream_ordering = self._last_successful_stream_ordering + if _tmp_last_successful_stream_ordering is None: # if it's still None, then this means we don't have the information # in our database ­ we haven't successfully sent a PDU to this server # (at least since the introduction of the feature tracking @@ -393,6 +392,8 @@ class PerDestinationQueue: self._catching_up = False return + last_successful_stream_ordering: int = _tmp_last_successful_stream_ordering + # get at most 50 catchup room/PDUs while True: event_ids = await self._store.get_catch_up_room_event_ids( diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 742ee5725..0c1cad86a 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -53,7 +53,7 @@ class TransactionManager: def __init__(self, hs: "synapse.server.HomeServer"): self._server_name = hs.hostname self.clock = hs.get_clock() # nb must be called this for @measure_func - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._transaction_actions = TransactionActions(self._store) self._transport_layer = hs.get_federation_transport_client() diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py index dff2b6835..87e99c7dd 100644 --- a/synapse/federation/transport/server/_base.py +++ b/synapse/federation/transport/server/_base.py @@ -55,7 +55,7 @@ class Authenticator: self._clock = hs.get_clock() self.keyring = hs.get_keyring() self.server_name = hs.hostname - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.federation_domain_whitelist = ( hs.config.federation.federation_domain_whitelist ) diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 4d75e58bf..9cc9a7339 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -746,7 +746,7 @@ class RoomComplexityServlet(BaseFederationServlet): server_name: str, ): super().__init__(hs, authenticator, ratelimiter, server_name) - self._store = self.hs.get_datastore() + self._store = self.hs.get_datastores().main async def on_GET( self, diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index a87896e53..ed26d6a6c 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -140,7 +140,7 @@ class GroupAttestionRenewer: def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.assestations = hs.get_groups_attestation_signing() self.transport_client = hs.get_federation_transport_client() self.is_mine_id = hs.is_mine_id diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 449bbc700..4c3a5a6e2 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -45,7 +45,7 @@ MAX_LONG_DESC_LEN = 10000 class GroupsServerWorkerHandler: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room_list_handler = hs.get_room_list_handler() self.auth = hs.get_auth() self.clock = hs.get_clock() diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index bad48713b..177b4f899 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -30,7 +30,7 @@ if TYPE_CHECKING: class AccountDataHandler: def __init__(self, hs: "HomeServer"): - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._instance_name = hs.get_instance_name() self._notifier = hs.get_notifier() @@ -166,7 +166,7 @@ class AccountDataHandler: class AccountDataEventSource(EventSource[int, JsonDict]): def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def get_current_key(self, direction: str = "f") -> int: return self.store.get_max_account_data_stream_id() diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 87e415df7..9d0975f63 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -43,7 +43,7 @@ class AccountValidityHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self.config = hs.config - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.send_email_handler = self.hs.get_send_email_handler() self.clock = self.hs.get_clock() diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 00ab5e79b..96376963f 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) class AdminHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_store = self.storage.state diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index a42c3558e..e6461cc3c 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -47,7 +47,7 @@ events_processed_counter = Counter("synapse_handlers_appservice_events_processed class ApplicationServicesHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.is_mine_id = hs.is_mine_id self.appservice_api = hs.get_application_service_api() self.scheduler = hs.get_application_service_scheduler() diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 572f54b1e..3e29c96a4 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -194,7 +194,7 @@ class AuthHandler: SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.clock = hs.get_clock() self.checkers: Dict[str, UserInteractiveAuthChecker] = {} @@ -1183,7 +1183,7 @@ class AuthHandler: # No password providers were able to handle this 3pid # Check local store - user_id = await self.hs.get_datastore().get_user_id_by_threepid( + user_id = await self.hs.get_datastores().main.get_user_id_by_threepid( medium, address ) if not user_id: diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py index 5d8f6c50a..7163af800 100644 --- a/synapse/handlers/cas.py +++ b/synapse/handlers/cas.py @@ -61,7 +61,7 @@ class CasHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self._hostname = hs.hostname - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._auth_handler = hs.get_auth_handler() self._registration_handler = hs.get_registration_handler() diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 7a13d76a6..e4eae0305 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -29,7 +29,7 @@ class DeactivateAccountHandler: """Handler which deals with deactivating user accounts.""" def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.hs = hs self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 36c05f836..934b5bd73 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -63,7 +63,7 @@ class DeviceWorkerHandler: def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self.state = hs.get_state_handler() self.state_store = hs.get_storage().state @@ -628,7 +628,7 @@ class DeviceListUpdater: "Handles incoming device list updates from federation and updates the DB" def __init__(self, hs: "HomeServer", device_handler: DeviceHandler): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.federation = hs.get_federation_client() self.clock = hs.get_clock() self.device_handler = device_handler diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index b582266af..4cb725d02 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -43,7 +43,7 @@ class DeviceMessageHandler: Args: hs: server """ - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self.is_mine = hs.is_mine diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 082f52179..b7064c662 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -44,7 +44,7 @@ class DirectoryHandler: self.state = hs.get_state_handler() self.appservice_handler = hs.get_application_service_handler() self.event_creation_handler = hs.get_event_creation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.config = hs.config self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search self.require_membership = hs.config.server.require_membership_for_aliases diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index d4dfddf63..d96456cd4 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -47,7 +47,7 @@ logger = logging.getLogger(__name__) class E2eKeysHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.federation = hs.get_federation_client() self.device_handler = hs.get_device_handler() self.is_mine = hs.is_mine @@ -1335,7 +1335,7 @@ class SigningKeyEduUpdater: """Handles incoming signing key updates from federation and updates the DB""" def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.federation = hs.get_federation_client() self.clock = hs.get_clock() self.e2e_keys_handler = e2e_keys_handler diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 12614b2c5..52e44a2d4 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -45,7 +45,7 @@ class E2eRoomKeysHandler: """ def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main # Used to lock whenever a client is uploading key data. This prevents collisions # between clients trying to upload the details of a new session, given all diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index 365063ebd..d441ebb0a 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -43,7 +43,7 @@ class EventAuthHandler: def __init__(self, hs: "HomeServer"): self._clock = hs.get_clock() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._server_name = hs.hostname async def check_auth_rules_from_context( diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index bac5de052..97e75e60c 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -33,7 +33,7 @@ logger = logging.getLogger(__name__) class EventStreamHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.hs = hs @@ -134,7 +134,7 @@ class EventStreamHandler: class EventHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() async def get_event( diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index e9ac920bc..c055c26ec 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -107,7 +107,7 @@ class FederationHandler: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_store = self.storage.state self.federation_client = hs.get_federation_client() diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 7683246be..09d0de1ea 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -95,7 +95,7 @@ class FederationEventHandler: """ def __init__(self, hs: "HomeServer"): - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._storage = hs.get_storage() self._state_store = self._storage.state diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index 9e270d461..e7a399787 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -63,7 +63,7 @@ def _create_rerouter(func_name: str) -> Callable[..., Awaitable[JsonDict]]: class GroupsLocalWorkerHandler: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room_list_handler = hs.get_room_list_handler() self.groups_server_handler = hs.get_groups_server_handler() self.transport_client = hs.get_federation_transport_client() diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index c83eaea35..57c9fdfe6 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -49,7 +49,7 @@ id_server_scheme = "https://" class IdentityHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main # An HTTP client for contacting trusted URLs. self.http_client = SimpleHttpClient(hs) # An HTTP client for contacting identity servers specified by clients. diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 346a06ff4..344f20f37 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -46,7 +46,7 @@ logger = logging.getLogger(__name__) class InitialSyncHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.state_handler = hs.get_state_handler() self.hs = hs diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 4d0da8428..a9c964cd7 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -75,7 +75,7 @@ class MessageHandler: self.auth = hs.get_auth() self.clock = hs.get_clock() self.state = hs.get_state_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_store = self.storage.state self._event_serializer = hs.get_event_client_serializer() @@ -397,7 +397,7 @@ class EventCreationHandler: self.hs = hs self.auth = hs.get_auth() self._event_auth_handler = hs.get_event_auth_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state = hs.get_state_handler() self.clock = hs.get_clock() diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 8f71d975e..593a2aac6 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -273,7 +273,7 @@ class OidcProvider: token_generator: "OidcSessionTokenGenerator", provider: OidcProviderConfig, ): - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._token_generator = token_generator diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 973f26296..5c01a426f 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -127,7 +127,7 @@ class PaginationHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_store = self.storage.state self.clock = hs.get_clock() diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index b223b7262..c155098be 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -133,7 +133,7 @@ class BasePresenceHandler(abc.ABC): def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.presence_router = hs.get_presence_router() self.state = hs.get_state_handler() self.is_mine_id = hs.is_mine_id @@ -1541,7 +1541,7 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): self.get_presence_handler = hs.get_presence_handler self.get_presence_router = hs.get_presence_router self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def get_new_events( self, diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 36e3ad2ba..dd27f0acc 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -54,7 +54,7 @@ class ProfileHandler: PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.hs = hs diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py index 58593e570..bad1acc63 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) class ReadMarkerHandler: def __init__(self, hs: "HomeServer"): self.server_name = hs.config.server.server_name - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.account_data_handler = hs.get_account_data_handler() self.read_marker_linearizer = Linearizer(name="read_marker") diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 5cb1ff749..b4132c353 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -29,7 +29,7 @@ class ReceiptsHandler: def __init__(self, hs: "HomeServer"): self.notifier = hs.get_notifier() self.server_name = hs.config.server.server_name - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.event_auth_handler = hs.get_event_auth_handler() self.hs = hs @@ -163,7 +163,7 @@ class ReceiptsHandler: class ReceiptEventSource(EventSource[int, JsonDict]): def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.config = hs.config @staticmethod diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 80320d2c0..05bb1e022 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -86,7 +86,7 @@ class LoginDict(TypedDict): class RegistrationHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.hs = hs self.auth = hs.get_auth() diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index a990727fc..7b965b4b9 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -105,7 +105,7 @@ class EventContext: class RoomCreationHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.clock = hs.get_clock() self.hs = hs @@ -1115,7 +1115,7 @@ class RoomContextHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_store = self.storage.state @@ -1246,7 +1246,7 @@ class RoomContextHandler: class TimestampLookupHandler: def __init__(self, hs: "HomeServer"): self.server_name = hs.hostname - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state_handler = hs.get_state_handler() self.federation_client = hs.get_federation_client() @@ -1386,7 +1386,7 @@ class TimestampLookupHandler: class RoomEventSource(EventSource[RoomStreamToken, EventBase]): def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def get_new_events( self, @@ -1476,7 +1476,7 @@ class RoomShutdownHandler: self._room_creation_handler = hs.get_room_creation_handler() self._replication = hs.get_replication_data_handler() self.event_creation_handler = hs.get_event_creation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def shutdown_room( self, diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index f8137ec04..abbf7b7b2 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -16,7 +16,7 @@ logger = logging.getLogger(__name__) class RoomBatchHandler: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state_store = hs.get_storage().state self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 1a33211a1..f3577b5d5 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -49,7 +49,7 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None) class RoomListHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.hs = hs self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search self.response_cache: ResponseCache[ diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index b2adc0f48..a582837cf 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -66,7 +66,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.state_handler = hs.get_state_handler() self.config = hs.config diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index 4844b69a0..2e61d1cbe 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -90,7 +90,7 @@ class RoomSummaryHandler: def __init__(self, hs: "HomeServer"): self._event_auth_handler = hs.get_event_auth_handler() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._event_serializer = hs.get_event_client_serializer() self._server_name = hs.hostname self._federation_client = hs.get_federation_client() diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py index 727d75a50..9602f0d0b 100644 --- a/synapse/handlers/saml.py +++ b/synapse/handlers/saml.py @@ -52,7 +52,7 @@ class Saml2SessionData: class SamlHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.server_name = hs.hostname self._saml_client = Saml2Client(hs.config.saml2.saml2_sp_config) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 0e0e58de0..aa16e417e 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -49,7 +49,7 @@ class _SearchResult: class SearchHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state_handler = hs.get_state_handler() self.clock = hs.get_clock() self.hs = hs diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index 706ad7276..73861bbd4 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -27,7 +27,7 @@ class SetPasswordHandler: """Handler which deals with changing user account passwords""" def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 0bb8b0929..ff5b5169c 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -180,7 +180,7 @@ class SsoHandler: def __init__(self, hs: "HomeServer"): self._clock = hs.get_clock() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._server_name = hs.hostname self._registration_handler = hs.get_registration_handler() self._auth_handler = hs.get_auth_handler() diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py index d30ba2b72..2d197282e 100644 --- a/synapse/handlers/state_deltas.py +++ b/synapse/handlers/state_deltas.py @@ -30,7 +30,7 @@ class MatchChange(Enum): class StateDeltasHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def _get_key_change( self, diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 29e41a4c7..436cd971c 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -39,7 +39,7 @@ class StatsHandler: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state = hs.get_state_handler() self.server_name = hs.hostname self.clock = hs.get_clock() diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index e6050cbce..98eaad331 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -266,7 +266,7 @@ class SyncResult: class SyncHandler: def __init__(self, hs: "HomeServer"): self.hs_config = hs.config - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self.presence_handler = hs.get_presence_handler() self.event_sources = hs.get_event_sources() diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index e4bed1c93..843c68eb0 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -57,7 +57,7 @@ class FollowerTypingHandler: """ def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.server_name = hs.config.server.server_name self.clock = hs.get_clock() self.is_mine_id = hs.is_mine_id @@ -446,7 +446,7 @@ class TypingWriterHandler(FollowerTypingHandler): class TypingNotificationEventSource(EventSource[int, JsonDict]): def __init__(self, hs: "HomeServer"): - self._main_store = hs.get_datastore() + self._main_store = hs.get_datastores().main self.clock = hs.get_clock() # We can't call get_typing_handler here because there's a cycle: # diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 184730ebe..014754a63 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -139,7 +139,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker): class _BaseThreepidAuthChecker: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def _check_threepid(self, medium: str, authdict: dict) -> dict: if "threepid_creds" not in authdict: @@ -255,7 +255,7 @@ class RegistrationTokenAuthChecker(UserInteractiveAuthChecker): super().__init__(hs) self.hs = hs self._enabled = bool(hs.config.registration.registration_requires_token) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def is_enabled(self) -> bool: return self._enabled diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 1565e034c..d27ed2be6 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -55,7 +55,7 @@ class UserDirectoryHandler(StateDeltasHandler): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.server_name = hs.hostname self.clock = hs.get_clock() self.notifier = hs.get_notifier() diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index e7656fbb9..40bf1e06d 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -351,7 +351,7 @@ class MatrixFederationHttpClient: ) self.clock = hs.get_clock() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self.version_string_bytes = hs.version_string.encode("ascii") self.default_timeout = 60 diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 07020bfb8..902916d80 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -172,7 +172,9 @@ class ModuleApi: # TODO: Fix this type hint once the types for the data stores have been ironed # out. - self._store: Union[DataStore, "GenericWorkerSlavedStore"] = hs.get_datastore() + self._store: Union[ + DataStore, "GenericWorkerSlavedStore" + ] = hs.get_datastores().main self._auth = hs.get_auth() self._auth_handler = auth_handler self._server_name = hs.hostname @@ -926,7 +928,7 @@ class ModuleApi: ) # Try to retrieve the resulting event. - event = await self._hs.get_datastore().get_event(event_id) + event = await self._hs.get_datastores().main.get_event(event_id) # update_membership is supposed to always return after the event has been # successfully persisted. @@ -1270,7 +1272,7 @@ class PublicRoomListManager: """ def __init__(self, hs: "HomeServer"): - self._store = hs.get_datastore() + self._store = hs.get_datastores().main async def room_is_in_public_room_list(self, room_id: str) -> bool: """Checks whether a room is in the public room list. diff --git a/synapse/notifier.py b/synapse/notifier.py index 753dd6b6a..16d15a1f3 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -222,7 +222,7 @@ class Notifier: self.hs = hs self.storage = hs.get_storage() self.event_sources = hs.get_event_sources() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.pending_new_room_events: List[_PendingRoomEventEntry] = [] # Called when there are new things to stream over replication diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index 5176a1c18..a1b771109 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -68,7 +68,7 @@ class ThrottleParams: class Pusher(metaclass=abc.ABCMeta): def __init__(self, hs: "HomeServer", pusher_config: PusherConfig): self.hs = hs - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.clock = self.hs.get_clock() self.pusher_id = pusher_config.id diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index bee660893..fecf86034 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -103,7 +103,7 @@ class BulkPushRuleEvaluator: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._event_auth_handler = hs.get_event_auth_handler() # Used by `RulesForRoom` to ensure only one thing mutates the cache at a @@ -366,7 +366,7 @@ class RulesForRoom: """ self.room_id = room_id self.is_mine_id = hs.is_mine_id - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room_push_rule_cache_metrics = room_push_rule_cache_metrics # Used to ensure only one thing mutates the cache at a time. Keyed off diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 39bb2acae..1710dd51b 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -66,7 +66,7 @@ class EmailPusher(Pusher): super().__init__(hs, pusher_config) self.mailer = mailer - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.email = pusher_config.pushkey self.timed_call: Optional[IDelayedCall] = None self.throttle_params: Dict[str, ThrottleParams] = {} diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 52c7ff357..581834452 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -133,7 +133,7 @@ class HttpPusher(Pusher): # XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems # to be largely redundant. perhaps we can remove it. badge = await push_tools.get_badge_count( - self.hs.get_datastore(), + self.hs.get_datastores().main, self.user_id, group_by_room=self._group_unread_count_by_room, ) @@ -283,7 +283,7 @@ class HttpPusher(Pusher): tweaks = push_rule_evaluator.tweaks_for_actions(push_action.actions) badge = await push_tools.get_badge_count( - self.hs.get_datastore(), + self.hs.get_datastores().main, self.user_id, group_by_room=self._group_unread_count_by_room, ) diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 3df8452ee..649a4f49d 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -112,7 +112,7 @@ class Mailer: self.template_text = template_text self.send_email_handler = hs.get_send_email_handler() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.state_store = self.hs.get_storage().state self.macaroon_gen = self.hs.get_macaroon_generator() self.state_handler = self.hs.get_state_handler() diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 7912311d2..d0cc657b4 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -59,7 +59,7 @@ class PusherPool: def __init__(self, hs: "HomeServer"): self.hs = hs self.pusher_factory = PusherFactory(hs) - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.clock = self.hs.get_clock() # We shard the handling of push notifications by user ID. diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index f2f40129f..3d6364572 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -63,7 +63,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): super().__init__(hs) self.device_list_updater = hs.get_device_handler().device_list_updater - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() @staticmethod diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index d529c8a19..3e7300b4a 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -68,7 +68,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.clock = hs.get_clock() self.federation_event_handler = hs.get_federation_event_handler() @@ -167,7 +167,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.registry = hs.get_federation_registry() @@ -214,7 +214,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.registry = hs.get_federation_registry() @@ -260,7 +260,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main @staticmethod async def _serialize_payload(room_id: str) -> JsonDict: # type: ignore[override] @@ -297,7 +297,7 @@ class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main @staticmethod async def _serialize_payload(room_id: str, room_version: RoomVersion) -> JsonDict: # type: ignore[override] diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 0145858e4..663bff573 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -50,7 +50,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): super().__init__(hs) self.federation_handler = hs.get_federation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() @staticmethod @@ -119,7 +119,7 @@ class ReplicationRemoteKnockRestServlet(ReplicationEndpoint): super().__init__(hs) self.federation_handler = hs.get_federation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() @staticmethod @@ -188,7 +188,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.member_handler = hs.get_room_member_handler() @@ -258,7 +258,7 @@ class ReplicationRemoteRescindKnockRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.member_handler = hs.get_room_member_handler() @@ -325,7 +325,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): super().__init__(hs) self.registeration_handler = hs.get_registration_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.distributor = hs.get_distributor() diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index c7f751b70..6c8f8388f 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -36,7 +36,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.registration_handler = hs.get_registration_handler() @staticmethod @@ -112,7 +112,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.registration_handler = hs.get_registration_handler() @staticmethod diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index 33e98daf8..ce7817683 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -69,7 +69,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.clock = hs.get_clock() diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index d59ce7ccf..1b8479b0b 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -111,7 +111,7 @@ class ReplicationDataHandler: """ def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self._reactor = hs.get_reactor() self._clock = hs.get_clock() @@ -340,7 +340,7 @@ class FederationSenderHandler: def __init__(self, hs: "HomeServer"): assert hs.should_send_federation() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._is_mine_id = hs.is_mine_id self._hs = hs diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 17e157239..0d2013a3c 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -95,7 +95,7 @@ class ReplicationCommandHandler: def __init__(self, hs: "HomeServer"): self._replication_data_handler = hs.get_replication_data_handler() self._presence_handler = hs.get_presence_handler() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._notifier = hs.get_notifier() self._clock = hs.get_clock() self._instance_id = hs.get_instance_id() diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index ecd6190f5..494e42a2b 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -72,7 +72,7 @@ class ReplicationStreamer: """ def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.notifier = hs.get_notifier() self._instance_name = hs.get_instance_name() diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 914b9eae8..23d631a76 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -239,7 +239,7 @@ class BackfillStream(Stream): ROW_TYPE = BackfillStreamRow def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main super().__init__( hs.get_instance_name(), self._current_token, @@ -267,7 +267,7 @@ class PresenceStream(Stream): ROW_TYPE = PresenceStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main if hs.get_instance_name() in hs.config.worker.writers.presence: # on the presence writer, query the presence handler @@ -355,7 +355,7 @@ class ReceiptsStream(Stream): ROW_TYPE = ReceiptsStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_max_receipt_stream_id), @@ -374,7 +374,7 @@ class PushRulesStream(Stream): ROW_TYPE = PushRulesStreamRow def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main super().__init__( hs.get_instance_name(), @@ -401,7 +401,7 @@ class PushersStream(Stream): ROW_TYPE = PushersStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), @@ -434,7 +434,7 @@ class CachesStream(Stream): ROW_TYPE = CachesStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), store.get_cache_stream_token_for_writer, @@ -455,7 +455,7 @@ class DeviceListsStream(Stream): ROW_TYPE = DeviceListsStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_device_stream_token), @@ -474,7 +474,7 @@ class ToDeviceStream(Stream): ROW_TYPE = ToDeviceStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_to_device_stream_token), @@ -495,7 +495,7 @@ class TagAccountDataStream(Stream): ROW_TYPE = TagAccountDataStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_max_account_data_stream_id), @@ -516,7 +516,7 @@ class AccountDataStream(Stream): ROW_TYPE = AccountDataStreamRow def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(self.store.get_max_account_data_stream_id), @@ -585,7 +585,7 @@ class GroupServerStream(Stream): ROW_TYPE = GroupsStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_group_stream_token), @@ -604,7 +604,7 @@ class UserSignatureStream(Stream): ROW_TYPE = UserSignatureStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_device_stream_token), diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 50c4a5ba0..26f4fa7cf 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -124,7 +124,7 @@ class EventsStream(Stream): NAME = "events" def __init__(self, hs: "HomeServer"): - self._store = hs.get_datastore() + self._store = hs.get_datastores().main super().__init__( hs.get_instance_name(), self._store._stream_id_gen.get_current_token_for_writer, diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index ba0d989d8..6de302f81 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -116,7 +116,7 @@ class PurgeHistoryRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.pagination_handler = hs.get_pagination_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( diff --git a/synapse/rest/admin/background_updates.py b/synapse/rest/admin/background_updates.py index e9bce22a3..93a78db81 100644 --- a/synapse/rest/admin/background_updates.py +++ b/synapse/rest/admin/background_updates.py @@ -112,7 +112,7 @@ class BackgroundUpdateStartJobRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self._auth, request) diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py index d9905ff56..cef46ba0d 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py @@ -44,7 +44,7 @@ class DeviceRestServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.is_mine = hs.is_mine async def on_GET( @@ -113,7 +113,7 @@ class DevicesRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.is_mine = hs.is_mine async def on_GET( @@ -144,7 +144,7 @@ class DeleteDevicesRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.is_mine = hs.is_mine async def on_POST( diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py index 38477f8ea..6d634eef7 100644 --- a/synapse/rest/admin/event_reports.py +++ b/synapse/rest/admin/event_reports.py @@ -53,7 +53,7 @@ class EventReportsRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) @@ -115,7 +115,7 @@ class EventReportDetailRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, report_id: str diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py index d162e0081..023ed9214 100644 --- a/synapse/rest/admin/federation.py +++ b/synapse/rest/admin/federation.py @@ -48,7 +48,7 @@ class ListDestinationsRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self._auth, request) @@ -105,7 +105,7 @@ class DestinationRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, destination: str @@ -165,7 +165,7 @@ class DestinationMembershipRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, destination: str @@ -221,7 +221,7 @@ class DestinationResetConnectionRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._authenticator = Authenticator(hs) async def on_POST( diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 299f5c9eb..8ca57bdb2 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -47,7 +47,7 @@ class QuarantineMediaInRoom(RestServlet): ] def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( @@ -74,7 +74,7 @@ class QuarantineMediaByUser(RestServlet): PATTERNS = admin_patterns("/user/(?P[^/]*)/media/quarantine$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( @@ -103,7 +103,7 @@ class QuarantineMediaByID(RestServlet): ) def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( @@ -132,7 +132,7 @@ class UnquarantineMediaByID(RestServlet): ) def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( @@ -156,7 +156,7 @@ class ProtectMediaByID(RestServlet): PATTERNS = admin_patterns("/media/protect/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( @@ -178,7 +178,7 @@ class UnprotectMediaByID(RestServlet): PATTERNS = admin_patterns("/media/unprotect/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( @@ -200,7 +200,7 @@ class ListMediaInRoom(RestServlet): PATTERNS = admin_patterns("/room/(?P[^/]*)/media$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_GET( @@ -251,7 +251,7 @@ class DeleteMediaByID(RestServlet): PATTERNS = admin_patterns("/media/(?P[^/]*)/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.server_name = hs.hostname self.media_repository = hs.get_media_repository() @@ -283,7 +283,7 @@ class DeleteMediaByDateSize(RestServlet): PATTERNS = admin_patterns("/media/(?P[^/]*)/delete$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.server_name = hs.hostname self.media_repository = hs.get_media_repository() @@ -352,7 +352,7 @@ class UserMediaRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.is_mine = hs.is_mine self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.media_repository = hs.get_media_repository() async def on_GET( diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py index 04948b640..af606e925 100644 --- a/synapse/rest/admin/registration_tokens.py +++ b/synapse/rest/admin/registration_tokens.py @@ -71,7 +71,7 @@ class ListRegistrationTokensRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) @@ -109,7 +109,7 @@ class NewRegistrationTokenRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() # A string of all the characters allowed to be in a registration_token self.allowed_chars = string.ascii_letters + string.digits + "._~-" @@ -260,7 +260,7 @@ class RegistrationTokenRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]: """Retrieve a registration token.""" diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 5b706efbc..f4736a3da 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -65,7 +65,7 @@ class RoomRestV2Servlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._pagination_handler = hs.get_pagination_handler() async def on_DELETE( @@ -188,7 +188,7 @@ class ListRoomRestServlet(RestServlet): PATTERNS = admin_patterns("/rooms$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() @@ -278,7 +278,7 @@ class RoomRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room_shutdown_handler = hs.get_room_shutdown_handler() self.pagination_handler = hs.get_pagination_handler() @@ -382,7 +382,7 @@ class RoomMembersRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, room_id: str @@ -408,7 +408,7 @@ class RoomStateRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self._event_serializer = hs.get_event_client_serializer() @@ -525,7 +525,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): def __init__(self, hs: "HomeServer"): super().__init__(hs) self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.event_creation_handler = hs.get_event_creation_handler() self.state_handler = hs.get_state_handler() self.is_mine_id = hs.is_mine_id @@ -670,7 +670,7 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet): def __init__(self, hs: "HomeServer"): super().__init__(hs) self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_DELETE( self, request: SynapseRequest, room_identifier: str @@ -781,7 +781,7 @@ class BlockRoomRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, room_id: str diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py index 7a6546372..3b142b840 100644 --- a/synapse/rest/admin/statistics.py +++ b/synapse/rest/admin/statistics.py @@ -38,7 +38,7 @@ class UserMediaStatisticsRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index c2617ee30..8e29ada8a 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -66,7 +66,7 @@ class UsersRestServletV2(RestServlet): """ def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() @@ -156,7 +156,7 @@ class UserRestServletV2(RestServlet): self.hs = hs self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth_handler = hs.get_auth_handler() self.profile_handler = hs.get_profile_handler() self.set_password_handler = hs.get_set_password_handler() @@ -588,7 +588,7 @@ class DeactivateAccountRestServlet(RestServlet): self._deactivate_account_handler = hs.get_deactivate_account_handler() self.auth = hs.get_auth() self.is_mine = hs.is_mine - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_POST( self, request: SynapseRequest, target_user_id: str @@ -674,7 +674,7 @@ class ResetPasswordRestServlet(RestServlet): PATTERNS = admin_patterns("/reset_password/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self._set_password_handler = hs.get_set_password_handler() @@ -717,7 +717,7 @@ class SearchUsersRestServlet(RestServlet): PATTERNS = admin_patterns("/search_users/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.is_mine = hs.is_mine @@ -775,7 +775,7 @@ class UserAdminServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/admin$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.is_mine = hs.is_mine @@ -835,7 +835,7 @@ class UserMembershipRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.is_mine = hs.is_mine self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, user_id: str @@ -864,7 +864,7 @@ class PushersRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.is_mine = hs.is_mine - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_GET( @@ -905,7 +905,7 @@ class UserTokenRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/login$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self.is_mine_id = hs.is_mine_id @@ -974,7 +974,7 @@ class ShadowBanRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/shadow_ban$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.is_mine_id = hs.is_mine_id @@ -1026,7 +1026,7 @@ class RateLimitRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/override_ratelimit$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.is_mine_id = hs.is_mine_id @@ -1129,7 +1129,7 @@ class AccountDataRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._is_mine_id = hs.is_mine_id async def on_GET( diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index 5802de5b7..4b217882e 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -60,7 +60,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs - self.datastore = hs.get_datastore() + self.datastore = hs.get_datastores().main self.config = hs.config self.identity_handler = hs.get_identity_handler() @@ -114,7 +114,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # This avoids a potential account hijack by requesting a password reset to # an email address which is controlled by the attacker but which, after # canonicalisation, matches the one in our database. - existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( + existing_user_id = await self.hs.get_datastores().main.get_user_id_by_threepid( "email", email ) @@ -168,7 +168,7 @@ class PasswordRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() - self.datastore = self.hs.get_datastore() + self.datastore = self.hs.get_datastores().main self.password_policy_handler = hs.get_password_policy_handler() self._set_password_handler = hs.get_set_password_handler() @@ -347,7 +347,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): self.hs = hs self.config = hs.config self.identity_handler = hs.get_identity_handler() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: self.mailer = Mailer( @@ -450,7 +450,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.hs = hs super().__init__() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.identity_handler = hs.get_identity_handler() async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: @@ -533,7 +533,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): super().__init__() self.config = hs.config self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: self._failure_email_template = ( self.config.email.email_add_threepid_template_failure_html @@ -600,7 +600,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): super().__init__() self.config = hs.config self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.identity_handler = hs.get_identity_handler() async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: @@ -634,7 +634,7 @@ class ThreepidRestServlet(RestServlet): self.identity_handler = hs.get_identity_handler() self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() - self.datastore = self.hs.get_datastore() + self.datastore = self.hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) @@ -768,7 +768,7 @@ class ThreepidUnbindRestServlet(RestServlet): self.hs = hs self.identity_handler = hs.get_identity_handler() self.auth = hs.get_auth() - self.datastore = self.hs.get_datastore() + self.datastore = self.hs.get_datastores().main async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: """Unbind the given 3pid from a specific identity server, or identity servers that are diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py index 58b8adbd3..bfe985939 100644 --- a/synapse/rest/client/account_data.py +++ b/synapse/rest/client/account_data.py @@ -42,7 +42,7 @@ class AccountDataServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.handler = hs.get_account_data_handler() async def on_PUT( @@ -90,7 +90,7 @@ class RoomAccountDataServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.handler = hs.get_account_data_handler() async def on_PUT( diff --git a/synapse/rest/client/directory.py b/synapse/rest/client/directory.py index ee247e3d1..e181a0dde 100644 --- a/synapse/rest/client/directory.py +++ b/synapse/rest/client/directory.py @@ -47,7 +47,7 @@ class ClientDirectoryServer(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.directory_handler = hs.get_directory_handler() self.auth = hs.get_auth() @@ -129,7 +129,7 @@ class ClientDirectoryListServer(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.directory_handler = hs.get_directory_handler() self.auth = hs.get_auth() @@ -173,7 +173,7 @@ class ClientAppserviceDirectoryListServer(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.directory_handler = hs.get_directory_handler() self.auth = hs.get_auth() diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py index 672c82106..916f5230f 100644 --- a/synapse/rest/client/events.py +++ b/synapse/rest/client/events.py @@ -39,7 +39,7 @@ class EventStreamRestServlet(RestServlet): super().__init__() self.event_stream_handler = hs.get_event_stream_handler() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) diff --git a/synapse/rest/client/groups.py b/synapse/rest/client/groups.py index a7e9aa3e9..7e1149c7f 100644 --- a/synapse/rest/client/groups.py +++ b/synapse/rest/client/groups.py @@ -705,7 +705,7 @@ class GroupAdminUsersInviteServlet(RestServlet): self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.is_mine_id = hs.is_mine_id @_validate_group_id @@ -854,7 +854,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main @_validate_group_id async def on_PUT( @@ -879,7 +879,7 @@ class PublicisedGroupsForUserServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.groups_handler = hs.get_groups_local_handler() async def on_GET( @@ -901,7 +901,7 @@ class PublicisedGroupsForUsersServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.groups_handler = hs.get_groups_local_handler() async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: diff --git a/synapse/rest/client/initial_sync.py b/synapse/rest/client/initial_sync.py index 49b1037b2..cfadcb8e5 100644 --- a/synapse/rest/client/initial_sync.py +++ b/synapse/rest/client/initial_sync.py @@ -33,7 +33,7 @@ class InitialSyncRestServlet(RestServlet): super().__init__() self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index 730c18f08..ce806e3c1 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -198,7 +198,7 @@ class KeyChangesServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index f9994658c..c9d44c596 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -104,13 +104,13 @@ class LoginRestServlet(RestServlet): self._well_known_builder = WellKnownBuilder(hs) self._address_ratelimiter = Ratelimiter( - store=hs.get_datastore(), + store=hs.get_datastores().main, clock=hs.get_clock(), rate_hz=self.hs.config.ratelimiting.rc_login_address.per_second, burst_count=self.hs.config.ratelimiting.rc_login_address.burst_count, ) self._account_ratelimiter = Ratelimiter( - store=hs.get_datastore(), + store=hs.get_datastores().main, clock=hs.get_clock(), rate_hz=self.hs.config.ratelimiting.rc_login_account.per_second, burst_count=self.hs.config.ratelimiting.rc_login_account.burst_count, diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py index 8e427a96a..20377a9ac 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py @@ -35,7 +35,7 @@ class NotificationsServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.clock = hs.get_clock() self._event_serializer = hs.get_event_client_serializer() diff --git a/synapse/rest/client/openid.py b/synapse/rest/client/openid.py index add56d699..820682ec4 100644 --- a/synapse/rest/client/openid.py +++ b/synapse/rest/client/openid.py @@ -67,7 +67,7 @@ class IdTokenServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.server_name = hs.config.server.server_name diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py index 8fe75bd75..a93f6fd5e 100644 --- a/synapse/rest/client/push_rule.py +++ b/synapse/rest/client/push_rule.py @@ -57,7 +57,7 @@ class PushRuleRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self._is_worker = hs.config.worker.worker_app is not None diff --git a/synapse/rest/client/pusher.py b/synapse/rest/client/pusher.py index 98604a938..d6487c31d 100644 --- a/synapse/rest/client/pusher.py +++ b/synapse/rest/client/pusher.py @@ -46,7 +46,9 @@ class PushersRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request) user = requester.user - pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) + pushers = await self.hs.get_datastores().main.get_pushers_by_user_id( + user.to_string() + ) filtered_pushers = [p.as_dict() for p in pushers] diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index b8a5135e0..70baf50fa 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -123,7 +123,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): request, "email", email ) - existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( + existing_user_id = await self.hs.get_datastores().main.get_user_id_by_threepid( "email", email ) @@ -203,7 +203,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): request, "msisdn", msisdn ) - existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( + existing_user_id = await self.hs.get_datastores().main.get_user_id_by_threepid( "msisdn", msisdn ) @@ -258,7 +258,7 @@ class RegistrationSubmitTokenServlet(RestServlet): self.auth = hs.get_auth() self.config = hs.config self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: self._failure_email_template = ( @@ -385,7 +385,7 @@ class RegistrationTokenValidityRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.ratelimiter = Ratelimiter( store=self.store, clock=hs.get_clock(), @@ -415,7 +415,7 @@ class RegisterRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() self.identity_handler = hs.get_identity_handler() diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 2cab83c4e..487ea38b5 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -85,7 +85,7 @@ class RelationPaginationServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self._event_serializer = hs.get_event_client_serializer() self.event_handler = hs.get_event_handler() @@ -190,7 +190,7 @@ class RelationAggregationPaginationServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.event_handler = hs.get_event_handler() async def on_GET( @@ -282,7 +282,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self._event_serializer = hs.get_event_client_serializer() self.event_handler = hs.get_event_handler() diff --git a/synapse/rest/client/report_event.py b/synapse/rest/client/report_event.py index d4a4adb50..6e962a453 100644 --- a/synapse/rest/client/report_event.py +++ b/synapse/rest/client/report_event.py @@ -38,7 +38,7 @@ class ReportEventRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_POST( self, request: SynapseRequest, room_id: str, event_id: str diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 90355e44b..5ccfe5a92 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -477,7 +477,7 @@ class RoomMemberListRestServlet(RestServlet): super().__init__() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, room_id: str @@ -553,7 +553,7 @@ class RoomMessageListRestServlet(RestServlet): self._hs = hs self.pagination_handler = hs.get_pagination_handler() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, room_id: str @@ -621,7 +621,7 @@ class RoomInitialSyncRestServlet(RestServlet): super().__init__() self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, room_id: str @@ -642,7 +642,7 @@ class RoomEventServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.clock = hs.get_clock() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self.event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() self.auth = hs.get_auth() @@ -1027,7 +1027,7 @@ class JoinedRoomsRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: @@ -1116,7 +1116,7 @@ class TimestampLookupRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self.timestamp_lookup_handler = hs.get_timestamp_lookup_handler() async def on_GET( diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py index 4b6be3832..0048973e5 100644 --- a/synapse/rest/client/room_batch.py +++ b/synapse/rest/client/room_batch.py @@ -75,7 +75,7 @@ class RoomBatchSendEventRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() self.room_batch_handler = hs.get_room_batch_handler() diff --git a/synapse/rest/client/shared_rooms.py b/synapse/rest/client/shared_rooms.py index 09a46737d..e669fa789 100644 --- a/synapse/rest/client/shared_rooms.py +++ b/synapse/rest/client/shared_rooms.py @@ -41,7 +41,7 @@ class UserSharedRoomsServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_directory_active = hs.config.server.update_user_directory async def on_GET( diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index f9615da52..f3018ff69 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -103,7 +103,7 @@ class SyncRestServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.sync_handler = hs.get_sync_handler() self.clock = hs.get_clock() self.filtering = hs.get_filtering() diff --git a/synapse/rest/client/tags.py b/synapse/rest/client/tags.py index c88cb9367..ca638755c 100644 --- a/synapse/rest/client/tags.py +++ b/synapse/rest/client/tags.py @@ -39,7 +39,7 @@ class TagListServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, user_id: str, room_id: str diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index 3d2afacc5..25f9ea285 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -78,7 +78,7 @@ class ConsentResource(DirectServeHtmlResource): super().__init__() self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.registration_handler = hs.get_registration_handler() # this is required by the request_handler wrapper diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 3923ba843..3525d6ae5 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -94,7 +94,7 @@ class RemoteKey(DirectServeJsonResource): super().__init__() self.fetcher = ServerKeyFetcher(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.federation_domain_whitelist = ( hs.config.federation.federation_domain_whitelist diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 71b9a34b1..6c414402b 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -75,7 +75,7 @@ class MediaRepository: self.client = hs.get_federation_http_client() self.clock = hs.get_clock() self.server_name = hs.hostname - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.max_upload_size = hs.config.media.max_upload_size self.max_image_pixels = hs.config.media.max_image_pixels diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index c08b60d10..14ea88b24 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -134,7 +134,7 @@ class PreviewUrlResource(DirectServeJsonResource): self.filepaths = media_repo.filepaths self.max_spider_size = hs.config.media.max_spider_size self.server_name = hs.hostname - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.client = SimpleHttpClient( hs, treq_args={"browser_like_redirects": True}, diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index ed91ef5a4..53b156524 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -50,7 +50,7 @@ class ThumbnailResource(DirectServeJsonResource): ): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.media_repo = media_repo self.media_storage = media_storage self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index fde28d08c..e73e431dc 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -37,7 +37,7 @@ class UploadResource(DirectServeJsonResource): self.media_repo = media_repo self.filepaths = media_repo.filepaths - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.server_name = hs.hostname self.auth = hs.get_auth() diff --git a/synapse/rest/synapse/client/password_reset.py b/synapse/rest/synapse/client/password_reset.py index 28a67f04e..6ac9dbc7c 100644 --- a/synapse/rest/synapse/client/password_reset.py +++ b/synapse/rest/synapse/client/password_reset.py @@ -44,7 +44,7 @@ class PasswordResetSubmitTokenResource(DirectServeHtmlResource): super().__init__() self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._local_threepid_handling_disabled_due_to_email_config = ( hs.config.email.local_threepid_handling_disabled_due_to_email_config diff --git a/synapse/server.py b/synapse/server.py index 4c07f2101..b5e2a319b 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -17,7 +17,7 @@ # homeservers; either as a full homeserver as a real application, or a small # partial one for unit test mocking. -# Imports required for the default HomeServer() implementation + import abc import functools import logging @@ -134,7 +134,7 @@ from synapse.server_notices.worker_server_notices_sender import ( WorkerServerNoticesSender, ) from synapse.state import StateHandler, StateResolutionHandler -from synapse.storage import Databases, DataStore, Storage +from synapse.storage import Databases, Storage from synapse.streams.events import EventSources from synapse.types import DomainSpecificString, ISynapseReactor from synapse.util import Clock @@ -225,7 +225,7 @@ class HomeServer(metaclass=abc.ABCMeta): # This is overridden in derived application classes # (such as synapse.app.homeserver.SynapseHomeServer) and gives the class to be - # instantiated during setup() for future return by get_datastore() + # instantiated during setup() for future return by get_datastores() DATASTORE_CLASS = abc.abstractproperty() tls_server_context_factory: Optional[IOpenSSLContextFactory] @@ -355,12 +355,6 @@ class HomeServer(metaclass=abc.ABCMeta): def get_clock(self) -> Clock: return Clock(self._reactor) - def get_datastore(self) -> DataStore: - if not self.datastores: - raise Exception("HomeServer.setup must be called before getting datastores") - - return self.datastores.main - def get_datastores(self) -> Databases: if not self.datastores: raise Exception("HomeServer.setup must be called before getting datastores") @@ -374,7 +368,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_registration_ratelimiter(self) -> Ratelimiter: return Ratelimiter( - store=self.get_datastore(), + store=self.get_datastores().main, clock=self.get_clock(), rate_hz=self.config.ratelimiting.rc_registration.per_second, burst_count=self.config.ratelimiting.rc_registration.burst_count, @@ -847,7 +841,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_request_ratelimiter(self) -> RequestRatelimiter: return RequestRatelimiter( - self.get_datastore(), + self.get_datastores().main, self.get_clock(), self.config.ratelimiting.rc_message, self.config.ratelimiting.rc_admin_redaction, diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py index e09a25591..698ca742e 100644 --- a/synapse/server_notices/consent_server_notices.py +++ b/synapse/server_notices/consent_server_notices.py @@ -32,7 +32,7 @@ class ConsentServerNotices: def __init__(self, hs: "HomeServer"): self._server_notices_manager = hs.get_server_notices_manager() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._users_in_progress: Set[str] = set() diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index 8522930b5..015dd08f0 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -36,7 +36,7 @@ class ResourceLimitsServerNotices: def __init__(self, hs: "HomeServer"): self._server_notices_manager = hs.get_server_notices_manager() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._auth = hs.get_auth() self._config = hs.config self._resouce_limited = False diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index 0cf60236f..7b4814e04 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -29,7 +29,7 @@ SERVER_NOTICE_ROOM_TAG = "m.server_notice" class ServerNoticesManager: def __init__(self, hs: "HomeServer"): - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._config = hs.config self._account_data_handler = hs.get_account_data_handler() self._room_creation_handler = hs.get_room_creation_handler() diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 67e8bc6ec..fcc24ad12 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -126,7 +126,7 @@ class StateHandler: def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state_store = hs.get_storage().state self.hs = hs self._state_resolution_handler = hs.get_state_resolution_handler() diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index 8f09dd8e8..e9a0cdc6b 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -112,7 +112,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): for tp in self.hs.config.server.mau_limits_reserved_threepids[ : self.hs.config.server.max_mau_value ]: - user_id = await self.hs.get_datastore().get_user_id_by_threepid( + user_id = await self.hs.get_datastores().main.get_user_id_by_threepid( tp["medium"], canonicalise_email(tp["address"]) ) if user_id: diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 4ec2a713c..fb8fe1729 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -48,7 +48,7 @@ class EventSources: # all the attributes of `_EventSourcesInner` are annotated. *(attribute.type(hs) for attribute in attr.fields(_EventSourcesInner)) # type: ignore[misc] ) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def get_current_token(self) -> StreamToken: push_rules_key = self.store.get_max_push_rules_stream_id() diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 4b53b6d40..686d17c0d 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -16,6 +16,8 @@ from unittest.mock import Mock import pymacaroons +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.auth import Auth from synapse.api.constants import UserTypes from synapse.api.errors import ( @@ -26,8 +28,10 @@ from synapse.api.errors import ( ResourceLimitError, ) from synapse.appservice import ApplicationService +from synapse.server import HomeServer from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import Requester +from synapse.util import Clock from tests import unittest from tests.test_utils import simple_async_mock @@ -36,10 +40,10 @@ from tests.utils import mock_getRawHeaders class AuthTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): self.store = Mock() - hs.get_datastore = Mock(return_value=self.store) + hs.datastores.main = self.store hs.get_auth_handler().store = self.store self.auth = Auth(hs) diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index b7fc33dc9..973f0f7fa 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -40,7 +40,7 @@ def MockEvent(**kwargs): class FilteringTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.filtering = hs.get_filtering() - self.datastore = hs.get_datastore() + self.datastore = hs.get_datastores().main def test_errors_on_invalid_filters(self): invalid_filters = [ diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index dcf0110c1..4ef754a18 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -8,7 +8,7 @@ from tests import unittest class TestRatelimiter(unittest.HomeserverTestCase): def test_allowed_via_can_do_action(self): limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(None, key="test_id", _time_now_s=0) @@ -39,7 +39,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): as_requester = create_requester("@user:example.com", app_service=appservice) limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=0) @@ -70,7 +70,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): as_requester = create_requester("@user:example.com", app_service=appservice) limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=0) @@ -92,7 +92,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): def test_allowed_via_ratelimit(self): limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) # Shouldn't raise @@ -116,7 +116,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): """ # Create a Ratelimiter with a very low allowed rate_hz and burst_count limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) # First attempt should be allowed @@ -162,7 +162,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): """ # Create a Ratelimiter with a very low allowed rate_hz and burst_count limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) # First attempt should be allowed @@ -190,7 +190,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): def test_pruning(self): limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) self.get_success_or_raise( limiter.can_do_action(None, key="test_id_1", _time_now_s=0) @@ -208,7 +208,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): """Test that users that have ratelimiting disabled in the DB aren't ratelimited. """ - store = self.hs.get_datastore() + store = self.hs.get_datastores().main user_id = "@user:test" requester = create_requester(user_id) @@ -233,7 +233,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): def test_multiple_actions(self): limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=3 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3 ) # Test that 4 actions aren't allowed with a maximum burst of 3. allowed, time_allowed = self.get_success_or_raise( diff --git a/tests/app/test_phone_stats_home.py b/tests/app/test_phone_stats_home.py index 19eb4c79d..df731eb59 100644 --- a/tests/app/test_phone_stats_home.py +++ b/tests/app/test_phone_stats_home.py @@ -32,7 +32,7 @@ class PhoneHomeTestCase(HomeserverTestCase): self.helper.send(room_id, "message", tok=access_token) # Check the R30 results do not count that user. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) # Advance 30 days (+ 1 second, because strict inequality causes issues if we are @@ -40,7 +40,7 @@ class PhoneHomeTestCase(HomeserverTestCase): self.reactor.advance(30 * ONE_DAY_IN_SECONDS + 1) # (Make sure the user isn't somehow counted by this point.) - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) # Send a message (this counts as activity) @@ -51,21 +51,21 @@ class PhoneHomeTestCase(HomeserverTestCase): self.reactor.advance(2 * 60 * 60) # *Now* the user is counted. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 1, "unknown": 1}) # Advance 29 days. The user has now not posted for 29 days. self.reactor.advance(29 * ONE_DAY_IN_SECONDS) # The user is still counted. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 1, "unknown": 1}) # Advance another day. The user has now not posted for 30 days. self.reactor.advance(ONE_DAY_IN_SECONDS) # The user is now no longer counted in R30. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) def test_r30_minimum_usage_using_default_config(self): @@ -84,7 +84,7 @@ class PhoneHomeTestCase(HomeserverTestCase): self.helper.send(room_id, "message", tok=access_token) # Check the R30 results do not count that user. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) # Advance 30 days (+ 1 second, because strict inequality causes issues if we are @@ -92,7 +92,7 @@ class PhoneHomeTestCase(HomeserverTestCase): self.reactor.advance(30 * ONE_DAY_IN_SECONDS + 1) # (Make sure the user isn't somehow counted by this point.) - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) # Send a message (this counts as activity) @@ -103,14 +103,14 @@ class PhoneHomeTestCase(HomeserverTestCase): self.reactor.advance(2 * 60 * 60) # *Now* the user is counted. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 1, "unknown": 1}) # Advance 27 days. The user has now not posted for 27 days. self.reactor.advance(27 * ONE_DAY_IN_SECONDS) # The user is still counted. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 1, "unknown": 1}) # Advance another day. The user has now not posted for 28 days. @@ -119,7 +119,7 @@ class PhoneHomeTestCase(HomeserverTestCase): # The user is now no longer counted in R30. # (This is because the user_ips table has been pruned, which by default # only preserves the last 28 days of entries.) - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) def test_r30_user_must_be_retained_for_at_least_a_month(self): @@ -135,7 +135,7 @@ class PhoneHomeTestCase(HomeserverTestCase): self.helper.send(room_id, "message", tok=access_token) # Check the user does not contribute to R30 yet. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) for _ in range(30): @@ -144,14 +144,16 @@ class PhoneHomeTestCase(HomeserverTestCase): self.helper.send(room_id, "I'm still here", tok=access_token) # Notice that the user *still* does not contribute to R30! - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success( + self.hs.get_datastores().main.count_r30_users() + ) self.assertEqual(r30_results, {"all": 0}) self.reactor.advance(ONE_DAY_IN_SECONDS) self.helper.send(room_id, "Still here!", tok=access_token) # *Now* the user appears in R30. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 1, "unknown": 1}) @@ -196,7 +198,7 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase): # (user_daily_visits is updated every 5 minutes using a looping call.) self.reactor.advance(FIVE_MINUTES_IN_SECONDS) - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Check the R30 results do not count that user. r30_results = self.get_success(store.count_r30v2_users()) @@ -275,7 +277,7 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase): # (user_daily_visits is updated every 5 minutes using a looping call.) self.reactor.advance(FIVE_MINUTES_IN_SECONDS) - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Check the user does not contribute to R30 yet. r30_results = self.get_success(store.count_r30v2_users()) @@ -347,7 +349,7 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase): # (user_daily_visits is updated every 5 minutes using a looping call.) self.reactor.advance(FIVE_MINUTES_IN_SECONDS) - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Check that the user does not contribute to R30v2, even though it's been # more than 30 days since registration. diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 17a9fb63a..3a4d50271 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -179,7 +179,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): kr = keyring.Keyring(self.hs) key1 = signedjson.key.generate_signing_key(1) - r = self.hs.get_datastore().store_server_verify_keys( + r = self.hs.get_datastores().main.store_server_verify_keys( "server9", time.time() * 1000, [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))], @@ -272,7 +272,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): ) key1 = signedjson.key.generate_signing_key(1) - r = self.hs.get_datastore().store_server_verify_keys( + r = self.hs.get_datastores().main.store_server_verify_keys( "server9", time.time() * 1000, [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), None))], @@ -448,7 +448,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): # check that the perspectives store is correctly updated lookup_triplet = (SERVER_NAME, testverifykey_id, None) key_json = self.get_success( - self.hs.get_datastore().get_server_keys_json([lookup_triplet]) + self.hs.get_datastores().main.get_server_keys_json([lookup_triplet]) ) res = key_json[lookup_triplet] self.assertEqual(len(res), 1) @@ -564,7 +564,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): # check that the perspectives store is correctly updated lookup_triplet = (SERVER_NAME, testverifykey_id, None) key_json = self.get_success( - self.hs.get_datastore().get_server_keys_json([lookup_triplet]) + self.hs.get_datastores().main.get_server_keys_json([lookup_triplet]) ) res = key_json[lookup_triplet] self.assertEqual(len(res), 1) @@ -683,7 +683,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): # check that the perspectives store is correctly updated lookup_triplet = (SERVER_NAME, testverifykey_id, None) key_json = self.get_success( - self.hs.get_datastore().get_server_keys_json([lookup_triplet]) + self.hs.get_datastores().main.get_server_keys_json([lookup_triplet]) ) res = key_json[lookup_triplet] self.assertEqual(len(res), 1) diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py index ca27388ae..defbc68c1 100644 --- a/tests/events/test_snapshot.py +++ b/tests/events/test_snapshot.py @@ -28,7 +28,7 @@ class TestEventContext(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.user_id = self.register_user("u1", "pass") diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index e40ef9587..9336181c9 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -55,7 +55,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): self.assertTrue(complexity > 0, complexity) # Artificially raise the complexity - store = self.hs.get_datastore() + store = self.hs.get_datastores().main store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23) # Get the room complexity again -- make sure it's our artificial value @@ -149,7 +149,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): ) # Artificially raise the complexity - self.hs.get_datastore().get_current_state_event_counts = ( + self.hs.get_datastores().main.get_current_state_event_counts = ( lambda x: make_awaitable(600) ) diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py index f0aa8ed9d..2873b4d43 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py @@ -64,7 +64,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): Dictionary of { event_id: str, stream_ordering: int } """ event_id, stream_ordering = self.get_success( - self.hs.get_datastore().db_pool.execute( + self.hs.get_datastores().main.db_pool.execute( "test:get_destination_rooms", None, """ @@ -125,7 +125,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): self.pump() lsso_1 = self.get_success( - self.hs.get_datastore().get_destination_last_successful_stream_ordering( + self.hs.get_datastores().main.get_destination_last_successful_stream_ordering( "host2" ) ) @@ -141,7 +141,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): event_id_2 = self.helper.send(room, "rabbits!", tok=u1_token)["event_id"] lsso_2 = self.get_success( - self.hs.get_datastore().get_destination_last_successful_stream_ordering( + self.hs.get_datastores().main.get_destination_last_successful_stream_ordering( "host2" ) ) @@ -216,7 +216,9 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): # let's also clear any backoffs self.get_success( - self.hs.get_datastore().set_destination_retry_timings("host2", None, 0, 0) + self.hs.get_datastores().main.set_destination_retry_timings( + "host2", None, 0, 0 + ) ) # bring the remote online and clear the received pdu list @@ -296,13 +298,13 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): # destination_rooms should already be populated, but let us pretend that we already # sent (successfully) up to and including event id 2 - event_2 = self.get_success(self.hs.get_datastore().get_event(event_id_2)) + event_2 = self.get_success(self.hs.get_datastores().main.get_event(event_id_2)) # also fetch event 5 so we know its last_successful_stream_ordering later - event_5 = self.get_success(self.hs.get_datastore().get_event(event_id_5)) + event_5 = self.get_success(self.hs.get_datastores().main.get_event(event_id_5)) self.get_success( - self.hs.get_datastore().set_destination_last_successful_stream_ordering( + self.hs.get_datastores().main.set_destination_last_successful_stream_ordering( "host2", event_2.internal_metadata.stream_ordering ) ) @@ -359,7 +361,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): # ASSERT: # - All servers are up to date so none should have outstanding catch-up outstanding_when_successful = self.get_success( - self.hs.get_datastore().get_catch_up_outstanding_destinations(None) + self.hs.get_datastores().main.get_catch_up_outstanding_destinations(None) ) self.assertEqual(outstanding_when_successful, []) @@ -370,7 +372,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): # - Mark zzzerver as being backed-off from now = self.clock.time_msec() self.get_success( - self.hs.get_datastore().set_destination_retry_timings( + self.hs.get_datastores().main.set_destination_retry_timings( "zzzerver", now, now, 24 * 60 * 60 * 1000 # retry in 1 day ) ) @@ -382,14 +384,14 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): # - all remotes are outstanding # - they are returned in batches of 25, in order outstanding_1 = self.get_success( - self.hs.get_datastore().get_catch_up_outstanding_destinations(None) + self.hs.get_datastores().main.get_catch_up_outstanding_destinations(None) ) self.assertEqual(len(outstanding_1), 25) self.assertEqual(outstanding_1, server_names[0:25]) outstanding_2 = self.get_success( - self.hs.get_datastore().get_catch_up_outstanding_destinations( + self.hs.get_datastores().main.get_catch_up_outstanding_destinations( outstanding_1[-1] ) ) @@ -457,7 +459,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): ) self.get_success( - self.hs.get_datastore().set_destination_last_successful_stream_ordering( + self.hs.get_datastores().main.set_destination_last_successful_stream_ordering( "host2", event_1.internal_metadata.stream_ordering ) ) diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index b2376e2db..60e0c31f4 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -176,7 +176,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): def get_users_who_share_room_with_user(user_id): return defer.succeed({"@user2:host2"}) - hs.get_datastore().get_users_who_share_room_with_user = ( + hs.get_datastores().main.get_users_who_share_room_with_user = ( get_users_who_share_room_with_user ) @@ -395,7 +395,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # run the prune job self.reactor.advance(10) self.get_success( - self.hs.get_datastore()._prune_old_outbound_device_pokes(prune_age=1) + self.hs.get_datastores().main._prune_old_outbound_device_pokes(prune_age=1) ) # recover the server @@ -445,7 +445,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # run the prune job self.reactor.advance(10) self.get_success( - self.hs.get_datastore()._prune_old_outbound_device_pokes(prune_age=1) + self.hs.get_datastores().main._prune_old_outbound_device_pokes(prune_age=1) ) # recover the server diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py index 686f42ab4..adf0535d9 100644 --- a/tests/federation/transport/test_knocking.py +++ b/tests/federation/transport/test_knocking.py @@ -198,7 +198,7 @@ class FederationKnockingTestCase( ] def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main # We're not going to be properly signing events as our remote homeserver is fake, # therefore disable event signature checks. diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index fe57ff267..9918ff680 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -38,7 +38,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_as_api = Mock() self.mock_scheduler = Mock() hs = Mock() - hs.get_datastore.return_value = self.mock_store + hs.get_datastores.return_value = Mock(main=self.mock_store) self.mock_store.get_received_ts.return_value = make_awaitable(0) self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None) self.mock_store.set_appservice_stream_type_pos.return_value = make_awaitable( @@ -355,7 +355,9 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): # Mock out application services, and allow defining our own in tests self._services: List[ApplicationService] = [] - self.hs.get_datastore().get_app_services = Mock(return_value=self._services) + self.hs.get_datastores().main.get_app_services = Mock( + return_value=self._services + ) # A user on the homeserver. self.local_user_device_id = "local_device" @@ -494,7 +496,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): # Create a fake device per message. We can't send to-device messages to # a device that doesn't exist. self.get_success( - self.hs.get_datastore().db_pool.simple_insert_many( + self.hs.get_datastores().main.db_pool.simple_insert_many( desc="test_application_services_receive_burst_of_to_device", table="devices", keys=("user_id", "device_id"), @@ -510,7 +512,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): # Seed the device_inbox table with our fake messages self.get_success( - self.hs.get_datastore().add_messages_to_device_inbox(messages, {}) + self.hs.get_datastores().main.add_messages_to_device_inbox(messages, {}) ) # Now have local_user send a final to-device message to exclusive_as_user. All unsent diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 03b8b8615..0c6e55e72 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -129,7 +129,7 @@ class AuthTestCase(unittest.HomeserverTestCase): def test_mau_limits_exceeded_large(self): self.auth_blocking._limit_usage_by_mau = True - self.hs.get_datastore().get_monthly_active_count = Mock( + self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.large_number_of_users) ) @@ -140,7 +140,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ResourceLimitError, ) - self.hs.get_datastore().get_monthly_active_count = Mock( + self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.large_number_of_users) ) self.get_failure( @@ -156,7 +156,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.auth_blocking._limit_usage_by_mau = True # Set the server to be at the edge of too many users. - self.hs.get_datastore().get_monthly_active_count = Mock( + self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.auth_blocking._max_mau_value) ) @@ -175,7 +175,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ) # If in monthly active cohort - self.hs.get_datastore().user_last_seen_monthly_active = Mock( + self.hs.get_datastores().main.user_last_seen_monthly_active = Mock( return_value=make_awaitable(self.clock.time_msec()) ) self.get_success( @@ -192,7 +192,7 @@ class AuthTestCase(unittest.HomeserverTestCase): def test_mau_limits_not_exceeded(self): self.auth_blocking._limit_usage_by_mau = True - self.hs.get_datastore().get_monthly_active_count = Mock( + self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.small_number_of_users) ) # Ensure does not raise exception @@ -202,7 +202,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ) ) - self.hs.get_datastore().get_monthly_active_count = Mock( + self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.small_number_of_users) ) self.get_success( diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index 8705ff894..a26722884 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -77,7 +77,7 @@ class CasHandlerTestCase(HomeserverTestCase): def test_map_cas_user_to_existing_user(self): """Existing users can log in with CAS account.""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) ) diff --git a/tests/handlers/test_deactivate_account.py b/tests/handlers/test_deactivate_account.py index 01096a158..ddda36c5a 100644 --- a/tests/handlers/test_deactivate_account.py +++ b/tests/handlers/test_deactivate_account.py @@ -34,7 +34,7 @@ class DeactivateAccountTestCase(HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self.user = self.register_user("user", "pass") self.token = self.login("user", "pass") diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 43031e07e..683677fd0 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -28,7 +28,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver("server", federation_http_client=None) self.handler = hs.get_device_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main return hs def prepare(self, reactor, clock, hs): @@ -263,7 +263,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase): self.handler = hs.get_device_handler() self.registration = hs.get_registration_handler() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main return hs def test_dehydrate_and_rehydrate_device(self): diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 0ea4e753e..65ab107d0 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -46,7 +46,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.handler = hs.get_directory_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.my_room = RoomAlias.from_string("#my-room:test") self.your_room = RoomAlias.from_string("#your-room:test") @@ -174,7 +174,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.handler = hs.get_directory_handler() self.state_handler = hs.get_state_handler() @@ -289,7 +289,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.handler = hs.get_directory_handler() self.state_handler = hs.get_state_handler() diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 734ed84d7..9338ab92e 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -34,7 +34,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.handler = hs.get_e2e_keys_handler() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main def test_query_local_devices_no_devices(self): """If the user has no devices, we expect an empty list.""" diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 496b58172..e8b4e39d1 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -45,7 +45,7 @@ class FederationTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver(federation_http_client=None) self.handler = hs.get_federation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state_store = hs.get_storage().state self._event_auth_handler = hs.get_event_auth_handler() return hs diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index 5816295d8..f4f7ab484 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -44,7 +44,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token) self.info = self.get_success( - self.hs.get_datastore().get_user_by_access_token( + self.hs.get_datastores().main.get_user_by_access_token( self.access_token, ) ) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index a552d8182..e8418b663 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -856,7 +856,7 @@ class OidcHandlerTestCase(HomeserverTestCase): auth_handler.complete_sso_login.reset_mock() # Test if the mxid is already taken - store = self.hs.get_datastore() + store = self.hs.get_datastores().main user3 = UserID.from_string("@test_user_3:test") self.get_success( store.register_user(user_id=user3.to_string(), password_hash=None) @@ -872,7 +872,7 @@ class OidcHandlerTestCase(HomeserverTestCase): @override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}}) def test_map_userinfo_to_existing_user(self): """Existing users can log in with OpenID Connect when allow_existing_users is True.""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main user = UserID.from_string("@test_user:test") self.get_success( store.register_user(user_id=user.to_string(), password_hash=None) @@ -996,7 +996,7 @@ class OidcHandlerTestCase(HomeserverTestCase): auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) ) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 671dc7d08..61d28603a 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -43,7 +43,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): servlets = [admin.register_servlets] def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main def test_offline_to_online(self): wheel_timer = Mock() @@ -891,7 +891,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): # self.event_builder_for_2 = EventBuilderFactory(hs) # self.event_builder_for_2.hostname = "test2" - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state = hs.get_state_handler() self._event_auth_handler = hs.get_event_auth_handler() diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 60235e569..69e299fc1 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -48,7 +48,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): return hs def prepare(self, reactor, clock, hs: HomeServer): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.frank = UserID.from_string("@1234abcd:test") self.bob = UserID.from_string("@4567:test") @@ -325,7 +325,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): properties are "mimetype" (for the file's type) and "size" (for the file's size). """ - store = self.hs.get_datastore() + store = self.hs.get_datastores().main for name, props in names_and_props.items(): self.get_success( diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index cd6f2c77a..51ee667ab 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -154,7 +154,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.handler = self.hs.get_registration_handler() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.lots_of_users = 100 self.small_number_of_users = 1 @@ -172,7 +172,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertGreater(len(result_token), 20) def test_if_user_exists(self): - store = self.hs.get_datastore() + store = self.hs.get_datastores().main frank = UserID.from_string("@frank:test") self.get_success( store.register_user(user_id=frank.to_string(), password_hash=None) @@ -760,7 +760,7 @@ class RemoteAutoJoinTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.handler = self.hs.get_registration_handler() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main @override_config({"auto_join_rooms": ["#room:remotetest"]}) def test_auto_create_auto_join_remote_room(self): diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index 50551aa6e..23941abed 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -142,7 +142,7 @@ class SamlHandlerTestCase(HomeserverTestCase): @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) def test_map_saml_response_to_existing_user(self): """Existing users can log in with SAML account.""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) ) @@ -217,7 +217,7 @@ class SamlHandlerTestCase(HomeserverTestCase): sso_handler.render_error = Mock(return_value=None) # register a user to occupy the first-choice MXID - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) ) diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 56207f4db..ecd78fa36 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -33,7 +33,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.handler = self.hs.get_stats_handler() def _add_background_updates(self): diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 07a760e91..66b0bd4d1 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -41,7 +41,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs: HomeServer): self.sync_handler = self.hs.get_sync_handler() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main # AuthBlocking reads from the hs' config on initialization. We need to # modify its config instead of the hs' @@ -248,7 +248,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): # the prev_events used when creating the join event, such that the ban does not # precede the join. mocked_get_prev_events = patch.object( - self.hs.get_datastore(), + self.hs.get_datastores().main, "get_prev_events_for_room", new_callable=MagicMock, return_value=make_awaitable([last_room_creation_event_id]), diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 000f9b9fd..e461e0359 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -91,7 +91,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.event_source = hs.get_event_sources().sources.typing - self.datastore = hs.get_datastore() + self.datastore = hs.get_datastores().main self.datastore.get_destination_retry_timings = Mock( return_value=defer.succeed(None) ) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 482c90ef6..e159169e2 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -77,7 +77,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): return hs def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.handler = hs.get_user_directory_handler() self.event_builder_factory = self.hs.get_event_builder_factory() self.event_creation_handler = self.hs.get_event_creation_handler() diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index d16cd141a..c3f20f969 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -41,7 +41,7 @@ class ModuleApiTestCase(HomeserverTestCase): ] def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main self.module_api = homeserver.get_module_api() self.event_creation_handler = homeserver.get_event_creation_handler() self.sync_handler = homeserver.get_sync_handler() diff --git a/tests/push/test_email.py b/tests/push/test_email.py index f8cba7b64..7a3b0d675 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -102,13 +102,13 @@ class EmailPusherTests(HomeserverTestCase): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(self.access_token) + self.hs.get_datastores().main.get_user_by_access_token(self.access_token) ) self.token_id = user_tuple.token_id # We need to add email to account before we can create a pusher. self.get_success( - hs.get_datastore().user_add_threepid( + hs.get_datastores().main.user_add_threepid( self.user_id, "email", "a@example.com", 0, 0 ) ) @@ -128,7 +128,7 @@ class EmailPusherTests(HomeserverTestCase): ) self.auth_handler = hs.get_auth_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def test_need_validated_email(self): """Test that we can only add an email pusher if the user has validated @@ -375,7 +375,7 @@ class EmailPusherTests(HomeserverTestCase): # check that the pusher for that email address has been deleted pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 0) @@ -388,14 +388,14 @@ class EmailPusherTests(HomeserverTestCase): # This resembles the old behaviour, which the background update below is intended # to clean up. self.get_success( - self.hs.get_datastore().user_delete_threepid( + self.hs.get_datastores().main.user_delete_threepid( self.user_id, "email", "a@example.com" ) ) # Run the "remove_deleted_email_pushers" background job self.get_success( - self.hs.get_datastore().db_pool.simple_insert( + self.hs.get_datastores().main.db_pool.simple_insert( table="background_updates", values={ "update_name": "remove_deleted_email_pushers", @@ -406,14 +406,14 @@ class EmailPusherTests(HomeserverTestCase): ) # ... and tell the DataStore that it hasn't finished all updates yet - self.hs.get_datastore().db_pool.updates._all_done = False + self.hs.get_datastores().main.db_pool.updates._all_done = False # Now let's actually drive the updates to completion self.wait_for_background_updates() # Check that all pushers with unlinked addresses were deleted pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 0) @@ -428,7 +428,7 @@ class EmailPusherTests(HomeserverTestCase): """ # Get the stream ordering before it gets sent pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -439,7 +439,7 @@ class EmailPusherTests(HomeserverTestCase): # It hasn't succeeded yet, so the stream ordering shouldn't have moved pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -458,7 +458,7 @@ class EmailPusherTests(HomeserverTestCase): # The stream ordering has increased pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) diff --git a/tests/push/test_http.py b/tests/push/test_http.py index e1e3fb97c..c284beb37 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -62,7 +62,7 @@ class HTTPPusherTests(HomeserverTestCase): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id @@ -108,7 +108,7 @@ class HTTPPusherTests(HomeserverTestCase): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id @@ -138,7 +138,7 @@ class HTTPPusherTests(HomeserverTestCase): # Get the stream ordering before it gets sent pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -149,7 +149,7 @@ class HTTPPusherTests(HomeserverTestCase): # It hasn't succeeded yet, so the stream ordering shouldn't have moved pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -170,7 +170,7 @@ class HTTPPusherTests(HomeserverTestCase): # The stream ordering has increased pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -192,7 +192,7 @@ class HTTPPusherTests(HomeserverTestCase): # The stream ordering has increased, again pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -224,7 +224,7 @@ class HTTPPusherTests(HomeserverTestCase): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id @@ -344,7 +344,7 @@ class HTTPPusherTests(HomeserverTestCase): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id @@ -430,7 +430,7 @@ class HTTPPusherTests(HomeserverTestCase): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id @@ -507,7 +507,7 @@ class HTTPPusherTests(HomeserverTestCase): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id @@ -613,7 +613,7 @@ class HTTPPusherTests(HomeserverTestCase): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 9fc50f885..a7a05a564 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -68,7 +68,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): # Since we use sqlite in memory databases we need to make sure the # databases objects are the same. - self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool + self.worker_hs.get_datastores().main.db_pool = hs.get_datastores().main.db_pool # Normally we'd pass in the handler to `setup_test_homeserver`, which would # eventually hit "Install @cache_in_self attributes" in tests/utils.py. @@ -233,7 +233,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): # We may have an attempt to connect to redis for the external cache already. self.connect_any_redis_attempts() - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.database_pool = store.db_pool self.reactor.lookups["testserv"] = "1.2.3.4" @@ -332,7 +332,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): lambda: self._handle_http_replication_attempt(worker_hs, port), ) - store = worker_hs.get_datastore() + store = worker_hs.get_datastores().main store.db_pool._db_pool = self.database_pool._db_pool # Set up TCP replication between master and the new worker if we don't diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 83e89383f..85be79d19 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -30,8 +30,8 @@ class BaseSlavedStoreTestCase(BaseStreamTestCase): self.reconnect() - self.master_store = hs.get_datastore() - self.slaved_store = self.worker_hs.get_datastore() + self.master_store = hs.get_datastores().main + self.slaved_store = self.worker_hs.get_datastores().main self.storage = hs.get_storage() def replicate(self): diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py index cdd052001..50fbff5f3 100644 --- a/tests/replication/tcp/streams/test_account_data.py +++ b/tests/replication/tcp/streams/test_account_data.py @@ -23,7 +23,7 @@ from tests.replication._base import BaseStreamTestCase class AccountDataStreamTestCase(BaseStreamTestCase): def test_update_function_room_account_data_limit(self): """Test replication with many room account data updates""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # generate lots of account data updates updates = [] @@ -69,7 +69,7 @@ class AccountDataStreamTestCase(BaseStreamTestCase): def test_update_function_global_account_data_limit(self): """Test replication with many global account data updates""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # generate lots of account data updates updates = [] diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index f198a9488..f9d5da723 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -136,7 +136,7 @@ class EventsStreamTestCase(BaseStreamTestCase): # this is the point in the DAG where we make a fork fork_point: List[str] = self.get_success( - self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id) + self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) ) events = [ @@ -291,7 +291,7 @@ class EventsStreamTestCase(BaseStreamTestCase): # this is the point in the DAG where we make a fork fork_point: List[str] = self.get_success( - self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id) + self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) ) events: List[EventBase] = [] diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index 38e292c1a..eb0011784 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -32,7 +32,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): # tell the master to send a new receipt self.get_success( - self.hs.get_datastore().insert_receipt( + self.hs.get_datastores().main.insert_receipt( "!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1} ) ) @@ -56,7 +56,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): self.test_handler.on_rdata.reset_mock() self.get_success( - self.hs.get_datastore().insert_receipt( + self.hs.get_datastores().main.insert_receipt( "!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2} ) ) diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index 92a5b53e1..ba1a63c0d 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -204,7 +204,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): def create_room_with_remote_server(self, user, token, remote_server="other_server"): room = self.helper.create_room_as(user, tok=token) - store = self.hs.get_datastore() + store = self.hs.get_datastores().main federation = self.hs.get_federation_event_handler() prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room)) diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py index 4094a75f3..8f4f6688c 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py @@ -50,7 +50,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): # Register a pusher user_dict = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_dict.token_id diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py index 596ba5a0c..5f142e84c 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py @@ -47,7 +47,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): self.other_access_token = self.login("otheruser", "pass") self.room_creator = self.hs.get_room_creation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def default_config(self): conf = super().default_config() @@ -99,7 +99,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): persisted_on_1 = False persisted_on_2 = False - store = self.hs.get_datastore() + store = self.hs.get_datastores().main user_id = self.register_user("user", "pass") access_token = self.login("user", "pass") @@ -166,7 +166,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): user_id = self.register_user("user", "pass") access_token = self.login("user", "pass") - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Create two room on the different workers. self._create_room(room_id1, user_id, access_token) @@ -194,7 +194,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): # # Worker2's event stream position will not advance until we call # __aexit__ again. - worker_store2 = worker_hs2.get_datastore() + worker_store2 = worker_hs2.get_datastores().main assert isinstance(worker_store2._stream_id_gen, MultiWriterIdGenerator) actx = worker_store2._stream_id_gen.get_next() diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py index 1e3fe9c62..fb36aa994 100644 --- a/tests/rest/admin/test_background_updates.py +++ b/tests/rest/admin/test_background_updates.py @@ -36,7 +36,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py index 71068d16c..929bbdc37 100644 --- a/tests/rest/admin/test_federation.py +++ b/tests/rest/admin/test_federation.py @@ -35,7 +35,7 @@ class FederationTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -537,7 +537,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index 86aff7575..0d47dd0af 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -634,7 +634,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: media_repo = hs.get_media_repository_resource() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.server_name = hs.hostname self.admin_user = self.register_user("admin", "pass", admin=True) @@ -767,7 +767,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: media_repo = hs.get_media_repository_resource() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py index 8513b1d2d..8354250ec 100644 --- a/tests/rest/admin/test_registration_tokens.py +++ b/tests/rest/admin/test_registration_tokens.py @@ -34,7 +34,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 23da0ad73..09c48e85c 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -50,7 +50,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" self.event_creation_handler._consent_uri_builder = consent_uri_builder - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -465,7 +465,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" self.event_creation_handler._consent_uri_builder = consent_uri_builder - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -2239,7 +2239,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py index 3c59f5f76..2c855bff9 100644 --- a/tests/rest/admin/test_server_notice.py +++ b/tests/rest/admin/test_server_notice.py @@ -38,7 +38,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room_shutdown_handler = hs.get_room_shutdown_handler() self.pagination_handler = hs.get_pagination_handler() self.server_notices_manager = self.hs.get_server_notices_manager() diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 272637e96..a60ea0a56 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -410,7 +410,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): even if the MAU limit is reached. """ handler = self.hs.get_registration_handler() - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Set monthly active users to the limit store.get_monthly_active_count = Mock( @@ -455,7 +455,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v2/users" def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -913,7 +913,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -1167,7 +1167,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth_handler = hs.get_auth_handler() # create users and get access tokens @@ -2609,7 +2609,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -2737,7 +2737,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.media_repo = hs.get_media_repository_resource() self.filepaths = MediaFilePaths(hs.config.media.media_store_path) @@ -3317,7 +3317,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -3609,7 +3609,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -3687,7 +3687,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -3913,7 +3913,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index afaa597f6..aa019c9a4 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -77,7 +77,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): return hs def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.submit_token_resource = PasswordResetSubmitTokenResource(hs) def test_basic_password_reset(self): @@ -398,7 +398,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase): self.deactivate(user_id, tok) - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Check that the user has been marked as deactivated. self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id))) @@ -409,7 +409,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase): def test_pending_invites(self): """Tests that deactivating a user rejects every pending invite for them.""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main inviter_id = self.register_user("inviter", "test") inviter_tok = self.login("inviter", "test") @@ -527,7 +527,7 @@ class WhoamiTestCase(unittest.HomeserverTestCase): namespaces={"users": [{"regex": user_id, "exclusive": True}]}, sender=user_id, ) - self.hs.get_datastore().services_cache.append(appservice) + self.hs.get_datastores().main.services_cache.append(appservice) whoami = self._whoami(as_token) self.assertEqual( @@ -586,7 +586,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): return self.hs def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_id = self.register_user("kermit", "test") self.user_id_tok = self.login("kermit", "test") diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py index 475c6bed3..a573cc3c2 100644 --- a/tests/rest/client/test_filter.py +++ b/tests/rest/client/test_filter.py @@ -32,7 +32,7 @@ class FilterTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.filtering = hs.get_filtering() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def test_add_filter(self): channel = self.make_request( diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index 19f5e4653..26d0d83e0 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -1101,8 +1101,8 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): }, ) - self.hs.get_datastore().services_cache.append(self.service) - self.hs.get_datastore().services_cache.append(self.another_service) + self.hs.get_datastores().main.services_cache.append(self.service) + self.hs.get_datastores().main.services_cache.append(self.another_service) return self.hs def test_login_appservice_user(self): diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py index ead883ded..b9647d5bd 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py @@ -292,7 +292,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): properties are "mimetype" (for the file's type) and "size" (for the file's size). """ - store = self.hs.get_datastore() + store = self.hs.get_datastores().main for name, props in names_and_props.items(): self.get_success( diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index 0f1c47dcb..2835d86e5 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -56,7 +56,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): sender="@as:test", ) - self.hs.get_datastore().services_cache.append(appservice) + self.hs.get_datastores().main.services_cache.append(appservice) request_data = json.dumps( {"username": "as_user_kermit", "type": APP_SERVICE_REGISTRATION_TYPE} ) @@ -80,7 +80,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): sender="@as:test", ) - self.hs.get_datastore().services_cache.append(appservice) + self.hs.get_datastores().main.services_cache.append(appservice) request_data = json.dumps({"username": "as_user_kermit"}) channel = self.make_request( @@ -210,7 +210,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): username = "kermit" device_id = "frogfone" token = "abcd" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.db_pool.simple_insert( "registration_tokens", @@ -316,7 +316,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): @override_config({"registration_requires_token": True}) def test_POST_registration_token_limit_uses(self): token = "abcd" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Create token that can be used once self.get_success( store.db_pool.simple_insert( @@ -391,7 +391,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): def test_POST_registration_token_expiry(self): token = "abcd" now = self.hs.get_clock().time_msec() - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Create token that expired yesterday self.get_success( store.db_pool.simple_insert( @@ -439,7 +439,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): def test_POST_registration_token_session_expiry(self): """Test `pending` is decremented when an uncompleted session expires.""" token = "abcd" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.db_pool.simple_insert( "registration_tokens", @@ -530,7 +530,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): 3. Expire the session """ token = "abcd" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.db_pool.simple_insert( "registration_tokens", @@ -657,7 +657,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): # Add a threepid self.get_success( - self.hs.get_datastore().user_add_threepid( + self.hs.get_datastores().main.user_add_threepid( user_id=user_id, medium="email", address=email, @@ -941,7 +941,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): self.email_attempts = [] self.hs.get_send_email_handler()._sendmail = sendmail - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main return self.hs @@ -1126,10 +1126,12 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): # We need to set these directly, instead of in the homeserver config dict above. # This is due to account validity-related config options not being read by # Synapse when account_validity.enabled is False. - self.hs.get_datastore()._account_validity_period = self.validity_period - self.hs.get_datastore()._account_validity_startup_job_max_delta = self.max_delta + self.hs.get_datastores().main._account_validity_period = self.validity_period + self.hs.get_datastores().main._account_validity_startup_job_max_delta = ( + self.max_delta + ) - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main return self.hs @@ -1163,7 +1165,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): def test_GET_token_valid(self): token = "abcd" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.db_pool.simple_insert( "registration_tokens", diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index dfd9ffcb9..5687dea48 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -53,7 +53,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): return config def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_id, self.user_token = self._create_user("alice") self.user2_id, self.user2_token = self._create_user("bob") @@ -107,7 +107,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): # Unless that event is referenced from another event! self.get_success( - self.hs.get_datastore().db_pool.simple_insert( + self.hs.get_datastores().main.db_pool.simple_insert( table="event_relations", values={ "event_id": "bar", diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index fe5b536d9..c41a1c14a 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -51,7 +51,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): self.user_id = self.register_user("user", "password") self.token = self.login("user", "password") - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.serializer = self.hs.get_event_client_serializer() self.clock = self.hs.get_clock() @@ -114,7 +114,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): """Tests that synapse.visibility.filter_events_for_client correctly filters out outdated events """ - store = self.hs.get_datastore() + store = self.hs.get_datastores().main storage = self.hs.get_storage() room_id = self.helper.create_room_as(self.user_id, tok=self.token) events = [] diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index b7f086927..1afd96b8f 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -65,7 +65,7 @@ class RoomBase(unittest.HomeserverTestCase): async def _insert_client_ip(*args, **kwargs): return None - self.hs.get_datastore().insert_client_ip = _insert_client_ip + self.hs.get_datastores().main.insert_client_ip = _insert_client_ip return self.hs @@ -667,7 +667,7 @@ class RoomsCreateTestCase(RoomBase): # Add the current user to the ratelimit overrides, allowing them no ratelimiting. self.get_success( - self.hs.get_datastore().set_ratelimit_for_user(self.user_id, 0, 0) + self.hs.get_datastores().main.set_ratelimit_for_user(self.user_id, 0, 0) ) # Test that the invites aren't ratelimited anymore. @@ -1060,7 +1060,9 @@ class RoomJoinRatelimitTestCase(RoomBase): user_id = self.register_user("testuser", "password") # Check that the new user successfully joined the four rooms - rooms = self.get_success(self.hs.get_datastore().get_rooms_for_user(user_id)) + rooms = self.get_success( + self.hs.get_datastores().main.get_rooms_for_user(user_id) + ) self.assertEqual(len(rooms), 4) @@ -1184,7 +1186,7 @@ class RoomMessageListTestCase(RoomBase): self.assertTrue("end" in channel.json_body) def test_room_messages_purge(self): - store = self.hs.get_datastore() + store = self.hs.get_datastores().main pagination_handler = self.hs.get_pagination_handler() # Send a first message in the room, which will be removed by the purge. diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index b0c44af03..7d0e66b53 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -34,7 +34,7 @@ class _ShadowBannedBase(unittest.HomeserverTestCase): self.banned_user_id = self.register_user("banned", "test") self.banned_access_token = self.login("banned", "test") - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.get_success( self.store.set_shadow_banned(UserID.from_string(self.banned_user_id), True) diff --git a/tests/rest/client/test_shared_rooms.py b/tests/rest/client/test_shared_rooms.py index 283eccd53..c42c8aff6 100644 --- a/tests/rest/client/test_shared_rooms.py +++ b/tests/rest/client/test_shared_rooms.py @@ -36,7 +36,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): return self.setup_test_homeserver(config=config) def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.handler = hs.get_user_directory_handler() def _get_shared_rooms(self, token, other_user) -> FakeChannel: diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index cd4af2b1f..e06256136 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -299,7 +299,7 @@ class SyncKnockTestCase( ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.url = "/sync?since=%s" self.next_batch = "s0" diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py index ee0abd529..de312cb63 100644 --- a/tests/rest/client/test_typing.py +++ b/tests/rest/client/test_typing.py @@ -57,7 +57,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): async def _insert_client_ip(*args, **kwargs): return None - hs.get_datastore().insert_client_ip = _insert_client_ip + hs.get_datastores().main.insert_client_ip = _insert_client_ip return hs diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py index a42388b26..7f79336ab 100644 --- a/tests/rest/client/test_upgrade_room.py +++ b/tests/rest/client/test_upgrade_room.py @@ -32,7 +32,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.creator = self.register_user("creator", "pass") self.creator_token = self.login(self.creator, "pass") diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 4cf1ed5dd..6878ccddb 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -243,7 +243,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): media_resource = hs.get_media_repository_resource() self.download_resource = media_resource.children[b"download"] self.thumbnail_resource = media_resource.children[b"thumbnail"] - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.media_repo = hs.get_media_repository() self.media_id = "example.com/12345" diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 36c495954..02b96c9e6 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -242,7 +242,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): return c def prepare(self, reactor, clock, hs): - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.server_notices_sender = self.hs.get_server_notices_sender() self.server_notices_manager = self.hs.get_server_notices_manager() self.event_source = self.hs.get_event_sources() diff --git a/tests/storage/databases/main/test_deviceinbox.py b/tests/storage/databases/main/test_deviceinbox.py index 36c933b9e..50c20c5b9 100644 --- a/tests/storage/databases/main/test_deviceinbox.py +++ b/tests/storage/databases/main/test_deviceinbox.py @@ -26,7 +26,7 @@ class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_id = self.register_user("foo", "pass") def test_background_remove_deleted_devices_from_device_inbox(self): diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 5ae491ff5..59def6e59 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -37,7 +37,7 @@ from tests import unittest class HaveSeenEventsTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store: EventsWorkerStore = hs.get_datastore() + self.store: EventsWorkerStore = hs.get_datastores().main # insert some test data for rid in ("room1", "room2"): @@ -122,7 +122,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store: EventsWorkerStore = hs.get_datastore() + self.store: EventsWorkerStore = hs.get_datastores().main self.user = self.register_user("user", "pass") self.token = self.login(self.user, "pass") @@ -163,7 +163,7 @@ class DatabaseOutageTestCase(unittest.HomeserverTestCase): """Test event fetching during a database outage.""" def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): - self.store: EventsWorkerStore = hs.get_datastore() + self.store: EventsWorkerStore = hs.get_datastores().main self.room_id = f"!room:{hs.hostname}" self.event_ids = [f"event{i}" for i in range(20)] diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py index d326a1d6a..3ac464696 100644 --- a/tests/storage/databases/main/test_lock.py +++ b/tests/storage/databases/main/test_lock.py @@ -20,7 +20,7 @@ from tests import unittest class LockTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs: HomeServer): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def test_simple_lock(self): """Test that we can take out a lock and that while we hold it nobody diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py index 7496974da..9abd0cb44 100644 --- a/tests/storage/databases/main/test_room.py +++ b/tests/storage/databases/main/test_room.py @@ -28,7 +28,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_id = self.register_user("foo", "pass") self.token = self.login("foo", "pass") diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 200b9198f..4899cd5c3 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -20,7 +20,7 @@ from tests import unittest class UpsertManyTests(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.storage = hs.get_datastore() + self.storage = hs.get_datastores().main self.table_name = "table_" + secrets.token_hex(6) self.get_success( diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py index d697d2bc1..272cd3540 100644 --- a/tests/storage/test_account_data.py +++ b/tests/storage/test_account_data.py @@ -21,7 +21,7 @@ from tests import unittest class IgnoredUsersTestCase(unittest.HomeserverTestCase): def prepare(self, hs, reactor, clock): - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.user = "@user:test" def _update_ignore_list( diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index ddcb7f554..50703ccae 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -467,7 +467,7 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer ) -> None: self.service = Mock(id="foo") - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.get_success( self.store.set_appservice_state(self.service, ApplicationServiceState.UP) ) diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 6156dfac4..39dcc094b 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -24,7 +24,7 @@ from tests.test_utils import make_awaitable, simple_async_mock class BackgroundUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): - self.updates: BackgroundUpdater = self.hs.get_datastore().db_pool.updates + self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates # the base test class should have run the real bg updates for us self.assertTrue( self.get_success(self.updates.has_completed_background_updates()) @@ -42,7 +42,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): # the target runtime for each bg update target_background_update_duration_ms = 100 - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.db_pool.simple_insert( "background_updates", @@ -102,7 +102,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): - self.updates: BackgroundUpdater = self.hs.get_datastore().db_pool.updates + self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates # the base test class should have run the real bg updates for us self.assertTrue( self.get_success(self.updates.has_completed_background_updates()) @@ -138,7 +138,7 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase): ) def test_controller(self): - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.db_pool.simple_insert( "background_updates", diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index a59c28f89..ce89c9691 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -30,7 +30,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): """ def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main self.room_creator = homeserver.get_room_creation_handler() # Create a test user and room @@ -242,7 +242,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): return self.setup_test_homeserver(config=config) def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main self.room_creator = homeserver.get_room_creation_handler() self.event_creator_handler = homeserver.get_event_creation_handler() diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index c8ac67e35..49ad3c132 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -35,7 +35,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): return hs def prepare(self, hs, reactor, clock): - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main def test_insert_new_client_ip(self): self.reactor.advance(12345678) @@ -666,7 +666,7 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase): return hs def prepare(self, hs, reactor, clock): - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.user_id = self.register_user("bob", "abc123", True) def test_request_with_xforwarded(self): diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index b547bf8d9..21ffc5a90 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -19,7 +19,7 @@ from tests.unittest import HomeserverTestCase class DeviceStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def test_store_new_device(self): self.get_success( diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py index 43628ce44..7b72a9242 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py @@ -19,7 +19,7 @@ from tests.unittest import HomeserverTestCase class DirectoryStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room = RoomID.from_string("!abcde:test") self.alias = RoomAlias.from_string("#my-room:test") diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py index 7556171d8..fb96ab3a2 100644 --- a/tests/storage/test_e2e_room_keys.py +++ b/tests/storage/test_e2e_room_keys.py @@ -28,7 +28,7 @@ room_key: RoomKey = { class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver("server", federation_http_client=None) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main return hs def test_room_keys_version_delete(self): diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index 3bf6e337f..0f04493ad 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -17,7 +17,7 @@ from tests.unittest import HomeserverTestCase class EndToEndKeyStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def test_key_without_device_name(self): now = 1470174257070 diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index e3273a93f..401020fd6 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -30,7 +30,7 @@ from tests.unittest import HomeserverTestCase class EventChainStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._next_stream_ordering = 1 def test_simple(self): @@ -492,7 +492,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_id = self.register_user("foo", "pass") self.token = self.login("foo", "pass") self.requester = create_requester(self.user_id) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 667ca90a4..645d564d1 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -31,7 +31,7 @@ import tests.utils class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def test_get_prev_events_for_room(self): room_id = "@ROOM:local" diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 738f3ad1d..c9e3b9fa7 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -30,7 +30,7 @@ HIGHLIGHT = [ class EventPushActionsStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.persist_events_store = hs.get_datastores().persist_events def test_get_unread_push_actions_for_user_in_range_for_http(self): diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index a8639d8f8..ef5e25873 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -32,7 +32,7 @@ class ExtremPruneTestCase(HomeserverTestCase): def prepare(self, reactor, clock, homeserver): self.state = self.hs.get_state_handler() self.persistence = self.hs.get_storage().persistence - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.register_user("user", "pass") self.token = self.login("user", "pass") @@ -341,7 +341,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase): def prepare(self, reactor, clock, homeserver): self.state = self.hs.get_state_handler() self.persistence = self.hs.get_storage().persistence - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main def test_remote_user_rooms_cache_invalidated(self): """Test that if the server leaves a room the `get_rooms_for_user` cache diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 748607828..6ac4b93f9 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -26,7 +26,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): skip = "Requires Postgres" def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) @@ -459,7 +459,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): skip = "Requires Postgres" def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) @@ -585,7 +585,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): skip = "Requires Postgres" def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py index a94b5fd72..905909552 100644 --- a/tests/storage/test_keys.py +++ b/tests/storage/test_keys.py @@ -37,7 +37,7 @@ KEY_2 = decode_verify_key_base64( class KeyStoreTestCase(tests.unittest.HomeserverTestCase): def test_get_server_verify_keys(self): - store = self.hs.get_datastore() + store = self.hs.get_datastores().main key_id_1 = "ed25519:key1" key_id_2 = "ed25519:KEY_ID_2" @@ -74,7 +74,7 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): def test_cache(self): """Check that updates correctly invalidate the cache.""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main key_id_1 = "ed25519:key1" key_id_2 = "ed25519:key2" diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py index f8d11bac4..4ca212fd1 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py @@ -22,7 +22,7 @@ class DataStoreTestCase(unittest.HomeserverTestCase): def setUp(self) -> None: super(DataStoreTestCase, self).setUp() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.user = UserID.from_string("@abcde:test") self.displayname = "Frank" diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index d6b4cdd78..79648d45d 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -45,7 +45,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): return config def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main # Advance the clock a bit reactor.advance(FORTY_DAYS) diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index d37736edf..b6f99af2f 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -22,7 +22,7 @@ from tests import unittest class ProfileStoreTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.u_frank = UserID.from_string("@frank:test") diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index 22a77c3cc..08cc60237 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -30,7 +30,7 @@ class PurgeTests(HomeserverTestCase): def prepare(self, reactor, clock, hs): self.room_id = self.helper.create_room_as(self.user_id) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = self.hs.get_storage() def test_purge_history(self): @@ -47,7 +47,7 @@ class PurgeTests(HomeserverTestCase): token = self.get_success( self.store.get_topological_token_for_event(last["event_id"]) ) - token_str = self.get_success(token.to_string(self.hs.get_datastore())) + token_str = self.get_success(token.to_string(self.hs.get_datastores().main)) # Purge everything before this topological token self.get_success( diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 8c95a0a2f..03e9cc7d4 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -30,7 +30,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): return config def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 974806528..1fa495f77 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -20,7 +20,7 @@ from tests.unittest import HomeserverTestCase class RegistrationStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_id = "@my-user:test" self.tokens = ["AbCdEfGhIjKlMnOpQrStUvWxYz", "BcDeFgHiJkLmNoPqRsTuVwXyZa"] diff --git a/tests/storage/test_rollback_worker.py b/tests/storage/test_rollback_worker.py index cfc8098af..0baa54312 100644 --- a/tests/storage/test_rollback_worker.py +++ b/tests/storage/test_rollback_worker.py @@ -56,7 +56,7 @@ class WorkerSchemaTests(HomeserverTestCase): def test_rolling_back(self): """Test that workers can start if the DB is a newer schema version""" - db_pool = self.hs.get_datastore().db_pool + db_pool = self.hs.get_datastores().main.db_pool db_conn = LoggingDatabaseConnection( db_pool._db_pool.connect(), db_pool.engine, @@ -72,7 +72,7 @@ class WorkerSchemaTests(HomeserverTestCase): def test_not_upgraded_old_schema_version(self): """Test that workers don't start if the DB has an older schema version""" - db_pool = self.hs.get_datastore().db_pool + db_pool = self.hs.get_datastores().main.db_pool db_conn = LoggingDatabaseConnection( db_pool._db_pool.connect(), db_pool.engine, @@ -92,7 +92,7 @@ class WorkerSchemaTests(HomeserverTestCase): Test that workers don't start if the DB is on the current schema version, but there are still outstanding delta migrations to run. """ - db_pool = self.hs.get_datastore().db_pool + db_pool = self.hs.get_datastores().main.db_pool db_conn = LoggingDatabaseConnection( db_pool._db_pool.connect(), db_pool.engine, diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 31ce7f625..42bfca2a8 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -23,7 +23,7 @@ class RoomStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): # We can't test RoomStore on its own without the DirectoryStore, for # management of the 'room_aliases' table - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room = RoomID.from_string("!abcde:test") self.alias = RoomAlias.from_string("#a-room-name:test") @@ -71,7 +71,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): # Room events need the full datastore, for persist_event() and # get_room_state() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.event_factory = hs.get_event_factory() diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py index 8971ecccb..befaa0fce 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py @@ -46,7 +46,7 @@ class NullByteInsertionTest(HomeserverTestCase): self.assertIn("event_id", response) # Check that search works for the message where the null byte was replaced - store = self.hs.get_datastore() + store = self.hs.get_datastores().main result = self.get_success( store.search_msgs([room_id], "hi bob", ["content.body"]) ) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 5cfdfe9b8..7028f0dfb 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -35,7 +35,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # We can't test the RoomMemberStore on its own without the other event # storage logic - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.u_alice = self.register_user("alice", "pass") self.t_alice = self.login("alice", "pass") @@ -212,7 +212,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main self.room_creator = homeserver.get_room_creation_handler() def test_can_rerun_update(self): diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 28c767ecf..f88f1c55f 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) class StateStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_datastore = self.storage.state.stores.state self.event_builder_factory = hs.get_event_builder_factory() diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index ce782c7e1..6a1cf3305 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -115,7 +115,7 @@ class PaginationTestCase(HomeserverTestCase): ) events, next_key = self.get_success( - self.hs.get_datastore().paginate_room_events( + self.hs.get_datastores().main.paginate_room_events( room_id=self.room_id, from_key=from_token.room_key, to_key=None, diff --git a/tests/storage/test_transactions.py b/tests/storage/test_transactions.py index bea9091d3..e05daa285 100644 --- a/tests/storage/test_transactions.py +++ b/tests/storage/test_transactions.py @@ -20,7 +20,7 @@ from tests.unittest import HomeserverTestCase class TransactionStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main def test_get_set_transactions(self): """Tests that we can successfully get a non-existent entry for diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 48f1e9d84..7f1964eb6 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -149,7 +149,7 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): return hs def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_dir_helper = GetUserDirectoryTables(self.store) def _purge_and_rebuild_user_dir(self) -> None: @@ -415,7 +415,7 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): class UserDirectoryStoreTestCase(HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main # alice and bob are both in !room_id. bobby is not but shares # a homeserver with alice. diff --git a/tests/test_federation.py b/tests/test_federation.py index 2b9804aba..c39816de8 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -52,11 +52,13 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) )[0]["room_id"] - self.store = self.homeserver.get_datastore() + self.store = self.homeserver.get_datastores().main # Figure out what the most recent event is most_recent = self.get_success( - self.homeserver.get_datastore().get_latest_event_ids_in_room(self.room_id) + self.homeserver.get_datastores().main.get_latest_event_ids_in_room( + self.room_id + ) )[0] join_event = make_event_from_dict( @@ -185,7 +187,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # Register a mock on the store so that the incoming update doesn't fail because # we don't share a room with the user. - store = self.homeserver.get_datastore() + store = self.homeserver.get_datastores().main store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"])) # Manually inject a fake device list update. We need this update to include at diff --git a/tests/test_mau.py b/tests/test_mau.py index 80ab40e25..46bd3075d 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -52,7 +52,7 @@ class TestMauLimit(unittest.HomeserverTestCase): return config def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main def test_simple_deny_mau(self): # Create and sync so that the MAU counts get updated diff --git a/tests/test_state.py b/tests/test_state.py index 76e0e8ca7..90800421f 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -162,7 +162,7 @@ class StateTestCase(unittest.TestCase): hs = Mock( spec_set=[ "config", - "get_datastore", + "get_datastores", "get_storage", "get_auth", "get_state_handler", @@ -173,7 +173,7 @@ class StateTestCase(unittest.TestCase): ] ) hs.config = default_config("tesths", True) - hs.get_datastore.return_value = self.store + hs.get_datastores.return_value = Mock(main=self.store) hs.get_state_handler.return_value = None hs.get_clock.return_value = MockClock() hs.get_auth.return_value = Auth(hs) diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index e9ec9e085..c654e36ee 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -85,7 +85,9 @@ async def create_event( **kwargs, ) -> Tuple[EventBase, EventContext]: if room_version is None: - room_version = await hs.get_datastore().get_room_version_id(kwargs["room_id"]) + room_version = await hs.get_datastores().main.get_room_version_id( + kwargs["room_id"] + ) builder = hs.get_event_builder_factory().for_room_version( KNOWN_ROOM_VERSIONS[room_version], kwargs diff --git a/tests/test_visibility.py b/tests/test_visibility.py index e0b08d67d..219b5660b 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -93,7 +93,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): events_to_filter.append(evt) # the erasey user gets erased - self.get_success(self.hs.get_datastore().mark_user_erased("@erased:local_hs")) + self.get_success( + self.hs.get_datastores().main.mark_user_erased("@erased:local_hs") + ) # ... and the filtering happens. filtered = self.get_success( diff --git a/tests/unittest.py b/tests/unittest.py index 7983c1e8b..0caa8e7a4 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -280,7 +280,7 @@ class HomeserverTestCase(TestCase): # We need a valid token ID to satisfy foreign key constraints. token_id = self.get_success( - self.hs.get_datastore().add_access_token_to_user( + self.hs.get_datastores().main.add_access_token_to_user( self.helper.auth_user_id, "some_fake_token", None, @@ -337,7 +337,7 @@ class HomeserverTestCase(TestCase): def wait_for_background_updates(self) -> None: """Block until all background database updates have completed.""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main while not self.get_success( store.db_pool.updates.has_completed_background_updates() ): @@ -504,7 +504,7 @@ class HomeserverTestCase(TestCase): self.get_success(stor.db_pool.updates.run_background_updates(False)) hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) - stor = hs.get_datastore() + stor = hs.get_datastores().main # Run the database background updates, when running against "master". if hs.__class__.__name__ == "TestHomeServer": @@ -722,14 +722,16 @@ class HomeserverTestCase(TestCase): Add the given event as an extremity to the room. """ self.get_success( - self.hs.get_datastore().db_pool.simple_insert( + self.hs.get_datastores().main.db_pool.simple_insert( table="event_forward_extremities", values={"room_id": room_id, "event_id": event_id}, desc="test_add_extremity", ) ) - self.hs.get_datastore().get_latest_event_ids_in_room.invalidate((room_id,)) + self.hs.get_datastores().main.get_latest_event_ids_in_room.invalidate( + (room_id,) + ) def attempt_wrong_password_login(self, username, password): """Attempts to login as the user with the given password, asserting @@ -775,7 +777,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase): verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version) self.get_success( - hs.get_datastore().store_server_verify_keys( + hs.get_datastores().main.store_server_verify_keys( from_server=self.OTHER_SERVER_NAME, ts_added_ms=clock.time_msec(), verify_keys=[ diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py index 9e1bebdc8..26cb71c64 100644 --- a/tests/util/test_retryutils.py +++ b/tests/util/test_retryutils.py @@ -24,7 +24,7 @@ from tests.unittest import HomeserverTestCase class RetryLimiterTestCase(HomeserverTestCase): def test_new_destination(self): """A happy-path case with a new destination and a successful operation""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) # advance the clock a bit before making the request @@ -38,7 +38,7 @@ class RetryLimiterTestCase(HomeserverTestCase): def test_limiter(self): """General test case which walks through the process of a failing request""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) diff --git a/tests/utils.py b/tests/utils.py index c06fc320f..ef99c72e0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -367,7 +367,7 @@ async def create_room(hs, room_id: str, creator_id: str): """Creates and persist a creation event for the given room""" persistence_store = hs.get_storage().persistence - store = hs.get_datastore() + store = hs.get_datastores().main event_builder_factory = hs.get_event_builder_factory() event_creation_handler = hs.get_event_creation_handler()