From 6f440fd859c411af8c32478b4353f7550619c3bd Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 9 Feb 2022 16:06:51 +0100 Subject: [PATCH 01/84] Recommend upgrading treq alongside twisted (#11943) --- CHANGES.md | 2 +- docs/upgrade.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 958024ff0..cd62e5256 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -7,7 +7,7 @@ Note that [Twisted 22.1.0](https://github.com/twisted/twisted/releases/tag/twist has recently been released, which fixes a [security issue](https://github.com/twisted/twisted/security/advisories/GHSA-92x2-jw7w-xvvx) within the Twisted library. We do not believe Synapse is affected by this vulnerability, though we advise server administrators who installed Synapse via pip to upgrade Twisted -with `pip install --upgrade Twisted` as a matter of good practice. The Docker image +with `pip install --upgrade Twisted treq` as a matter of good practice. The Docker image `matrixdotorg/synapse` and the Debian packages from `packages.matrix.org` are using the updated library. diff --git a/docs/upgrade.md b/docs/upgrade.md index 0105f87f9..df873e531 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -93,7 +93,7 @@ Note that [Twisted 22.1.0](https://github.com/twisted/twisted/releases/tag/twist has recently been released, which fixes a [security issue](https://github.com/twisted/twisted/security/advisories/GHSA-92x2-jw7w-xvvx) within the Twisted library. We do not believe Synapse is affected by this vulnerability, though we advise server administrators who installed Synapse via pip to upgrade Twisted -with `pip install --upgrade Twisted` as a matter of good practice. The Docker image +with `pip install --upgrade Twisted treq` as a matter of good practice. The Docker image `matrixdotorg/synapse` and the Debian packages from `packages.matrix.org` are using the updated library. From 87f200571386b131f841f8abc8bb08efb6c3be52 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 15 Feb 2022 11:27:56 +0000 Subject: [PATCH 02/84] Add some tests for propagation of device list changes between local users (#11972) --- changelog.d/11972.misc | 1 + synapse/notifier.py | 4 +- tests/rest/client/test_device_lists.py | 155 +++++++++++++++++++++++++ tests/rest/client/utils.py | 6 +- 4 files changed, 163 insertions(+), 3 deletions(-) create mode 100644 changelog.d/11972.misc create mode 100644 tests/rest/client/test_device_lists.py diff --git a/changelog.d/11972.misc b/changelog.d/11972.misc new file mode 100644 index 000000000..29c38bfd8 --- /dev/null +++ b/changelog.d/11972.misc @@ -0,0 +1 @@ +Add tests for device list changes between local users. \ No newline at end of file diff --git a/synapse/notifier.py b/synapse/notifier.py index e0fad2da6..753dd6b6a 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -138,7 +138,7 @@ class _NotifierUserStream: self.current_token = self.current_token.copy_and_advance(stream_key, stream_id) self.last_notified_token = self.current_token self.last_notified_ms = time_now_ms - noify_deferred = self.notify_deferred + notify_deferred = self.notify_deferred log_kv( { @@ -153,7 +153,7 @@ class _NotifierUserStream: with PreserveLoggingContext(): self.notify_deferred = ObservableDeferred(defer.Deferred()) - noify_deferred.callback(self.current_token) + notify_deferred.callback(self.current_token) def remove(self, notifier: "Notifier") -> None: """Remove this listener from all the indexes in the Notifier diff --git a/tests/rest/client/test_device_lists.py b/tests/rest/client/test_device_lists.py new file mode 100644 index 000000000..16070cf02 --- /dev/null +++ b/tests/rest/client/test_device_lists.py @@ -0,0 +1,155 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.rest import admin, devices, room, sync +from synapse.rest.client import account, login, register + +from tests import unittest + + +class DeviceListsTestCase(unittest.HomeserverTestCase): + """Tests regarding device list changes.""" + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + register.register_servlets, + account.register_servlets, + room.register_servlets, + sync.register_servlets, + devices.register_servlets, + ] + + def test_receiving_local_device_list_changes(self): + """Tests that a local users that share a room receive each other's device list + changes. + """ + # Register two users + test_device_id = "TESTDEVICE" + alice_user_id = self.register_user("alice", "correcthorse") + alice_access_token = self.login( + alice_user_id, "correcthorse", device_id=test_device_id + ) + + bob_user_id = self.register_user("bob", "ponyponypony") + bob_access_token = self.login(bob_user_id, "ponyponypony") + + # Create a room for them to coexist peacefully in + new_room_id = self.helper.create_room_as( + alice_user_id, is_public=True, tok=alice_access_token + ) + self.assertIsNotNone(new_room_id) + + # Have Bob join the room + self.helper.invite( + new_room_id, alice_user_id, bob_user_id, tok=alice_access_token + ) + self.helper.join(new_room_id, bob_user_id, tok=bob_access_token) + + # Now have Bob initiate an initial sync (in order to get a since token) + channel = self.make_request( + "GET", + "/sync", + access_token=bob_access_token, + ) + self.assertEqual(channel.code, 200, channel.json_body) + next_batch_token = channel.json_body["next_batch"] + + # ...and then an incremental sync. This should block until the sync stream is woken up, + # which we hope will happen as a result of Alice updating their device list. + bob_sync_channel = self.make_request( + "GET", + f"/sync?since={next_batch_token}&timeout=30000", + access_token=bob_access_token, + # Start the request, then continue on. + await_result=False, + ) + + # Have alice update their device list + channel = self.make_request( + "PUT", + f"/devices/{test_device_id}", + { + "display_name": "New Device Name", + }, + access_token=alice_access_token, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that bob's incremental sync contains the updated device list. + # If not, the client would only receive the device list update on the + # *next* sync. + bob_sync_channel.await_result() + self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body) + + changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get( + "changed", [] + ) + self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body) + + def test_not_receiving_local_device_list_changes(self): + """Tests a local users DO NOT receive device updates from each other if they do not + share a room. + """ + # Register two users + test_device_id = "TESTDEVICE" + alice_user_id = self.register_user("alice", "correcthorse") + alice_access_token = self.login( + alice_user_id, "correcthorse", device_id=test_device_id + ) + + bob_user_id = self.register_user("bob", "ponyponypony") + bob_access_token = self.login(bob_user_id, "ponyponypony") + + # These users do not share a room. They are lonely. + + # Have Bob initiate an initial sync (in order to get a since token) + channel = self.make_request( + "GET", + "/sync", + access_token=bob_access_token, + ) + self.assertEqual(channel.code, 200, channel.json_body) + next_batch_token = channel.json_body["next_batch"] + + # ...and then an incremental sync. This should block until the sync stream is woken up, + # which we hope will happen as a result of Alice updating their device list. + bob_sync_channel = self.make_request( + "GET", + f"/sync?since={next_batch_token}&timeout=1000", + access_token=bob_access_token, + # Start the request, then continue on. + await_result=False, + ) + + # Have alice update their device list + channel = self.make_request( + "PUT", + f"/devices/{test_device_id}", + { + "display_name": "New Device Name", + }, + access_token=alice_access_token, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that bob's incremental sync does not contain the updated device list. + bob_sync_channel.await_result() + self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body) + + changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get( + "changed", [] + ) + self.assertNotIn( + alice_user_id, changed_device_lists, bob_sync_channel.json_body + ) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 1c0cb0cf4..2b3fdadff 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -106,9 +106,13 @@ class RestHelper: default room version. tok: The access token to use in the request. expect_code: The expected HTTP response code. + extra_content: Extra keys to include in the body of the /createRoom request. + Note that if is_public is set, the "visibility" key will be overridden. + If room_version is set, the "room_version" key will be overridden. + custom_headers: HTTP headers to include in the request. Returns: - The ID of the newly created room. + The ID of the newly created room, or None if the request failed. """ temp_id = self.auth_user_id self.auth_user_id = room_creator From 45f45404de2d0c4d68954eddc2dc905e50dfafe9 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 15 Feb 2022 08:26:57 -0500 Subject: [PATCH 03/84] Fix incorrect thread summaries when the latest event is edited. (#11992) If the latest event in a thread was edited than the original event content was included in bundled aggregation for threads instead of the edited event content. --- changelog.d/11992.bugfix | 1 + synapse/events/utils.py | 69 ++++++++++++------- .../storage/databases/main/events_worker.py | 2 +- synapse/storage/databases/main/relations.py | 24 +++++-- tests/rest/client/test_relations.py | 42 +++++++++++ 5 files changed, 107 insertions(+), 31 deletions(-) create mode 100644 changelog.d/11992.bugfix diff --git a/changelog.d/11992.bugfix b/changelog.d/11992.bugfix new file mode 100644 index 000000000..f73c86bb2 --- /dev/null +++ b/changelog.d/11992.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse v1.48.0 where an edit of the latest event in a thread would not be properly applied to the thread summary. diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 243696b35..9386fa29d 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -425,6 +425,33 @@ class EventClientSerializer: return serialized_event + def _apply_edit( + self, orig_event: EventBase, serialized_event: JsonDict, edit: EventBase + ) -> None: + """Replace the content, preserving existing relations of the serialized event. + + Args: + orig_event: The original event. + serialized_event: The original event, serialized. This is modified. + edit: The event which edits the above. + """ + + # Ensure we take copies of the edit content, otherwise we risk modifying + # the original event. + edit_content = edit.content.copy() + + # Unfreeze the event content if necessary, so that we may modify it below + edit_content = unfreeze(edit_content) + serialized_event["content"] = edit_content.get("m.new_content", {}) + + # Check for existing relations + relates_to = orig_event.content.get("m.relates_to") + if relates_to: + # Keep the relations, ensuring we use a dict copy of the original + serialized_event["content"]["m.relates_to"] = relates_to.copy() + else: + serialized_event["content"].pop("m.relates_to", None) + def _inject_bundled_aggregations( self, event: EventBase, @@ -450,26 +477,11 @@ class EventClientSerializer: serialized_aggregations[RelationTypes.REFERENCE] = aggregations.references if aggregations.replace: - # If there is an edit replace the content, preserving existing - # relations. + # If there is an edit, apply it to the event. edit = aggregations.replace + self._apply_edit(event, serialized_event, edit) - # Ensure we take copies of the edit content, otherwise we risk modifying - # the original event. - edit_content = edit.content.copy() - - # Unfreeze the event content if necessary, so that we may modify it below - edit_content = unfreeze(edit_content) - serialized_event["content"] = edit_content.get("m.new_content", {}) - - # Check for existing relations - relates_to = event.content.get("m.relates_to") - if relates_to: - # Keep the relations, ensuring we use a dict copy of the original - serialized_event["content"]["m.relates_to"] = relates_to.copy() - else: - serialized_event["content"].pop("m.relates_to", None) - + # Include information about it in the relations dict. serialized_aggregations[RelationTypes.REPLACE] = { "event_id": edit.event_id, "origin_server_ts": edit.origin_server_ts, @@ -478,13 +490,22 @@ class EventClientSerializer: # If this event is the start of a thread, include a summary of the replies. if aggregations.thread: + thread = aggregations.thread + + # Don't bundle aggregations as this could recurse forever. + serialized_latest_event = self.serialize_event( + thread.latest_event, time_now, bundle_aggregations=None + ) + # Manually apply an edit, if one exists. + if thread.latest_edit: + self._apply_edit( + thread.latest_event, serialized_latest_event, thread.latest_edit + ) + serialized_aggregations[RelationTypes.THREAD] = { - # Don't bundle aggregations as this could recurse forever. - "latest_event": self.serialize_event( - aggregations.thread.latest_event, time_now, bundle_aggregations=None - ), - "count": aggregations.thread.count, - "current_user_participated": aggregations.thread.current_user_participated, + "latest_event": serialized_latest_event, + "count": thread.count, + "current_user_participated": thread.current_user_participated, } # Include the bundled aggregations in the event. diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 8d4287045..712b8ce20 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -408,7 +408,7 @@ class EventsWorkerStore(SQLBaseStore): include the previous states content in the unsigned field. allow_rejected: If True, return rejected events. Otherwise, - omits rejeted events from the response. + omits rejected events from the response. Returns: A mapping from event_id to event. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index e2c27e594..5582029f9 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -53,8 +53,13 @@ logger = logging.getLogger(__name__) @attr.s(slots=True, frozen=True, auto_attribs=True) class _ThreadAggregation: + # The latest event in the thread. latest_event: EventBase + # The latest edit to the latest event in the thread. + latest_edit: Optional[EventBase] + # The total number of events in the thread. count: int + # True if the current user has sent an event to the thread. current_user_participated: bool @@ -461,8 +466,8 @@ class RelationsWorkerStore(SQLBaseStore): @cachedList(cached_method_name="get_thread_summary", list_name="event_ids") async def _get_thread_summaries( self, event_ids: Collection[str] - ) -> Dict[str, Optional[Tuple[int, EventBase]]]: - """Get the number of threaded replies and the latest reply (if any) for the given event. + ) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]: + """Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event. Args: event_ids: Summarize the thread related to this event ID. @@ -471,8 +476,10 @@ class RelationsWorkerStore(SQLBaseStore): A map of the thread summary each event. A missing event implies there are no threaded replies. - Each summary includes the number of items in the thread and the most - recent response. + Each summary is a tuple of: + The number of events in the thread. + The most recent event in the thread. + The most recent edit to the most recent event in the thread, if applicable. """ def _get_thread_summaries_txn( @@ -558,6 +565,9 @@ class RelationsWorkerStore(SQLBaseStore): latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined] + # Check to see if any of those events are edited. + latest_edits = await self._get_applicable_edits(latest_event_ids.values()) + # Map to the event IDs to the thread summary. # # There might not be a summary due to there not being a thread or @@ -568,7 +578,8 @@ class RelationsWorkerStore(SQLBaseStore): summary = None if latest_event: - summary = (counts[parent_event_id], latest_event) + latest_edit = latest_edits.get(latest_event_id) + summary = (counts[parent_event_id], latest_event, latest_edit) summaries[parent_event_id] = summary return summaries @@ -828,11 +839,12 @@ class RelationsWorkerStore(SQLBaseStore): ) for event_id, summary in summaries.items(): if summary: - thread_count, latest_thread_event = summary + thread_count, latest_thread_event, edit = summary results.setdefault( event_id, BundledAggregations() ).thread = _ThreadAggregation( latest_event=latest_thread_event, + latest_edit=edit, count=thread_count, # If there's a thread summary it must also exist in the # participated dictionary. diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index de80aca03..dfd9ffcb9 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1123,6 +1123,48 @@ class RelationsTestCase(unittest.HomeserverTestCase): {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) + @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) + def test_edit_thread(self): + """Test that editing a thread works.""" + + # Create a thread and edit the last event. + channel = self._send_relation( + RelationTypes.THREAD, + "m.room.message", + content={"msgtype": "m.text", "body": "A threaded reply!"}, + ) + self.assertEquals(200, channel.code, channel.json_body) + threaded_event_id = channel.json_body["event_id"] + + new_body = {"msgtype": "m.text", "body": "I've been edited!"} + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + parent_id=threaded_event_id, + ) + self.assertEquals(200, channel.code, channel.json_body) + + # Fetch the thread root, to get the bundled aggregation for the thread. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + # We expect that the edit message appears in the thread summary in the + # unsigned relations section. + relations_dict = channel.json_body["unsigned"].get("m.relations") + self.assertIn(RelationTypes.THREAD, relations_dict) + + thread_summary = relations_dict[RelationTypes.THREAD] + self.assertIn("latest_event", thread_summary) + latest_event_in_thread = thread_summary["latest_event"] + self.assertEquals( + latest_event_in_thread["content"]["body"], "I've been edited!" + ) + def test_edit_edit(self): """Test that an edit cannot be edited.""" new_body = {"msgtype": "m.text", "body": "Initial edit"} From e44f91d678e22936b7e2f0d8bf4890159507533b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 15 Feb 2022 08:47:05 -0500 Subject: [PATCH 04/84] Refactor search code to reduce function size. (#11991) Splits the search code into a few logical functions instead of a single unreadable function. There are also a few additional changes for readability. After refactoring it was clear to see there were some unused and unnecessary variables, which were simplified. --- changelog.d/11991.misc | 1 + synapse/handlers/search.py | 645 +++++++++++++++-------- synapse/storage/databases/main/search.py | 17 +- 3 files changed, 436 insertions(+), 227 deletions(-) create mode 100644 changelog.d/11991.misc diff --git a/changelog.d/11991.misc b/changelog.d/11991.misc new file mode 100644 index 000000000..34a3b3a6b --- /dev/null +++ b/changelog.d/11991.misc @@ -0,0 +1 @@ +Refactor the search code for improved readability. diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 41cb80907..afd14da11 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -14,8 +14,9 @@ import itertools import logging -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional +from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple +import attr from unpaddedbase64 import decode_base64, encode_base64 from synapse.api.constants import EventTypes, Membership @@ -32,6 +33,20 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _SearchResult: + # The count of results. + count: int + # A mapping of event ID to the rank of that event. + rank_map: Dict[str, int] + # A list of the resulting events. + allowed_events: List[EventBase] + # A map of room ID to results. + room_groups: Dict[str, JsonDict] + # A set of event IDs to highlight. + highlights: Set[str] + + class SearchHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -100,7 +115,7 @@ class SearchHandler: """Performs a full text search for a user. Args: - user + user: The user performing the search. content: Search parameters batch: The next_batch parameter. Used for pagination. @@ -156,6 +171,8 @@ class SearchHandler: # Include context around each event? event_context = room_cat.get("event_context", None) + before_limit = after_limit = None + include_profile = False # Group results together? May allow clients to paginate within a # group @@ -182,6 +199,73 @@ class SearchHandler: % (set(group_keys) - {"room_id", "sender"},), ) + return await self._search( + user, + batch_group, + batch_group_key, + batch_token, + search_term, + keys, + filter_dict, + order_by, + include_state, + group_keys, + event_context, + before_limit, + after_limit, + include_profile, + ) + + async def _search( + self, + user: UserID, + batch_group: Optional[str], + batch_group_key: Optional[str], + batch_token: Optional[str], + search_term: str, + keys: List[str], + filter_dict: JsonDict, + order_by: str, + include_state: bool, + group_keys: List[str], + event_context: Optional[bool], + before_limit: Optional[int], + after_limit: Optional[int], + include_profile: bool, + ) -> JsonDict: + """Performs a full text search for a user. + + Args: + user: The user performing the search. + batch_group: Pagination information. + batch_group_key: Pagination information. + batch_token: Pagination information. + search_term: Search term to search for + keys: List of keys to search in, currently supports + "content.body", "content.name", "content.topic" + filter_dict: The JSON to build a filter out of. + order_by: How to order the results. Valid values ore "rank" and "recent". + include_state: True if the state of the room at each result should + be included. + group_keys: A list of ways to group the results. Valid values are + "room_id" and "sender". + event_context: True to include contextual events around results. + before_limit: + The number of events before a result to include as context. + + Only used if event_context is True. + after_limit: + The number of events after a result to include as context. + + Only used if event_context is True. + include_profile: True if historical profile information should be + included in the event context. + + Only used if event_context is True. + + Returns: + dict to be returned to the client with results of search + """ search_filter = Filter(self.hs, filter_dict) # TODO: Search through left rooms too @@ -216,209 +300,57 @@ class SearchHandler: } } - rank_map = {} # event_id -> rank of event - allowed_events = [] - # Holds result of grouping by room, if applicable - room_groups: Dict[str, JsonDict] = {} - # Holds result of grouping by sender, if applicable - sender_group: Dict[str, JsonDict] = {} - - # Holds the next_batch for the entire result set if one of those exists - global_next_batch = None - - highlights = set() - - count = None + sender_group: Optional[Dict[str, JsonDict]] if order_by == "rank": - search_result = await self.store.search_msgs(room_ids, search_term, keys) - - count = search_result["count"] - - if search_result["highlights"]: - highlights.update(search_result["highlights"]) - - results = search_result["results"] - - rank_map.update({r["event"].event_id: r["rank"] for r in results}) - - filtered_events = await search_filter.filter([r["event"] for r in results]) - - events = await filter_events_for_client( - self.storage, user.to_string(), filtered_events + search_result, sender_group = await self._search_by_rank( + user, room_ids, search_term, keys, search_filter ) - - events.sort(key=lambda e: -rank_map[e.event_id]) - allowed_events = events[: search_filter.limit] - - for e in allowed_events: - rm = room_groups.setdefault( - e.room_id, {"results": [], "order": rank_map[e.event_id]} - ) - rm["results"].append(e.event_id) - - s = sender_group.setdefault( - e.sender, {"results": [], "order": rank_map[e.event_id]} - ) - s["results"].append(e.event_id) - + # Unused return values for rank search. + global_next_batch = None elif order_by == "recent": - room_events: List[EventBase] = [] - i = 0 - - pagination_token = batch_token - - # We keep looping and we keep filtering until we reach the limit - # or we run out of things. - # But only go around 5 times since otherwise synapse will be sad. - while len(room_events) < search_filter.limit and i < 5: - i += 1 - search_result = await self.store.search_rooms( - room_ids, - search_term, - keys, - search_filter.limit * 2, - pagination_token=pagination_token, - ) - - if search_result["highlights"]: - highlights.update(search_result["highlights"]) - - count = search_result["count"] - - results = search_result["results"] - - results_map = {r["event"].event_id: r for r in results} - - rank_map.update({r["event"].event_id: r["rank"] for r in results}) - - filtered_events = await search_filter.filter( - [r["event"] for r in results] - ) - - events = await filter_events_for_client( - self.storage, user.to_string(), filtered_events - ) - - room_events.extend(events) - room_events = room_events[: search_filter.limit] - - if len(results) < search_filter.limit * 2: - pagination_token = None - break - else: - pagination_token = results[-1]["pagination_token"] - - for event in room_events: - group = room_groups.setdefault(event.room_id, {"results": []}) - group["results"].append(event.event_id) - - if room_events and len(room_events) >= search_filter.limit: - last_event_id = room_events[-1].event_id - pagination_token = results_map[last_event_id]["pagination_token"] - - # We want to respect the given batch group and group keys so - # that if people blindly use the top level `next_batch` token - # it returns more from the same group (if applicable) rather - # than reverting to searching all results again. - if batch_group and batch_group_key: - global_next_batch = encode_base64( - ( - "%s\n%s\n%s" - % (batch_group, batch_group_key, pagination_token) - ).encode("ascii") - ) - else: - global_next_batch = encode_base64( - ("%s\n%s\n%s" % ("all", "", pagination_token)).encode("ascii") - ) - - for room_id, group in room_groups.items(): - group["next_batch"] = encode_base64( - ("%s\n%s\n%s" % ("room_id", room_id, pagination_token)).encode( - "ascii" - ) - ) - - allowed_events.extend(room_events) - + search_result, global_next_batch = await self._search_by_recent( + user, + room_ids, + search_term, + keys, + search_filter, + batch_group, + batch_group_key, + batch_token, + ) + # Unused return values for recent search. + sender_group = None else: # We should never get here due to the guard earlier. raise NotImplementedError() - logger.info("Found %d events to return", len(allowed_events)) + logger.info("Found %d events to return", len(search_result.allowed_events)) # If client has asked for "context" for each event (i.e. some surrounding # events and state), fetch that if event_context is not None: - now_token = self.hs.get_event_sources().get_current_token() + # Note that before and after limit must be set in this case. + assert before_limit is not None + assert after_limit is not None - contexts = {} - for event in allowed_events: - res = await self.store.get_events_around( - event.room_id, event.event_id, before_limit, after_limit - ) - - logger.info( - "Context for search returned %d and %d events", - len(res.events_before), - len(res.events_after), - ) - - events_before = await filter_events_for_client( - self.storage, user.to_string(), res.events_before - ) - - events_after = await filter_events_for_client( - self.storage, user.to_string(), res.events_after - ) - - context = { - "events_before": events_before, - "events_after": events_after, - "start": await now_token.copy_and_replace( - "room_key", res.start - ).to_string(self.store), - "end": await now_token.copy_and_replace( - "room_key", res.end - ).to_string(self.store), - } - - if include_profile: - senders = { - ev.sender - for ev in itertools.chain(events_before, [event], events_after) - } - - if events_after: - last_event_id = events_after[-1].event_id - else: - last_event_id = event.event_id - - state_filter = StateFilter.from_types( - [(EventTypes.Member, sender) for sender in senders] - ) - - state = await self.state_store.get_state_for_event( - last_event_id, state_filter - ) - - context["profile_info"] = { - s.state_key: { - "displayname": s.content.get("displayname", None), - "avatar_url": s.content.get("avatar_url", None), - } - for s in state.values() - if s.type == EventTypes.Member and s.state_key in senders - } - - contexts[event.event_id] = context + contexts = await self._calculate_event_contexts( + user, + search_result.allowed_events, + before_limit, + after_limit, + include_profile, + ) else: contexts = {} # TODO: Add a limit - time_now = self.clock.time_msec() + state_results = {} + if include_state: + for room_id in {e.room_id for e in search_result.allowed_events}: + state = await self.state_handler.get_current_state(room_id) + state_results[room_id] = list(state.values()) aggregations = None if self._msc3666_enabled: @@ -432,11 +364,16 @@ class SearchHandler: for context in contexts.values() ), # The returned events. - allowed_events, + search_result.allowed_events, ), user.to_string(), ) + # We're now about to serialize the events. We should not make any + # blocking calls after this. Otherwise, the 'age' will be wrong. + + time_now = self.clock.time_msec() + for context in contexts.values(): context["events_before"] = self._event_serializer.serialize_events( context["events_before"], time_now, bundle_aggregations=aggregations # type: ignore[arg-type] @@ -445,44 +382,33 @@ class SearchHandler: context["events_after"], time_now, bundle_aggregations=aggregations # type: ignore[arg-type] ) - state_results = {} - if include_state: - for room_id in {e.room_id for e in allowed_events}: - state = await self.state_handler.get_current_state(room_id) - state_results[room_id] = list(state.values()) + results = [ + { + "rank": search_result.rank_map[e.event_id], + "result": self._event_serializer.serialize_event( + e, time_now, bundle_aggregations=aggregations + ), + "context": contexts.get(e.event_id, {}), + } + for e in search_result.allowed_events + ] - # We're now about to serialize the events. We should not make any - # blocking calls after this. Otherwise the 'age' will be wrong - - results = [] - for e in allowed_events: - results.append( - { - "rank": rank_map[e.event_id], - "result": self._event_serializer.serialize_event( - e, time_now, bundle_aggregations=aggregations - ), - "context": contexts.get(e.event_id, {}), - } - ) - - rooms_cat_res = { + rooms_cat_res: JsonDict = { "results": results, - "count": count, - "highlights": list(highlights), + "count": search_result.count, + "highlights": list(search_result.highlights), } if state_results: - s = {} - for room_id, state_events in state_results.items(): - s[room_id] = self._event_serializer.serialize_events( - state_events, time_now - ) + rooms_cat_res["state"] = { + room_id: self._event_serializer.serialize_events(state_events, time_now) + for room_id, state_events in state_results.items() + } - rooms_cat_res["state"] = s - - if room_groups and "room_id" in group_keys: - rooms_cat_res.setdefault("groups", {})["room_id"] = room_groups + if search_result.room_groups and "room_id" in group_keys: + rooms_cat_res.setdefault("groups", {})[ + "room_id" + ] = search_result.room_groups if sender_group and "sender" in group_keys: rooms_cat_res.setdefault("groups", {})["sender"] = sender_group @@ -491,3 +417,282 @@ class SearchHandler: rooms_cat_res["next_batch"] = global_next_batch return {"search_categories": {"room_events": rooms_cat_res}} + + async def _search_by_rank( + self, + user: UserID, + room_ids: Collection[str], + search_term: str, + keys: Iterable[str], + search_filter: Filter, + ) -> Tuple[_SearchResult, Dict[str, JsonDict]]: + """ + Performs a full text search for a user ordering by rank. + + Args: + user: The user performing the search. + room_ids: List of room ids to search in + search_term: Search term to search for + keys: List of keys to search in, currently supports + "content.body", "content.name", "content.topic" + search_filter: The event filter to use. + + Returns: + A tuple of: + The search results. + A map of sender ID to results. + """ + rank_map = {} # event_id -> rank of event + # Holds result of grouping by room, if applicable + room_groups: Dict[str, JsonDict] = {} + # Holds result of grouping by sender, if applicable + sender_group: Dict[str, JsonDict] = {} + + search_result = await self.store.search_msgs(room_ids, search_term, keys) + + if search_result["highlights"]: + highlights = search_result["highlights"] + else: + highlights = set() + + results = search_result["results"] + + # event_id -> rank of event + rank_map = {r["event"].event_id: r["rank"] for r in results} + + filtered_events = await search_filter.filter([r["event"] for r in results]) + + events = await filter_events_for_client( + self.storage, user.to_string(), filtered_events + ) + + events.sort(key=lambda e: -rank_map[e.event_id]) + allowed_events = events[: search_filter.limit] + + for e in allowed_events: + rm = room_groups.setdefault( + e.room_id, {"results": [], "order": rank_map[e.event_id]} + ) + rm["results"].append(e.event_id) + + s = sender_group.setdefault( + e.sender, {"results": [], "order": rank_map[e.event_id]} + ) + s["results"].append(e.event_id) + + return ( + _SearchResult( + search_result["count"], + rank_map, + allowed_events, + room_groups, + highlights, + ), + sender_group, + ) + + async def _search_by_recent( + self, + user: UserID, + room_ids: Collection[str], + search_term: str, + keys: Iterable[str], + search_filter: Filter, + batch_group: Optional[str], + batch_group_key: Optional[str], + batch_token: Optional[str], + ) -> Tuple[_SearchResult, Optional[str]]: + """ + Performs a full text search for a user ordering by recent. + + Args: + user: The user performing the search. + room_ids: List of room ids to search in + search_term: Search term to search for + keys: List of keys to search in, currently supports + "content.body", "content.name", "content.topic" + search_filter: The event filter to use. + batch_group: Pagination information. + batch_group_key: Pagination information. + batch_token: Pagination information. + + Returns: + A tuple of: + The search results. + Optionally, a pagination token. + """ + rank_map = {} # event_id -> rank of event + # Holds result of grouping by room, if applicable + room_groups: Dict[str, JsonDict] = {} + + # Holds the next_batch for the entire result set if one of those exists + global_next_batch = None + + highlights = set() + + room_events: List[EventBase] = [] + i = 0 + + pagination_token = batch_token + + # We keep looping and we keep filtering until we reach the limit + # or we run out of things. + # But only go around 5 times since otherwise synapse will be sad. + while len(room_events) < search_filter.limit and i < 5: + i += 1 + search_result = await self.store.search_rooms( + room_ids, + search_term, + keys, + search_filter.limit * 2, + pagination_token=pagination_token, + ) + + if search_result["highlights"]: + highlights.update(search_result["highlights"]) + + count = search_result["count"] + + results = search_result["results"] + + results_map = {r["event"].event_id: r for r in results} + + rank_map.update({r["event"].event_id: r["rank"] for r in results}) + + filtered_events = await search_filter.filter([r["event"] for r in results]) + + events = await filter_events_for_client( + self.storage, user.to_string(), filtered_events + ) + + room_events.extend(events) + room_events = room_events[: search_filter.limit] + + if len(results) < search_filter.limit * 2: + break + else: + pagination_token = results[-1]["pagination_token"] + + for event in room_events: + group = room_groups.setdefault(event.room_id, {"results": []}) + group["results"].append(event.event_id) + + if room_events and len(room_events) >= search_filter.limit: + last_event_id = room_events[-1].event_id + pagination_token = results_map[last_event_id]["pagination_token"] + + # We want to respect the given batch group and group keys so + # that if people blindly use the top level `next_batch` token + # it returns more from the same group (if applicable) rather + # than reverting to searching all results again. + if batch_group and batch_group_key: + global_next_batch = encode_base64( + ( + "%s\n%s\n%s" % (batch_group, batch_group_key, pagination_token) + ).encode("ascii") + ) + else: + global_next_batch = encode_base64( + ("%s\n%s\n%s" % ("all", "", pagination_token)).encode("ascii") + ) + + for room_id, group in room_groups.items(): + group["next_batch"] = encode_base64( + ("%s\n%s\n%s" % ("room_id", room_id, pagination_token)).encode( + "ascii" + ) + ) + + return ( + _SearchResult(count, rank_map, room_events, room_groups, highlights), + global_next_batch, + ) + + async def _calculate_event_contexts( + self, + user: UserID, + allowed_events: List[EventBase], + before_limit: int, + after_limit: int, + include_profile: bool, + ) -> Dict[str, JsonDict]: + """ + Calculates the contextual events for any search results. + + Args: + user: The user performing the search. + allowed_events: The search results. + before_limit: + The number of events before a result to include as context. + after_limit: + The number of events after a result to include as context. + include_profile: True if historical profile information should be + included in the event context. + + Returns: + A map of event ID to contextual information. + """ + now_token = self.hs.get_event_sources().get_current_token() + + contexts = {} + for event in allowed_events: + res = await self.store.get_events_around( + event.room_id, event.event_id, before_limit, after_limit + ) + + logger.info( + "Context for search returned %d and %d events", + len(res.events_before), + len(res.events_after), + ) + + events_before = await filter_events_for_client( + self.storage, user.to_string(), res.events_before + ) + + events_after = await filter_events_for_client( + self.storage, user.to_string(), res.events_after + ) + + context = { + "events_before": events_before, + "events_after": events_after, + "start": await now_token.copy_and_replace( + "room_key", res.start + ).to_string(self.store), + "end": await now_token.copy_and_replace("room_key", res.end).to_string( + self.store + ), + } + + if include_profile: + senders = { + ev.sender + for ev in itertools.chain(events_before, [event], events_after) + } + + if events_after: + last_event_id = events_after[-1].event_id + else: + last_event_id = event.event_id + + state_filter = StateFilter.from_types( + [(EventTypes.Member, sender) for sender in senders] + ) + + state = await self.state_store.get_state_for_event( + last_event_id, state_filter + ) + + context["profile_info"] = { + s.state_key: { + "displayname": s.content.get("displayname", None), + "avatar_url": s.content.get("avatar_url", None), + } + for s in state.values() + if s.type == EventTypes.Member and s.state_key in senders + } + + contexts[event.event_id] = context + + return contexts diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 2d085a576..acea300ed 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -28,6 +28,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.types import JsonDict if TYPE_CHECKING: from synapse.server import HomeServer @@ -381,17 +382,19 @@ class SearchStore(SearchBackgroundUpdateStore): ): super().__init__(database, db_conn, hs) - async def search_msgs(self, room_ids, search_term, keys): + async def search_msgs( + self, room_ids: Collection[str], search_term: str, keys: Iterable[str] + ) -> JsonDict: """Performs a full text search over events with given keys. Args: - room_ids (list): List of room ids to search in - search_term (str): Search term to search for - keys (list): List of keys to search in, currently supports + room_ids: List of room ids to search in + search_term: Search term to search for + keys: List of keys to search in, currently supports "content.body", "content.name", "content.topic" Returns: - list of dicts + Dictionary of results """ clauses = [] @@ -499,10 +502,10 @@ class SearchStore(SearchBackgroundUpdateStore): self, room_ids: Collection[str], search_term: str, - keys: List[str], + keys: Iterable[str], limit, pagination_token: Optional[str] = None, - ) -> List[dict]: + ) -> JsonDict: """Performs a full text search over events with given keys. Args: From 5598556b776e3c8cc93b8ede4af9f2f5b21ff935 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 15 Feb 2022 13:59:15 +0000 Subject: [PATCH 05/84] Docker: remove `VOLUME` directive (#11997) The driver for this is to stop Complement complaining about it, but as far as I can tell it was pointless and needed to go away anyway. I'm a bit unclear about what exactly VOLUME does, but I think what it means is that, if you don't override it with an explicit -v argument, then docker run will create a temporary volume, and copy things into it. The temporary volume is then deleted when the container finishes. That only sounds useful if your image has something to copy into it (otherwise you may as well just use the default root filesystem), and our image notably doesn't copy anything into /data. So... this wasn't doing anything, except annoying Complement? --- changelog.d/11997.docker | 1 + docker/Dockerfile | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) create mode 100644 changelog.d/11997.docker diff --git a/changelog.d/11997.docker b/changelog.d/11997.docker new file mode 100644 index 000000000..1b3271457 --- /dev/null +++ b/changelog.d/11997.docker @@ -0,0 +1 @@ +The docker image no longer automatically creates a temporary volume at `/data`. This is not expected to affect normal usage. diff --git a/docker/Dockerfile b/docker/Dockerfile index 306f75ae5..e4c1c19b8 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -98,8 +98,6 @@ COPY --from=builder /install /usr/local COPY ./docker/start.py /start.py COPY ./docker/conf /conf -VOLUME ["/data"] - EXPOSE 8008/tcp 8009/tcp 8448/tcp ENTRYPOINT ["/start.py"] From dc9fe61050deb5eda59b161e4c0404ff36e3ac59 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 15 Feb 2022 14:26:28 +0000 Subject: [PATCH 06/84] Fix incorrect `get_rooms_for_user` for remote user (#11999) When the server leaves a room the `get_rooms_for_user` cache is not correctly invalidated for the remote users in the room. This means that subsequent calls to `get_rooms_for_user` for the remote users would incorrectly include the room (it shouldn't be included because the server no longer knows anything about the room). --- changelog.d/11999.bugfix | 1 + synapse/storage/databases/main/events.py | 27 +++--- tests/storage/test_events.py | 107 +++++++++++++++++++++++ 3 files changed, 124 insertions(+), 11 deletions(-) create mode 100644 changelog.d/11999.bugfix diff --git a/changelog.d/11999.bugfix b/changelog.d/11999.bugfix new file mode 100644 index 000000000..fd8409590 --- /dev/null +++ b/changelog.d/11999.bugfix @@ -0,0 +1 @@ +Fix long standing bug where `get_rooms_for_user` was not correctly invalidated for remote users when the server left a room. diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 5246fccad..a1d7a9b41 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -975,6 +975,17 @@ class PersistEventsStore: to_delete = delta_state.to_delete to_insert = delta_state.to_insert + # Figure out the changes of membership to invalidate the + # `get_rooms_for_user` cache. + # We find out which membership events we may have deleted + # and which we have added, then we invalidate the caches for all + # those users. + members_changed = { + state_key + for ev_type, state_key in itertools.chain(to_delete, to_insert) + if ev_type == EventTypes.Member + } + if delta_state.no_longer_in_room: # Server is no longer in the room so we delete the room from # current_state_events, being careful we've already updated the @@ -993,6 +1004,11 @@ class PersistEventsStore: """ txn.execute(sql, (stream_id, self._instance_name, room_id)) + # We also want to invalidate the membership caches for users + # that were in the room. + users_in_room = self.store.get_users_in_room_txn(txn, room_id) + members_changed.update(users_in_room) + self.db_pool.simple_delete_txn( txn, table="current_state_events", @@ -1102,17 +1118,6 @@ class PersistEventsStore: # Invalidate the various caches - # Figure out the changes of membership to invalidate the - # `get_rooms_for_user` cache. - # We find out which membership events we may have deleted - # and which we have added, then we invalidate the caches for all - # those users. - members_changed = { - state_key - for ev_type, state_key in itertools.chain(to_delete, to_insert) - if ev_type == EventTypes.Member - } - for member in members_changed: txn.call_after( self.store.get_rooms_for_user_with_stream_ordering.invalidate, diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index f462a8b1c..a8639d8f8 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -329,3 +329,110 @@ class ExtremPruneTestCase(HomeserverTestCase): # Check the new extremity is just the new remote event. self.assert_extremities([local_message_event_id, remote_event_2.event_id]) + + +class InvalideUsersInRoomCacheTestCase(HomeserverTestCase): + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.state = self.hs.get_state_handler() + self.persistence = self.hs.get_storage().persistence + self.store = self.hs.get_datastore() + + def test_remote_user_rooms_cache_invalidated(self): + """Test that if the server leaves a room the `get_rooms_for_user` cache + is invalidated for remote users. + """ + + # Set up a room with a local and remote user in it. + user_id = self.register_user("user", "pass") + token = self.login("user", "pass") + + room_id = self.helper.create_room_as( + "user", room_version=RoomVersions.V6.identifier, tok=token + ) + + body = self.helper.send(room_id, body="Test", tok=token) + local_message_event_id = body["event_id"] + + # Fudge a join event for a remote user. + remote_user = "@user:other" + remote_event_1 = event_from_pdu_json( + { + "type": EventTypes.Member, + "state_key": remote_user, + "content": {"membership": Membership.JOIN}, + "room_id": room_id, + "sender": remote_user, + "depth": 5, + "prev_events": [local_message_event_id], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + RoomVersions.V6, + ) + + context = self.get_success(self.state.compute_event_context(remote_event_1)) + self.get_success(self.persistence.persist_event(remote_event_1, context)) + + # Call `get_rooms_for_user` to add the remote user to the cache + rooms = self.get_success(self.store.get_rooms_for_user(remote_user)) + self.assertEqual(set(rooms), {room_id}) + + # Now we have the local server leave the room, and check that calling + # `get_user_in_room` for the remote user no longer includes the room. + self.helper.leave(room_id, user_id, tok=token) + + rooms = self.get_success(self.store.get_rooms_for_user(remote_user)) + self.assertEqual(set(rooms), set()) + + def test_room_remote_user_cache_invalidated(self): + """Test that if the server leaves a room the `get_users_in_room` cache + is invalidated for remote users. + """ + + # Set up a room with a local and remote user in it. + user_id = self.register_user("user", "pass") + token = self.login("user", "pass") + + room_id = self.helper.create_room_as( + "user", room_version=RoomVersions.V6.identifier, tok=token + ) + + body = self.helper.send(room_id, body="Test", tok=token) + local_message_event_id = body["event_id"] + + # Fudge a join event for a remote user. + remote_user = "@user:other" + remote_event_1 = event_from_pdu_json( + { + "type": EventTypes.Member, + "state_key": remote_user, + "content": {"membership": Membership.JOIN}, + "room_id": room_id, + "sender": remote_user, + "depth": 5, + "prev_events": [local_message_event_id], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + RoomVersions.V6, + ) + + context = self.get_success(self.state.compute_event_context(remote_event_1)) + self.get_success(self.persistence.persist_event(remote_event_1, context)) + + # Call `get_users_in_room` to add the remote user to the cache + users = self.get_success(self.store.get_users_in_room(room_id)) + self.assertEqual(set(users), {user_id, remote_user}) + + # Now we have the local server leave the room, and check that calling + # `get_user_in_room` for the remote user no longer includes the room. + self.helper.leave(room_id, user_id, tok=token) + + users = self.get_success(self.store.get_users_in_room(room_id)) + self.assertEqual(users, []) From 0dbbe33a65c17cdb1ad41d6109b5629029dce886 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 15 Feb 2022 14:31:04 +0000 Subject: [PATCH 07/84] Track cache invalidations (#12000) Currently we only track evictions due to size or time constraints. --- changelog.d/12000.feature | 1 + synapse/util/caches/__init__.py | 1 + synapse/util/caches/expiringcache.py | 5 +++++ synapse/util/caches/lrucache.py | 4 +++- 4 files changed, 10 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12000.feature diff --git a/changelog.d/12000.feature b/changelog.d/12000.feature new file mode 100644 index 000000000..246cc87f0 --- /dev/null +++ b/changelog.d/12000.feature @@ -0,0 +1 @@ +Track cache invalidations in Prometheus metrics, as already happens for cache eviction based on size or time. diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 15debd6c4..1cbc180ed 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -56,6 +56,7 @@ response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["n class EvictionReason(Enum): size = auto() time = auto() + invalidation = auto() @attr.s(slots=True, auto_attribs=True) diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index 67ee4c693..c6a5d0dfc 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -133,6 +133,11 @@ class ExpiringCache(Generic[KT, VT]): raise KeyError(key) return default + if self.iterable: + self.metrics.inc_evictions(EvictionReason.invalidation, len(value.value)) + else: + self.metrics.inc_evictions(EvictionReason.invalidation) + return value.value def __contains__(self, key: KT) -> bool: diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 7548b3854..45ff0de63 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -560,8 +560,10 @@ class LruCache(Generic[KT, VT]): def cache_pop(key: KT, default: Optional[T] = None) -> Union[None, T, VT]: node = cache.get(key, None) if node: - delete_node(node) + evicted_len = delete_node(node) cache.pop(node.key, None) + if metrics: + metrics.inc_evictions(EvictionReason.invalidation, evicted_len) return node.value else: return default From bab2394aa9b514f51feb2c378a39d92143e9cec8 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 15 Feb 2022 14:33:28 +0000 Subject: [PATCH 08/84] `_auth_and_persist_outliers`: drop events we have already seen (#11994) We already have two copies of this code, in 2/3 of the callers of `_auth_and_persist_outliers`. Before I add a third, let's push it down. --- changelog.d/11994.misc | 1 + synapse/handlers/federation_event.py | 44 +++++++++++++--------------- 2 files changed, 21 insertions(+), 24 deletions(-) create mode 100644 changelog.d/11994.misc diff --git a/changelog.d/11994.misc b/changelog.d/11994.misc new file mode 100644 index 000000000..d64297dd7 --- /dev/null +++ b/changelog.d/11994.misc @@ -0,0 +1 @@ +Move common deduplication code down into `_auth_and_persist_outliers`. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 9edc7369d..6dc27a38f 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -419,8 +419,6 @@ class FederationEventHandler: Raises: SynapseError if the response is in some way invalid. """ - event_map = {e.event_id: e for e in itertools.chain(auth_events, state)} - create_event = None for e in auth_events: if (e.type, e.state_key) == (EventTypes.Create, ""): @@ -439,11 +437,6 @@ class FederationEventHandler: if room_version.identifier != room_version_id: raise SynapseError(400, "Room version mismatch") - # filter out any events we have already seen - seen_remotes = await self._store.have_seen_events(room_id, event_map.keys()) - for s in seen_remotes: - event_map.pop(s, None) - # persist the auth chain and state events. # # any invalid events here will be marked as rejected, and we'll carry on. @@ -455,7 +448,9 @@ class FederationEventHandler: # signatures right now doesn't mean that we will *never* be able to, so it # is premature to reject them. # - await self._auth_and_persist_outliers(room_id, event_map.values()) + await self._auth_and_persist_outliers( + room_id, itertools.chain(auth_events, state) + ) # and now persist the join event itself. logger.info("Peristing join-via-remote %s", event) @@ -1245,6 +1240,16 @@ class FederationEventHandler: """ event_map = {event.event_id: event for event in events} + # filter out any events we have already seen. This might happen because + # the events were eagerly pushed to us (eg, during a room join), or because + # another thread has raced against us since we decided to request the event. + # + # This is just an optimisation, so it doesn't need to be watertight - the event + # persister does another round of deduplication. + seen_remotes = await self._store.have_seen_events(room_id, event_map.keys()) + for s in seen_remotes: + event_map.pop(s, None) + # XXX: it might be possible to kick this process off in parallel with fetching # the events. while event_map: @@ -1717,31 +1722,22 @@ class FederationEventHandler: event_id: the event for which we are lacking auth events """ try: - remote_event_map = { - e.event_id: e - for e in await self._federation_client.get_event_auth( - destination, room_id, event_id - ) - } + remote_events = await self._federation_client.get_event_auth( + destination, room_id, event_id + ) + except RequestSendFailed as e1: # The other side isn't around or doesn't implement the # endpoint, so lets just bail out. logger.info("Failed to get event auth from remote: %s", e1) return - logger.info("/event_auth returned %i events", len(remote_event_map)) + logger.info("/event_auth returned %i events", len(remote_events)) # `event` may be returned, but we should not yet process it. - remote_event_map.pop(event_id, None) + remote_auth_events = (e for e in remote_events if e.event_id != event_id) - # nor should we reprocess any events we have already seen. - seen_remotes = await self._store.have_seen_events( - room_id, remote_event_map.keys() - ) - for s in seen_remotes: - remote_event_map.pop(s, None) - - await self._auth_and_persist_outliers(room_id, remote_event_map.values()) + await self._auth_and_persist_outliers(room_id, remote_auth_events) async def _update_context_for_auth_events( self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase] From 2b5643b3afa4cddc7809c8b51fe813d2f0987235 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 15 Feb 2022 15:01:00 +0000 Subject: [PATCH 09/84] Optimise calculating device_list changes in `/sync`. (#11974) For users with large accounts it is inefficient to calculate the set of users they share a room with (and takes a lot of space in the cache). Instead we can look at users whose devices have changed since the last sync and check if they share a room with the syncing user. --- changelog.d/11974.misc | 1 + synapse/handlers/sync.py | 66 +++++++++++++++----- synapse/storage/databases/main/devices.py | 10 +++ synapse/storage/databases/main/roommember.py | 62 ++++++++++++++++++ 4 files changed, 125 insertions(+), 14 deletions(-) create mode 100644 changelog.d/11974.misc diff --git a/changelog.d/11974.misc b/changelog.d/11974.misc new file mode 100644 index 000000000..1debad236 --- /dev/null +++ b/changelog.d/11974.misc @@ -0,0 +1 @@ +Optimise calculating device_list changes in `/sync`. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index aa9a76f8a..e6050cbce 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1289,23 +1289,54 @@ class SyncHandler: # room with by looking at all users that have left a room plus users # that were in a room we've left. - users_who_share_room = await self.store.get_users_who_share_room_with_user( - user_id + users_that_have_changed = set() + + joined_rooms = sync_result_builder.joined_room_ids + + # Step 1a, check for changes in devices of users we share a room + # with + # + # We do this in two different ways depending on what we have cached. + # If we already have a list of all the user that have changed since + # the last sync then it's likely more efficient to compare the rooms + # they're in with the rooms the syncing user is in. + # + # If we don't have that info cached then we get all the users that + # share a room with our user and check if those users have changed. + changed_users = self.store.get_cached_device_list_changes( + since_token.device_list_key ) + if changed_users is not None: + result = await self.store.get_rooms_for_users_with_stream_ordering( + changed_users + ) - # Always tell the user about their own devices. We check as the user - # ID is almost certainly already included (unless they're not in any - # rooms) and taking a copy of the set is relatively expensive. - if user_id not in users_who_share_room: - users_who_share_room = set(users_who_share_room) - users_who_share_room.add(user_id) + for changed_user_id, entries in result.items(): + # Check if the changed user shares any rooms with the user, + # or if the changed user is the syncing user (as we always + # want to include device list updates of their own devices). + if user_id == changed_user_id or any( + e.room_id in joined_rooms for e in entries + ): + users_that_have_changed.add(changed_user_id) + else: + users_who_share_room = ( + await self.store.get_users_who_share_room_with_user(user_id) + ) - tracked_users = users_who_share_room + # Always tell the user about their own devices. We check as the user + # ID is almost certainly already included (unless they're not in any + # rooms) and taking a copy of the set is relatively expensive. + if user_id not in users_who_share_room: + users_who_share_room = set(users_who_share_room) + users_who_share_room.add(user_id) - # Step 1a, check for changes in devices of users we share a room with - users_that_have_changed = await self.store.get_users_whose_devices_changed( - since_token.device_list_key, tracked_users - ) + tracked_users = users_who_share_room + users_that_have_changed = ( + await self.store.get_users_whose_devices_changed( + since_token.device_list_key, tracked_users + ) + ) # Step 1b, check for newly joined rooms for room_id in newly_joined_rooms: @@ -1329,7 +1360,14 @@ class SyncHandler: newly_left_users.update(left_users) # Remove any users that we still share a room with. - newly_left_users -= users_who_share_room + left_users_rooms = ( + await self.store.get_rooms_for_users_with_stream_ordering( + newly_left_users + ) + ) + for user_id, entries in left_users_rooms.items(): + if any(e.room_id in joined_rooms for e in entries): + newly_left_users.discard(user_id) return DeviceLists(changed=users_that_have_changed, left=newly_left_users) else: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 8d845fe95..3b3a089b7 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -670,6 +670,16 @@ class DeviceWorkerStore(SQLBaseStore): device["device_id"]: db_to_json(device["content"]) for device in devices } + def get_cached_device_list_changes( + self, + from_key: int, + ) -> Optional[Set[str]]: + """Get set of users whose devices have changed since `from_key`, or None + if that information is not in our cache. + """ + + return self._device_list_stream_cache.get_all_entities_changed(from_key) + async def get_users_whose_devices_changed( self, from_key: int, user_ids: Iterable[str] ) -> Set[str]: diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 4489732fd..e48ec5f49 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -504,6 +504,68 @@ class RoomMemberWorkerStore(EventsWorkerStore): for room_id, instance, stream_id in txn ) + @cachedList( + cached_method_name="get_rooms_for_user_with_stream_ordering", + list_name="user_ids", + ) + async def get_rooms_for_users_with_stream_ordering( + self, user_ids: Collection[str] + ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]: + """A batched version of `get_rooms_for_user_with_stream_ordering`. + + Returns: + Map from user_id to set of rooms that is currently in. + """ + return await self.db_pool.runInteraction( + "get_rooms_for_users_with_stream_ordering", + self._get_rooms_for_users_with_stream_ordering_txn, + user_ids, + ) + + def _get_rooms_for_users_with_stream_ordering_txn( + self, txn, user_ids: Collection[str] + ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]: + + clause, args = make_in_list_sql_clause( + self.database_engine, + "c.state_key", + user_ids, + ) + + if self._current_state_events_membership_up_to_date: + sql = f""" + SELECT c.state_key, room_id, e.instance_name, e.stream_ordering + FROM current_state_events AS c + INNER JOIN events AS e USING (room_id, event_id) + WHERE + c.type = 'm.room.member' + AND c.membership = ? + AND {clause} + """ + else: + sql = f""" + SELECT c.state_key, room_id, e.instance_name, e.stream_ordering + FROM current_state_events AS c + INNER JOIN room_memberships AS m USING (room_id, event_id) + INNER JOIN events AS e USING (room_id, event_id) + WHERE + c.type = 'm.room.member' + AND m.membership = ? + AND {clause} + """ + + txn.execute(sql, [Membership.JOIN] + args) + + result = {user_id: set() for user_id in user_ids} + for user_id, room_id, instance, stream_id in txn: + result[user_id].add( + GetRoomsForUserWithStreamOrdering( + room_id, PersistedEventPosition(instance, stream_id) + ) + ) + + return {user_id: frozenset(v) for user_id, v in result.items()} + async def get_users_server_still_shares_room_with( self, user_ids: Collection[str] ) -> Set[str]: From 130fd45393d39425a4aa35a388b1e5a74ba6f6f4 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 16 Feb 2022 12:16:48 +0100 Subject: [PATCH 10/84] Limit concurrent AS joins (#11996) Initially introduced in matrix-org-hotfixes by e5537cf (and tweaked by later commits). Fixes #11995 See also #4826 --- changelog.d/11996.misc | 1 + synapse/handlers/room_member.py | 46 +++++++++++++++++++-------------- 2 files changed, 28 insertions(+), 19 deletions(-) create mode 100644 changelog.d/11996.misc diff --git a/changelog.d/11996.misc b/changelog.d/11996.misc new file mode 100644 index 000000000..6c675fd19 --- /dev/null +++ b/changelog.d/11996.misc @@ -0,0 +1 @@ +Limit concurrent joins from applications services. \ No newline at end of file diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index bf1a47efb..b2adc0f48 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -82,6 +82,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.event_auth_handler = hs.get_event_auth_handler() self.member_linearizer: Linearizer = Linearizer(name="member") + self.member_as_limiter = Linearizer(max_count=10, name="member_as_limiter") self.clock = hs.get_clock() self.spam_checker = hs.get_spam_checker() @@ -500,25 +501,32 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): key = (room_id,) - with (await self.member_linearizer.queue(key)): - result = await self.update_membership_locked( - requester, - target, - room_id, - action, - txn_id=txn_id, - remote_room_hosts=remote_room_hosts, - third_party_signed=third_party_signed, - ratelimit=ratelimit, - content=content, - new_room=new_room, - require_consent=require_consent, - outlier=outlier, - historical=historical, - allow_no_prev_events=allow_no_prev_events, - prev_event_ids=prev_event_ids, - auth_event_ids=auth_event_ids, - ) + as_id = object() + if requester.app_service: + as_id = requester.app_service.id + + # We first linearise by the application service (to try to limit concurrent joins + # by application services), and then by room ID. + with (await self.member_as_limiter.queue(as_id)): + with (await self.member_linearizer.queue(key)): + result = await self.update_membership_locked( + requester, + target, + room_id, + action, + txn_id=txn_id, + remote_room_hosts=remote_room_hosts, + third_party_signed=third_party_signed, + ratelimit=ratelimit, + content=content, + new_room=new_room, + require_consent=require_consent, + outlier=outlier, + historical=historical, + allow_no_prev_events=allow_no_prev_events, + prev_event_ids=prev_event_ids, + auth_event_ids=auth_event_ids, + ) return result From 7a92d68441f8182c6eac467ff2c23e71112659f0 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 16 Feb 2022 06:53:21 -0500 Subject: [PATCH 11/84] Fix a typo in a comment. --- synapse/storage/databases/main/relations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 5582029f9..36aa1092f 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -489,7 +489,7 @@ class RelationsWorkerStore(SQLBaseStore): # TODO Should this only allow m.room.message events. if isinstance(self.database_engine, PostgresEngine): # The `DISTINCT ON` clause will pick the *first* row it encounters, - # so ordering by topologica ordering + stream ordering desc will + # so ordering by topological ordering + stream ordering desc will # ensure we get the latest event in the thread. sql = """ SELECT DISTINCT ON (parent.event_id) parent.event_id, child.event_id FROM events AS child From 73fc4887834b33afaf23ff48a68772f8ef9b924c Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 16 Feb 2022 12:25:43 +0000 Subject: [PATCH 12/84] Explain the meaning of spam checker callbacks' return values (#12003) Co-authored-by: Patrick Cloke --- changelog.d/12003.doc | 1 + docs/modules/spam_checker_callbacks.md | 40 +++++++++++++++++--------- 2 files changed, 28 insertions(+), 13 deletions(-) create mode 100644 changelog.d/12003.doc diff --git a/changelog.d/12003.doc b/changelog.d/12003.doc new file mode 100644 index 000000000..1ac816355 --- /dev/null +++ b/changelog.d/12003.doc @@ -0,0 +1 @@ +Explain the meaning of spam checker callbacks' return values. diff --git a/docs/modules/spam_checker_callbacks.md b/docs/modules/spam_checker_callbacks.md index 2eb9032f4..2b672b78f 100644 --- a/docs/modules/spam_checker_callbacks.md +++ b/docs/modules/spam_checker_callbacks.md @@ -16,10 +16,12 @@ _First introduced in Synapse v1.37.0_ async def check_event_for_spam(event: "synapse.events.EventBase") -> Union[bool, str] ``` -Called when receiving an event from a client or via federation. The module can return -either a `bool` to indicate whether the event must be rejected because of spam, or a `str` -to indicate the event must be rejected because of spam and to give a rejection reason to -forward to clients. +Called when receiving an event from a client or via federation. The callback must return +either: +- an error message string, to indicate the event must be rejected because of spam and + give a rejection reason to forward to clients; +- the boolean `True`, to indicate that the event is spammy, but not provide further details; or +- the booelan `False`, to indicate that the event is not considered spammy. If multiple modules implement this callback, they will be considered in order. If a callback returns `False`, Synapse falls through to the next one. The value of the first @@ -35,7 +37,10 @@ async def user_may_join_room(user: str, room: str, is_invited: bool) -> bool ``` Called when a user is trying to join a room. The module must return a `bool` to indicate -whether the user can join the room. The user is represented by their Matrix user ID (e.g. +whether the user can join the room. Return `False` to prevent the user from joining the +room; otherwise return `True` to permit the joining. + +The user is represented by their Matrix user ID (e.g. `@alice:example.com`) and the room is represented by its Matrix ID (e.g. `!room:example.com`). The module is also given a boolean to indicate whether the user currently has a pending invite in the room. @@ -58,7 +63,8 @@ async def user_may_invite(inviter: str, invitee: str, room_id: str) -> bool Called when processing an invitation. The module must return a `bool` indicating whether the inviter can invite the invitee to the given room. Both inviter and invitee are -represented by their Matrix user ID (e.g. `@alice:example.com`). +represented by their Matrix user ID (e.g. `@alice:example.com`). Return `False` to prevent +the invitation; otherwise return `True` to permit it. If multiple modules implement this callback, they will be considered in order. If a callback returns `True`, Synapse falls through to the next one. The value of the first @@ -80,7 +86,8 @@ async def user_may_send_3pid_invite( Called when processing an invitation using a third-party identifier (also called a 3PID, e.g. an email address or a phone number). The module must return a `bool` indicating -whether the inviter can invite the invitee to the given room. +whether the inviter can invite the invitee to the given room. Return `False` to prevent +the invitation; otherwise return `True` to permit it. The inviter is represented by their Matrix user ID (e.g. `@alice:example.com`), and the invitee is represented by its medium (e.g. "email") and its address @@ -117,6 +124,7 @@ async def user_may_create_room(user: str) -> bool Called when processing a room creation request. The module must return a `bool` indicating whether the given user (represented by their Matrix user ID) is allowed to create a room. +Return `False` to prevent room creation; otherwise return `True` to permit it. If multiple modules implement this callback, they will be considered in order. If a callback returns `True`, Synapse falls through to the next one. The value of the first @@ -133,7 +141,8 @@ async def user_may_create_room_alias(user: str, room_alias: "synapse.types.RoomA Called when trying to associate an alias with an existing room. The module must return a `bool` indicating whether the given user (represented by their Matrix user ID) is allowed -to set the given alias. +to set the given alias. Return `False` to prevent the alias creation; otherwise return +`True` to permit it. If multiple modules implement this callback, they will be considered in order. If a callback returns `True`, Synapse falls through to the next one. The value of the first @@ -150,7 +159,8 @@ async def user_may_publish_room(user: str, room_id: str) -> bool Called when trying to publish a room to the homeserver's public rooms directory. The module must return a `bool` indicating whether the given user (represented by their -Matrix user ID) is allowed to publish the given room. +Matrix user ID) is allowed to publish the given room. Return `False` to prevent the +room from being published; otherwise return `True` to permit its publication. If multiple modules implement this callback, they will be considered in order. If a callback returns `True`, Synapse falls through to the next one. The value of the first @@ -166,8 +176,11 @@ async def check_username_for_spam(user_profile: Dict[str, str]) -> bool ``` Called when computing search results in the user directory. The module must return a -`bool` indicating whether the given user profile can appear in search results. The profile -is represented as a dictionary with the following keys: +`bool` indicating whether the given user should be excluded from user directory +searches. Return `True` to indicate that the user is spammy and exclude them from +search results; otherwise return `False`. + +The profile is represented as a dictionary with the following keys: * `user_id`: The Matrix ID for this user. * `display_name`: The user's display name. @@ -225,8 +238,9 @@ async def check_media_file_for_spam( ) -> bool ``` -Called when storing a local or remote file. The module must return a boolean indicating -whether the given file can be stored in the homeserver's media store. +Called when storing a local or remote file. The module must return a `bool` indicating +whether the given file should be excluded from the homeserver's media store. Return +`True` to prevent this file from being stored; otherwise return `False`. If multiple modules implement this callback, they will be considered in order. If a callback returns `False`, Synapse falls through to the next one. The value of the first From 40771773909cb03d9296e3f0505e4e32372f10aa Mon Sep 17 00:00:00 2001 From: lukasdenk <63459921+lukasdenk@users.noreply.github.com> Date: Thu, 17 Feb 2022 11:23:54 +0100 Subject: [PATCH 13/84] Prevent duplicate push notifications for room reads (#11835) --- changelog.d/11835.feature | 1 + synapse/push/httppusher.py | 7 +- tests/push/test_http.py | 133 ++++++++++++++++++------------------- 3 files changed, 73 insertions(+), 68 deletions(-) create mode 100644 changelog.d/11835.feature diff --git a/changelog.d/11835.feature b/changelog.d/11835.feature new file mode 100644 index 000000000..7cee39b08 --- /dev/null +++ b/changelog.d/11835.feature @@ -0,0 +1 @@ +Make a `POST` to `/rooms//receipt/m.read/` only trigger a push notification if the count of unread messages is different to the one in the last successfully sent push. diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 96559081d..49bcc06e0 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -109,6 +109,7 @@ class HttpPusher(Pusher): self.data_minus_url = {} self.data_minus_url.update(self.data) del self.data_minus_url["url"] + self.badge_count_last_call: Optional[int] = None def on_started(self, should_check_for_notifs: bool) -> None: """Called when this pusher has been started. @@ -136,7 +137,9 @@ class HttpPusher(Pusher): self.user_id, group_by_room=self._group_unread_count_by_room, ) - await self._send_badge(badge) + if self.badge_count_last_call is None or self.badge_count_last_call != badge: + self.badge_count_last_call = badge + await self._send_badge(badge) def on_timer(self) -> None: self._start_processing() @@ -402,6 +405,8 @@ class HttpPusher(Pusher): rejected = [] if "rejected" in resp: rejected = resp["rejected"] + else: + self.badge_count_last_call = badge return rejected async def _send_badge(self, badge: int) -> None: diff --git a/tests/push/test_http.py b/tests/push/test_http.py index c068d329a..e1e3fb97c 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -571,9 +571,7 @@ class HTTPPusherTests(HomeserverTestCase): # Carry out our option-value specific test # # This push should still only contain an unread count of 1 (for 1 unread room) - self.assertEqual( - self.push_attempts[5][2]["notification"]["counts"]["unread"], 1 - ) + self._check_push_attempt(6, 1) @override_config({"push": {"group_unread_count_by_room": False}}) def test_push_unread_count_message_count(self): @@ -585,11 +583,9 @@ class HTTPPusherTests(HomeserverTestCase): # Carry out our option-value specific test # - # We're counting every unread message, so there should now be 4 since the + # We're counting every unread message, so there should now be 3 since the # last read receipt - self.assertEqual( - self.push_attempts[5][2]["notification"]["counts"]["unread"], 4 - ) + self._check_push_attempt(6, 3) def _test_push_unread_count(self): """ @@ -597,8 +593,9 @@ class HTTPPusherTests(HomeserverTestCase): Note that: * Sending messages will cause push notifications to go out to relevant users - * Sending a read receipt will cause a "badge update" notification to go out to - the user that sent the receipt + * Sending a read receipt will cause the HTTP pusher to check whether the unread + count has changed since the last push notification. If so, a "badge update" + notification goes out to the user that sent the receipt """ # Register the user who gets notified user_id = self.register_user("user", "pass") @@ -642,24 +639,74 @@ class HTTPPusherTests(HomeserverTestCase): # position in the room. We'll set the read position to this event in a moment first_message_event_id = response["event_id"] - # Advance time a bit (so the pusher will register something has happened) and - # make the push succeed - self.push_attempts[0][0].callback({}) - self.pump() + expected_push_attempts = 1 + self._check_push_attempt(expected_push_attempts, 0) - # Check our push made it + self._send_read_request(access_token, first_message_event_id, room_id) + + # Unread count has not changed. Therefore, ensure that read request does not + # trigger a push notification. self.assertEqual(len(self.push_attempts), 1) - self.assertEqual( - self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify" - ) + # Send another message + response2 = self.helper.send( + room_id, body="How's the weather today?", tok=other_access_token + ) + second_message_event_id = response2["event_id"] + + expected_push_attempts += 1 + + self._check_push_attempt(expected_push_attempts, 1) + + self._send_read_request(access_token, second_message_event_id, room_id) + expected_push_attempts += 1 + + self._check_push_attempt(expected_push_attempts, 0) + + # If we're grouping by room, sending more messages shouldn't increase the + # unread count, as they're all being sent in the same room. Otherwise, it + # should. Therefore, the last call to _check_push_attempt is done in the + # caller method. + self.helper.send(room_id, body="Hello?", tok=other_access_token) + expected_push_attempts += 1 + + self._advance_time_and_make_push_succeed(expected_push_attempts) + + self.helper.send(room_id, body="Hello??", tok=other_access_token) + expected_push_attempts += 1 + + self._advance_time_and_make_push_succeed(expected_push_attempts) + + self.helper.send(room_id, body="HELLO???", tok=other_access_token) + + def _advance_time_and_make_push_succeed(self, expected_push_attempts): + self.pump() + self.push_attempts[expected_push_attempts - 1][0].callback({}) + + def _check_push_attempt( + self, expected_push_attempts: int, expected_unread_count_last_push: int + ) -> None: + """ + Makes sure that the last expected push attempt succeeds and checks whether + it contains the expected unread count. + """ + self._advance_time_and_make_push_succeed(expected_push_attempts) + # Check our push made it + self.assertEqual(len(self.push_attempts), expected_push_attempts) + _, push_url, push_body = self.push_attempts[expected_push_attempts - 1] + self.assertEqual( + push_url, + "http://example.com/_matrix/push/v1/notify", + ) # Check that the unread count for the room is 0 # # The unread count is zero as the user has no read receipt in the room yet self.assertEqual( - self.push_attempts[0][2]["notification"]["counts"]["unread"], 0 + push_body["notification"]["counts"]["unread"], + expected_unread_count_last_push, ) + def _send_read_request(self, access_token, message_event_id, room_id): # Now set the user's read receipt position to the first event # # This will actually trigger a new notification to be sent out so that @@ -667,56 +714,8 @@ class HTTPPusherTests(HomeserverTestCase): # count goes down channel = self.make_request( "POST", - "/rooms/%s/receipt/m.read/%s" % (room_id, first_message_event_id), + "/rooms/%s/receipt/m.read/%s" % (room_id, message_event_id), {}, access_token=access_token, ) self.assertEqual(channel.code, 200, channel.json_body) - - # Advance time and make the push succeed - self.push_attempts[1][0].callback({}) - self.pump() - - # Unread count is still zero as we've read the only message in the room - self.assertEqual(len(self.push_attempts), 2) - self.assertEqual( - self.push_attempts[1][2]["notification"]["counts"]["unread"], 0 - ) - - # Send another message - self.helper.send( - room_id, body="How's the weather today?", tok=other_access_token - ) - - # Advance time and make the push succeed - self.push_attempts[2][0].callback({}) - self.pump() - - # This push should contain an unread count of 1 as there's now been one - # message since our last read receipt - self.assertEqual(len(self.push_attempts), 3) - self.assertEqual( - self.push_attempts[2][2]["notification"]["counts"]["unread"], 1 - ) - - # Since we're grouping by room, sending more messages shouldn't increase the - # unread count, as they're all being sent in the same room - self.helper.send(room_id, body="Hello?", tok=other_access_token) - - # Advance time and make the push succeed - self.pump() - self.push_attempts[3][0].callback({}) - - self.helper.send(room_id, body="Hello??", tok=other_access_token) - - # Advance time and make the push succeed - self.pump() - self.push_attempts[4][0].callback({}) - - self.helper.send(room_id, body="HELLO???", tok=other_access_token) - - # Advance time and make the push succeed - self.pump() - self.push_attempts[5][0].callback({}) - - self.assertEqual(len(self.push_attempts), 6) From 696acd35151ff32bb69e555baee4e584c504d4d6 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 17 Feb 2022 11:59:26 +0000 Subject: [PATCH 14/84] `send_join` response: get create event from `state`, not `auth_chain` (#12005) msc3706 proposes changing the `/send_join` response: > Any events returned within `state` can be omitted from `auth_chain`. Currently, we rely on `m.room.create` being returned in `auth_chain`, but since the `m.room.create` event must necessarily be part of the state, the above change will break this. In short, let's look for `m.room.create` in `state` rather than `auth_chain`. --- changelog.d/12005.misc | 1 + synapse/handlers/federation_event.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12005.misc diff --git a/changelog.d/12005.misc b/changelog.d/12005.misc new file mode 100644 index 000000000..45e21dbe5 --- /dev/null +++ b/changelog.d/12005.misc @@ -0,0 +1 @@ +Preparation for faster-room-join work: when parsing the `send_join` response, get the `m.room.create` event from `state`, not `auth_chain`. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 6dc27a38f..7683246be 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -420,7 +420,7 @@ class FederationEventHandler: SynapseError if the response is in some way invalid. """ create_event = None - for e in auth_events: + for e in state: if (e.type, e.state_key) == (EventTypes.Create, ""): create_event = e break From e69f8f0a8e3b3f3b647efc42fe67b6b9270d26e0 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 17 Feb 2022 08:32:18 -0500 Subject: [PATCH 15/84] Remove support for the legacy structured logging configuration. (#12008) --- changelog.d/12008.removal | 1 + docs/structured_logging.md | 12 +-- docs/upgrade.md | 9 ++ synapse/config/logger.py | 12 ++- synapse/logging/_structured.py | 163 --------------------------------- 5 files changed, 23 insertions(+), 174 deletions(-) create mode 100644 changelog.d/12008.removal delete mode 100644 synapse/logging/_structured.py diff --git a/changelog.d/12008.removal b/changelog.d/12008.removal new file mode 100644 index 000000000..57599d9ee --- /dev/null +++ b/changelog.d/12008.removal @@ -0,0 +1 @@ +Remove support for the legacy structured logging configuration (please see the the [upgrade notes](https://matrix-org.github.io/synapse/develop/upgrade#legacy-structured-logging-configuration-removal) if you are using `structured: true` in the Synapse configuration). diff --git a/docs/structured_logging.md b/docs/structured_logging.md index 14db85f58..805c86765 100644 --- a/docs/structured_logging.md +++ b/docs/structured_logging.md @@ -81,14 +81,12 @@ remote endpoint at 10.1.2.3:9999. ## Upgrading from legacy structured logging configuration -Versions of Synapse prior to v1.23.0 included a custom structured logging -configuration which is deprecated. It used a `structured: true` flag and -configured `drains` instead of ``handlers`` and `formatters`. +Versions of Synapse prior to v1.54.0 automatically converted the legacy +structured logging configuration, which was deprecated in v1.23.0, to the standard +library logging configuration. -Synapse currently automatically converts the old configuration to the new -configuration, but this will be removed in a future version of Synapse. The -following reference can be used to update your configuration. Based on the drain -`type`, we can pick a new handler: +The following reference can be used to update your configuration. Based on the +drain `type`, we can pick a new handler: 1. For a type of `console`, `console_json`, or `console_json_terse`: a handler with a class of `logging.StreamHandler` and a `stream` of `ext://sys.stdout` diff --git a/docs/upgrade.md b/docs/upgrade.md index 477d7d0e8..9860ae97b 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -85,6 +85,15 @@ process, for example: dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb ``` +# Upgrading to v1.54.0 + +## Legacy structured logging configuration removal + +This release removes support for the `structured: true` logging configuration +which was deprecated in Synapse v1.23.0. If your logging configuration contains +`structured: true` then it should be modified based on the +[structured logging documentation](structured_logging.md). + # Upgrading to v1.53.0 ## Dropping support for `webclient` listeners and non-HTTP(S) `web_client_location` diff --git a/synapse/config/logger.py b/synapse/config/logger.py index b7145a44a..cbbe22196 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -33,7 +33,6 @@ from twisted.logger import ( globalLogBeginner, ) -from synapse.logging._structured import setup_structured_logging from synapse.logging.context import LoggingContextFilter from synapse.logging.filter import MetadataFilter @@ -138,6 +137,12 @@ Support for the log_file configuration option and --log-file command-line option removed in Synapse 1.3.0. You should instead set up a separate log configuration file. """ +STRUCTURED_ERROR = """\ +Support for the structured configuration option was removed in Synapse 1.54.0. +You should instead use the standard logging configuration. See +https://matrix-org.github.io/synapse/v1.54/structured_logging.html +""" + class LoggingConfig(Config): section = "logging" @@ -292,10 +297,9 @@ def _load_logging_config(log_config_path: str) -> None: if not log_config: logging.warning("Loaded a blank logging config?") - # If the old structured logging configuration is being used, convert it to - # the new style configuration. + # If the old structured logging configuration is being used, raise an error. if "structured" in log_config and log_config.get("structured"): - log_config = setup_structured_logging(log_config) + raise ConfigError(STRUCTURED_ERROR) logging.config.dictConfig(log_config) diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py deleted file mode 100644 index b9933a152..000000000 --- a/synapse/logging/_structured.py +++ /dev/null @@ -1,163 +0,0 @@ -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os.path -from typing import Any, Dict, Generator, Optional, Tuple - -from constantly import NamedConstant, Names - -from synapse.config._base import ConfigError - - -class DrainType(Names): - CONSOLE = NamedConstant() - CONSOLE_JSON = NamedConstant() - CONSOLE_JSON_TERSE = NamedConstant() - FILE = NamedConstant() - FILE_JSON = NamedConstant() - NETWORK_JSON_TERSE = NamedConstant() - - -DEFAULT_LOGGERS = {"synapse": {"level": "info"}} - - -def parse_drain_configs( - drains: dict, -) -> Generator[Tuple[str, Dict[str, Any]], None, None]: - """ - Parse the drain configurations. - - Args: - drains (dict): A list of drain configurations. - - Yields: - dict instances representing a logging handler. - - Raises: - ConfigError: If any of the drain configuration items are invalid. - """ - - for name, config in drains.items(): - if "type" not in config: - raise ConfigError("Logging drains require a 'type' key.") - - try: - logging_type = DrainType.lookupByName(config["type"].upper()) - except ValueError: - raise ConfigError( - "%s is not a known logging drain type." % (config["type"],) - ) - - # Either use the default formatter or the tersejson one. - if logging_type in ( - DrainType.CONSOLE_JSON, - DrainType.FILE_JSON, - ): - formatter: Optional[str] = "json" - elif logging_type in ( - DrainType.CONSOLE_JSON_TERSE, - DrainType.NETWORK_JSON_TERSE, - ): - formatter = "tersejson" - else: - # A formatter of None implies using the default formatter. - formatter = None - - if logging_type in [ - DrainType.CONSOLE, - DrainType.CONSOLE_JSON, - DrainType.CONSOLE_JSON_TERSE, - ]: - location = config.get("location") - if location is None or location not in ["stdout", "stderr"]: - raise ConfigError( - ( - "The %s drain needs the 'location' key set to " - "either 'stdout' or 'stderr'." - ) - % (logging_type,) - ) - - yield name, { - "class": "logging.StreamHandler", - "formatter": formatter, - "stream": "ext://sys." + location, - } - - elif logging_type in [DrainType.FILE, DrainType.FILE_JSON]: - if "location" not in config: - raise ConfigError( - "The %s drain needs the 'location' key set." % (logging_type,) - ) - - location = config.get("location") - if os.path.abspath(location) != location: - raise ConfigError( - "File paths need to be absolute, '%s' is a relative path" - % (location,) - ) - - yield name, { - "class": "logging.FileHandler", - "formatter": formatter, - "filename": location, - } - - elif logging_type in [DrainType.NETWORK_JSON_TERSE]: - host = config.get("host") - port = config.get("port") - maximum_buffer = config.get("maximum_buffer", 1000) - - yield name, { - "class": "synapse.logging.RemoteHandler", - "formatter": formatter, - "host": host, - "port": port, - "maximum_buffer": maximum_buffer, - } - - else: - raise ConfigError( - "The %s drain type is currently not implemented." - % (config["type"].upper(),) - ) - - -def setup_structured_logging( - log_config: dict, -) -> dict: - """ - Convert a legacy structured logging configuration (from Synapse < v1.23.0) - to one compatible with the new standard library handlers. - """ - if "drains" not in log_config: - raise ConfigError("The logging configuration requires a list of drains.") - - new_config = { - "version": 1, - "formatters": { - "json": {"class": "synapse.logging.JsonFormatter"}, - "tersejson": {"class": "synapse.logging.TerseJsonFormatter"}, - }, - "handlers": {}, - "loggers": log_config.get("loggers", DEFAULT_LOGGERS), - "root": {"handlers": []}, - } - - for handler_name, handler in parse_drain_configs(log_config["drains"]): - new_config["handlers"][handler_name] = handler - - # Add each handler to the root logger. - new_config["root"]["handlers"].append(handler_name) - - return new_config From 6127c4b9f19702909938ddc95b7e4219d20ac349 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 17 Feb 2022 15:55:14 +0000 Subject: [PATCH 16/84] Configure `tox` to use `venv` (#12015) As the comment says, virtualenv is a pile of fail. --- .ci/scripts/test_old_deps.sh | 4 +++- changelog.d/12015.misc | 1 + tox.ini | 5 +++++ 3 files changed, 9 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12015.misc diff --git a/.ci/scripts/test_old_deps.sh b/.ci/scripts/test_old_deps.sh index 54ec3c8b0..b2859f752 100755 --- a/.ci/scripts/test_old_deps.sh +++ b/.ci/scripts/test_old_deps.sh @@ -8,7 +8,9 @@ export DEBIAN_FRONTEND=noninteractive set -ex apt-get update -apt-get install -y python3 python3-dev python3-pip libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox libjpeg-dev libwebp-dev +apt-get install -y \ + python3 python3-dev python3-pip python3-venv \ + libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox libjpeg-dev libwebp-dev export LANG="C.UTF-8" diff --git a/changelog.d/12015.misc b/changelog.d/12015.misc new file mode 100644 index 000000000..3aa32ab4c --- /dev/null +++ b/changelog.d/12015.misc @@ -0,0 +1 @@ +Configure `tox` to use `venv` rather than `virtualenv`. diff --git a/tox.ini b/tox.ini index 32679e910..2b3d39e03 100644 --- a/tox.ini +++ b/tox.ini @@ -4,6 +4,11 @@ envlist = packaging, py37, py38, py39, py310, check_codestyle, check_isort # we require tox>=2.3.2 for the fix to https://github.com/tox-dev/tox/issues/208 minversion = 2.3.2 +# the tox-venv plugin makes tox use python's built-in `venv` module rather than +# the legacy `virtualenv` tool. `virtualenv` embeds its own `pip`, `setuptools`, +# etc, and ends up being rather unreliable. +requires = tox-venv + [base] deps = python-subunit From da0e9f8efdac1571eab35ad2cc842073ca54769c Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 17 Feb 2022 16:11:59 +0000 Subject: [PATCH 17/84] Faster joins: parse msc3706 fields in send_join response (#12011) Part of my work on #11249: add code to handle the new fields added in MSC3706. --- changelog.d/12011.misc | 1 + synapse/config/experimental.py | 4 + synapse/federation/federation_client.py | 15 ++- synapse/federation/transport/client.py | 118 ++++++++++++++++------ synapse/python_dependencies.py | 3 +- tests/federation/transport/test_client.py | 32 ++++++ 6 files changed, 140 insertions(+), 33 deletions(-) create mode 100644 changelog.d/12011.misc diff --git a/changelog.d/12011.misc b/changelog.d/12011.misc new file mode 100644 index 000000000..258b0e389 --- /dev/null +++ b/changelog.d/12011.misc @@ -0,0 +1 @@ +Preparation for faster-room-join work: parse msc3706 fields in send_join response. diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 09d692d9a..12b5638cf 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -64,3 +64,7 @@ class ExperimentalConfig(Config): # MSC3706 (server-side support for partial state in /send_join responses) self.msc3706_enabled: bool = experimental.get("msc3706_enabled", False) + + # experimental support for faster joins over federation (msc2775, msc3706) + # requires a target server with msc3706_enabled enabled. + self.faster_joins_enabled: bool = experimental.get("faster_joins", False) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 74f17aa4d..9f56f97d9 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1,4 +1,4 @@ -# Copyright 2015-2021 The Matrix.org Foundation C.I.C. +# Copyright 2015-2022 The Matrix.org Foundation C.I.C. # Copyright 2020 Sorunome # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -89,6 +89,12 @@ class SendJoinResult: state: List[EventBase] auth_chain: List[EventBase] + # True if 'state' elides non-critical membership events + partial_state: bool + + # if 'partial_state' is set, a list of the servers in the room (otherwise empty) + servers_in_room: List[str] + class FederationClient(FederationBase): def __init__(self, hs: "HomeServer"): @@ -876,11 +882,18 @@ class FederationClient(FederationBase): % (auth_chain_create_events,) ) + if response.partial_state and not response.servers_in_room: + raise InvalidResponseError( + "partial_state was set, but no servers were listed in the room" + ) + return SendJoinResult( event=event, state=signed_state, auth_chain=signed_auth, origin=destination, + partial_state=response.partial_state, + servers_in_room=response.servers_in_room or [], ) # MSC3083 defines additional error codes for room joins. diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 8782586cd..dca6e5c45 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -1,4 +1,4 @@ -# Copyright 2014-2021 The Matrix.org Foundation C.I.C. +# Copyright 2014-2022 The Matrix.org Foundation C.I.C. # Copyright 2020 Sorunome # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -60,6 +60,7 @@ class TransportLayerClient: def __init__(self, hs): self.server_name = hs.hostname self.client = hs.get_federation_http_client() + self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled async def get_room_state_ids( self, destination: str, room_id: str, event_id: str @@ -336,10 +337,15 @@ class TransportLayerClient: content: JsonDict, ) -> "SendJoinResponse": path = _create_v2_path("/send_join/%s/%s", room_id, event_id) + query_params: Dict[str, str] = {} + if self._faster_joins_enabled: + # lazy-load state on join + query_params["org.matrix.msc3706.partial_state"] = "true" return await self.client.put_json( destination=destination, path=path, + args=query_params, data=content, parser=SendJoinParser(room_version, v1_api=False), max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN, @@ -1271,6 +1277,12 @@ class SendJoinResponse: # "event" is not included in the response. event: Optional[EventBase] = None + # The room state is incomplete + partial_state: bool = False + + # List of servers in the room + servers_in_room: Optional[List[str]] = None + @ijson.coroutine def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None]: @@ -1297,6 +1309,32 @@ def _event_list_parser( events.append(event) +@ijson.coroutine +def _partial_state_parser(response: SendJoinResponse) -> Generator[None, Any, None]: + """Helper function for use with `ijson.items_coro` + + Parses the partial_state field in send_join responses + """ + while True: + val = yield + if not isinstance(val, bool): + raise TypeError("partial_state must be a boolean") + response.partial_state = val + + +@ijson.coroutine +def _servers_in_room_parser(response: SendJoinResponse) -> Generator[None, Any, None]: + """Helper function for use with `ijson.items_coro` + + Parses the servers_in_room field in send_join responses + """ + while True: + val = yield + if not isinstance(val, list) or any(not isinstance(x, str) for x in val): + raise TypeError("servers_in_room must be a list of strings") + response.servers_in_room = val + + class SendJoinParser(ByteParser[SendJoinResponse]): """A parser for the response to `/send_join` requests. @@ -1308,44 +1346,62 @@ class SendJoinParser(ByteParser[SendJoinResponse]): CONTENT_TYPE = "application/json" def __init__(self, room_version: RoomVersion, v1_api: bool): - self._response = SendJoinResponse([], [], {}) + self._response = SendJoinResponse([], [], event_dict={}) self._room_version = room_version + self._coros = [] # The V1 API has the shape of `[200, {...}]`, which we handle by # prefixing with `item.*`. prefix = "item." if v1_api else "" - self._coro_state = ijson.items_coro( - _event_list_parser(room_version, self._response.state), - prefix + "state.item", - use_float=True, - ) - self._coro_auth = ijson.items_coro( - _event_list_parser(room_version, self._response.auth_events), - prefix + "auth_chain.item", - use_float=True, - ) - # TODO Remove the unstable prefix when servers have updated. - # - # By re-using the same event dictionary this will cause the parsing of - # org.matrix.msc3083.v2.event and event to stomp over each other. - # Generally this should be fine. - self._coro_unstable_event = ijson.kvitems_coro( - _event_parser(self._response.event_dict), - prefix + "org.matrix.msc3083.v2.event", - use_float=True, - ) - self._coro_event = ijson.kvitems_coro( - _event_parser(self._response.event_dict), - prefix + "event", - use_float=True, - ) + self._coros = [ + ijson.items_coro( + _event_list_parser(room_version, self._response.state), + prefix + "state.item", + use_float=True, + ), + ijson.items_coro( + _event_list_parser(room_version, self._response.auth_events), + prefix + "auth_chain.item", + use_float=True, + ), + # TODO Remove the unstable prefix when servers have updated. + # + # By re-using the same event dictionary this will cause the parsing of + # org.matrix.msc3083.v2.event and event to stomp over each other. + # Generally this should be fine. + ijson.kvitems_coro( + _event_parser(self._response.event_dict), + prefix + "org.matrix.msc3083.v2.event", + use_float=True, + ), + ijson.kvitems_coro( + _event_parser(self._response.event_dict), + prefix + "event", + use_float=True, + ), + ] + + if not v1_api: + self._coros.append( + ijson.items_coro( + _partial_state_parser(self._response), + "org.matrix.msc3706.partial_state", + use_float="True", + ) + ) + + self._coros.append( + ijson.items_coro( + _servers_in_room_parser(self._response), + "org.matrix.msc3706.servers_in_room", + use_float="True", + ) + ) def write(self, data: bytes) -> int: - self._coro_state.send(data) - self._coro_auth.send(data) - self._coro_unstable_event.send(data) - self._coro_event.send(data) + for c in self._coros: + c.send(data) return len(data) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 86162e0f2..f43fbb584 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -87,7 +87,8 @@ REQUIREMENTS = [ # We enforce that we have a `cryptography` version that bundles an `openssl` # with the latest security patches. "cryptography>=3.4.7", - "ijson>=3.1", + # ijson 3.1.4 fixes a bug with "." in property names + "ijson>=3.1.4", "matrix-common~=1.1.0", ] diff --git a/tests/federation/transport/test_client.py b/tests/federation/transport/test_client.py index a7031a55f..c2320ce13 100644 --- a/tests/federation/transport/test_client.py +++ b/tests/federation/transport/test_client.py @@ -62,3 +62,35 @@ class SendJoinParserTestCase(TestCase): self.assertEqual(len(parsed_response.state), 1, parsed_response) self.assertEqual(parsed_response.event_dict, {}, parsed_response) self.assertIsNone(parsed_response.event, parsed_response) + self.assertFalse(parsed_response.partial_state, parsed_response) + self.assertEqual(parsed_response.servers_in_room, None, parsed_response) + + def test_partial_state(self) -> None: + """Check that the partial_state flag is correctly parsed""" + parser = SendJoinParser(RoomVersions.V1, False) + response = { + "org.matrix.msc3706.partial_state": True, + } + + serialised_response = json.dumps(response).encode() + + # Send data to the parser + parser.write(serialised_response) + + # Retrieve and check the parsed SendJoinResponse + parsed_response = parser.finish() + self.assertTrue(parsed_response.partial_state) + + def test_servers_in_room(self) -> None: + """Check that the servers_in_room field is correctly parsed""" + parser = SendJoinParser(RoomVersions.V1, False) + response = {"org.matrix.msc3706.servers_in_room": ["hs1", "hs2"]} + + serialised_response = json.dumps(response).encode() + + # Send data to the parser + parser.write(serialised_response) + + # Retrieve and check the parsed SendJoinResponse + parsed_response = parser.finish() + self.assertEqual(parsed_response.servers_in_room, ["hs1", "hs2"]) From 707049c6ff61193ffdfba909b4f17e9158c1d3e1 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 17 Feb 2022 17:54:16 +0100 Subject: [PATCH 18/84] Allow modules to set a display name on registration (#12009) Co-authored-by: Patrick Cloke --- changelog.d/12009.feature | 1 + .../password_auth_provider_callbacks.md | 35 ++++- synapse/handlers/auth.py | 58 +++++++++ synapse/module_api/__init__.py | 5 + synapse/rest/client/register.py | 7 + tests/handlers/test_password_providers.py | 123 +++++++++++++----- 6 files changed, 195 insertions(+), 34 deletions(-) create mode 100644 changelog.d/12009.feature diff --git a/changelog.d/12009.feature b/changelog.d/12009.feature new file mode 100644 index 000000000..c8a531481 --- /dev/null +++ b/changelog.d/12009.feature @@ -0,0 +1 @@ +Enable modules to set a custom display name when registering a user. diff --git a/docs/modules/password_auth_provider_callbacks.md b/docs/modules/password_auth_provider_callbacks.md index 88b59bb09..ec810fd29 100644 --- a/docs/modules/password_auth_provider_callbacks.md +++ b/docs/modules/password_auth_provider_callbacks.md @@ -85,7 +85,7 @@ If the authentication is unsuccessful, the module must return `None`. If multiple modules implement this callback, they will be considered in order. If a callback returns `None`, Synapse falls through to the next one. The value of the first callback that does not return `None` will be used. If this happens, Synapse will not call -any of the subsequent implementations of this callback. If every callback return `None`, +any of the subsequent implementations of this callback. If every callback returns `None`, the authentication is denied. ### `on_logged_out` @@ -162,10 +162,38 @@ return `None`. If multiple modules implement this callback, they will be considered in order. If a callback returns `None`, Synapse falls through to the next one. The value of the first callback that does not return `None` will be used. If this happens, Synapse will not call -any of the subsequent implementations of this callback. If every callback return `None`, +any of the subsequent implementations of this callback. If every callback returns `None`, the username provided by the user is used, if any (otherwise one is automatically generated). +### `get_displayname_for_registration` + +_First introduced in Synapse v1.54.0_ + +```python +async def get_displayname_for_registration( + uia_results: Dict[str, Any], + params: Dict[str, Any], +) -> Optional[str] +``` + +Called when registering a new user. The module can return a display name to set for the +user being registered by returning it as a string, or `None` if it doesn't wish to force a +display name for this user. + +This callback is called once [User-Interactive Authentication](https://spec.matrix.org/latest/client-server-api/#user-interactive-authentication-api) +has been completed by the user. It is not called when registering a user via SSO. It is +passed two dictionaries, which include the information that the user has provided during +the registration process. These dictionaries are identical to the ones passed to +[`get_username_for_registration`](#get_username_for_registration), so refer to the +documentation of this callback for more information about them. + +If multiple modules implement this callback, they will be considered in order. If a +callback returns `None`, Synapse falls through to the next one. The value of the first +callback that does not return `None` will be used. If this happens, Synapse will not call +any of the subsequent implementations of this callback. If every callback returns `None`, +the username will be used (e.g. `alice` if the user being registered is `@alice:example.com`). + ## `is_3pid_allowed` _First introduced in Synapse v1.53.0_ @@ -194,8 +222,7 @@ The example module below implements authentication checkers for two different lo - Is checked by the method: `self.check_my_login` - `m.login.password` (defined in [the spec](https://matrix.org/docs/spec/client_server/latest#password-based)) - Expects a `password` field to be sent to `/login` - - Is checked by the method: `self.check_pass` - + - Is checked by the method: `self.check_pass` ```python from typing import Awaitable, Callable, Optional, Tuple diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 6959d1aa7..572f54b1e 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -2064,6 +2064,10 @@ GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[ [JsonDict, JsonDict], Awaitable[Optional[str]], ] +GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK = Callable[ + [JsonDict, JsonDict], + Awaitable[Optional[str]], +] IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]] @@ -2080,6 +2084,9 @@ class PasswordAuthProvider: self.get_username_for_registration_callbacks: List[ GET_USERNAME_FOR_REGISTRATION_CALLBACK ] = [] + self.get_displayname_for_registration_callbacks: List[ + GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK + ] = [] self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = [] # Mapping from login type to login parameters @@ -2099,6 +2106,9 @@ class PasswordAuthProvider: get_username_for_registration: Optional[ GET_USERNAME_FOR_REGISTRATION_CALLBACK ] = None, + get_displayname_for_registration: Optional[ + GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK + ] = None, ) -> None: # Register check_3pid_auth callback if check_3pid_auth is not None: @@ -2148,6 +2158,11 @@ class PasswordAuthProvider: get_username_for_registration, ) + if get_displayname_for_registration is not None: + self.get_displayname_for_registration_callbacks.append( + get_displayname_for_registration, + ) + if is_3pid_allowed is not None: self.is_3pid_allowed_callbacks.append(is_3pid_allowed) @@ -2350,6 +2365,49 @@ class PasswordAuthProvider: return None + async def get_displayname_for_registration( + self, + uia_results: JsonDict, + params: JsonDict, + ) -> Optional[str]: + """Defines the display name to use when registering the user, using the + credentials and parameters provided during the UIA flow. + + Stops at the first callback that returns a tuple containing at least one string. + + Args: + uia_results: The credentials provided during the UIA flow. + params: The parameters provided by the registration request. + + Returns: + A tuple which first element is the display name, and the second is an MXC URL + to the user's avatar. + """ + for callback in self.get_displayname_for_registration_callbacks: + try: + res = await callback(uia_results, params) + + if isinstance(res, str): + return res + elif res is not None: + # mypy complains that this line is unreachable because it assumes the + # data returned by the module fits the expected type. We just want + # to make sure this is the case. + logger.warning( # type: ignore[unreachable] + "Ignoring non-string value returned by" + " get_displayname_for_registration callback %s: %s", + callback, + res, + ) + except Exception as e: + logger.error( + "Module raised an exception in get_displayname_for_registration: %s", + e, + ) + raise SynapseError(code=500, msg="Internal Server Error") + + return None + async def is_3pid_allowed( self, medium: str, diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index d4fca3692..8a17b912d 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -70,6 +70,7 @@ from synapse.handlers.account_validity import ( from synapse.handlers.auth import ( CHECK_3PID_AUTH_CALLBACK, CHECK_AUTH_CALLBACK, + GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK, GET_USERNAME_FOR_REGISTRATION_CALLBACK, IS_3PID_ALLOWED_CALLBACK, ON_LOGGED_OUT_CALLBACK, @@ -317,6 +318,9 @@ class ModuleApi: get_username_for_registration: Optional[ GET_USERNAME_FOR_REGISTRATION_CALLBACK ] = None, + get_displayname_for_registration: Optional[ + GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK + ] = None, ) -> None: """Registers callbacks for password auth provider capabilities. @@ -328,6 +332,7 @@ class ModuleApi: is_3pid_allowed=is_3pid_allowed, auth_checkers=auth_checkers, get_username_for_registration=get_username_for_registration, + get_displayname_for_registration=get_displayname_for_registration, ) def register_background_update_controller_callbacks( diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index c965e2bda..b8a5135e0 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -694,11 +694,18 @@ class RegisterRestServlet(RestServlet): session_id ) + display_name = await ( + self.password_auth_provider.get_displayname_for_registration( + auth_result, params + ) + ) + registered_user_id = await self.registration_handler.register_user( localpart=desired_username, password_hash=password_hash, guest_access_token=guest_access_token, threepid=threepid, + default_display_name=display_name, address=client_addr, user_agent_ips=entries, ) diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 4740dd0a6..49d832de8 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -84,7 +84,7 @@ class CustomAuthProvider: def __init__(self, config, api: ModuleApi): api.register_password_auth_provider_callbacks( - auth_checkers={("test.login_type", ("test_field",)): self.check_auth}, + auth_checkers={("test.login_type", ("test_field",)): self.check_auth} ) def check_auth(self, *args): @@ -122,7 +122,7 @@ class PasswordCustomAuthProvider: auth_checkers={ ("test.login_type", ("test_field",)): self.check_auth, ("m.login.password", ("password",)): self.check_auth, - }, + } ) pass @@ -163,6 +163,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): account.register_servlets, ] + CALLBACK_USERNAME = "get_username_for_registration" + CALLBACK_DISPLAYNAME = "get_displayname_for_registration" + def setUp(self): # we use a global mock device, so make sure we are starting with a clean slate mock_password_provider.reset_mock() @@ -754,7 +757,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): """Tests that the get_username_for_registration callback can define the username of a user when registering. """ - self._setup_get_username_for_registration() + self._setup_get_name_for_registration( + callback_name=self.CALLBACK_USERNAME, + ) username = "rin" channel = self.make_request( @@ -777,30 +782,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): """Tests that the get_username_for_registration callback is only called at the end of the UIA flow. """ - m = self._setup_get_username_for_registration() + m = self._setup_get_name_for_registration( + callback_name=self.CALLBACK_USERNAME, + ) - # Initiate the UIA flow. username = "rin" - channel = self.make_request( - "POST", - "register", - {"username": username, "type": "m.login.password", "password": "bar"}, - ) - self.assertEqual(channel.code, 401) - self.assertIn("session", channel.json_body) + res = self._do_uia_assert_mock_not_called(username, m) - # Check that the callback hasn't been called yet. - m.assert_not_called() - - # Finish the UIA flow. - session = channel.json_body["session"] - channel = self.make_request( - "POST", - "register", - {"auth": {"session": session, "type": LoginType.DUMMY}}, - ) - self.assertEqual(channel.code, 200, channel.json_body) - mxid = channel.json_body["user_id"] + mxid = res["user_id"] self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo") # Check that the callback has been called. @@ -817,6 +806,56 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self._test_3pid_allowed("rin", False) self._test_3pid_allowed("kitay", True) + def test_displayname(self): + """Tests that the get_displayname_for_registration callback can define the + display name of a user when registering. + """ + self._setup_get_name_for_registration( + callback_name=self.CALLBACK_DISPLAYNAME, + ) + + username = "rin" + channel = self.make_request( + "POST", + "/register", + { + "username": username, + "password": "bar", + "auth": {"type": LoginType.DUMMY}, + }, + ) + self.assertEqual(channel.code, 200) + + # Our callback takes the username and appends "-foo" to it, check that's what we + # have. + user_id = UserID.from_string(channel.json_body["user_id"]) + display_name = self.get_success( + self.hs.get_profile_handler().get_displayname(user_id) + ) + + self.assertEqual(display_name, username + "-foo") + + def test_displayname_uia(self): + """Tests that the get_displayname_for_registration callback is only called at the + end of the UIA flow. + """ + m = self._setup_get_name_for_registration( + callback_name=self.CALLBACK_DISPLAYNAME, + ) + + username = "rin" + res = self._do_uia_assert_mock_not_called(username, m) + + user_id = UserID.from_string(res["user_id"]) + display_name = self.get_success( + self.hs.get_profile_handler().get_displayname(user_id) + ) + + self.assertEqual(display_name, username + "-foo") + + # Check that the callback has been called. + m.assert_called_once() + def _test_3pid_allowed(self, username: str, registration: bool): """Tests that the "is_3pid_allowed" module callback is called correctly, using either /register or /account URLs depending on the arguments. @@ -877,23 +916,47 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): m.assert_called_once_with("email", "bar@test.com", registration) - def _setup_get_username_for_registration(self) -> Mock: - """Registers a get_username_for_registration callback that appends "-foo" to the - username the client is trying to register. + def _setup_get_name_for_registration(self, callback_name: str) -> Mock: + """Registers either a get_username_for_registration callback or a + get_displayname_for_registration callback that appends "-foo" to the username the + client is trying to register. """ - async def get_username_for_registration(uia_results, params): + async def callback(uia_results, params): self.assertIn(LoginType.DUMMY, uia_results) username = params["username"] return username + "-foo" - m = Mock(side_effect=get_username_for_registration) + m = Mock(side_effect=callback) password_auth_provider = self.hs.get_password_auth_provider() - password_auth_provider.get_username_for_registration_callbacks.append(m) + getattr(password_auth_provider, callback_name + "_callbacks").append(m) return m + def _do_uia_assert_mock_not_called(self, username: str, m: Mock) -> JsonDict: + # Initiate the UIA flow. + channel = self.make_request( + "POST", + "register", + {"username": username, "type": "m.login.password", "password": "bar"}, + ) + self.assertEqual(channel.code, 401) + self.assertIn("session", channel.json_body) + + # Check that the callback hasn't been called yet. + m.assert_not_called() + + # Finish the UIA flow. + session = channel.json_body["session"] + channel = self.make_request( + "POST", + "register", + {"auth": {"session": session, "type": LoginType.DUMMY}}, + ) + self.assertEqual(channel.code, 200, channel.json_body) + return channel.json_body + def _get_login_flows(self) -> JsonDict: channel = self.make_request("GET", "/_matrix/client/r0/login") self.assertEqual(channel.code, 200, channel.result) From 3f4d25a48ba17a1e2bcc6510f6cf6de7dd496a11 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 17 Feb 2022 17:22:55 +0000 Subject: [PATCH 19/84] Remove unstable MSC3283 flags (#12018) Fixes #11962 --- changelog.d/12018.removal | 1 + synapse/config/experimental.py | 3 --- synapse/rest/client/capabilities.py | 14 -------------- 3 files changed, 1 insertion(+), 17 deletions(-) create mode 100644 changelog.d/12018.removal diff --git a/changelog.d/12018.removal b/changelog.d/12018.removal new file mode 100644 index 000000000..e940b6222 --- /dev/null +++ b/changelog.d/12018.removal @@ -0,0 +1 @@ +Drop support for [MSC3283](https://github.com/matrix-org/matrix-doc/pull/3283) unstable flags now that the stable flags are supported. diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 12b5638cf..bcdeb9ee2 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -41,9 +41,6 @@ class ExperimentalConfig(Config): # MSC3244 (room version capabilities) self.msc3244_enabled: bool = experimental.get("msc3244_enabled", True) - # MSC3283 (set displayname, avatar_url and change 3pid capabilities) - self.msc3283_enabled: bool = experimental.get("msc3283_enabled", False) - # MSC3266 (room summary api) self.msc3266_enabled: bool = experimental.get("msc3266_enabled", False) diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py index 6682da077..e05c926b6 100644 --- a/synapse/rest/client/capabilities.py +++ b/synapse/rest/client/capabilities.py @@ -72,20 +72,6 @@ class CapabilitiesRestServlet(RestServlet): "org.matrix.msc3244.room_capabilities" ] = MSC3244_CAPABILITIES - # Must be removed in later versions. - # Is only included for migration. - # Also the parts in `synapse/config/experimental.py`. - if self.config.experimental.msc3283_enabled: - response["capabilities"]["org.matrix.msc3283.set_displayname"] = { - "enabled": self.config.registration.enable_set_displayname - } - response["capabilities"]["org.matrix.msc3283.set_avatar_url"] = { - "enabled": self.config.registration.enable_set_avatar_url - } - response["capabilities"]["org.matrix.msc3283.3pid_changes"] = { - "enabled": self.config.registration.enable_3pid_changes - } - if self.config.experimental.msc3440_enabled: response["capabilities"]["io.element.thread"] = {"enabled": True} From 40e256e7aa31a6b90e665b340858abbd3a2999c9 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 18 Feb 2022 12:38:48 +0100 Subject: [PATCH 20/84] Update the olddeps CI check to use an old version of markupsafe (#12025) --- changelog.d/12025.misc | 1 + tox.ini | 3 +++ 2 files changed, 4 insertions(+) create mode 100644 changelog.d/12025.misc diff --git a/changelog.d/12025.misc b/changelog.d/12025.misc new file mode 100644 index 000000000..d9475a771 --- /dev/null +++ b/changelog.d/12025.misc @@ -0,0 +1 @@ +Update the `olddeps` CI job to use an old version of `markupsafe`. diff --git a/tox.ini b/tox.ini index 2b3d39e03..41678aa38 100644 --- a/tox.ini +++ b/tox.ini @@ -124,6 +124,9 @@ usedevelop = false deps = Automat == 0.8.0 lxml + # markupsafe 2.1 introduced a change that breaks Jinja 2.x. Since we depend on + # Jinja >= 2.9, it means this test suite will fail if markupsafe >= 2.1 is installed. + markupsafe < 2.1 {[base]deps} commands = From 5a6911598ad2d3dea96b9f8c1cffccd4f4840bf7 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Fri, 18 Feb 2022 06:11:18 -0600 Subject: [PATCH 21/84] Fix 500 error with Postgres when looking backwards with the MSC3030 `/timestamp_to_event` endpoint (#12024) --- changelog.d/12024.bugfix | 1 + synapse/storage/databases/main/events_worker.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12024.bugfix diff --git a/changelog.d/12024.bugfix b/changelog.d/12024.bugfix new file mode 100644 index 000000000..59bcdb93a --- /dev/null +++ b/changelog.d/12024.bugfix @@ -0,0 +1 @@ +Fix 500 error with Postgres when looking backwards with the [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) `/timestamp_to_event?dir=b` endpoint. diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 712b8ce20..2a255d103 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1854,7 +1854,7 @@ class EventsWorkerStore(SQLBaseStore): forward_edge_query = """ SELECT 1 FROM event_edges /* Check to make sure the event referencing our event in question is not rejected */ - LEFT JOIN rejections ON event_edges.event_id == rejections.event_id + LEFT JOIN rejections ON event_edges.event_id = rejections.event_id WHERE event_edges.room_id = ? AND event_edges.prev_event_id = ? From 19bd9cff1afba2149ae58cd5b5648470c48d51a5 Mon Sep 17 00:00:00 2001 From: Travis Ralston Date: Fri, 18 Feb 2022 05:48:23 -0700 Subject: [PATCH 22/84] Use stable MSC3069 `is_guest` flag on `/whoami`. (#12021) Keeping backwards compatibility with the unstable flag for now. --- changelog.d/12021.feature | 1 + synapse/rest/client/account.py | 2 ++ tests/rest/client/test_account.py | 9 ++++++--- 3 files changed, 9 insertions(+), 3 deletions(-) create mode 100644 changelog.d/12021.feature diff --git a/changelog.d/12021.feature b/changelog.d/12021.feature new file mode 100644 index 000000000..01378df8c --- /dev/null +++ b/changelog.d/12021.feature @@ -0,0 +1 @@ +Support only the stable identifier for [MSC3069](https://github.com/matrix-org/matrix-doc/pull/3069)'s `is_guest` on `/_matrix/client/v3/account/whoami`. \ No newline at end of file diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index cfa2aee76..efe299e69 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -883,7 +883,9 @@ class WhoamiRestServlet(RestServlet): response = { "user_id": requester.user.to_string(), # MSC: https://github.com/matrix-org/matrix-doc/pull/3069 + # Entered spec in Matrix 1.2 "org.matrix.msc3069.is_guest": bool(requester.is_guest), + "is_guest": bool(requester.is_guest), } # Appservices and similar accounts do not have device IDs diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index 89d85b0a1..51146c471 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -486,8 +486,9 @@ class WhoamiTestCase(unittest.HomeserverTestCase): { "user_id": user_id, "device_id": device_id, - # Unstable until MSC3069 enters spec + # MSC3069 entered spec in Matrix 1.2 but maintained compatibility "org.matrix.msc3069.is_guest": False, + "is_guest": False, }, ) @@ -505,8 +506,9 @@ class WhoamiTestCase(unittest.HomeserverTestCase): { "user_id": user_id, "device_id": device_id, - # Unstable until MSC3069 enters spec + # MSC3069 entered spec in Matrix 1.2 but maintained compatibility "org.matrix.msc3069.is_guest": True, + "is_guest": True, }, ) @@ -528,8 +530,9 @@ class WhoamiTestCase(unittest.HomeserverTestCase): whoami, { "user_id": user_id, - # Unstable until MSC3069 enters spec + # MSC3069 entered spec in Matrix 1.2 but maintained compatibility "org.matrix.msc3069.is_guest": False, + "is_guest": False, }, ) self.assertFalse(hasattr(whoami, "device_id")) From 31a298fec792ec1d1efb5f47763e4b0a16f24e6d Mon Sep 17 00:00:00 2001 From: Travis Ralston Date: Fri, 18 Feb 2022 05:49:53 -0700 Subject: [PATCH 23/84] Advertise Matrix 1.1 in `/_matrix/client/versions` (#12020) --- changelog.d/12020.feature | 1 + synapse/rest/client/versions.py | 1 + 2 files changed, 2 insertions(+) create mode 100644 changelog.d/12020.feature diff --git a/changelog.d/12020.feature b/changelog.d/12020.feature new file mode 100644 index 000000000..1ac9d2060 --- /dev/null +++ b/changelog.d/12020.feature @@ -0,0 +1 @@ +Advertise Matrix 1.1 support on `/_matrix/client/versions`. \ No newline at end of file diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 2290c57c1..35b88e9bb 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -73,6 +73,7 @@ class VersionsRestServlet(RestServlet): "r0.5.0", "r0.6.0", "r0.6.1", + "v1.1", ], # as per MSC1497: "unstable_features": { From eb609c65d0794dd49efcd924bdc8743fd4253a93 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Fri, 18 Feb 2022 14:54:31 +0000 Subject: [PATCH 24/84] Fix bug in `StateFilter.return_expanded()` and add some tests. (#12016) --- changelog.d/12016.misc | 1 + synapse/storage/state.py | 8 ++- tests/storage/test_state.py | 109 ++++++++++++++++++++++++++++++++++++ 3 files changed, 117 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12016.misc diff --git a/changelog.d/12016.misc b/changelog.d/12016.misc new file mode 100644 index 000000000..8856ef46a --- /dev/null +++ b/changelog.d/12016.misc @@ -0,0 +1 @@ +Fix bug in `StateFilter.return_expanded()` and add some tests. \ No newline at end of file diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 913448f0f..e79ecf64a 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -204,13 +204,16 @@ class StateFilter: if get_all_members: # We want to return everything. return StateFilter.all() - else: + elif EventTypes.Member in self.types: # We want to return all non-members, but only particular # memberships return StateFilter( types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}), include_others=True, ) + else: + # We want to return all non-members + return _ALL_NON_MEMBER_STATE_FILTER def make_sql_filter_clause(self) -> Tuple[str, List[str]]: """Converts the filter to an SQL clause. @@ -528,6 +531,9 @@ class StateFilter: _ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True) +_ALL_NON_MEMBER_STATE_FILTER = StateFilter( + types=frozendict({EventTypes.Member: frozenset()}), include_others=True +) _NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 70d52b088..28c767ecf 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -992,3 +992,112 @@ class StateFilterDifferenceTestCase(TestCase): StateFilter.none(), StateFilter.all(), ) + + +class StateFilterTestCase(TestCase): + def test_return_expanded(self): + """ + Tests the behaviour of the return_expanded() function that expands + StateFilters to include more state types (for the sake of cache hit rate). + """ + + self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all()) + + self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none()) + + # Concrete-only state filters stay the same + # (Case: mixed filter) + self.assertEqual( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + "some.other.state.type": {""}, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + "some.other.state.type": {""}, + }, + include_others=False, + ), + ) + + # Concrete-only state filters stay the same + # (Case: non-member-only filter) + self.assertEqual( + StateFilter.freeze( + {"some.other.state.type": {""}}, include_others=False + ).return_expanded(), + StateFilter.freeze({"some.other.state.type": {""}}, include_others=False), + ) + + # Concrete-only state filters stay the same + # (Case: member-only filter) + self.assertEqual( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + }, + include_others=False, + ), + ) + + # Wildcard member-only state filters stay the same + self.assertEqual( + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + ) + + # If there is a wildcard in the non-member portion of the filter, + # it's expanded to include ALL non-member events. + # (Case: mixed filter) + self.assertEqual( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + "some.other.state.type": None, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + {EventTypes.Member: {"@wombat:test", "@alicia:test"}}, + include_others=True, + ), + ) + + # If there is a wildcard in the non-member portion of the filter, + # it's expanded to include ALL non-member events. + # (Case: non-member-only filter) + self.assertEqual( + StateFilter.freeze( + { + "some.other.state.type": None, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze({EventTypes.Member: set()}, include_others=True), + ) + self.assertEqual( + StateFilter.freeze( + { + "some.other.state.type": None, + "yet.another.state.type": {"wombat"}, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze({EventTypes.Member: set()}, include_others=True), + ) From e6acd3cf4fe6910e99683c2ebe3dd917f0a3ae14 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Fri, 18 Feb 2022 15:57:26 +0000 Subject: [PATCH 25/84] Upgrade mypy to version 0.931 (#12030) Upgrade mypy to 0.931, mypy-zope to 0.3.5 and fix new complaints. --- changelog.d/12030.misc | 1 + setup.py | 4 ++-- stubs/sortedcontainers/sorteddict.pyi | 13 +++++++++---- synapse/handlers/search.py | 2 +- synapse/push/baserules.py | 10 ++++++---- synapse/push/httppusher.py | 2 +- synapse/streams/events.py | 6 ++++-- synapse/util/daemonize.py | 8 +++++--- synapse/util/patch_inline_callbacks.py | 6 ++++-- 9 files changed, 33 insertions(+), 19 deletions(-) create mode 100644 changelog.d/12030.misc diff --git a/changelog.d/12030.misc b/changelog.d/12030.misc new file mode 100644 index 000000000..607ee97ce --- /dev/null +++ b/changelog.d/12030.misc @@ -0,0 +1 @@ +Upgrade mypy to version 0.931. diff --git a/setup.py b/setup.py index d0511c767..c80cb6f20 100755 --- a/setup.py +++ b/setup.py @@ -103,8 +103,8 @@ CONDITIONAL_REQUIREMENTS["lint"] = [ ] CONDITIONAL_REQUIREMENTS["mypy"] = [ - "mypy==0.910", - "mypy-zope==0.3.2", + "mypy==0.931", + "mypy-zope==0.3.5", "types-bleach>=4.1.0", "types-jsonschema>=3.2.0", "types-opentracing>=2.4.2", diff --git a/stubs/sortedcontainers/sorteddict.pyi b/stubs/sortedcontainers/sorteddict.pyi index 0eaef0049..344d55cce 100644 --- a/stubs/sortedcontainers/sorteddict.pyi +++ b/stubs/sortedcontainers/sorteddict.pyi @@ -66,13 +66,18 @@ class SortedDict(Dict[_KT, _VT]): def __copy__(self: _SD) -> _SD: ... @classmethod @overload - def fromkeys(cls, seq: Iterable[_T_h]) -> SortedDict[_T_h, None]: ... + def fromkeys( + cls, seq: Iterable[_T_h], value: None = ... + ) -> SortedDict[_T_h, None]: ... @classmethod @overload def fromkeys(cls, seq: Iterable[_T_h], value: _S) -> SortedDict[_T_h, _S]: ... - def keys(self) -> SortedKeysView[_KT]: ... - def items(self) -> SortedItemsView[_KT, _VT]: ... - def values(self) -> SortedValuesView[_VT]: ... + # As of Python 3.10, `dict_{keys,items,values}` have an extra `mapping` attribute and so + # `Sorted{Keys,Items,Values}View` are no longer compatible with them. + # See https://github.com/python/typeshed/issues/6837 + def keys(self) -> SortedKeysView[_KT]: ... # type: ignore[override] + def items(self) -> SortedItemsView[_KT, _VT]: ... # type: ignore[override] + def values(self) -> SortedValuesView[_VT]: ... # type: ignore[override] @overload def pop(self, key: _KT) -> _VT: ... @overload diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index afd14da11..0e0e58de0 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -654,7 +654,7 @@ class SearchHandler: self.storage, user.to_string(), res.events_after ) - context = { + context: JsonDict = { "events_before": events_before, "events_after": events_after, "start": await now_token.copy_and_replace( diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 910b05c0d..832eaa34e 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -130,7 +130,9 @@ def make_base_prepend_rules( return rules -BASE_APPEND_CONTENT_RULES = [ +# We have to annotate these types, otherwise mypy infers them as +# `List[Dict[str, Sequence[Collection[str]]]]`. +BASE_APPEND_CONTENT_RULES: List[Dict[str, Any]] = [ { "rule_id": "global/content/.m.rule.contains_user_name", "conditions": [ @@ -149,7 +151,7 @@ BASE_APPEND_CONTENT_RULES = [ ] -BASE_PREPEND_OVERRIDE_RULES = [ +BASE_PREPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [ { "rule_id": "global/override/.m.rule.master", "enabled": False, @@ -159,7 +161,7 @@ BASE_PREPEND_OVERRIDE_RULES = [ ] -BASE_APPEND_OVERRIDE_RULES = [ +BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [ { "rule_id": "global/override/.m.rule.suppress_notices", "conditions": [ @@ -278,7 +280,7 @@ BASE_APPEND_OVERRIDE_RULES = [ ] -BASE_APPEND_UNDERRIDE_RULES = [ +BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [ { "rule_id": "global/underride/.m.rule.call", "conditions": [ diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 49bcc06e0..52c7ff357 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -325,7 +325,7 @@ class HttpPusher(Pusher): # This was checked in the __init__, but mypy doesn't seem to know that. assert self.data is not None if self.data.get("format") == "event_id_only": - d = { + d: Dict[str, Any] = { "notification": { "event_id": event.event_id, "room_id": event.room_id, diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 21591d0bf..4ec2a713c 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -37,14 +37,16 @@ class _EventSourcesInner: account_data: AccountDataEventSource def get_sources(self) -> Iterator[Tuple[str, EventSource]]: - for attribute in _EventSourcesInner.__attrs_attrs__: # type: ignore[attr-defined] + for attribute in attr.fields(_EventSourcesInner): yield attribute.name, getattr(self, attribute.name) class EventSources: def __init__(self, hs: "HomeServer"): self.sources = _EventSourcesInner( - *(attribute.type(hs) for attribute in _EventSourcesInner.__attrs_attrs__) # type: ignore[attr-defined] + # mypy thinks attribute.type is `Optional`, but we know it's never `None` here since + # all the attributes of `_EventSourcesInner` are annotated. + *(attribute.type(hs) for attribute in attr.fields(_EventSourcesInner)) # type: ignore[misc] ) self.store = hs.get_datastore() diff --git a/synapse/util/daemonize.py b/synapse/util/daemonize.py index de04f34e4..031880ec3 100644 --- a/synapse/util/daemonize.py +++ b/synapse/util/daemonize.py @@ -20,7 +20,7 @@ import os import signal import sys from types import FrameType, TracebackType -from typing import NoReturn, Type +from typing import NoReturn, Optional, Type def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -> None: @@ -100,7 +100,9 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") - # also catch any other uncaught exceptions before we get that far.) def excepthook( - type_: Type[BaseException], value: BaseException, traceback: TracebackType + type_: Type[BaseException], + value: BaseException, + traceback: Optional[TracebackType], ) -> None: logger.critical("Unhanded exception", exc_info=(type_, value, traceback)) @@ -123,7 +125,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") - sys.exit(1) # write a log line on SIGTERM. - def sigterm(signum: signal.Signals, frame: FrameType) -> NoReturn: + def sigterm(signum: int, frame: Optional[FrameType]) -> NoReturn: logger.warning("Caught signal %s. Stopping daemon." % signum) sys.exit(0) diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py index 1f18654d4..6d4b0b7c5 100644 --- a/synapse/util/patch_inline_callbacks.py +++ b/synapse/util/patch_inline_callbacks.py @@ -14,7 +14,7 @@ import functools import sys -from typing import Any, Callable, Generator, List, TypeVar +from typing import Any, Callable, Generator, List, TypeVar, cast from twisted.internet import defer from twisted.internet.defer import Deferred @@ -174,7 +174,9 @@ def _check_yield_points( ) ) changes.append(err) - return getattr(e, "value", None) + # The `StopIteration` or `_DefGen_Return` contains the return value from the + # generator. + return cast(T, e.value) frame = gen.gi_frame From 284ea2025a27e2ceb6dacdd1f6f0b0fff3814dde Mon Sep 17 00:00:00 2001 From: reivilibre Date: Fri, 18 Feb 2022 17:23:31 +0000 Subject: [PATCH 26/84] Track and deduplicate in-flight requests to `_get_state_for_groups`. (#10870) Co-authored-by: Patrick Cloke --- changelog.d/10870.misc | 1 + synapse/storage/databases/state/store.py | 203 +++++++++++++++++--- tests/storage/databases/test_state_store.py | 133 +++++++++++++ 3 files changed, 312 insertions(+), 25 deletions(-) create mode 100644 changelog.d/10870.misc create mode 100644 tests/storage/databases/test_state_store.py diff --git a/changelog.d/10870.misc b/changelog.d/10870.misc new file mode 100644 index 000000000..3af049b96 --- /dev/null +++ b/changelog.d/10870.misc @@ -0,0 +1 @@ +Deduplicate in-flight requests in `_get_state_for_groups`. diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 7614d76ac..3af69a207 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -13,11 +13,23 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + Iterable, + Optional, + Sequence, + Set, + Tuple, +) import attr +from twisted.internet import defer + from synapse.api.constants import EventTypes +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -29,6 +41,12 @@ from synapse.storage.state import StateFilter from synapse.storage.types import Cursor from synapse.storage.util.sequence import build_sequence_generator from synapse.types import MutableStateMap, StateKey, StateMap +from synapse.util import unwrapFirstError +from synapse.util.async_helpers import ( + AbstractObservableDeferred, + ObservableDeferred, + yieldable_gather_results, +) from synapse.util.caches.descriptors import cached from synapse.util.caches.dictionary_cache import DictionaryCache @@ -37,7 +55,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) - MAX_STATE_DELTA_HOPS = 100 @@ -106,6 +123,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): 500000, ) + # Current ongoing get_state_for_groups in-flight requests + # {group ID -> {StateFilter -> ObservableDeferred}} + self._state_group_inflight_requests: Dict[ + int, Dict[StateFilter, AbstractObservableDeferred[StateMap[str]]] + ] = {} + def get_max_state_group_txn(txn: Cursor) -> int: txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups") return txn.fetchone()[0] # type: ignore @@ -157,7 +180,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): ) async def _get_state_groups_from_groups( - self, groups: List[int], state_filter: StateFilter + self, groups: Sequence[int], state_filter: StateFilter ) -> Dict[int, StateMap[str]]: """Returns the state groups for a given set of groups from the database, filtering on types of state events. @@ -228,6 +251,150 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return state_filter.filter_state(state_dict_ids), not missing_types + def _get_state_for_group_gather_inflight_requests( + self, group: int, state_filter_left_over: StateFilter + ) -> Tuple[Sequence[AbstractObservableDeferred[StateMap[str]]], StateFilter]: + """ + Attempts to gather in-flight requests and re-use them to retrieve state + for the given state group, filtered with the given state filter. + + Used as part of _get_state_for_group_using_inflight_cache. + + Returns: + Tuple of two values: + A sequence of ObservableDeferreds to observe + A StateFilter representing what else needs to be requested to fulfill the request + """ + + inflight_requests = self._state_group_inflight_requests.get(group) + if inflight_requests is None: + # no requests for this group, need to retrieve it all ourselves + return (), state_filter_left_over + + # The list of ongoing requests which will help narrow the current request. + reusable_requests = [] + for (request_state_filter, request_deferred) in inflight_requests.items(): + new_state_filter_left_over = state_filter_left_over.approx_difference( + request_state_filter + ) + if new_state_filter_left_over == state_filter_left_over: + # Reusing this request would not gain us anything, so don't bother. + continue + + reusable_requests.append(request_deferred) + state_filter_left_over = new_state_filter_left_over + if state_filter_left_over == StateFilter.none(): + # we have managed to collect enough of the in-flight requests + # to cover our StateFilter and give us the state we need. + break + + return reusable_requests, state_filter_left_over + + async def _get_state_for_group_fire_request( + self, group: int, state_filter: StateFilter + ) -> StateMap[str]: + """ + Fires off a request to get the state at a state group, + potentially filtering by type and/or state key. + + This request will be tracked in the in-flight request cache and automatically + removed when it is finished. + + Used as part of _get_state_for_group_using_inflight_cache. + + Args: + group: ID of the state group for which we want to get state + state_filter: the state filter used to fetch state from the database + """ + cache_sequence_nm = self._state_group_cache.sequence + cache_sequence_m = self._state_group_members_cache.sequence + + # Help the cache hit ratio by expanding the filter a bit + db_state_filter = state_filter.return_expanded() + + async def _the_request() -> StateMap[str]: + group_to_state_dict = await self._get_state_groups_from_groups( + (group,), state_filter=db_state_filter + ) + + # Now let's update the caches + self._insert_into_cache( + group_to_state_dict, + db_state_filter, + cache_seq_num_members=cache_sequence_m, + cache_seq_num_non_members=cache_sequence_nm, + ) + + # Remove ourselves from the in-flight cache + group_request_dict = self._state_group_inflight_requests[group] + del group_request_dict[db_state_filter] + if not group_request_dict: + # If there are no more requests in-flight for this group, + # clean up the cache by removing the empty dictionary + del self._state_group_inflight_requests[group] + + return group_to_state_dict[group] + + # We don't immediately await the result, so must use run_in_background + # But we DO await the result before the current log context (request) + # finishes, so don't need to run it as a background process. + request_deferred = run_in_background(_the_request) + observable_deferred = ObservableDeferred(request_deferred, consumeErrors=True) + + # Insert the ObservableDeferred into the cache + group_request_dict = self._state_group_inflight_requests.setdefault(group, {}) + group_request_dict[db_state_filter] = observable_deferred + + return await make_deferred_yieldable(observable_deferred.observe()) + + async def _get_state_for_group_using_inflight_cache( + self, group: int, state_filter: StateFilter + ) -> MutableStateMap[str]: + """ + Gets the state at a state group, potentially filtering by type and/or + state key. + + 1. Calls _get_state_for_group_gather_inflight_requests to gather any + ongoing requests which might overlap with the current request. + 2. Fires a new request, using _get_state_for_group_fire_request, + for any state which cannot be gathered from ongoing requests. + + Args: + group: ID of the state group for which we want to get state + state_filter: the state filter used to fetch state from the database + Returns: + state map + """ + + # first, figure out whether we can re-use any in-flight requests + # (and if so, what would be left over) + ( + reusable_requests, + state_filter_left_over, + ) = self._get_state_for_group_gather_inflight_requests(group, state_filter) + + if state_filter_left_over != StateFilter.none(): + # Fetch remaining state + remaining = await self._get_state_for_group_fire_request( + group, state_filter_left_over + ) + assembled_state: MutableStateMap[str] = dict(remaining) + else: + assembled_state = {} + + gathered = await make_deferred_yieldable( + defer.gatherResults( + (r.observe() for r in reusable_requests), consumeErrors=True + ) + ).addErrback(unwrapFirstError) + + # assemble our result. + for result_piece in gathered: + assembled_state.update(result_piece) + + # Filter out any state that may be more than what we asked for. + return state_filter.filter_state(assembled_state) + async def _get_state_for_groups( self, groups: Iterable[int], state_filter: Optional[StateFilter] = None ) -> Dict[int, MutableStateMap[str]]: @@ -269,31 +436,17 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): if not incomplete_groups: return state - cache_sequence_nm = self._state_group_cache.sequence - cache_sequence_m = self._state_group_members_cache.sequence + async def get_from_cache(group: int, state_filter: StateFilter) -> None: + state[group] = await self._get_state_for_group_using_inflight_cache( + group, state_filter + ) - # Help the cache hit ratio by expanding the filter a bit - db_state_filter = state_filter.return_expanded() - - group_to_state_dict = await self._get_state_groups_from_groups( - list(incomplete_groups), state_filter=db_state_filter + await yieldable_gather_results( + get_from_cache, + incomplete_groups, + state_filter, ) - # Now lets update the caches - self._insert_into_cache( - group_to_state_dict, - db_state_filter, - cache_seq_num_members=cache_sequence_m, - cache_seq_num_non_members=cache_sequence_nm, - ) - - # And finally update the result dict, by filtering out any extra - # stuff we pulled out of the database. - for group, group_state_dict in group_to_state_dict.items(): - # We just replace any existing entries, as we will have loaded - # everything we need from the database anyway. - state[group] = state_filter.filter_state(group_state_dict) - return state def _get_state_for_groups_using_cache( diff --git a/tests/storage/databases/test_state_store.py b/tests/storage/databases/test_state_store.py new file mode 100644 index 000000000..cf126ee62 --- /dev/null +++ b/tests/storage/databases/test_state_store.py @@ -0,0 +1,133 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import typing +from typing import Dict, List, Sequence, Tuple +from unittest.mock import patch + +from twisted.internet.defer import Deferred, ensureDeferred +from twisted.test.proto_helpers import MemoryReactor + +from synapse.storage.state import StateFilter +from synapse.types import MutableStateMap, StateMap +from synapse.util import Clock + +from tests.unittest import HomeserverTestCase + +if typing.TYPE_CHECKING: + from synapse.server import HomeServer + + +class StateGroupInflightCachingTestCase(HomeserverTestCase): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: "HomeServer" + ) -> None: + self.state_storage = homeserver.get_storage().state + self.state_datastore = homeserver.get_datastores().state + # Patch out the `_get_state_groups_from_groups`. + # This is useful because it lets us pretend we have a slow database. + get_state_groups_patch = patch.object( + self.state_datastore, + "_get_state_groups_from_groups", + self._fake_get_state_groups_from_groups, + ) + get_state_groups_patch.start() + + self.addCleanup(get_state_groups_patch.stop) + self.get_state_group_calls: List[ + Tuple[Tuple[int, ...], StateFilter, Deferred[Dict[int, StateMap[str]]]] + ] = [] + + def _fake_get_state_groups_from_groups( + self, groups: Sequence[int], state_filter: StateFilter + ) -> "Deferred[Dict[int, StateMap[str]]]": + d: Deferred[Dict[int, StateMap[str]]] = Deferred() + self.get_state_group_calls.append((tuple(groups), state_filter, d)) + return d + + def _complete_request_fake( + self, + groups: Tuple[int, ...], + state_filter: StateFilter, + d: "Deferred[Dict[int, StateMap[str]]]", + ) -> None: + """ + Assemble a fake database response and complete the database request. + """ + + result: Dict[int, StateMap[str]] = {} + + for group in groups: + group_result: MutableStateMap[str] = {} + result[group] = group_result + + for state_type, state_keys in state_filter.types.items(): + if state_keys is None: + group_result[(state_type, "a")] = "xyz" + group_result[(state_type, "b")] = "xyz" + else: + for state_key in state_keys: + group_result[(state_type, state_key)] = "abc" + + if state_filter.include_others: + group_result[("other.event.type", "state.key")] = "123" + + d.callback(result) + + def test_duplicate_requests_deduplicated(self) -> None: + """ + Tests that duplicate requests for state are deduplicated. + + This test: + - requests some state (state group 42, 'all' state filter) + - requests it again, before the first request finishes + - checks to see that only one database query was made + - completes the database query + - checks that both requests see the same retrieved state + """ + req1 = ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, StateFilter.all() + ) + ) + self.pump(by=0.1) + + # This should have gone to the database + self.assertEqual(len(self.get_state_group_calls), 1) + self.assertFalse(req1.called) + + req2 = ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, StateFilter.all() + ) + ) + self.pump(by=0.1) + + # No more calls should have gone to the database + self.assertEqual(len(self.get_state_group_calls), 1) + self.assertFalse(req1.called) + self.assertFalse(req2.called) + + groups, sf, d = self.get_state_group_calls[0] + self.assertEqual(groups, (42,)) + self.assertEqual(sf, StateFilter.all()) + + # Now we can complete the request + self._complete_request_fake(groups, sf, d) + + self.assertEqual( + self.get_success(req1), {("other.event.type", "state.key"): "123"} + ) + self.assertEqual( + self.get_success(req2), {("other.event.type", "state.key"): "123"} + ) From 444b04058b497da15812d7f14858e6270d54abb5 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 18 Feb 2022 12:24:25 -0500 Subject: [PATCH 27/84] Document why auth providers aren't validated in the admin API. (#12004) Since it is reasonable to give a future or past auth provider, which might not be in the current configuration. --- changelog.d/12004.doc | 1 + docs/admin_api/user_admin_api.md | 3 ++- synapse/module_api/__init__.py | 6 +++++- .../storage/databases/main/registration.py | 21 +++++++++++++++++++ 4 files changed, 29 insertions(+), 2 deletions(-) create mode 100644 changelog.d/12004.doc diff --git a/changelog.d/12004.doc b/changelog.d/12004.doc new file mode 100644 index 000000000..0b4baef21 --- /dev/null +++ b/changelog.d/12004.doc @@ -0,0 +1 @@ +Clarify information about external Identity Provider IDs. diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index 1bbe23708..4076fcab6 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -126,7 +126,8 @@ Body parameters: [Sample Configuration File](../usage/configuration/homeserver_sample_config.html) section `sso` and `oidc_providers`. - `auth_provider` - string. ID of the external identity provider. Value of `idp_id` - in homeserver configuration. + in the homeserver configuration. Note that no error is raised if the provided + value is not in the homeserver configuration. - `external_id` - string, user ID in the external identity provider. - `avatar_url` - string, optional, must be a [MXC URI](https://matrix.org/docs/spec/client_server/r0.6.0#matrix-content-mxc-uris). diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 8a17b912d..07020bfb8 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -653,7 +653,11 @@ class ModuleApi: Added in Synapse v1.9.0. Args: - auth_provider: identifier for the remote auth provider + auth_provider: identifier for the remote auth provider, see `sso` and + `oidc_providers` in the homeserver configuration. + + Note that no error is raised if the provided value is not in the + homeserver configuration. external_id: id on that system user_id: complete mxid that it is mapped to """ diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index aac94fa46..17110bb03 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -622,10 +622,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): ) -> None: """Record a mapping from an external user id to a mxid + See notes in _record_user_external_id_txn about what constitutes valid data. + Args: auth_provider: identifier for the remote auth provider external_id: id on that system user_id: complete mxid that it is mapped to + Raises: ExternalIDReuseException if the new external_id could not be mapped. """ @@ -648,6 +651,21 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): external_id: str, user_id: str, ) -> None: + """ + Record a mapping from an external user id to a mxid. + + Note that the auth provider IDs (and the external IDs) are not validated + against configured IdPs as Synapse does not know its relationship to + external systems. For example, it might be useful to pre-configure users + before enabling a new IdP or an IdP might be temporarily offline, but + still valid. + + Args: + txn: The database transaction. + auth_provider: identifier for the remote auth provider + external_id: id on that system + user_id: complete mxid that it is mapped to + """ self.db_pool.simple_insert_txn( txn, @@ -687,10 +705,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): """Replace mappings from external user ids to a mxid in a single transaction. All mappings are deleted and the new ones are created. + See notes in _record_user_external_id_txn about what constitutes valid data. + Args: record_external_ids: List with tuple of auth_provider and external_id to record user_id: complete mxid that it is mapped to + Raises: ExternalIDReuseException if the new external_id could not be mapped. """ From 99f6d79fe17b2f96c6c3cf85c4f0fee255758300 Mon Sep 17 00:00:00 2001 From: Travis Ralston Date: Mon, 21 Feb 2022 08:59:29 -0700 Subject: [PATCH 28/84] Advertise Matrix 1.2 in `/_matrix/client/versions` (#12022) Co-authored-by: Patrick Cloke --- changelog.d/12022.feature | 1 + synapse/rest/client/versions.py | 1 + 2 files changed, 2 insertions(+) create mode 100644 changelog.d/12022.feature diff --git a/changelog.d/12022.feature b/changelog.d/12022.feature new file mode 100644 index 000000000..188fb1257 --- /dev/null +++ b/changelog.d/12022.feature @@ -0,0 +1 @@ +Advertise Matrix 1.2 support on `/_matrix/client/versions`. \ No newline at end of file diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 35b88e9bb..00f29344a 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -74,6 +74,7 @@ class VersionsRestServlet(RestServlet): "r0.6.0", "r0.6.1", "v1.1", + "v1.2", ], # as per MSC1497: "unstable_features": { From 7c82da27aa6acff3ac40343719b440133955c207 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Mon, 21 Feb 2022 17:03:06 +0100 Subject: [PATCH 29/84] Add type hints to `synapse/storage/databases/main` (#11984) --- changelog.d/11984.misc | 1 + mypy.ini | 3 - synapse/handlers/presence.py | 26 ++++---- synapse/storage/databases/main/presence.py | 61 +++++++++++++------ .../storage/databases/main/purge_events.py | 13 ++-- .../storage/databases/main/user_directory.py | 22 +++---- synapse/types.py | 6 +- 7 files changed, 79 insertions(+), 53 deletions(-) create mode 100644 changelog.d/11984.misc diff --git a/changelog.d/11984.misc b/changelog.d/11984.misc new file mode 100644 index 000000000..8e405b922 --- /dev/null +++ b/changelog.d/11984.misc @@ -0,0 +1 @@ +Add missing type hints to storage classes. \ No newline at end of file diff --git a/mypy.ini b/mypy.ini index 63848d664..610660b9b 100644 --- a/mypy.ini +++ b/mypy.ini @@ -31,14 +31,11 @@ exclude = (?x) |synapse/storage/databases/main/group_server.py |synapse/storage/databases/main/metrics.py |synapse/storage/databases/main/monthly_active_users.py - |synapse/storage/databases/main/presence.py - |synapse/storage/databases/main/purge_events.py |synapse/storage/databases/main/push_rule.py |synapse/storage/databases/main/receipts.py |synapse/storage/databases/main/roommember.py |synapse/storage/databases/main/search.py |synapse/storage/databases/main/state.py - |synapse/storage/databases/main/user_directory.py |synapse/storage/schema/ |tests/api/test_auth.py diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 067c43ae4..b223b7262 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -204,25 +204,27 @@ class BasePresenceHandler(abc.ABC): Returns: dict: `user_id` -> `UserPresenceState` """ - states = { - user_id: self.user_to_current_state.get(user_id, None) - for user_id in user_ids - } + states = {} + missing = [] + for user_id in user_ids: + state = self.user_to_current_state.get(user_id, None) + if state: + states[user_id] = state + else: + missing.append(user_id) - missing = [user_id for user_id, state in states.items() if not state] if missing: # There are things not in our in memory cache. Lets pull them out of # the database. res = await self.store.get_presence_for_users(missing) states.update(res) - missing = [user_id for user_id, state in states.items() if not state] - if missing: - new = { - user_id: UserPresenceState.default(user_id) for user_id in missing - } - states.update(new) - self.user_to_current_state.update(new) + for user_id in missing: + # if user has no state in database, create the state + if not res.get(user_id, None): + new_state = UserPresenceState.default(user_id) + states[user_id] = new_state + self.user_to_current_state[user_id] = new_state return states diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 4f05811a7..d3c461168 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -12,15 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple, cast from synapse.api.presence import PresenceState, UserPresenceState from synapse.replication.tcp.streams import PresenceStream from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.engines import PostgresEngine from synapse.storage.types import Connection -from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator +from synapse.storage.util.id_generators import ( + AbstractStreamIdGenerator, + MultiWriterIdGenerator, + StreamIdGenerator, +) from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.iterutils import batch_iter @@ -35,7 +43,7 @@ class PresenceBackgroundUpdateStore(SQLBaseStore): database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: "HomeServer", - ): + ) -> None: super().__init__(database, db_conn, hs) # Used by `PresenceStore._get_active_presence()` @@ -54,11 +62,14 @@ class PresenceStore(PresenceBackgroundUpdateStore): database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: "HomeServer", - ): + ) -> None: super().__init__(database, db_conn, hs) + self._instance_name = hs.get_instance_name() + self._presence_id_gen: AbstractStreamIdGenerator + self._can_persist_presence = ( - hs.get_instance_name() in hs.config.worker.writers.presence + self._instance_name in hs.config.worker.writers.presence ) if isinstance(database.engine, PostgresEngine): @@ -109,7 +120,9 @@ class PresenceStore(PresenceBackgroundUpdateStore): return stream_orderings[-1], self._presence_id_gen.get_current_token() - def _update_presence_txn(self, txn, stream_orderings, presence_states): + def _update_presence_txn( + self, txn: LoggingTransaction, stream_orderings, presence_states + ) -> None: for stream_id, state in zip(stream_orderings, presence_states): txn.call_after( self.presence_stream_cache.entity_has_changed, state.user_id, stream_id @@ -183,19 +196,23 @@ class PresenceStore(PresenceBackgroundUpdateStore): if last_id == current_id: return [], current_id, False - def get_all_presence_updates_txn(txn): + def get_all_presence_updates_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, list]], int, bool]: sql = """ SELECT stream_id, user_id, state, last_active_ts, last_federation_update_ts, last_user_sync_ts, - status_msg, - currently_active + status_msg, currently_active FROM presence_stream WHERE ? < stream_id AND stream_id <= ? ORDER BY stream_id ASC LIMIT ? """ txn.execute(sql, (last_id, current_id, limit)) - updates = [(row[0], row[1:]) for row in txn] + updates = cast( + List[Tuple[int, list]], + [(row[0], row[1:]) for row in txn], + ) upper_bound = current_id limited = False @@ -210,7 +227,7 @@ class PresenceStore(PresenceBackgroundUpdateStore): ) @cached() - def _get_presence_for_user(self, user_id): + def _get_presence_for_user(self, user_id: str) -> None: raise NotImplementedError() @cachedList( @@ -218,7 +235,9 @@ class PresenceStore(PresenceBackgroundUpdateStore): list_name="user_ids", num_args=1, ) - async def get_presence_for_users(self, user_ids): + async def get_presence_for_users( + self, user_ids: Iterable[str] + ) -> Dict[str, UserPresenceState]: rows = await self.db_pool.simple_select_many_batch( table="presence_stream", column="user_id", @@ -257,7 +276,9 @@ class PresenceStore(PresenceBackgroundUpdateStore): True if the user should have full presence sent to them, False otherwise. """ - def _should_user_receive_full_presence_with_token_txn(txn): + def _should_user_receive_full_presence_with_token_txn( + txn: LoggingTransaction, + ) -> bool: sql = """ SELECT 1 FROM users_to_send_full_presence_to WHERE user_id = ? @@ -271,7 +292,7 @@ class PresenceStore(PresenceBackgroundUpdateStore): _should_user_receive_full_presence_with_token_txn, ) - async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]): + async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]) -> None: """Adds to the list of users who should receive a full snapshot of presence upon their next sync. @@ -353,10 +374,10 @@ class PresenceStore(PresenceBackgroundUpdateStore): return users_to_state - def get_current_presence_token(self): + def get_current_presence_token(self) -> int: return self._presence_id_gen.get_current_token() - def _get_active_presence(self, db_conn: Connection): + def _get_active_presence(self, db_conn: Connection) -> List[UserPresenceState]: """Fetch non-offline presence from the database so that we can register the appropriate time outs. """ @@ -379,12 +400,12 @@ class PresenceStore(PresenceBackgroundUpdateStore): return [UserPresenceState(**row) for row in rows] - def take_presence_startup_info(self): + def take_presence_startup_info(self) -> List[UserPresenceState]: active_on_startup = self._presence_on_startup - self._presence_on_startup = None + self._presence_on_startup = [] return active_on_startup - def process_replication_rows(self, stream_name, instance_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows) -> None: if stream_name == PresenceStream.NAME: self._presence_id_gen.advance(instance_name, token) for row in rows: diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index e87a8fb85..2e3818e43 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -13,9 +13,10 @@ # limitations under the License. import logging -from typing import Any, List, Set, Tuple +from typing import Any, List, Set, Tuple, cast from synapse.api.errors import SynapseError +from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main import CacheInvalidationWorkerStore from synapse.storage.databases.main.state import StateGroupWorkerStore from synapse.types import RoomStreamToken @@ -55,7 +56,11 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): ) def _purge_history_txn( - self, txn, room_id: str, token: RoomStreamToken, delete_local_events: bool + self, + txn: LoggingTransaction, + room_id: str, + token: RoomStreamToken, + delete_local_events: bool, ) -> Set[int]: # Tables that should be pruned: # event_auth @@ -273,7 +278,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): """, (room_id,), ) - (min_depth,) = txn.fetchone() + (min_depth,) = cast(Tuple[int], txn.fetchone()) logger.info("[purge] updating room_depth to %d", min_depth) @@ -318,7 +323,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): "purge_room", self._purge_room_txn, room_id ) - def _purge_room_txn(self, txn, room_id: str) -> List[int]: + def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]: # First we fetch all the state groups that should be deleted, before # we delete that information. txn.execute( diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index f7c778bdf..e7fddd242 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -58,7 +58,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: "HomeServer", - ): + ) -> None: super().__init__(database, db_conn, hs) self.server_name = hs.hostname @@ -234,10 +234,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): processed_event_count = 0 for room_id, event_count in rooms_to_work_on: - is_in_room = await self.is_host_joined(room_id, self.server_name) + is_in_room = await self.is_host_joined(room_id, self.server_name) # type: ignore[attr-defined] if is_in_room: - users_with_profile = await self.get_users_in_room_with_profiles(room_id) + users_with_profile = await self.get_users_in_room_with_profiles(room_id) # type: ignore[attr-defined] # Throw away users excluded from the directory. users_with_profile = { user_id: profile @@ -368,7 +368,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): for user_id in users_to_work_on: if await self.should_include_local_user_in_dir(user_id): - profile = await self.get_profileinfo(get_localpart_from_id(user_id)) + profile = await self.get_profileinfo(get_localpart_from_id(user_id)) # type: ignore[attr-defined] await self.update_profile_in_user_dir( user_id, profile.display_name, profile.avatar_url ) @@ -397,7 +397,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): # technically it could be DM-able. In the future, this could potentially # be configurable per-appservice whether the appservice sender can be # contacted. - if self.get_app_service_by_user_id(user) is not None: + if self.get_app_service_by_user_id(user) is not None: # type: ignore[attr-defined] return False # We're opting to exclude appservice users (anyone matching the user @@ -405,17 +405,17 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): # they could be DM-able. In the future, this could potentially # be configurable per-appservice whether the appservice users can be # contacted. - if self.get_if_app_services_interested_in_user(user): + if self.get_if_app_services_interested_in_user(user): # type: ignore[attr-defined] # TODO we might want to make this configurable for each app service return False # Support users are for diagnostics and should not appear in the user directory. - if await self.is_support_user(user): + if await self.is_support_user(user): # type: ignore[attr-defined] return False # Deactivated users aren't contactable, so should not appear in the user directory. try: - if await self.get_user_deactivated_status(user): + if await self.get_user_deactivated_status(user): # type: ignore[attr-defined] return False except StoreError: # No such user in the users table. No need to do this when calling @@ -433,20 +433,20 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): (EventTypes.RoomHistoryVisibility, ""), ) - current_state_ids = await self.get_filtered_current_state_ids( + current_state_ids = await self.get_filtered_current_state_ids( # type: ignore[attr-defined] room_id, StateFilter.from_types(types_to_filter) ) join_rules_id = current_state_ids.get((EventTypes.JoinRules, "")) if join_rules_id: - join_rule_ev = await self.get_event(join_rules_id, allow_none=True) + join_rule_ev = await self.get_event(join_rules_id, allow_none=True) # type: ignore[attr-defined] if join_rule_ev: if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC: return True hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, "")) if hist_vis_id: - hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True) + hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True) # type: ignore[attr-defined] if hist_vis_ev: if ( hist_vis_ev.content.get("history_visibility") diff --git a/synapse/types.py b/synapse/types.py index f89fb216a..53be3583a 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -51,7 +51,7 @@ from synapse.util.stringutils import parse_and_validate_server_name if TYPE_CHECKING: from synapse.appservice.api import ApplicationService - from synapse.storage.databases.main import DataStore + from synapse.storage.databases.main import DataStore, PurgeEventsStore # Define a state map type from type/state_key to T (usually an event ID or # event) @@ -485,7 +485,7 @@ class RoomStreamToken: ) @classmethod - async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken": + async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken": try: if string[0] == "s": return cls(topological=None, stream=int(string[1:])) @@ -502,7 +502,7 @@ class RoomStreamToken: instance_id = int(key) pos = int(value) - instance_name = await store.get_name_from_instance_id(instance_id) + instance_name = await store.get_name_from_instance_id(instance_id) # type: ignore[attr-defined] instance_map[instance_name] = pos return cls( From a85dde34459451f2ccece2374ca6280b5f55335f Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 21 Feb 2022 18:37:04 +0000 Subject: [PATCH 30/84] Minor typing fixes (#12034) These started failing in https://github.com/matrix-org/synapse/pull/12031... I'm a bit mystified by how they ever worked. --- changelog.d/12034.misc | 1 + .../federation/sender/per_destination_queue.py | 18 +++++++++--------- synapse/handlers/message.py | 10 ++++++---- synapse/handlers/register.py | 6 +++--- 4 files changed, 19 insertions(+), 16 deletions(-) create mode 100644 changelog.d/12034.misc diff --git a/changelog.d/12034.misc b/changelog.d/12034.misc new file mode 100644 index 000000000..8374a6322 --- /dev/null +++ b/changelog.d/12034.misc @@ -0,0 +1 @@ +Minor typing fixes. diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 8152e80b8..c3132f731 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -381,7 +381,9 @@ class PerDestinationQueue: ) ) - if self._last_successful_stream_ordering is None: + last_successful_stream_ordering = self._last_successful_stream_ordering + + if last_successful_stream_ordering is None: # if it's still None, then this means we don't have the information # in our database ­ we haven't successfully sent a PDU to this server # (at least since the introduction of the feature tracking @@ -394,8 +396,7 @@ class PerDestinationQueue: # get at most 50 catchup room/PDUs while True: event_ids = await self._store.get_catch_up_room_event_ids( - self._destination, - self._last_successful_stream_ordering, + self._destination, last_successful_stream_ordering ) if not event_ids: @@ -403,7 +404,7 @@ class PerDestinationQueue: # of a race condition, so we check that no new events have been # skipped due to us being in catch-up mode - if self._catchup_last_skipped > self._last_successful_stream_ordering: + if self._catchup_last_skipped > last_successful_stream_ordering: # another event has been skipped because we were in catch-up mode continue @@ -470,7 +471,7 @@ class PerDestinationQueue: # offline if ( p.internal_metadata.stream_ordering - < self._last_successful_stream_ordering + < last_successful_stream_ordering ): continue @@ -513,12 +514,11 @@ class PerDestinationQueue: # from the *original* PDU, rather than the PDU(s) we actually # send. This is because we use it to mark our position in the # queue of missed PDUs to process. - self._last_successful_stream_ordering = ( - pdu.internal_metadata.stream_ordering - ) + last_successful_stream_ordering = pdu.internal_metadata.stream_ordering + self._last_successful_stream_ordering = last_successful_stream_ordering await self._store.set_destination_last_successful_stream_ordering( - self._destination, self._last_successful_stream_ordering + self._destination, last_successful_stream_ordering ) def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 9267e586a..4d0da8428 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -550,10 +550,11 @@ class EventCreationHandler: if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "": room_version_id = event_dict["content"]["room_version"] - room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id) - if not room_version_obj: + maybe_room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id) + if not maybe_room_version_obj: # this can happen if support is withdrawn for a room version raise UnsupportedRoomVersionError(room_version_id) + room_version_obj = maybe_room_version_obj else: try: room_version_obj = await self.store.get_room_version( @@ -1145,12 +1146,13 @@ class EventCreationHandler: room_version_id = event.content.get( "room_version", RoomVersions.V1.identifier ) - room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id) - if not room_version_obj: + maybe_room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id) + if not maybe_room_version_obj: raise UnsupportedRoomVersionError( "Attempt to create a room with unsupported room version %s" % (room_version_id,) ) + room_version_obj = maybe_room_version_obj else: room_version_obj = await self.store.get_room_version(event.room_id) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index a719d5eef..80320d2c0 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -320,12 +320,12 @@ class RegistrationHandler: if fail_count > 10: raise SynapseError(500, "Unable to find a suitable guest user ID") - localpart = await self.store.generate_user_id() - user = UserID(localpart, self.hs.hostname) + generated_localpart = await self.store.generate_user_id() + user = UserID(generated_localpart, self.hs.hostname) user_id = user.to_string() self.check_user_id_not_appservice_exclusive(user_id) if generate_display_name: - default_display_name = localpart + default_display_name = generated_localpart try: await self.register_with_store( user_id=user_id, From 3070af4809016f547adf55fb02dbe9e569590f7e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 21 Feb 2022 19:27:35 +0000 Subject: [PATCH 31/84] remote join processing: get create event from state, not auth_chain (#12039) A follow-up to #12005, in which I apparently missed that there are a bunch of other places that assume the create event is in the auth chain. --- changelog.d/12039.misc | 1 + synapse/federation/federation_client.py | 6 ++++-- synapse/handlers/federation.py | 2 +- synapse/storage/databases/main/room.py | 4 ++-- 4 files changed, 8 insertions(+), 5 deletions(-) create mode 100644 changelog.d/12039.misc diff --git a/changelog.d/12039.misc b/changelog.d/12039.misc new file mode 100644 index 000000000..45e21dbe5 --- /dev/null +++ b/changelog.d/12039.misc @@ -0,0 +1 @@ +Preparation for faster-room-join work: when parsing the `send_join` response, get the `m.room.create` event from `state`, not `auth_chain`. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 9f56f97d9..48c90bf0b 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -870,13 +870,15 @@ class FederationClient(FederationBase): for s in signed_state: s.internal_metadata = copy.deepcopy(s.internal_metadata) - # double-check that the same create event has ended up in the auth chain + # double-check that the auth chain doesn't include a different create event auth_chain_create_events = [ e.event_id for e in signed_auth if (e.type, e.state_key) == (EventTypes.Create, "") ] - if auth_chain_create_events != [create_event.event_id]: + if auth_chain_create_events and auth_chain_create_events != [ + create_event.event_id + ]: raise InvalidResponseError( "Unexpected create event(s) in auth chain: %s" % (auth_chain_create_events,) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index c0f642005..c8356f233 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -516,7 +516,7 @@ class FederationHandler: await self.store.upsert_room_on_join( room_id=room_id, room_version=room_version_obj, - auth_events=auth_chain, + state_events=state, ) max_stream_id = await self._federation_event_handler.process_remote_join( diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 95167116c..0416df64c 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1498,7 +1498,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") async def upsert_room_on_join( - self, room_id: str, room_version: RoomVersion, auth_events: List[EventBase] + self, room_id: str, room_version: RoomVersion, state_events: List[EventBase] ) -> None: """Ensure that the room is stored in the table @@ -1511,7 +1511,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): has_auth_chain_index = await self.has_auth_chain_index(room_id) create_event = None - for e in auth_events: + for e in state_events: if (e.type, e.state_key) == (EventTypes.Create, ""): create_event = e break From d7cb0dcbaa59c9e16f9b46415dbc645735dce8f4 Mon Sep 17 00:00:00 2001 From: Travis Ralston Date: Tue, 22 Feb 2022 04:20:45 -0700 Subject: [PATCH 32/84] Use v3 endpoints for fallback auth (Matrix 1.1) (#12019) --- changelog.d/12019.misc | 1 + synapse/rest/client/auth.py | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) create mode 100644 changelog.d/12019.misc diff --git a/changelog.d/12019.misc b/changelog.d/12019.misc new file mode 100644 index 000000000..b2186320e --- /dev/null +++ b/changelog.d/12019.misc @@ -0,0 +1 @@ +Use Matrix v1.1 endpoints (`/_matrix/client/v3/auth/...`) in fallback auth HTML forms. \ No newline at end of file diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py index 9c15a0433..e0b2b80e5 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py @@ -62,7 +62,7 @@ class AuthRestServlet(RestServlet): if stagetype == LoginType.RECAPTCHA: html = self.recaptcha_template.render( session=session, - myurl="%s/r0/auth/%s/fallback/web" + myurl="%s/v3/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), sitekey=self.hs.config.captcha.recaptcha_public_key, ) @@ -74,7 +74,7 @@ class AuthRestServlet(RestServlet): self.hs.config.server.public_baseurl, self.hs.config.consent.user_consent_version, ), - myurl="%s/r0/auth/%s/fallback/web" + myurl="%s/v3/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), ) @@ -118,7 +118,7 @@ class AuthRestServlet(RestServlet): # Authentication failed, let user try again html = self.recaptcha_template.render( session=session, - myurl="%s/r0/auth/%s/fallback/web" + myurl="%s/v3/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), sitekey=self.hs.config.captcha.recaptcha_public_key, error=e.msg, @@ -143,7 +143,7 @@ class AuthRestServlet(RestServlet): self.hs.config.server.public_baseurl, self.hs.config.consent.user_consent_version, ), - myurl="%s/r0/auth/%s/fallback/web" + myurl="%s/v3/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), error=e.msg, ) From 1ae492c8c09edef4a6d2af65588895305eaedec3 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 22 Feb 2022 11:30:19 +0000 Subject: [PATCH 33/84] Move isort config to `pyproject.toml` (#12052) --- changelog.d/12052.misc | 1 + pyproject.toml | 12 ++++++++++++ setup.cfg | 11 ----------- tox.ini | 2 +- 4 files changed, 14 insertions(+), 12 deletions(-) create mode 100644 changelog.d/12052.misc diff --git a/changelog.d/12052.misc b/changelog.d/12052.misc new file mode 100644 index 000000000..fbaff67e9 --- /dev/null +++ b/changelog.d/12052.misc @@ -0,0 +1 @@ +Move `isort` configuration to `pyproject.toml`. diff --git a/pyproject.toml b/pyproject.toml index 963f149c6..c9cd0cf6e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,3 +54,15 @@ exclude = ''' )/ ) ''' + +[tool.isort] +line_length = 88 +sections = ["FUTURE", "STDLIB", "THIRDPARTY", "TWISTED", "FIRSTPARTY", "TESTS", "LOCALFOLDER"] +default_section = "THIRDPARTY" +known_first_party = ["synapse"] +known_tests = ["tests"] +known_twisted = ["twisted", "OpenSSL"] +multi_line_output = 3 +include_trailing_comma = true +combine_as_imports = true + diff --git a/setup.cfg b/setup.cfg index e5ceb7ed1..a0506572d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -19,14 +19,3 @@ ignore = # E731: do not assign a lambda expression, use a def # E501: Line too long (black enforces this for us) ignore=W503,W504,E203,E731,E501 - -[isort] -line_length = 88 -sections=FUTURE,STDLIB,THIRDPARTY,TWISTED,FIRSTPARTY,TESTS,LOCALFOLDER -default_section=THIRDPARTY -known_first_party = synapse -known_tests=tests -known_twisted=twisted,OpenSSL -multi_line_output=3 -include_trailing_comma=true -combine_as_imports=true diff --git a/tox.ini b/tox.ini index 41678aa38..436ecf755 100644 --- a/tox.ini +++ b/tox.ini @@ -166,7 +166,7 @@ commands = [testenv:check_isort] extras = lint -commands = isort -c --df --sp setup.cfg {[base]lint_targets} +commands = isort -c --df {[base]lint_targets} [testenv:check-newsfragment] skip_install = true From af2c1e3d2a56c4042db27e70b72409ce8f4b406e Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 22 Feb 2022 11:33:37 +0000 Subject: [PATCH 34/84] Tidy the building of sdists and wheels (#12051) * Don't build distribution pkgs in tests.yml * Run `release-artifacts` on release branches * Use backend-meta workflow for packaging --- .github/workflows/release-artifacts.yml | 14 ++------------ .github/workflows/tests.yml | 17 +---------------- changelog.d/12051.misc | 1 + 3 files changed, 4 insertions(+), 28 deletions(-) create mode 100644 changelog.d/12051.misc diff --git a/.github/workflows/release-artifacts.yml b/.github/workflows/release-artifacts.yml index eb294f161..eee3633d5 100644 --- a/.github/workflows/release-artifacts.yml +++ b/.github/workflows/release-artifacts.yml @@ -7,7 +7,7 @@ on: # of things breaking (but only build one set of debs) pull_request: push: - branches: ["develop"] + branches: ["develop", "release-*"] # we do the full build on tags. tags: ["v*"] @@ -91,17 +91,7 @@ jobs: build-sdist: name: "Build pypi distribution files" - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v2 - - run: pip install wheel - - run: | - python setup.py sdist bdist_wheel - - uses: actions/upload-artifact@v2 - with: - name: python-dist - path: dist/* + uses: "matrix-org/backend-meta/.github/workflows/packaging.yml@v1" # if it's a tag, create a release and attach the artifacts to it attach-assets: diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 75ac1304b..bbf1033bd 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -48,24 +48,10 @@ jobs: env: PULL_REQUEST_NUMBER: ${{ github.event.number }} - lint-sdist: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v2 - with: - python-version: "3.x" - - run: pip install wheel - - run: python setup.py sdist bdist_wheel - - uses: actions/upload-artifact@v2 - with: - name: Python Distributions - path: dist/* - # Dummy step to gate other tests on without repeating the whole list linting-done: if: ${{ !cancelled() }} # Run this even if prior jobs were skipped - needs: [lint, lint-crlf, lint-newsfile, lint-sdist] + needs: [lint, lint-crlf, lint-newsfile] runs-on: ubuntu-latest steps: - run: "true" @@ -397,7 +383,6 @@ jobs: - lint - lint-crlf - lint-newsfile - - lint-sdist - trial - trial-olddeps - sytest diff --git a/changelog.d/12051.misc b/changelog.d/12051.misc new file mode 100644 index 000000000..995919135 --- /dev/null +++ b/changelog.d/12051.misc @@ -0,0 +1 @@ +Tidy up GitHub Actions config which builds distributions for PyPI. \ No newline at end of file From 546b9c9e648f5e2b25bb7c8350570787ff9befae Mon Sep 17 00:00:00 2001 From: reivilibre Date: Tue, 22 Feb 2022 11:44:11 +0000 Subject: [PATCH 35/84] Add more tests for in-flight state query duplication. (#12033) --- changelog.d/12033.misc | 1 + tests/storage/databases/test_state_store.py | 196 +++++++++++++++++--- 2 files changed, 174 insertions(+), 23 deletions(-) create mode 100644 changelog.d/12033.misc diff --git a/changelog.d/12033.misc b/changelog.d/12033.misc new file mode 100644 index 000000000..3af049b96 --- /dev/null +++ b/changelog.d/12033.misc @@ -0,0 +1 @@ +Deduplicate in-flight requests in `_get_state_for_groups`. diff --git a/tests/storage/databases/test_state_store.py b/tests/storage/databases/test_state_store.py index cf126ee62..3a4a4a3a2 100644 --- a/tests/storage/databases/test_state_store.py +++ b/tests/storage/databases/test_state_store.py @@ -18,8 +18,9 @@ from unittest.mock import patch from twisted.internet.defer import Deferred, ensureDeferred from twisted.test.proto_helpers import MemoryReactor +from synapse.api.constants import EventTypes from synapse.storage.state import StateFilter -from synapse.types import MutableStateMap, StateMap +from synapse.types import StateMap from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -27,6 +28,21 @@ from tests.unittest import HomeserverTestCase if typing.TYPE_CHECKING: from synapse.server import HomeServer +# StateFilter for ALL non-m.room.member state events +ALL_NON_MEMBERS_STATE_FILTER = StateFilter.freeze( + types={EventTypes.Member: set()}, + include_others=True, +) + +FAKE_STATE = { + (EventTypes.Member, "@alice:test"): "join", + (EventTypes.Member, "@bob:test"): "leave", + (EventTypes.Member, "@charlie:test"): "invite", + ("test.type", "a"): "AAA", + ("test.type", "b"): "BBB", + ("other.event.type", "state.key"): "123", +} + class StateGroupInflightCachingTestCase(HomeserverTestCase): def prepare( @@ -65,24 +81,8 @@ class StateGroupInflightCachingTestCase(HomeserverTestCase): Assemble a fake database response and complete the database request. """ - result: Dict[int, StateMap[str]] = {} - - for group in groups: - group_result: MutableStateMap[str] = {} - result[group] = group_result - - for state_type, state_keys in state_filter.types.items(): - if state_keys is None: - group_result[(state_type, "a")] = "xyz" - group_result[(state_type, "b")] = "xyz" - else: - for state_key in state_keys: - group_result[(state_type, state_key)] = "abc" - - if state_filter.include_others: - group_result[("other.event.type", "state.key")] = "123" - - d.callback(result) + # Return a filtered copy of the fake state + d.callback({group: state_filter.filter_state(FAKE_STATE) for group in groups}) def test_duplicate_requests_deduplicated(self) -> None: """ @@ -125,9 +125,159 @@ class StateGroupInflightCachingTestCase(HomeserverTestCase): # Now we can complete the request self._complete_request_fake(groups, sf, d) - self.assertEqual( - self.get_success(req1), {("other.event.type", "state.key"): "123"} + self.assertEqual(self.get_success(req1), FAKE_STATE) + self.assertEqual(self.get_success(req2), FAKE_STATE) + + def test_smaller_request_deduplicated(self) -> None: + """ + Tests that duplicate requests for state are deduplicated. + + This test: + - requests some state (state group 42, 'all' state filter) + - requests a subset of that state, before the first request finishes + - checks to see that only one database query was made + - completes the database query + - checks that both requests see the correct retrieved state + """ + req1 = ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, StateFilter.from_types((("test.type", None),)) + ) ) - self.assertEqual( - self.get_success(req2), {("other.event.type", "state.key"): "123"} + self.pump(by=0.1) + + # This should have gone to the database + self.assertEqual(len(self.get_state_group_calls), 1) + self.assertFalse(req1.called) + + req2 = ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, StateFilter.from_types((("test.type", "b"),)) + ) ) + self.pump(by=0.1) + + # No more calls should have gone to the database, because the second + # request was already in the in-flight cache! + self.assertEqual(len(self.get_state_group_calls), 1) + self.assertFalse(req1.called) + self.assertFalse(req2.called) + + groups, sf, d = self.get_state_group_calls[0] + self.assertEqual(groups, (42,)) + # The state filter is expanded internally for increased cache hit rate, + # so we the database sees a wider state filter than requested. + self.assertEqual(sf, ALL_NON_MEMBERS_STATE_FILTER) + + # Now we can complete the request + self._complete_request_fake(groups, sf, d) + + self.assertEqual( + self.get_success(req1), + {("test.type", "a"): "AAA", ("test.type", "b"): "BBB"}, + ) + self.assertEqual(self.get_success(req2), {("test.type", "b"): "BBB"}) + + def test_partially_overlapping_request_deduplicated(self) -> None: + """ + Tests that partially-overlapping requests are partially deduplicated. + + This test: + - requests a single type of wildcard state + (This is internally expanded to be all non-member state) + - requests the entire state in parallel + - checks to see that two database queries were made, but that the second + one is only for member state. + - completes the database queries + - checks that both requests have the correct result. + """ + + req1 = ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, StateFilter.from_types((("test.type", None),)) + ) + ) + self.pump(by=0.1) + + # This should have gone to the database + self.assertEqual(len(self.get_state_group_calls), 1) + self.assertFalse(req1.called) + + req2 = ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, StateFilter.all() + ) + ) + self.pump(by=0.1) + + # Because it only partially overlaps, this also went to the database + self.assertEqual(len(self.get_state_group_calls), 2) + self.assertFalse(req1.called) + self.assertFalse(req2.called) + + # First request: + groups, sf, d = self.get_state_group_calls[0] + self.assertEqual(groups, (42,)) + # The state filter is expanded internally for increased cache hit rate, + # so we the database sees a wider state filter than requested. + self.assertEqual(sf, ALL_NON_MEMBERS_STATE_FILTER) + self._complete_request_fake(groups, sf, d) + + # Second request: + groups, sf, d = self.get_state_group_calls[1] + self.assertEqual(groups, (42,)) + # The state filter is narrowed to only request membership state, because + # the remainder of the state is already being queried in the first request! + self.assertEqual( + sf, StateFilter.freeze({EventTypes.Member: None}, include_others=False) + ) + self._complete_request_fake(groups, sf, d) + + # Check the results are correct + self.assertEqual( + self.get_success(req1), + {("test.type", "a"): "AAA", ("test.type", "b"): "BBB"}, + ) + self.assertEqual(self.get_success(req2), FAKE_STATE) + + def test_in_flight_requests_stop_being_in_flight(self) -> None: + """ + Tests that in-flight request deduplication doesn't somehow 'hold on' + to completed requests: once they're done, they're taken out of the + in-flight cache. + """ + req1 = ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, StateFilter.all() + ) + ) + self.pump(by=0.1) + + # This should have gone to the database + self.assertEqual(len(self.get_state_group_calls), 1) + self.assertFalse(req1.called) + + # Complete the request right away. + self._complete_request_fake(*self.get_state_group_calls[0]) + self.assertTrue(req1.called) + + # Send off another request + req2 = ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, StateFilter.all() + ) + ) + self.pump(by=0.1) + + # It should have gone to the database again, because the previous request + # isn't in-flight and therefore isn't available for deduplication. + self.assertEqual(len(self.get_state_group_calls), 2) + self.assertFalse(req2.called) + + # Complete the request right away. + self._complete_request_fake(*self.get_state_group_calls[1]) + self.assertTrue(req2.called) + groups, sf, d = self.get_state_group_calls[0] + + self.assertEqual(self.get_success(req1), FAKE_STATE) + self.assertEqual(self.get_success(req2), FAKE_STATE) From 45e2c04f78abea88224191fdc424646e1bdd14f4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 22 Feb 2022 12:00:05 +0000 Subject: [PATCH 36/84] Update changelog --- CHANGES.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 8d91a7921..4bbb50b75 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -11,7 +11,7 @@ Features -------- - Add experimental support for sending to-device messages to application services, as specified by [MSC2409](https://github.com/matrix-org/matrix-doc/pull/2409). ([\#11215](https://github.com/matrix-org/synapse/issues/11215), [\#11966](https://github.com/matrix-org/synapse/issues/11966)) -- Remove account data (including client config, push rules and ignored users) upon user deactivation. ([\#11655](https://github.com/matrix-org/synapse/issues/11655)) +- Add a background database update to purge account data for deactivated users. ([\#11655](https://github.com/matrix-org/synapse/issues/11655)) - Experimental support for [MSC3666](https://github.com/matrix-org/matrix-doc/pull/3666): including bundled aggregations in server side search results. ([\#11837](https://github.com/matrix-org/synapse/issues/11837)) - Enable cache time-based expiry by default. The `expiry_time` config flag has been superseded by `expire_caches` and `cache_entry_ttl`. ([\#11849](https://github.com/matrix-org/synapse/issues/11849)) - Add a callback to allow modules to allow or forbid a 3PID (email address, phone number) from being associated to a local account. ([\#11854](https://github.com/matrix-org/synapse/issues/11854)) @@ -273,7 +273,7 @@ Bugfixes Synapse 1.50.0 (2022-01-18) =========================== -**This release contains a critical bug that may prevent clients from being able to connect. +**This release contains a critical bug that may prevent clients from being able to connect. As such, it is not recommended to upgrade to 1.50.0. Instead, please upgrade straight to to 1.50.1. Further details are available in [this issue](https://github.com/matrix-org/synapse/issues/11763).** From 1bf9cbbf75f14ef7746b595f0a9167f0d4dd210f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 22 Feb 2022 12:00:46 +0000 Subject: [PATCH 37/84] Update changelog --- CHANGES.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index 4bbb50b75..367adff5d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,7 +1,7 @@ Synapse 1.53.0 (2022-02-22) =========================== -No significant changes. +No significant changes since 1.53.0rc1. Synapse 1.53.0rc1 (2022-02-15) From 066171643b5812b05dd9352ee650f524567de877 Mon Sep 17 00:00:00 2001 From: AndrewRyanChama <89478935+AndrewRyanChama@users.noreply.github.com> Date: Tue, 22 Feb 2022 04:11:39 -0800 Subject: [PATCH 38/84] Fetch images when previewing Twitter URLs. (#11985) By including "bot" in the User-Agent, which some sites use to decide whether to include additional Open Graph information. --- changelog.d/11985.feature | 1 + synapse/res/providers.json | 4 +--- synapse/rest/media/v1/preview_url_resource.py | 10 +++++++++- 3 files changed, 11 insertions(+), 4 deletions(-) create mode 100644 changelog.d/11985.feature diff --git a/changelog.d/11985.feature b/changelog.d/11985.feature new file mode 100644 index 000000000..120d888a4 --- /dev/null +++ b/changelog.d/11985.feature @@ -0,0 +1 @@ +Fetch images when previewing Twitter URLs. Contributed by @AndrewRyanChama. diff --git a/synapse/res/providers.json b/synapse/res/providers.json index f1838f955..7b9958e45 100644 --- a/synapse/res/providers.json +++ b/synapse/res/providers.json @@ -5,8 +5,6 @@ "endpoints": [ { "schemes": [ - "https://twitter.com/*/status/*", - "https://*.twitter.com/*/status/*", "https://twitter.com/*/moments/*", "https://*.twitter.com/*/moments/*" ], @@ -14,4 +12,4 @@ } ] } -] \ No newline at end of file +] diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 8d3d1e54d..c08b60d10 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -402,7 +402,15 @@ class PreviewUrlResource(DirectServeJsonResource): url, output_stream=output_stream, max_size=self.max_spider_size, - headers={"Accept-Language": self.url_preview_accept_language}, + headers={ + b"Accept-Language": self.url_preview_accept_language, + # Use a custom user agent for the preview because some sites will only return + # Open Graph metadata to crawler user agents. Omit the Synapse version + # string to avoid leaking information. + b"User-Agent": [ + "Synapse (bot; +https://github.com/matrix-org/synapse)" + ], + }, is_allowed_content_type=_is_previewable, ) except SynapseError: From 7273011f60afbb1c9754ec73ee3661b19dca6bbd Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 22 Feb 2022 12:17:10 +0000 Subject: [PATCH 39/84] Faster joins: Support for calling `/federation/v1/state` (#12013) This is an endpoint that we have server-side support for, but no client-side support. It's going to be useful for resyncing partial-stated rooms, so let's introduce it. --- changelog.d/12013.misc | 1 + synapse/federation/federation_base.py | 10 +- synapse/federation/federation_client.py | 93 +++++++++++-- synapse/federation/transport/client.py | 70 +++++++++- synapse/http/matrixfederationclient.py | 50 ++++++- tests/federation/test_federation_client.py | 149 +++++++++++++++++++++ tests/unittest.py | 21 +++ 7 files changed, 377 insertions(+), 17 deletions(-) create mode 100644 changelog.d/12013.misc create mode 100644 tests/federation/test_federation_client.py diff --git a/changelog.d/12013.misc b/changelog.d/12013.misc new file mode 100644 index 000000000..c0fca8dcc --- /dev/null +++ b/changelog.d/12013.misc @@ -0,0 +1 @@ +Preparation for faster-room-join work: Support for calling `/federation/v1/state` on a remote server. diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 896168c05..fab6da3c0 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -47,6 +47,11 @@ class FederationBase: ) -> EventBase: """Checks that event is correctly signed by the sending server. + Also checks the content hash, and redacts the event if there is a mismatch. + + Also runs the event through the spam checker; if it fails, redacts the event + and flags it as soft-failed. + Args: room_version: The room version of the PDU pdu: the event to be checked @@ -55,7 +60,10 @@ class FederationBase: * the original event if the checks pass * a redacted version of the event (if the signature matched but the hash did not) - * throws a SynapseError if the signature check failed.""" + + Raises: + SynapseError if the signature check failed. + """ try: await _check_sigs_on_pdu(self.keyring, room_version, pdu) except SynapseError as e: diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 48c90bf0b..c2997997d 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -419,26 +419,90 @@ class FederationClient(FederationBase): return state_event_ids, auth_event_ids + async def get_room_state( + self, + destination: str, + room_id: str, + event_id: str, + room_version: RoomVersion, + ) -> Tuple[List[EventBase], List[EventBase]]: + """Calls the /state endpoint to fetch the state at a particular point + in the room. + + Any invalid events (those with incorrect or unverifiable signatures or hashes) + are filtered out from the response, and any duplicate events are removed. + + (Size limits and other event-format checks are *not* performed.) + + Note that the result is not ordered, so callers must be careful to process + the events in an order that handles dependencies. + + Returns: + a tuple of (state events, auth events) + """ + result = await self.transport_layer.get_room_state( + room_version, + destination, + room_id, + event_id, + ) + state_events = result.state + auth_events = result.auth_events + + # we may as well filter out any duplicates from the response, to save + # processing them multiple times. (In particular, events may be present in + # `auth_events` as well as `state`, which is redundant). + # + # We don't rely on the sort order of the events, so we can just stick them + # in a dict. + state_event_map = {event.event_id: event for event in state_events} + auth_event_map = { + event.event_id: event + for event in auth_events + if event.event_id not in state_event_map + } + + logger.info( + "Processing from /state: %d state events, %d auth events", + len(state_event_map), + len(auth_event_map), + ) + + valid_auth_events = await self._check_sigs_and_hash_and_fetch( + destination, auth_event_map.values(), room_version + ) + + valid_state_events = await self._check_sigs_and_hash_and_fetch( + destination, state_event_map.values(), room_version + ) + + return valid_state_events, valid_auth_events + async def _check_sigs_and_hash_and_fetch( self, origin: str, pdus: Collection[EventBase], room_version: RoomVersion, ) -> List[EventBase]: - """Takes a list of PDUs and checks the signatures and hashes of each - one. If a PDU fails its signature check then we check if we have it in - the database and if not then request if from the originating server of - that PDU. + """Checks the signatures and hashes of a list of events. + + If a PDU fails its signature check then we check if we have it in + the database, and if not then request it from the sender's server (if that + is different from `origin`). If that still fails, the event is omitted from + the returned list. If a PDU fails its content hash check then it is redacted. - The given list of PDUs are not modified, instead the function returns + Also runs each event through the spam checker; if it fails, redacts the event + and flags it as soft-failed. + + The given list of PDUs are not modified; instead the function returns a new list. Args: - origin - pdu - room_version + origin: The server that sent us these events + pdus: The events to be checked + room_version: the version of the room these events are in Returns: A list of PDUs that have valid signatures and hashes. @@ -469,11 +533,16 @@ class FederationClient(FederationBase): origin: str, room_version: RoomVersion, ) -> Optional[EventBase]: - """Takes a PDU and checks its signatures and hashes. If the PDU fails - its signature check then we check if we have it in the database and if - not then request if from the originating server of that PDU. + """Takes a PDU and checks its signatures and hashes. - If then PDU fails its content hash check then it is redacted. + If the PDU fails its signature check then we check if we have it in the + database; if not, we then request it from sender's server (if that is not the + same as `origin`). If that still fails, we return None. + + If the PDU fails its content hash check, it is redacted. + + Also runs the event through the spam checker; if it fails, redacts the event + and flags it as soft-failed. Args: origin diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index dca6e5c45..7e510e224 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -65,13 +65,12 @@ class TransportLayerClient: async def get_room_state_ids( self, destination: str, room_id: str, event_id: str ) -> JsonDict: - """Requests all state for a given room from the given server at the - given event. Returns the state's event_id's + """Requests the IDs of all state for a given room at the given event. Args: destination: The host name of the remote homeserver we want to get the state from. - context: The name of the context we want the state of + room_id: the room we want the state of event_id: The event we want the context at. Returns: @@ -87,6 +86,29 @@ class TransportLayerClient: try_trailing_slash_on_400=True, ) + async def get_room_state( + self, room_version: RoomVersion, destination: str, room_id: str, event_id: str + ) -> "StateRequestResponse": + """Requests the full state for a given room at the given event. + + Args: + room_version: the version of the room (required to build the event objects) + destination: The host name of the remote homeserver we want + to get the state from. + room_id: the room we want the state of + event_id: The event we want the context at. + + Returns: + Results in a dict received from the remote homeserver. + """ + path = _create_v1_path("/state/%s", room_id) + return await self.client.get_json( + destination, + path=path, + args={"event_id": event_id}, + parser=_StateParser(room_version), + ) + async def get_event( self, destination: str, event_id: str, timeout: Optional[int] = None ) -> JsonDict: @@ -1284,6 +1306,14 @@ class SendJoinResponse: servers_in_room: Optional[List[str]] = None +@attr.s(slots=True, auto_attribs=True) +class StateRequestResponse: + """The parsed response of a `/state` request.""" + + auth_events: List[EventBase] + state: List[EventBase] + + @ijson.coroutine def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None]: """Helper function for use with `ijson.kvitems_coro` to parse key-value pairs @@ -1411,3 +1441,37 @@ class SendJoinParser(ByteParser[SendJoinResponse]): self._response.event_dict, self._room_version ) return self._response + + +class _StateParser(ByteParser[StateRequestResponse]): + """A parser for the response to `/state` requests. + + Args: + room_version: The version of the room. + """ + + CONTENT_TYPE = "application/json" + + def __init__(self, room_version: RoomVersion): + self._response = StateRequestResponse([], []) + self._room_version = room_version + self._coros = [ + ijson.items_coro( + _event_list_parser(room_version, self._response.state), + "pdus.item", + use_float=True, + ), + ijson.items_coro( + _event_list_parser(room_version, self._response.auth_events), + "auth_chain.item", + use_float=True, + ), + ] + + def write(self, data: bytes) -> int: + for c in self._coros: + c.send(data) + return len(data) + + def finish(self) -> StateRequestResponse: + return self._response diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index c5f8fcbb2..e7656fbb9 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -958,6 +958,7 @@ class MatrixFederationHttpClient: ) return body + @overload async def get_json( self, destination: str, @@ -967,7 +968,38 @@ class MatrixFederationHttpClient: timeout: Optional[int] = None, ignore_backoff: bool = False, try_trailing_slash_on_400: bool = False, + parser: Literal[None] = None, + max_response_size: Optional[int] = None, ) -> Union[JsonDict, list]: + ... + + @overload + async def get_json( + self, + destination: str, + path: str, + args: Optional[QueryArgs] = ..., + retry_on_dns_fail: bool = ..., + timeout: Optional[int] = ..., + ignore_backoff: bool = ..., + try_trailing_slash_on_400: bool = ..., + parser: ByteParser[T] = ..., + max_response_size: Optional[int] = ..., + ) -> T: + ... + + async def get_json( + self, + destination: str, + path: str, + args: Optional[QueryArgs] = None, + retry_on_dns_fail: bool = True, + timeout: Optional[int] = None, + ignore_backoff: bool = False, + try_trailing_slash_on_400: bool = False, + parser: Optional[ByteParser] = None, + max_response_size: Optional[int] = None, + ): """GETs some json from the given host homeserver and path Args: @@ -992,6 +1024,13 @@ class MatrixFederationHttpClient: try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED response we should try appending a trailing slash to the end of the request. Workaround for #3622 in Synapse <= v0.99.3. + + parser: The parser to use to decode the response. Defaults to + parsing as JSON. + + max_response_size: The maximum size to read from the response. If None, + uses the default. + Returns: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. @@ -1026,8 +1065,17 @@ class MatrixFederationHttpClient: else: _sec_timeout = self.default_timeout + if parser is None: + parser = JsonParser() + body = await _handle_response( - self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser() + self.reactor, + _sec_timeout, + request, + response, + start_ms, + parser=parser, + max_response_size=max_response_size, ) return body diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py new file mode 100644 index 000000000..ec8864daf --- /dev/null +++ b/tests/federation/test_federation_client.py @@ -0,0 +1,149 @@ +# Copyright 2022 Matrix.org Federation C.I.C +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from unittest import mock + +import twisted.web.client +from twisted.internet import defer +from twisted.internet.protocol import Protocol +from twisted.python.failure import Failure +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.room_versions import RoomVersions +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock + +from tests.unittest import FederatingHomeserverTestCase + + +class FederationClientTest(FederatingHomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer): + super().prepare(reactor, clock, homeserver) + + # mock out the Agent used by the federation client, which is easier than + # catching the HTTPS connection and do the TLS stuff. + self._mock_agent = mock.create_autospec(twisted.web.client.Agent, spec_set=True) + homeserver.get_federation_http_client().agent = self._mock_agent + + def test_get_room_state(self): + creator = f"@creator:{self.OTHER_SERVER_NAME}" + test_room_id = "!room_id" + + # mock up some events to use in the response. + # In real life, these would have things in `prev_events` and `auth_events`, but that's + # a bit annoying to mock up, and the code under test doesn't care, so we don't bother. + create_event_dict = self.add_hashes_and_signatures( + { + "room_id": test_room_id, + "type": "m.room.create", + "state_key": "", + "sender": creator, + "content": {"creator": creator}, + "prev_events": [], + "auth_events": [], + "origin_server_ts": 500, + } + ) + member_event_dict = self.add_hashes_and_signatures( + { + "room_id": test_room_id, + "type": "m.room.member", + "sender": creator, + "state_key": creator, + "content": {"membership": "join"}, + "prev_events": [], + "auth_events": [], + "origin_server_ts": 600, + } + ) + pl_event_dict = self.add_hashes_and_signatures( + { + "room_id": test_room_id, + "type": "m.room.power_levels", + "sender": creator, + "state_key": "", + "content": {}, + "prev_events": [], + "auth_events": [], + "origin_server_ts": 700, + } + ) + + # mock up the response, and have the agent return it + self._mock_agent.request.return_value = defer.succeed( + _mock_response( + { + "pdus": [ + create_event_dict, + member_event_dict, + pl_event_dict, + ], + "auth_chain": [ + create_event_dict, + member_event_dict, + ], + } + ) + ) + + # now fire off the request + state_resp, auth_resp = self.get_success( + self.hs.get_federation_client().get_room_state( + "yet_another_server", + test_room_id, + "event_id", + RoomVersions.V9, + ) + ) + + # check the right call got made to the agent + self._mock_agent.request.assert_called_once_with( + b"GET", + b"matrix://yet_another_server/_matrix/federation/v1/state/%21room_id?event_id=event_id", + headers=mock.ANY, + bodyProducer=None, + ) + + # ... and that the response is correct. + + # the auth_resp should be empty because all the events are also in state + self.assertEqual(auth_resp, []) + + # all of the events should be returned in state_resp, though not necessarily + # in the same order. We just check the type on the assumption that if the type + # is right, so is the rest of the event. + self.assertCountEqual( + [e.type for e in state_resp], + ["m.room.create", "m.room.member", "m.room.power_levels"], + ) + + +def _mock_response(resp: JsonDict): + body = json.dumps(resp).encode("utf-8") + + def deliver_body(p: Protocol): + p.dataReceived(body) + p.connectionLost(Failure(twisted.web.client.ResponseDone())) + + response = mock.Mock( + code=200, + phrase=b"OK", + headers=twisted.web.client.Headers({"content-Type": ["application/json"]}), + length=len(body), + deliverBody=deliver_body, + ) + mock.seal(response) + return response diff --git a/tests/unittest.py b/tests/unittest.py index a71892cb9..7983c1e8b 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -51,7 +51,10 @@ from twisted.web.server import Request from synapse import events from synapse.api.constants import EventTypes, Membership +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.config.homeserver import HomeServerConfig +from synapse.config.server import DEFAULT_ROOM_VERSION +from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.federation.transport.server import TransportLayerServer from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest, SynapseSite @@ -839,6 +842,24 @@ class FederatingHomeserverTestCase(HomeserverTestCase): client_ip=client_ip, ) + def add_hashes_and_signatures( + self, + event_dict: JsonDict, + room_version: RoomVersion = KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION], + ) -> JsonDict: + """Adds hashes and signatures to the given event dict + + Returns: + The modified event dict, for convenience + """ + add_hashes_and_signatures( + room_version, + event_dict, + signature_name=self.OTHER_SERVER_NAME, + signing_key=self.OTHER_SERVER_SIGNATURE_KEY, + ) + return event_dict + def _auth_header_for_request( origin: str, From 235d2916ceb0c9a8e874ea8ac6994d604d743444 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Tue, 22 Feb 2022 13:29:04 +0000 Subject: [PATCH 40/84] Fix slow performance of `/logout` in some cases where refresh tokens are in use. The slowness existed since the initial implementation of refresh tokens. (#12056) --- changelog.d/12056.bugfix | 1 + .../storage/databases/main/registration.py | 18 +++- .../04_refresh_tokens_index_next_token_id.sql | 28 ++++++ tests/rest/client/test_auth.py | 93 ++++++++++++++++++- 4 files changed, 136 insertions(+), 4 deletions(-) create mode 100644 changelog.d/12056.bugfix create mode 100644 synapse/storage/schema/main/delta/68/04_refresh_tokens_index_next_token_id.sql diff --git a/changelog.d/12056.bugfix b/changelog.d/12056.bugfix new file mode 100644 index 000000000..210e30c63 --- /dev/null +++ b/changelog.d/12056.bugfix @@ -0,0 +1 @@ +Fix slow performance of `/logout` in some cases where refresh tokens are in use. The slowness existed since the initial implementation of refresh tokens. \ No newline at end of file diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 17110bb03..dc6665237 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1681,7 +1681,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): user_id=row[1], device_id=row[2], next_token_id=row[3], - has_next_refresh_token_been_refreshed=row[4], + # SQLite returns 0 or 1 for false/true, so convert to a bool. + has_next_refresh_token_been_refreshed=bool(row[4]), # This column is nullable, ensure it's a boolean has_next_access_token_been_used=(row[5] or False), expiry_ts=row[6], @@ -1697,12 +1698,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): Set the successor of a refresh token, removing the existing successor if any. + This also deletes the predecessor refresh and access tokens, + since they cannot be valid anymore. + Args: token_id: ID of the refresh token to update. next_token_id: ID of its successor. """ - def _replace_refresh_token_txn(txn) -> None: + def _replace_refresh_token_txn(txn: LoggingTransaction) -> None: # First check if there was an existing refresh token old_next_token_id = self.db_pool.simple_select_one_onecol_txn( txn, @@ -1728,6 +1732,16 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): {"id": old_next_token_id}, ) + # Delete the previous refresh token, since we only want to keep the + # last 2 refresh tokens in the database. + # (The predecessor of the latest refresh token is still useful in + # case the refresh was interrupted and the client re-uses the old + # one.) + # This cascades to delete the associated access token. + self.db_pool.simple_delete_txn( + txn, "refresh_tokens", {"next_token_id": token_id} + ) + await self.db_pool.runInteraction( "replace_refresh_token", _replace_refresh_token_txn ) diff --git a/synapse/storage/schema/main/delta/68/04_refresh_tokens_index_next_token_id.sql b/synapse/storage/schema/main/delta/68/04_refresh_tokens_index_next_token_id.sql new file mode 100644 index 000000000..09305638e --- /dev/null +++ b/synapse/storage/schema/main/delta/68/04_refresh_tokens_index_next_token_id.sql @@ -0,0 +1,28 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- next_token_id is a foreign key reference, so previously required a table scan +-- when a row in the referenced table was deleted. +-- As it was self-referential and cascaded deletes, this led to O(t*n) time to +-- delete a row, where t: number of rows in the table and n: number of rows in +-- the ancestral 'chain' of access tokens. +-- +-- This index is partial since we only require it for rows which reference +-- another. +-- Performance was tested to be the same regardless of whether the index was +-- full or partial, but a partial index can be smaller. +CREATE INDEX refresh_tokens_next_token_id + ON refresh_tokens(next_token_id) + WHERE next_token_id IS NOT NULL; diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index 27cb856b0..4a68d6657 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -13,15 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. from http import HTTPStatus -from typing import Optional, Union +from typing import Optional, Tuple, Union from twisted.internet.defer import succeed import synapse.rest.admin from synapse.api.constants import LoginType from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker -from synapse.rest.client import account, auth, devices, login, register +from synapse.rest.client import account, auth, devices, login, logout, register from synapse.rest.synapse.client import build_synapse_client_resource_tree +from synapse.storage.database import LoggingTransaction from synapse.types import JsonDict, UserID from tests import unittest @@ -527,6 +528,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): auth.register_servlets, account.register_servlets, login.register_servlets, + logout.register_servlets, synapse.rest.admin.register_servlets_for_client_rest_resource, register.register_servlets, ] @@ -984,3 +986,90 @@ class RefreshAuthTests(unittest.HomeserverTestCase): self.assertEqual( fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result ) + + def test_many_token_refresh(self): + """ + If a refresh is performed many times during a session, there shouldn't be + extra 'cruft' built up over time. + + This test was written specifically to troubleshoot a case where logout + was very slow if a lot of refreshes had been performed for the session. + """ + + def _refresh(refresh_token: str) -> Tuple[str, str]: + """ + Performs one refresh, returning the next refresh token and access token. + """ + refresh_response = self.use_refresh_token(refresh_token) + self.assertEqual( + refresh_response.code, HTTPStatus.OK, refresh_response.result + ) + return ( + refresh_response.json_body["refresh_token"], + refresh_response.json_body["access_token"], + ) + + def _table_length(table_name: str) -> int: + """ + Helper to get the size of a table, in rows. + For testing only; trivially vulnerable to SQL injection. + """ + + def _txn(txn: LoggingTransaction) -> int: + txn.execute(f"SELECT COUNT(1) FROM {table_name}") + row = txn.fetchone() + # Query is infallible + assert row is not None + return row[0] + + return self.get_success( + self.hs.get_datastores().main.db_pool.runInteraction( + "_table_length", _txn + ) + ) + + # Before we log in, there are no access tokens. + self.assertEqual(_table_length("access_tokens"), 0) + self.assertEqual(_table_length("refresh_tokens"), 0) + + body = { + "type": "m.login.password", + "user": "test", + "password": self.user_pass, + "refresh_token": True, + } + login_response = self.make_request( + "POST", + "/_matrix/client/v3/login", + body, + ) + self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result) + + access_token = login_response.json_body["access_token"] + refresh_token = login_response.json_body["refresh_token"] + + # Now that we have logged in, there should be one access token and one + # refresh token + self.assertEqual(_table_length("access_tokens"), 1) + self.assertEqual(_table_length("refresh_tokens"), 1) + + for _ in range(5): + refresh_token, access_token = _refresh(refresh_token) + + # After 5 sequential refreshes, there should only be the latest two + # refresh/access token pairs. + # (The last one is preserved because it's in use! + # The one before that is preserved because it can still be used to + # replace the last token pair, in case of e.g. a network interruption.) + self.assertEqual(_table_length("access_tokens"), 2) + self.assertEqual(_table_length("refresh_tokens"), 2) + + logout_response = self.make_request( + "POST", "/_matrix/client/v3/logout", {}, access_token=access_token + ) + self.assertEqual(logout_response.code, HTTPStatus.OK, logout_response.result) + + # Now that we have logged in, there should be no access token + # and no refresh token + self.assertEqual(_table_length("access_tokens"), 0) + self.assertEqual(_table_length("refresh_tokens"), 0) From 81364db49b7778021edcd5912555dd3e1583c1b8 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 22 Feb 2022 13:33:22 +0000 Subject: [PATCH 41/84] Run `_handle_queued_pdus` as a background process (#12041) ... to ensure it gets a proper log context, mostly. --- changelog.d/12041.misc | 1 + synapse/handlers/federation.py | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) create mode 100644 changelog.d/12041.misc diff --git a/changelog.d/12041.misc b/changelog.d/12041.misc new file mode 100644 index 000000000..e56dc093d --- /dev/null +++ b/changelog.d/12041.misc @@ -0,0 +1 @@ +After joining a room, create a dedicated logcontext to process the queued events. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index c8356f233..e9ac920bc 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -49,8 +49,8 @@ from synapse.logging.context import ( make_deferred_yieldable, nested_logging_context, preserve_fn, - run_in_background, ) +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.federation import ( ReplicationCleanRoomRestServlet, ReplicationStoreRoomOnOutlierMembershipRestServlet, @@ -559,7 +559,9 @@ class FederationHandler: # lots of requests for missing prev_events which we do actually # have. Hence we fire off the background task, but don't wait for it. - run_in_background(self._handle_queued_pdus, room_queue) + run_as_background_process( + "handle_queued_pdus", self._handle_queued_pdus, room_queue + ) async def do_knock( self, From 7bcc28f82fe160dc8dca1d70f8dc3667c94fb738 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 22 Feb 2022 09:09:40 -0500 Subject: [PATCH 42/84] Use room version 9 as the default room version (per MSC3589). (#12058) --- changelog.d/12058.feature | 1 + docs/sample_config.yaml | 2 +- synapse/config/server.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 changelog.d/12058.feature diff --git a/changelog.d/12058.feature b/changelog.d/12058.feature new file mode 100644 index 000000000..7b7169222 --- /dev/null +++ b/changelog.d/12058.feature @@ -0,0 +1 @@ +Use room version 9 as the default room version (per [MSC3589](https://github.com/matrix-org/matrix-doc/pull/3589)). diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index d2bb3d420..6f3623c88 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -163,7 +163,7 @@ presence: # For example, for room version 1, default_room_version should be set # to "1". # -#default_room_version: "6" +#default_room_version: "9" # The GC threshold parameters to pass to `gc.set_threshold`, if defined # diff --git a/synapse/config/server.py b/synapse/config/server.py index 7bc962454..49cd0a4f1 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -146,7 +146,7 @@ DEFAULT_IP_RANGE_BLACKLIST = [ "fec0::/10", ] -DEFAULT_ROOM_VERSION = "6" +DEFAULT_ROOM_VERSION = "9" ROOM_COMPLEXITY_TOO_GREAT = ( "Your homeserver is unable to join rooms this large or complex. " From dcb6a378372d08a46a76c591704be4dc15c68df3 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Tue, 22 Feb 2022 14:24:31 +0000 Subject: [PATCH 43/84] Cap the number of in-flight requests for state from a single group (#11608) --- changelog.d/11608.misc | 1 + synapse/storage/databases/state/store.py | 16 +++++ tests/storage/databases/test_state_store.py | 69 +++++++++++++++++++++ 3 files changed, 86 insertions(+) create mode 100644 changelog.d/11608.misc diff --git a/changelog.d/11608.misc b/changelog.d/11608.misc new file mode 100644 index 000000000..3af049b96 --- /dev/null +++ b/changelog.d/11608.misc @@ -0,0 +1 @@ +Deduplicate in-flight requests in `_get_state_for_groups`. diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 3af69a207..b8016f679 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -56,6 +56,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) MAX_STATE_DELTA_HOPS = 100 +MAX_INFLIGHT_REQUESTS_PER_GROUP = 5 @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -258,6 +259,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): Attempts to gather in-flight requests and re-use them to retrieve state for the given state group, filtered with the given state filter. + If there are more than MAX_INFLIGHT_REQUESTS_PER_GROUP in-flight requests, + and there *still* isn't enough information to complete the request by solely + reusing others, a full state filter will be requested to ensure that subsequent + requests can reuse this request. + Used as part of _get_state_for_group_using_inflight_cache. Returns: @@ -288,6 +294,16 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): # to cover our StateFilter and give us the state we need. break + if ( + state_filter_left_over != StateFilter.none() + and len(inflight_requests) >= MAX_INFLIGHT_REQUESTS_PER_GROUP + ): + # There are too many requests for this group. + # To prevent even more from building up, we request the whole + # state filter to guarantee that we can be reused by any subsequent + # requests for this state group. + return (), StateFilter.all() + return reusable_requests, state_filter_left_over async def _get_state_for_group_fire_request( diff --git a/tests/storage/databases/test_state_store.py b/tests/storage/databases/test_state_store.py index 3a4a4a3a2..076b66080 100644 --- a/tests/storage/databases/test_state_store.py +++ b/tests/storage/databases/test_state_store.py @@ -19,6 +19,7 @@ from twisted.internet.defer import Deferred, ensureDeferred from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes +from synapse.storage.databases.state.store import MAX_INFLIGHT_REQUESTS_PER_GROUP from synapse.storage.state import StateFilter from synapse.types import StateMap from synapse.util import Clock @@ -281,3 +282,71 @@ class StateGroupInflightCachingTestCase(HomeserverTestCase): self.assertEqual(self.get_success(req1), FAKE_STATE) self.assertEqual(self.get_success(req2), FAKE_STATE) + + def test_inflight_requests_capped(self) -> None: + """ + Tests that the number of in-flight requests is capped to 5. + + - requests several pieces of state separately + (5 to hit the limit, 1 to 'shunt out', another that comes after the + group has been 'shunted out') + - checks to see that the torrent of requests is shunted out by + rewriting one of the filters as the 'all' state filter + - requests after that one do not cause any additional queries + """ + # 5 at the time of writing. + CAP_COUNT = MAX_INFLIGHT_REQUESTS_PER_GROUP + + reqs = [] + + # Request 7 different keys (1 to 7) of the `some.state` type. + for req_id in range(CAP_COUNT + 2): + reqs.append( + ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, + StateFilter.freeze( + {"some.state": {str(req_id + 1)}}, include_others=False + ), + ) + ) + ) + self.pump(by=0.1) + + # There should only be 6 calls to the database, not 7. + self.assertEqual(len(self.get_state_group_calls), CAP_COUNT + 1) + + # Assert that the first 5 are exact requests for the individual pieces + # wanted + for req_id in range(CAP_COUNT): + groups, sf, d = self.get_state_group_calls[req_id] + self.assertEqual( + sf, + StateFilter.freeze( + {"some.state": {str(req_id + 1)}}, include_others=False + ), + ) + + # The 6th request should be the 'all' state filter + groups, sf, d = self.get_state_group_calls[CAP_COUNT] + self.assertEqual(sf, StateFilter.all()) + + # Complete the queries and check which requests complete as a result + for req_id in range(CAP_COUNT): + # This request should not have been completed yet + self.assertFalse(reqs[req_id].called) + + groups, sf, d = self.get_state_group_calls[req_id] + self._complete_request_fake(groups, sf, d) + + # This should have only completed this one request + self.assertTrue(reqs[req_id].called) + + # Now complete the final query; the last 2 requests should complete + # as a result + self.assertFalse(reqs[CAP_COUNT].called) + self.assertFalse(reqs[CAP_COUNT + 1].called) + groups, sf, d = self.get_state_group_calls[CAP_COUNT] + self._complete_request_fake(groups, sf, d) + self.assertTrue(reqs[CAP_COUNT].called) + self.assertTrue(reqs[CAP_COUNT + 1].called) From 94a396e7c4b4488d7f0ca08672114a4a586cf42c Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 22 Feb 2022 14:52:56 +0000 Subject: [PATCH 44/84] Prune setup.cfg some more (#12059) * Remove `trial` section from setup.cfg This was added in the initial commit from 2014. I can't see that it does anything. Maybe it's there so that you can run `trial` without any extra args, but if I do that then I just get the `--help` message. * Move flake8's config to its own file --- .flake8 | 11 +++++++++++ MANIFEST.in | 1 + changelog.d/12052.misc | 2 +- changelog.d/12059.misc | 1 + setup.cfg | 12 ------------ 5 files changed, 14 insertions(+), 13 deletions(-) create mode 100644 .flake8 create mode 100644 changelog.d/12059.misc diff --git a/.flake8 b/.flake8 new file mode 100644 index 000000000..acb118c86 --- /dev/null +++ b/.flake8 @@ -0,0 +1,11 @@ +# TODO: incorporate this into pyproject.toml if flake8 supports it in the future. +# See https://github.com/PyCQA/flake8/issues/234 +[flake8] +# see https://pycodestyle.readthedocs.io/en/latest/intro.html#error-codes +# for error codes. The ones we ignore are: +# W503: line break before binary operator +# W504: line break after binary operator +# E203: whitespace before ':' (which is contrary to pep8?) +# E731: do not assign a lambda expression, use a def +# E501: Line too long (black enforces this for us) +ignore=W503,W504,E203,E731,E501 diff --git a/MANIFEST.in b/MANIFEST.in index c24786c3b..76d14eb64 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -45,6 +45,7 @@ include book.toml include pyproject.toml recursive-include changelog.d * +include .flake8 prune .circleci prune .github prune .ci diff --git a/changelog.d/12052.misc b/changelog.d/12052.misc index fbaff67e9..11755ae61 100644 --- a/changelog.d/12052.misc +++ b/changelog.d/12052.misc @@ -1 +1 @@ -Move `isort` configuration to `pyproject.toml`. +Move configuration out of `setup.cfg`. diff --git a/changelog.d/12059.misc b/changelog.d/12059.misc new file mode 100644 index 000000000..9ba4759d9 --- /dev/null +++ b/changelog.d/12059.misc @@ -0,0 +1 @@ +Move configuration out of `setup.cfg`. \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index a0506572d..6213f3265 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,3 @@ -[trial] -test_suite = tests - [check-manifest] ignore = .git-blame-ignore-revs @@ -10,12 +7,3 @@ ignore = pylint.cfg tox.ini -[flake8] -# see https://pycodestyle.readthedocs.io/en/latest/intro.html#error-codes -# for error codes. The ones we ignore are: -# W503: line break before binary operator -# W504: line break after binary operator -# E203: whitespace before ':' (which is contrary to pep8?) -# E731: do not assign a lambda expression, use a def -# E501: Line too long (black enforces this for us) -ignore=W503,W504,E203,E731,E501 From 250104d357c17a1c87fa46af35bbf3612f4ef171 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 22 Feb 2022 16:10:10 +0100 Subject: [PATCH 45/84] Implement account status endpoints (MSC3720) (#12001) See matrix-org/matrix-doc#3720 Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> --- changelog.d/12001.feature | 1 + synapse/config/experimental.py | 3 + synapse/federation/federation_client.py | 60 +++++- synapse/federation/transport/client.py | 19 +- .../federation/transport/server/__init__.py | 8 + .../federation/transport/server/federation.py | 35 +++ synapse/handlers/account.py | 144 +++++++++++++ synapse/rest/client/account.py | 33 +++ synapse/rest/client/capabilities.py | 5 + synapse/server.py | 5 + tests/rest/client/test_account.py | 204 +++++++++++++++++- 11 files changed, 511 insertions(+), 6 deletions(-) create mode 100644 changelog.d/12001.feature create mode 100644 synapse/handlers/account.py diff --git a/changelog.d/12001.feature b/changelog.d/12001.feature new file mode 100644 index 000000000..dc1153c49 --- /dev/null +++ b/changelog.d/12001.feature @@ -0,0 +1 @@ +Implement experimental support for [MSC3720](https://github.com/matrix-org/matrix-doc/pull/3720) (account status endpoints). diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index bcdeb9ee2..772eb3501 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -65,3 +65,6 @@ class ExperimentalConfig(Config): # experimental support for faster joins over federation (msc2775, msc3706) # requires a target server with msc3706_enabled enabled. self.faster_joins_enabled: bool = experimental.get("faster_joins", False) + + # MSC3720 (Account status endpoint) + self.msc3720_enabled: bool = experimental.get("msc3720_enabled", False) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index c2997997d..2121e92e3 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -56,7 +56,7 @@ from synapse.api.room_versions import ( from synapse.events import EventBase, builder from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.transport.client import SendJoinResponse -from synapse.types import JsonDict, get_domain_from_id +from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination @@ -1610,6 +1610,64 @@ class FederationClient(FederationBase): except ValueError as e: raise InvalidResponseError(str(e)) + async def get_account_status( + self, destination: str, user_ids: List[str] + ) -> Tuple[JsonDict, List[str]]: + """Retrieves account statuses for a given list of users on a given remote + homeserver. + + If the request fails for any reason, all user IDs for this destination are marked + as failed. + + Args: + destination: the destination to contact + user_ids: the user ID(s) for which to request account status(es) + + Returns: + The account statuses, as well as the list of user IDs for which it was not + possible to retrieve a status. + """ + try: + res = await self.transport_layer.get_account_status(destination, user_ids) + except Exception: + # If the query failed for any reason, mark all the users as failed. + return {}, user_ids + + statuses = res.get("account_statuses", {}) + failures = res.get("failures", []) + + if not isinstance(statuses, dict) or not isinstance(failures, list): + # Make sure we're not feeding back malformed data back to the caller. + logger.warning( + "Destination %s responded with malformed data to account_status query", + destination, + ) + return {}, user_ids + + for user_id in user_ids: + # Any account whose status is missing is a user we failed to receive the + # status of. + if user_id not in statuses and user_id not in failures: + failures.append(user_id) + + # Filter out any user ID that doesn't belong to the remote server that sent its + # status (or failure). + def filter_user_id(user_id: str) -> bool: + try: + return UserID.from_string(user_id).domain == destination + except SynapseError: + # If the user ID doesn't parse, ignore it. + return False + + filtered_statuses = dict( + # item is a (key, value) tuple, so item[0] is the user ID. + filter(lambda item: filter_user_id(item[0]), statuses.items()) + ) + + filtered_failures = list(filter(filter_user_id, failures)) + + return filtered_statuses, filtered_failures + @attr.s(frozen=True, slots=True, auto_attribs=True) class TimestampToEventResponse: diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 7e510e224..69998de52 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -258,8 +258,9 @@ class TransportLayerClient: args: dict, retry_on_dns_fail: bool, ignore_backoff: bool = False, + prefix: str = FEDERATION_V1_PREFIX, ) -> JsonDict: - path = _create_v1_path("/query/%s", query_type) + path = _create_path(prefix, "/query/%s", query_type) return await self.client.get_json( destination=destination, @@ -1247,6 +1248,22 @@ class TransportLayerClient: args={"suggested_only": "true" if suggested_only else "false"}, ) + async def get_account_status( + self, destination: str, user_ids: List[str] + ) -> JsonDict: + """ + Args: + destination: The remote server. + user_ids: The user ID(s) for which to request account status(es). + """ + path = _create_path( + FEDERATION_UNSTABLE_PREFIX, "/org.matrix.msc3720/account_status" + ) + + return await self.client.post_json( + destination=destination, path=path, data={"user_ids": user_ids} + ) + def _create_path(federation_prefix: str, path: str, *args: str) -> str: """ diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py index db4fe2c79..67a634790 100644 --- a/synapse/federation/transport/server/__init__.py +++ b/synapse/federation/transport/server/__init__.py @@ -24,6 +24,7 @@ from synapse.federation.transport.server._base import ( ) from synapse.federation.transport.server.federation import ( FEDERATION_SERVLET_CLASSES, + FederationAccountStatusServlet, FederationTimestampLookupServlet, ) from synapse.federation.transport.server.groups_local import GROUP_LOCAL_SERVLET_CLASSES @@ -336,6 +337,13 @@ def register_servlets( ): continue + # Only allow the `/account_status` servlet if msc3720 is enabled + if ( + servletclass == FederationAccountStatusServlet + and not hs.config.experimental.msc3720_enabled + ): + continue + servletclass( hs=hs, authenticator=authenticator, diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index e85a8eda5..4d75e58bf 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -766,6 +766,40 @@ class RoomComplexityServlet(BaseFederationServlet): return 200, complexity +class FederationAccountStatusServlet(BaseFederationServerServlet): + PATH = "/query/account_status" + PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc3720" + + def __init__( + self, + hs: "HomeServer", + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self._account_handler = hs.get_account_handler() + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Mapping[bytes, Sequence[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: + if "user_ids" not in content: + raise SynapseError( + 400, "Required parameter 'user_ids' is missing", Codes.MISSING_PARAM + ) + + statuses, failures = await self._account_handler.get_account_statuses( + content["user_ids"], + allow_remote=False, + ) + + return 200, {"account_statuses": statuses, "failures": failures} + + FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( FederationSendServlet, FederationEventServlet, @@ -797,4 +831,5 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( FederationRoomHierarchyUnstableServlet, FederationV1SendKnockServlet, FederationMakeKnockServlet, + FederationAccountStatusServlet, ) diff --git a/synapse/handlers/account.py b/synapse/handlers/account.py new file mode 100644 index 000000000..f8cfe9f6d --- /dev/null +++ b/synapse/handlers/account.py @@ -0,0 +1,144 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Dict, List, Tuple + +from synapse.api.errors import Codes, SynapseError +from synapse.types import JsonDict, UserID + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +class AccountHandler: + def __init__(self, hs: "HomeServer"): + self._store = hs.get_datastore() + self._is_mine = hs.is_mine + self._federation_client = hs.get_federation_client() + + async def get_account_statuses( + self, + user_ids: List[str], + allow_remote: bool, + ) -> Tuple[JsonDict, List[str]]: + """Get account statuses for a list of user IDs. + + If one or more account(s) belong to remote homeservers, retrieve their status(es) + over federation if allowed. + + Args: + user_ids: The list of accounts to retrieve the status of. + allow_remote: Whether to try to retrieve the status of remote accounts, if + any. + + Returns: + The account statuses as well as the list of users whose statuses could not be + retrieved. + + Raises: + SynapseError if a required parameter is missing or malformed, or if one of + the accounts isn't local to this homeserver and allow_remote is False. + """ + statuses = {} + failures = [] + remote_users: List[UserID] = [] + + for raw_user_id in user_ids: + try: + user_id = UserID.from_string(raw_user_id) + except SynapseError: + raise SynapseError( + 400, + f"Not a valid Matrix user ID: {raw_user_id}", + Codes.INVALID_PARAM, + ) + + if self._is_mine(user_id): + status = await self._get_local_account_status(user_id) + statuses[user_id.to_string()] = status + else: + if not allow_remote: + raise SynapseError( + 400, + f"Not a local user: {raw_user_id}", + Codes.INVALID_PARAM, + ) + + remote_users.append(user_id) + + if allow_remote and len(remote_users) > 0: + remote_statuses, remote_failures = await self._get_remote_account_statuses( + remote_users, + ) + + statuses.update(remote_statuses) + failures += remote_failures + + return statuses, failures + + async def _get_local_account_status(self, user_id: UserID) -> JsonDict: + """Retrieve the status of a local account. + + Args: + user_id: The account to retrieve the status of. + + Returns: + The account's status. + """ + status = {"exists": False} + + userinfo = await self._store.get_userinfo_by_id(user_id.to_string()) + + if userinfo is not None: + status = { + "exists": True, + "deactivated": userinfo.is_deactivated, + } + + return status + + async def _get_remote_account_statuses( + self, remote_users: List[UserID] + ) -> Tuple[JsonDict, List[str]]: + """Send out federation requests to retrieve the statuses of remote accounts. + + Args: + remote_users: The accounts to retrieve the statuses of. + + Returns: + The statuses of the accounts, and a list of accounts for which no status + could be retrieved. + """ + # Group remote users by destination, so we only send one request per remote + # homeserver. + by_destination: Dict[str, List[str]] = {} + for user in remote_users: + if user.domain not in by_destination: + by_destination[user.domain] = [] + + by_destination[user.domain].append(user.to_string()) + + # Retrieve the statuses and failures for remote accounts. + final_statuses: JsonDict = {} + final_failures: List[str] = [] + for destination, users in by_destination.items(): + statuses, failures = await self._federation_client.get_account_status( + destination, + users, + ) + + final_statuses.update(statuses) + final_failures += failures + + return final_statuses, final_failures diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index efe299e69..5802de5b7 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -896,6 +896,36 @@ class WhoamiRestServlet(RestServlet): return 200, response +class AccountStatusRestServlet(RestServlet): + PATTERNS = client_patterns( + "/org.matrix.msc3720/account_status$", unstable=True, releases=() + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self._auth = hs.get_auth() + self._store = hs.get_datastore() + self._is_mine = hs.is_mine + self._federation_client = hs.get_federation_client() + self._account_handler = hs.get_account_handler() + + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await self._auth.get_user_by_req(request) + + body = parse_json_object_from_request(request) + if "user_ids" not in body: + raise SynapseError( + 400, "Required parameter 'user_ids' is missing", Codes.MISSING_PARAM + ) + + statuses, failures = await self._account_handler.get_account_statuses( + body["user_ids"], + allow_remote=True, + ) + + return 200, {"account_statuses": statuses, "failures": failures} + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: EmailPasswordRequestTokenRestServlet(hs).register(http_server) PasswordRestServlet(hs).register(http_server) @@ -910,3 +940,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ThreepidUnbindRestServlet(hs).register(http_server) ThreepidDeleteRestServlet(hs).register(http_server) WhoamiRestServlet(hs).register(http_server) + + if hs.config.experimental.msc3720_enabled: + AccountStatusRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py index e05c926b6..b80fdd371 100644 --- a/synapse/rest/client/capabilities.py +++ b/synapse/rest/client/capabilities.py @@ -75,6 +75,11 @@ class CapabilitiesRestServlet(RestServlet): if self.config.experimental.msc3440_enabled: response["capabilities"]["io.element.thread"] = {"enabled": True} + if self.config.experimental.msc3720_enabled: + response["capabilities"]["org.matrix.msc3720.account_status"] = { + "enabled": True, + } + return HTTPStatus.OK, response diff --git a/synapse/server.py b/synapse/server.py index 564afdcb9..4c07f2101 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -62,6 +62,7 @@ from synapse.federation.sender import AbstractFederationSender, FederationSender from synapse.federation.transport.client import TransportLayerClient from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler +from synapse.handlers.account import AccountHandler from synapse.handlers.account_data import AccountDataHandler from synapse.handlers.account_validity import AccountValidityHandler from synapse.handlers.admin import AdminHandler @@ -807,6 +808,10 @@ class HomeServer(metaclass=abc.ABCMeta): def get_external_cache(self) -> ExternalCache: return ExternalCache(self) + @cache_in_self + def get_account_handler(self) -> AccountHandler: + return AccountHandler(self) + @cache_in_self def get_outbound_redis_connection(self) -> "RedisProtocol": """ diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index 51146c471..afaa597f6 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -1,6 +1,4 @@ -# Copyright 2015-2016 OpenMarket Ltd -# Copyright 2017-2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2022 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,16 +15,22 @@ import json import os import re from email.parser import Parser -from typing import Optional +from typing import Dict, List, Optional +from unittest.mock import Mock import pkg_resources +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.constants import LoginType, Membership from synapse.api.errors import Codes, HttpResponseException from synapse.appservice import ApplicationService +from synapse.rest import admin from synapse.rest.client import account, login, register, room from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.server import FakeSite, make_request @@ -1040,3 +1044,195 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): threepids = {threepid["address"] for threepid in channel.json_body["threepids"]} self.assertIn(expected_email, threepids) + + +class AccountStatusTestCase(unittest.HomeserverTestCase): + servlets = [ + account.register_servlets, + admin.register_servlets, + login.register_servlets, + ] + + url = "/_matrix/client/unstable/org.matrix.msc3720/account_status" + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["experimental_features"] = {"msc3720_enabled": True} + + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer): + self.requester = self.register_user("requester", "password") + self.requester_tok = self.login("requester", "password") + self.server_name = homeserver.config.server.server_name + + def test_missing_mxid(self): + """Tests that not providing any MXID raises an error.""" + self._test_status( + users=None, + expected_status_code=400, + expected_errcode=Codes.MISSING_PARAM, + ) + + def test_invalid_mxid(self): + """Tests that providing an invalid MXID raises an error.""" + self._test_status( + users=["bad:test"], + expected_status_code=400, + expected_errcode=Codes.INVALID_PARAM, + ) + + def test_local_user_not_exists(self): + """Tests that the account status endpoints correctly reports that a user doesn't + exist. + """ + user = "@unknown:" + self.hs.config.server.server_name + + self._test_status( + users=[user], + expected_statuses={ + user: { + "exists": False, + }, + }, + expected_failures=[], + ) + + def test_local_user_exists(self): + """Tests that the account status endpoint correctly reports that a user doesn't + exist. + """ + user = self.register_user("someuser", "password") + + self._test_status( + users=[user], + expected_statuses={ + user: { + "exists": True, + "deactivated": False, + }, + }, + expected_failures=[], + ) + + def test_local_user_deactivated(self): + """Tests that the account status endpoint correctly reports a deactivated user.""" + user = self.register_user("someuser", "password") + self.get_success( + self.hs.get_datastore().set_user_deactivated_status(user, deactivated=True) + ) + + self._test_status( + users=[user], + expected_statuses={ + user: { + "exists": True, + "deactivated": True, + }, + }, + expected_failures=[], + ) + + def test_mixed_local_and_remote_users(self): + """Tests that if some users are remote the account status endpoint correctly + merges the remote responses with the local result. + """ + # We use 3 users: one doesn't exist but belongs on the local homeserver, one is + # deactivated and belongs on one remote homeserver, and one belongs to another + # remote homeserver that didn't return any result (the federation code should + # mark that user as a failure). + users = [ + "@unknown:" + self.hs.config.server.server_name, + "@deactivated:remote", + "@failed:otherremote", + "@bad:badremote", + ] + + async def post_json(destination, path, data, *a, **kwa): + if destination == "remote": + return { + "account_statuses": { + users[1]: { + "exists": True, + "deactivated": True, + }, + } + } + if destination == "otherremote": + return {} + if destination == "badremote": + # badremote tries to overwrite the status of a user that doesn't belong + # to it (i.e. users[1]) with false data, which Synapse is expected to + # ignore. + return { + "account_statuses": { + users[3]: { + "exists": False, + }, + users[1]: { + "exists": False, + }, + } + } + + # Register a mock that will return the expected result depending on the remote. + self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json) + + # Check that we've got the correct response from the client-side endpoint. + self._test_status( + users=users, + expected_statuses={ + users[0]: { + "exists": False, + }, + users[1]: { + "exists": True, + "deactivated": True, + }, + users[3]: { + "exists": False, + }, + }, + expected_failures=[users[2]], + ) + + def _test_status( + self, + users: Optional[List[str]], + expected_status_code: int = 200, + expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None, + expected_failures: Optional[List[str]] = None, + expected_errcode: Optional[str] = None, + ): + """Send a request to the account status endpoint and check that the response + matches with what's expected. + + Args: + users: The account(s) to request the status of, if any. If set to None, no + `user_id` query parameter will be included in the request. + expected_status_code: The expected HTTP status code. + expected_statuses: The expected account statuses, if any. + expected_failures: The expected failures, if any. + expected_errcode: The expected Matrix error code, if any. + """ + content = {} + if users is not None: + content["user_ids"] = users + + channel = self.make_request( + method="POST", + path=self.url, + content=content, + access_token=self.requester_tok, + ) + + self.assertEqual(channel.code, expected_status_code) + + if expected_statuses is not None: + self.assertEqual(channel.json_body["account_statuses"], expected_statuses) + + if expected_failures is not None: + self.assertEqual(channel.json_body["failures"], expected_failures) + + if expected_errcode is not None: + self.assertEqual(channel.json_body["errcode"], expected_errcode) From 6d14b3dabfe38c6ae487d0f663e294056b6cc056 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 22 Feb 2022 15:52:08 +0000 Subject: [PATCH 46/84] Better error message when failing to request from another process (#12060) --- changelog.d/12060.misc | 1 + synapse/replication/http/_base.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12060.misc diff --git a/changelog.d/12060.misc b/changelog.d/12060.misc new file mode 100644 index 000000000..d771e6a1b --- /dev/null +++ b/changelog.d/12060.misc @@ -0,0 +1 @@ +Fix error message when a worker process fails to talk to another worker process. diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index bc1d28dd1..2e697c74a 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -268,7 +268,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): raise e.to_synapse_error() except Exception as e: _outgoing_request_counter.labels(cls.NAME, "ERR").inc() - raise SynapseError(502, "Failed to talk to main process") from e + raise SynapseError( + 502, f"Failed to talk to {instance_name} process" + ) from e _outgoing_request_counter.labels(cls.NAME, 200).inc() return result From e3fe6347be1da930b6a0ed2005b565369800a327 Mon Sep 17 00:00:00 2001 From: Travis Ralston Date: Tue, 22 Feb 2022 11:35:01 -0700 Subject: [PATCH 47/84] Remove excess condition on `knock->leave` check (#11900) --- changelog.d/11900.misc | 1 + synapse/event_auth.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) create mode 100644 changelog.d/11900.misc diff --git a/changelog.d/11900.misc b/changelog.d/11900.misc new file mode 100644 index 000000000..edd2852fd --- /dev/null +++ b/changelog.d/11900.misc @@ -0,0 +1 @@ +Remove unnecessary condition on knock->leave auth rule check. \ No newline at end of file diff --git a/synapse/event_auth.py b/synapse/event_auth.py index eca00bc97..621a3efcc 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -374,9 +374,9 @@ def _is_membership_change_allowed( return # Require the user to be in the room for membership changes other than join/knock. - if Membership.JOIN != membership and ( - RoomVersion.msc2403_knocking and Membership.KNOCK != membership - ): + # Note that the room version check for knocking is done implicitly by `caller_knocked` + # and the ability to set a membership of `knock` in the first place. + if Membership.JOIN != membership and Membership.KNOCK != membership: # If the user has been invited or has knocked, they are allowed to change their # membership event to leave if ( From c1ac2a81350f3b5b86f4c53a585eccd17e3b8e75 Mon Sep 17 00:00:00 2001 From: Nicolas Werner <89468146+nico-famedly@users.noreply.github.com> Date: Wed, 23 Feb 2022 10:06:18 +0000 Subject: [PATCH 48/84] Rename default branch of complement.sh to main (#12063) The complement.sh script relies on the name of the ref matching the name of the unpacked folder. The branch redirect from renaming the default branch breaks that assumption. Signed-off-by: Nicolas Werner --- changelog.d/12063.misc | 1 + scripts-dev/complement.sh | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 changelog.d/12063.misc diff --git a/changelog.d/12063.misc b/changelog.d/12063.misc new file mode 100644 index 000000000..e48c5dd08 --- /dev/null +++ b/changelog.d/12063.misc @@ -0,0 +1 @@ +Fix using the complement.sh script without specifying a dir or a branch. Contributed by Nico on behalf of Famedly. diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index e08ffedaf..0aecb3daf 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -5,7 +5,7 @@ # It makes a Synapse image which represents the current checkout, # builds a synapse-complement image on top, then runs tests with it. # -# By default the script will fetch the latest Complement master branch and +# By default the script will fetch the latest Complement main branch and # run tests with that. This can be overridden to use a custom Complement # checkout by setting the COMPLEMENT_DIR environment variable to the # filepath of a local Complement checkout or by setting the COMPLEMENT_REF @@ -32,7 +32,7 @@ cd "$(dirname $0)/.." # Check for a user-specified Complement checkout if [[ -z "$COMPLEMENT_DIR" ]]; then - COMPLEMENT_REF=${COMPLEMENT_REF:-master} + COMPLEMENT_REF=${COMPLEMENT_REF:-main} echo "COMPLEMENT_DIR not set. Fetching Complement checkout from ${COMPLEMENT_REF}..." wget -Nq https://github.com/matrix-org/complement/archive/${COMPLEMENT_REF}.tar.gz tar -xzf ${COMPLEMENT_REF}.tar.gz From e24ff8ebe3d4119d377355402245947f7de61c00 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 23 Feb 2022 11:04:02 +0000 Subject: [PATCH 49/84] Remove `HomeServer.get_datastore()` (#12031) The presence of this method was confusing, and mostly present for backwards compatibility. Let's get rid of it. Part of #11733 --- changelog.d/12031.misc | 1 + docs/manhole.md | 2 +- scripts/update_synapse_database | 2 +- synapse/api/auth.py | 2 +- synapse/api/auth_blocking.py | 2 +- synapse/api/filtering.py | 4 +-- synapse/app/_base.py | 2 +- synapse/app/generic_worker.py | 2 +- synapse/app/homeserver.py | 2 +- synapse/app/phone_stats_home.py | 14 +++++--- synapse/appservice/scheduler.py | 2 +- synapse/crypto/keyring.py | 4 +-- synapse/events/builder.py | 2 +- synapse/events/third_party_rules.py | 2 +- synapse/federation/federation_base.py | 2 +- synapse/federation/sender/__init__.py | 2 +- .../sender/per_destination_queue.py | 9 ++--- .../federation/sender/transaction_manager.py | 2 +- synapse/federation/transport/server/_base.py | 2 +- .../federation/transport/server/federation.py | 2 +- synapse/groups/attestations.py | 2 +- synapse/groups/groups_server.py | 2 +- synapse/handlers/account_data.py | 4 +-- synapse/handlers/account_validity.py | 2 +- synapse/handlers/admin.py | 2 +- synapse/handlers/appservice.py | 2 +- synapse/handlers/auth.py | 4 +-- synapse/handlers/cas.py | 2 +- synapse/handlers/deactivate_account.py | 2 +- synapse/handlers/device.py | 4 +-- synapse/handlers/devicemessage.py | 2 +- synapse/handlers/directory.py | 2 +- synapse/handlers/e2e_keys.py | 4 +-- synapse/handlers/e2e_room_keys.py | 2 +- synapse/handlers/event_auth.py | 2 +- synapse/handlers/events.py | 4 +-- synapse/handlers/federation.py | 2 +- synapse/handlers/federation_event.py | 2 +- synapse/handlers/groups_local.py | 2 +- synapse/handlers/identity.py | 2 +- synapse/handlers/initial_sync.py | 2 +- synapse/handlers/message.py | 4 +-- synapse/handlers/oidc.py | 2 +- synapse/handlers/pagination.py | 2 +- synapse/handlers/presence.py | 4 +-- synapse/handlers/profile.py | 2 +- synapse/handlers/read_marker.py | 2 +- synapse/handlers/receipts.py | 4 +-- synapse/handlers/register.py | 2 +- synapse/handlers/room.py | 10 +++--- synapse/handlers/room_batch.py | 2 +- synapse/handlers/room_list.py | 2 +- synapse/handlers/room_member.py | 2 +- synapse/handlers/room_summary.py | 2 +- synapse/handlers/saml.py | 2 +- synapse/handlers/search.py | 2 +- synapse/handlers/set_password.py | 2 +- synapse/handlers/sso.py | 2 +- synapse/handlers/state_deltas.py | 2 +- synapse/handlers/stats.py | 2 +- synapse/handlers/sync.py | 2 +- synapse/handlers/typing.py | 4 +-- synapse/handlers/ui_auth/checkers.py | 4 +-- synapse/handlers/user_directory.py | 2 +- synapse/http/matrixfederationclient.py | 2 +- synapse/module_api/__init__.py | 8 +++-- synapse/notifier.py | 2 +- synapse/push/__init__.py | 2 +- synapse/push/bulk_push_rule_evaluator.py | 4 +-- synapse/push/emailpusher.py | 2 +- synapse/push/httppusher.py | 4 +-- synapse/push/mailer.py | 2 +- synapse/push/pusherpool.py | 2 +- synapse/replication/http/devices.py | 2 +- synapse/replication/http/federation.py | 10 +++--- synapse/replication/http/membership.py | 10 +++--- synapse/replication/http/register.py | 4 +-- synapse/replication/http/send_event.py | 2 +- synapse/replication/tcp/client.py | 4 +-- synapse/replication/tcp/handler.py | 2 +- synapse/replication/tcp/resource.py | 2 +- synapse/replication/tcp/streams/_base.py | 24 ++++++------- synapse/replication/tcp/streams/events.py | 2 +- synapse/rest/admin/__init__.py | 2 +- synapse/rest/admin/background_updates.py | 2 +- synapse/rest/admin/devices.py | 6 ++-- synapse/rest/admin/event_reports.py | 4 +-- synapse/rest/admin/federation.py | 8 ++--- synapse/rest/admin/media.py | 20 +++++------ synapse/rest/admin/registration_tokens.py | 6 ++-- synapse/rest/admin/rooms.py | 16 ++++----- synapse/rest/admin/statistics.py | 2 +- synapse/rest/admin/users.py | 24 ++++++------- synapse/rest/client/account.py | 18 +++++----- synapse/rest/client/account_data.py | 4 +-- synapse/rest/client/directory.py | 6 ++-- synapse/rest/client/events.py | 2 +- synapse/rest/client/groups.py | 8 ++--- synapse/rest/client/initial_sync.py | 2 +- synapse/rest/client/keys.py | 2 +- synapse/rest/client/login.py | 4 +-- synapse/rest/client/notifications.py | 2 +- synapse/rest/client/openid.py | 2 +- synapse/rest/client/push_rule.py | 2 +- synapse/rest/client/pusher.py | 4 ++- synapse/rest/client/register.py | 10 +++--- synapse/rest/client/relations.py | 6 ++-- synapse/rest/client/report_event.py | 2 +- synapse/rest/client/room.py | 12 +++---- synapse/rest/client/room_batch.py | 2 +- synapse/rest/client/shared_rooms.py | 2 +- synapse/rest/client/sync.py | 2 +- synapse/rest/client/tags.py | 2 +- synapse/rest/consent/consent_resource.py | 2 +- synapse/rest/key/v2/remote_key_resource.py | 2 +- synapse/rest/media/v1/media_repository.py | 2 +- synapse/rest/media/v1/preview_url_resource.py | 2 +- synapse/rest/media/v1/thumbnail_resource.py | 2 +- synapse/rest/media/v1/upload_resource.py | 2 +- synapse/rest/synapse/client/password_reset.py | 2 +- synapse/server.py | 16 +++------ .../server_notices/consent_server_notices.py | 2 +- .../resource_limits_server_notices.py | 2 +- .../server_notices/server_notices_manager.py | 2 +- synapse/state/__init__.py | 2 +- .../databases/main/monthly_active_users.py | 2 +- synapse/streams/events.py | 2 +- tests/api/test_auth.py | 8 +++-- tests/api/test_filtering.py | 2 +- tests/api/test_ratelimiting.py | 18 +++++----- tests/app/test_phone_stats_home.py | 34 ++++++++++--------- tests/crypto/test_keyring.py | 10 +++--- tests/events/test_snapshot.py | 2 +- tests/federation/test_complexity.py | 4 +-- tests/federation/test_federation_catch_up.py | 26 +++++++------- tests/federation/test_federation_sender.py | 6 ++-- tests/federation/transport/test_knocking.py | 2 +- tests/handlers/test_appservice.py | 10 +++--- tests/handlers/test_auth.py | 12 +++---- tests/handlers/test_cas.py | 2 +- tests/handlers/test_deactivate_account.py | 2 +- tests/handlers/test_device.py | 4 +-- tests/handlers/test_directory.py | 6 ++-- tests/handlers/test_e2e_keys.py | 2 +- tests/handlers/test_federation.py | 2 +- tests/handlers/test_message.py | 2 +- tests/handlers/test_oidc.py | 6 ++-- tests/handlers/test_presence.py | 4 +-- tests/handlers/test_profile.py | 4 +-- tests/handlers/test_register.py | 6 ++-- tests/handlers/test_saml.py | 4 +-- tests/handlers/test_stats.py | 2 +- tests/handlers/test_sync.py | 4 +-- tests/handlers/test_typing.py | 2 +- tests/handlers/test_user_directory.py | 2 +- tests/module_api/test_api.py | 2 +- tests/push/test_email.py | 22 ++++++------ tests/push/test_http.py | 22 ++++++------ tests/replication/_base.py | 6 ++-- tests/replication/slave/storage/_base.py | 4 +-- .../tcp/streams/test_account_data.py | 4 +-- tests/replication/tcp/streams/test_events.py | 4 +-- .../replication/tcp/streams/test_receipts.py | 4 +-- .../test_federation_sender_shard.py | 2 +- tests/replication/test_pusher_shard.py | 2 +- .../test_sharded_event_persister.py | 8 ++--- tests/rest/admin/test_background_updates.py | 2 +- tests/rest/admin/test_federation.py | 4 +-- tests/rest/admin/test_media.py | 4 +-- tests/rest/admin/test_registration_tokens.py | 2 +- tests/rest/admin/test_room.py | 6 ++-- tests/rest/admin/test_server_notice.py | 2 +- tests/rest/admin/test_user.py | 20 +++++------ tests/rest/client/test_account.py | 10 +++--- tests/rest/client/test_filter.py | 2 +- tests/rest/client/test_login.py | 4 +-- tests/rest/client/test_profile.py | 2 +- tests/rest/client/test_register.py | 28 ++++++++------- tests/rest/client/test_relations.py | 4 +-- tests/rest/client/test_retention.py | 4 +-- tests/rest/client/test_rooms.py | 10 +++--- tests/rest/client/test_shadow_banned.py | 2 +- tests/rest/client/test_shared_rooms.py | 2 +- tests/rest/client/test_sync.py | 2 +- tests/rest/client/test_typing.py | 2 +- tests/rest/client/test_upgrade_room.py | 2 +- tests/rest/media/v1/test_media_storage.py | 2 +- .../test_resource_limits_server_notices.py | 2 +- .../databases/main/test_deviceinbox.py | 2 +- .../databases/main/test_events_worker.py | 6 ++-- tests/storage/databases/main/test_lock.py | 2 +- tests/storage/databases/main/test_room.py | 2 +- tests/storage/test__base.py | 2 +- tests/storage/test_account_data.py | 2 +- tests/storage/test_appservice.py | 2 +- tests/storage/test_background_update.py | 8 ++--- tests/storage/test_cleanup_extrems.py | 4 +-- tests/storage/test_client_ips.py | 4 +-- tests/storage/test_devices.py | 2 +- tests/storage/test_directory.py | 2 +- tests/storage/test_e2e_room_keys.py | 2 +- tests/storage/test_end_to_end_keys.py | 2 +- tests/storage/test_event_chain.py | 4 +-- tests/storage/test_event_federation.py | 2 +- tests/storage/test_event_push_actions.py | 2 +- tests/storage/test_events.py | 4 +-- tests/storage/test_id_generators.py | 6 ++-- tests/storage/test_keys.py | 4 +-- tests/storage/test_main.py | 2 +- tests/storage/test_monthly_active_users.py | 2 +- tests/storage/test_profile.py | 2 +- tests/storage/test_purge.py | 4 +-- tests/storage/test_redaction.py | 2 +- tests/storage/test_registration.py | 2 +- tests/storage/test_rollback_worker.py | 6 ++-- tests/storage/test_room.py | 4 +-- tests/storage/test_room_search.py | 2 +- tests/storage/test_roommember.py | 4 +-- tests/storage/test_state.py | 2 +- tests/storage/test_stream.py | 2 +- tests/storage/test_transactions.py | 2 +- tests/storage/test_user_directory.py | 4 +-- tests/test_federation.py | 8 +++-- tests/test_mau.py | 2 +- tests/test_state.py | 4 +-- tests/test_utils/event_injection.py | 4 ++- tests/test_visibility.py | 4 ++- tests/unittest.py | 14 ++++---- tests/util/test_retryutils.py | 4 +-- tests/utils.py | 2 +- 230 files changed, 526 insertions(+), 500 deletions(-) create mode 100644 changelog.d/12031.misc diff --git a/changelog.d/12031.misc b/changelog.d/12031.misc new file mode 100644 index 000000000..d4bedc6b9 --- /dev/null +++ b/changelog.d/12031.misc @@ -0,0 +1 @@ +Remove legacy `HomeServer.get_datastore()`. diff --git a/docs/manhole.md b/docs/manhole.md index 715ed840f..a82fad0f0 100644 --- a/docs/manhole.md +++ b/docs/manhole.md @@ -94,6 +94,6 @@ As a simple example, retrieving an event from the database: ```pycon >>> from twisted.internet import defer ->>> defer.ensureDeferred(hs.get_datastore().get_event('$1416420717069yeQaw:matrix.org')) +>>> defer.ensureDeferred(hs.get_datastores().main.get_event('$1416420717069yeQaw:matrix.org')) > ``` diff --git a/scripts/update_synapse_database b/scripts/update_synapse_database index 5c6453d77..f43676afa 100755 --- a/scripts/update_synapse_database +++ b/scripts/update_synapse_database @@ -44,7 +44,7 @@ class MockHomeserver(HomeServer): def run_background_updates(hs): - store = hs.get_datastore() + store = hs.get_datastores().main async def run_background_updates(): await store.db_pool.updates.run_background_updates(sleep=False) diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 683241201..01c32417d 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -60,7 +60,7 @@ class Auth: def __init__(self, hs: "HomeServer"): self.hs = hs self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state = hs.get_state_handler() self._account_validity_handler = hs.get_account_validity_handler() diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py index 08fe160c9..22348d2d8 100644 --- a/synapse/api/auth_blocking.py +++ b/synapse/api/auth_blocking.py @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) class AuthBlocking: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._server_notices_mxid = hs.config.servernotices.server_notices_mxid self._hs_disabled = hs.config.server.hs_disabled diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index d087c816d..fe4cc2e8e 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -150,7 +150,7 @@ def matrix_user_id_validator(user_id_str: str) -> UserID: class Filtering: def __init__(self, hs: "HomeServer"): self._hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.DEFAULT_FILTER_COLLECTION = FilterCollection(hs, {}) @@ -294,7 +294,7 @@ class FilterCollection: class Filter: def __init__(self, hs: "HomeServer", filter_json: JsonDict): self._hs = hs - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self.filter_json = filter_json self.limit = filter_json.get("limit", 10) diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 452c0c09d..3e59805ba 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -448,7 +448,7 @@ async def start(hs: "HomeServer") -> None: # It is now safe to start your Synapse. hs.start_listening() - hs.get_datastore().db_pool.start_profiling() + hs.get_datastores().main.db_pool.start_profiling() hs.get_pusherpool().start() # Log when we start the shut down process. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index aadc882bf..1536a4272 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -142,7 +142,7 @@ class KeyUploadServlet(RestServlet): """ super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.http_client = hs.get_simple_http_client() self.main_uri = hs.config.worker.worker_main_http_uri diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index bfb30003c..b9931001c 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -372,7 +372,7 @@ def setup(config_options: List[str]) -> SynapseHomeServer: await _base.start(hs) - hs.get_datastore().db_pool.updates.start_doing_background_updates() + hs.get_datastores().main.db_pool.updates.start_doing_background_updates() register_start(start) diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py index 899dba5c3..40dbdace8 100644 --- a/synapse/app/phone_stats_home.py +++ b/synapse/app/phone_stats_home.py @@ -82,7 +82,7 @@ async def phone_stats_home( # General statistics # - store = hs.get_datastore() + store = hs.get_datastores().main stats["homeserver"] = hs.config.server.server_name stats["server_context"] = hs.config.server.server_context @@ -170,18 +170,22 @@ def start_phone_stats_home(hs: "HomeServer") -> None: # Rather than update on per session basis, batch up the requests. # If you increase the loop period, the accuracy of user_daily_visits # table will decrease - clock.looping_call(hs.get_datastore().generate_user_daily_visits, 5 * 60 * 1000) + clock.looping_call( + hs.get_datastores().main.generate_user_daily_visits, 5 * 60 * 1000 + ) # monthly active user limiting functionality - clock.looping_call(hs.get_datastore().reap_monthly_active_users, 1000 * 60 * 60) - hs.get_datastore().reap_monthly_active_users() + clock.looping_call( + hs.get_datastores().main.reap_monthly_active_users, 1000 * 60 * 60 + ) + hs.get_datastores().main.reap_monthly_active_users() @wrap_as_background_process("generate_monthly_active_users") async def generate_monthly_active_users() -> None: current_mau_count = 0 current_mau_count_by_service = {} reserved_users: Sized = () - store = hs.get_datastore() + store = hs.get_datastores().main if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only: current_mau_count = await store.get_monthly_active_count() current_mau_count_by_service = ( diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index c42fa32ff..b4e602e88 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -92,7 +92,7 @@ class ApplicationServiceScheduler: def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.as_api = hs.get_application_service_api() self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 72d4a69aa..93d56c077 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -476,7 +476,7 @@ class StoreKeyFetcher(KeyFetcher): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def _fetch_keys( self, keys_to_fetch: List[_FetchKeyRequest] @@ -498,7 +498,7 @@ class BaseV2KeyFetcher(KeyFetcher): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.config = hs.config async def process_v2_response( diff --git a/synapse/events/builder.py b/synapse/events/builder.py index eb39e0ae3..1ea1bb7d3 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -189,7 +189,7 @@ class EventBuilderFactory: self.hostname = hs.hostname self.signing_key = hs.signing_key - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state = hs.get_state_handler() self._event_auth_handler = hs.get_event_auth_handler() diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 1bb8ca714..71ec100a7 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -143,7 +143,7 @@ class ThirdPartyEventRules: def __init__(self, hs: "HomeServer"): self.third_party_rules = None - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = [] self._on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = [] diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index fab6da3c0..41ac49fdc 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -39,7 +39,7 @@ class FederationBase: self.server_name = hs.hostname self.keyring = hs.get_keyring() self.spam_checker = hs.get_spam_checker() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._clock = hs.get_clock() async def _check_sigs_and_hash( diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 720d7bd74..6106a486d 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -228,7 +228,7 @@ class FederationSender(AbstractFederationSender): self.hs = hs self.server_name = hs.hostname - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state = hs.get_state_handler() self.clock = hs.get_clock() diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index c3132f731..c8768f22b 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -76,7 +76,7 @@ class PerDestinationQueue: ): self._server_name = hs.hostname self._clock = hs.get_clock() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._transaction_manager = transaction_manager self._instance_name = hs.get_instance_name() self._federation_shard_config = hs.config.worker.federation_shard_config @@ -381,9 +381,8 @@ class PerDestinationQueue: ) ) - last_successful_stream_ordering = self._last_successful_stream_ordering - - if last_successful_stream_ordering is None: + _tmp_last_successful_stream_ordering = self._last_successful_stream_ordering + if _tmp_last_successful_stream_ordering is None: # if it's still None, then this means we don't have the information # in our database ­ we haven't successfully sent a PDU to this server # (at least since the introduction of the feature tracking @@ -393,6 +392,8 @@ class PerDestinationQueue: self._catching_up = False return + last_successful_stream_ordering: int = _tmp_last_successful_stream_ordering + # get at most 50 catchup room/PDUs while True: event_ids = await self._store.get_catch_up_room_event_ids( diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 742ee5725..0c1cad86a 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -53,7 +53,7 @@ class TransactionManager: def __init__(self, hs: "synapse.server.HomeServer"): self._server_name = hs.hostname self.clock = hs.get_clock() # nb must be called this for @measure_func - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._transaction_actions = TransactionActions(self._store) self._transport_layer = hs.get_federation_transport_client() diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py index dff2b6835..87e99c7dd 100644 --- a/synapse/federation/transport/server/_base.py +++ b/synapse/federation/transport/server/_base.py @@ -55,7 +55,7 @@ class Authenticator: self._clock = hs.get_clock() self.keyring = hs.get_keyring() self.server_name = hs.hostname - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.federation_domain_whitelist = ( hs.config.federation.federation_domain_whitelist ) diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 4d75e58bf..9cc9a7339 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -746,7 +746,7 @@ class RoomComplexityServlet(BaseFederationServlet): server_name: str, ): super().__init__(hs, authenticator, ratelimiter, server_name) - self._store = self.hs.get_datastore() + self._store = self.hs.get_datastores().main async def on_GET( self, diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index a87896e53..ed26d6a6c 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -140,7 +140,7 @@ class GroupAttestionRenewer: def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.assestations = hs.get_groups_attestation_signing() self.transport_client = hs.get_federation_transport_client() self.is_mine_id = hs.is_mine_id diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 449bbc700..4c3a5a6e2 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -45,7 +45,7 @@ MAX_LONG_DESC_LEN = 10000 class GroupsServerWorkerHandler: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room_list_handler = hs.get_room_list_handler() self.auth = hs.get_auth() self.clock = hs.get_clock() diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index bad48713b..177b4f899 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -30,7 +30,7 @@ if TYPE_CHECKING: class AccountDataHandler: def __init__(self, hs: "HomeServer"): - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._instance_name = hs.get_instance_name() self._notifier = hs.get_notifier() @@ -166,7 +166,7 @@ class AccountDataHandler: class AccountDataEventSource(EventSource[int, JsonDict]): def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def get_current_key(self, direction: str = "f") -> int: return self.store.get_max_account_data_stream_id() diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 87e415df7..9d0975f63 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -43,7 +43,7 @@ class AccountValidityHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self.config = hs.config - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.send_email_handler = self.hs.get_send_email_handler() self.clock = self.hs.get_clock() diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 00ab5e79b..96376963f 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) class AdminHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_store = self.storage.state diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index a42c3558e..e6461cc3c 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -47,7 +47,7 @@ events_processed_counter = Counter("synapse_handlers_appservice_events_processed class ApplicationServicesHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.is_mine_id = hs.is_mine_id self.appservice_api = hs.get_application_service_api() self.scheduler = hs.get_application_service_scheduler() diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 572f54b1e..3e29c96a4 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -194,7 +194,7 @@ class AuthHandler: SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.clock = hs.get_clock() self.checkers: Dict[str, UserInteractiveAuthChecker] = {} @@ -1183,7 +1183,7 @@ class AuthHandler: # No password providers were able to handle this 3pid # Check local store - user_id = await self.hs.get_datastore().get_user_id_by_threepid( + user_id = await self.hs.get_datastores().main.get_user_id_by_threepid( medium, address ) if not user_id: diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py index 5d8f6c50a..7163af800 100644 --- a/synapse/handlers/cas.py +++ b/synapse/handlers/cas.py @@ -61,7 +61,7 @@ class CasHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self._hostname = hs.hostname - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._auth_handler = hs.get_auth_handler() self._registration_handler = hs.get_registration_handler() diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 7a13d76a6..e4eae0305 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -29,7 +29,7 @@ class DeactivateAccountHandler: """Handler which deals with deactivating user accounts.""" def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.hs = hs self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 36c05f836..934b5bd73 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -63,7 +63,7 @@ class DeviceWorkerHandler: def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self.state = hs.get_state_handler() self.state_store = hs.get_storage().state @@ -628,7 +628,7 @@ class DeviceListUpdater: "Handles incoming device list updates from federation and updates the DB" def __init__(self, hs: "HomeServer", device_handler: DeviceHandler): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.federation = hs.get_federation_client() self.clock = hs.get_clock() self.device_handler = device_handler diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index b582266af..4cb725d02 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -43,7 +43,7 @@ class DeviceMessageHandler: Args: hs: server """ - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self.is_mine = hs.is_mine diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 082f52179..b7064c662 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -44,7 +44,7 @@ class DirectoryHandler: self.state = hs.get_state_handler() self.appservice_handler = hs.get_application_service_handler() self.event_creation_handler = hs.get_event_creation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.config = hs.config self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search self.require_membership = hs.config.server.require_membership_for_aliases diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index d4dfddf63..d96456cd4 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -47,7 +47,7 @@ logger = logging.getLogger(__name__) class E2eKeysHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.federation = hs.get_federation_client() self.device_handler = hs.get_device_handler() self.is_mine = hs.is_mine @@ -1335,7 +1335,7 @@ class SigningKeyEduUpdater: """Handles incoming signing key updates from federation and updates the DB""" def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.federation = hs.get_federation_client() self.clock = hs.get_clock() self.e2e_keys_handler = e2e_keys_handler diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 12614b2c5..52e44a2d4 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -45,7 +45,7 @@ class E2eRoomKeysHandler: """ def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main # Used to lock whenever a client is uploading key data. This prevents collisions # between clients trying to upload the details of a new session, given all diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index 365063ebd..d441ebb0a 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -43,7 +43,7 @@ class EventAuthHandler: def __init__(self, hs: "HomeServer"): self._clock = hs.get_clock() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._server_name = hs.hostname async def check_auth_rules_from_context( diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index bac5de052..97e75e60c 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -33,7 +33,7 @@ logger = logging.getLogger(__name__) class EventStreamHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.hs = hs @@ -134,7 +134,7 @@ class EventStreamHandler: class EventHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() async def get_event( diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index e9ac920bc..c055c26ec 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -107,7 +107,7 @@ class FederationHandler: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_store = self.storage.state self.federation_client = hs.get_federation_client() diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 7683246be..09d0de1ea 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -95,7 +95,7 @@ class FederationEventHandler: """ def __init__(self, hs: "HomeServer"): - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._storage = hs.get_storage() self._state_store = self._storage.state diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index 9e270d461..e7a399787 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -63,7 +63,7 @@ def _create_rerouter(func_name: str) -> Callable[..., Awaitable[JsonDict]]: class GroupsLocalWorkerHandler: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room_list_handler = hs.get_room_list_handler() self.groups_server_handler = hs.get_groups_server_handler() self.transport_client = hs.get_federation_transport_client() diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index c83eaea35..57c9fdfe6 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -49,7 +49,7 @@ id_server_scheme = "https://" class IdentityHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main # An HTTP client for contacting trusted URLs. self.http_client = SimpleHttpClient(hs) # An HTTP client for contacting identity servers specified by clients. diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 346a06ff4..344f20f37 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -46,7 +46,7 @@ logger = logging.getLogger(__name__) class InitialSyncHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.state_handler = hs.get_state_handler() self.hs = hs diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 4d0da8428..a9c964cd7 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -75,7 +75,7 @@ class MessageHandler: self.auth = hs.get_auth() self.clock = hs.get_clock() self.state = hs.get_state_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_store = self.storage.state self._event_serializer = hs.get_event_client_serializer() @@ -397,7 +397,7 @@ class EventCreationHandler: self.hs = hs self.auth = hs.get_auth() self._event_auth_handler = hs.get_event_auth_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state = hs.get_state_handler() self.clock = hs.get_clock() diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 8f71d975e..593a2aac6 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -273,7 +273,7 @@ class OidcProvider: token_generator: "OidcSessionTokenGenerator", provider: OidcProviderConfig, ): - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._token_generator = token_generator diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 973f26296..5c01a426f 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -127,7 +127,7 @@ class PaginationHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_store = self.storage.state self.clock = hs.get_clock() diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index b223b7262..c155098be 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -133,7 +133,7 @@ class BasePresenceHandler(abc.ABC): def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.presence_router = hs.get_presence_router() self.state = hs.get_state_handler() self.is_mine_id = hs.is_mine_id @@ -1541,7 +1541,7 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): self.get_presence_handler = hs.get_presence_handler self.get_presence_router = hs.get_presence_router self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def get_new_events( self, diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 36e3ad2ba..dd27f0acc 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -54,7 +54,7 @@ class ProfileHandler: PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.hs = hs diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py index 58593e570..bad1acc63 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) class ReadMarkerHandler: def __init__(self, hs: "HomeServer"): self.server_name = hs.config.server.server_name - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.account_data_handler = hs.get_account_data_handler() self.read_marker_linearizer = Linearizer(name="read_marker") diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 5cb1ff749..b4132c353 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -29,7 +29,7 @@ class ReceiptsHandler: def __init__(self, hs: "HomeServer"): self.notifier = hs.get_notifier() self.server_name = hs.config.server.server_name - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.event_auth_handler = hs.get_event_auth_handler() self.hs = hs @@ -163,7 +163,7 @@ class ReceiptsHandler: class ReceiptEventSource(EventSource[int, JsonDict]): def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.config = hs.config @staticmethod diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 80320d2c0..05bb1e022 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -86,7 +86,7 @@ class LoginDict(TypedDict): class RegistrationHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.hs = hs self.auth = hs.get_auth() diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index a990727fc..7b965b4b9 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -105,7 +105,7 @@ class EventContext: class RoomCreationHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.clock = hs.get_clock() self.hs = hs @@ -1115,7 +1115,7 @@ class RoomContextHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_store = self.storage.state @@ -1246,7 +1246,7 @@ class RoomContextHandler: class TimestampLookupHandler: def __init__(self, hs: "HomeServer"): self.server_name = hs.hostname - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state_handler = hs.get_state_handler() self.federation_client = hs.get_federation_client() @@ -1386,7 +1386,7 @@ class TimestampLookupHandler: class RoomEventSource(EventSource[RoomStreamToken, EventBase]): def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def get_new_events( self, @@ -1476,7 +1476,7 @@ class RoomShutdownHandler: self._room_creation_handler = hs.get_room_creation_handler() self._replication = hs.get_replication_data_handler() self.event_creation_handler = hs.get_event_creation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def shutdown_room( self, diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index f8137ec04..abbf7b7b2 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -16,7 +16,7 @@ logger = logging.getLogger(__name__) class RoomBatchHandler: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state_store = hs.get_storage().state self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 1a33211a1..f3577b5d5 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -49,7 +49,7 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None) class RoomListHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.hs = hs self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search self.response_cache: ResponseCache[ diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index b2adc0f48..a582837cf 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -66,7 +66,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.state_handler = hs.get_state_handler() self.config = hs.config diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index 4844b69a0..2e61d1cbe 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -90,7 +90,7 @@ class RoomSummaryHandler: def __init__(self, hs: "HomeServer"): self._event_auth_handler = hs.get_event_auth_handler() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._event_serializer = hs.get_event_client_serializer() self._server_name = hs.hostname self._federation_client = hs.get_federation_client() diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py index 727d75a50..9602f0d0b 100644 --- a/synapse/handlers/saml.py +++ b/synapse/handlers/saml.py @@ -52,7 +52,7 @@ class Saml2SessionData: class SamlHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.server_name = hs.hostname self._saml_client = Saml2Client(hs.config.saml2.saml2_sp_config) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 0e0e58de0..aa16e417e 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -49,7 +49,7 @@ class _SearchResult: class SearchHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state_handler = hs.get_state_handler() self.clock = hs.get_clock() self.hs = hs diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index 706ad7276..73861bbd4 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -27,7 +27,7 @@ class SetPasswordHandler: """Handler which deals with changing user account passwords""" def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 0bb8b0929..ff5b5169c 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -180,7 +180,7 @@ class SsoHandler: def __init__(self, hs: "HomeServer"): self._clock = hs.get_clock() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._server_name = hs.hostname self._registration_handler = hs.get_registration_handler() self._auth_handler = hs.get_auth_handler() diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py index d30ba2b72..2d197282e 100644 --- a/synapse/handlers/state_deltas.py +++ b/synapse/handlers/state_deltas.py @@ -30,7 +30,7 @@ class MatchChange(Enum): class StateDeltasHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def _get_key_change( self, diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 29e41a4c7..436cd971c 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -39,7 +39,7 @@ class StatsHandler: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state = hs.get_state_handler() self.server_name = hs.hostname self.clock = hs.get_clock() diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index e6050cbce..98eaad331 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -266,7 +266,7 @@ class SyncResult: class SyncHandler: def __init__(self, hs: "HomeServer"): self.hs_config = hs.config - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self.presence_handler = hs.get_presence_handler() self.event_sources = hs.get_event_sources() diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index e4bed1c93..843c68eb0 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -57,7 +57,7 @@ class FollowerTypingHandler: """ def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.server_name = hs.config.server.server_name self.clock = hs.get_clock() self.is_mine_id = hs.is_mine_id @@ -446,7 +446,7 @@ class TypingWriterHandler(FollowerTypingHandler): class TypingNotificationEventSource(EventSource[int, JsonDict]): def __init__(self, hs: "HomeServer"): - self._main_store = hs.get_datastore() + self._main_store = hs.get_datastores().main self.clock = hs.get_clock() # We can't call get_typing_handler here because there's a cycle: # diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 184730ebe..014754a63 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -139,7 +139,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker): class _BaseThreepidAuthChecker: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def _check_threepid(self, medium: str, authdict: dict) -> dict: if "threepid_creds" not in authdict: @@ -255,7 +255,7 @@ class RegistrationTokenAuthChecker(UserInteractiveAuthChecker): super().__init__(hs) self.hs = hs self._enabled = bool(hs.config.registration.registration_requires_token) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def is_enabled(self) -> bool: return self._enabled diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 1565e034c..d27ed2be6 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -55,7 +55,7 @@ class UserDirectoryHandler(StateDeltasHandler): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.server_name = hs.hostname self.clock = hs.get_clock() self.notifier = hs.get_notifier() diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index e7656fbb9..40bf1e06d 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -351,7 +351,7 @@ class MatrixFederationHttpClient: ) self.clock = hs.get_clock() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self.version_string_bytes = hs.version_string.encode("ascii") self.default_timeout = 60 diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 07020bfb8..902916d80 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -172,7 +172,9 @@ class ModuleApi: # TODO: Fix this type hint once the types for the data stores have been ironed # out. - self._store: Union[DataStore, "GenericWorkerSlavedStore"] = hs.get_datastore() + self._store: Union[ + DataStore, "GenericWorkerSlavedStore" + ] = hs.get_datastores().main self._auth = hs.get_auth() self._auth_handler = auth_handler self._server_name = hs.hostname @@ -926,7 +928,7 @@ class ModuleApi: ) # Try to retrieve the resulting event. - event = await self._hs.get_datastore().get_event(event_id) + event = await self._hs.get_datastores().main.get_event(event_id) # update_membership is supposed to always return after the event has been # successfully persisted. @@ -1270,7 +1272,7 @@ class PublicRoomListManager: """ def __init__(self, hs: "HomeServer"): - self._store = hs.get_datastore() + self._store = hs.get_datastores().main async def room_is_in_public_room_list(self, room_id: str) -> bool: """Checks whether a room is in the public room list. diff --git a/synapse/notifier.py b/synapse/notifier.py index 753dd6b6a..16d15a1f3 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -222,7 +222,7 @@ class Notifier: self.hs = hs self.storage = hs.get_storage() self.event_sources = hs.get_event_sources() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.pending_new_room_events: List[_PendingRoomEventEntry] = [] # Called when there are new things to stream over replication diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index 5176a1c18..a1b771109 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -68,7 +68,7 @@ class ThrottleParams: class Pusher(metaclass=abc.ABCMeta): def __init__(self, hs: "HomeServer", pusher_config: PusherConfig): self.hs = hs - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.clock = self.hs.get_clock() self.pusher_id = pusher_config.id diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index bee660893..fecf86034 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -103,7 +103,7 @@ class BulkPushRuleEvaluator: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._event_auth_handler = hs.get_event_auth_handler() # Used by `RulesForRoom` to ensure only one thing mutates the cache at a @@ -366,7 +366,7 @@ class RulesForRoom: """ self.room_id = room_id self.is_mine_id = hs.is_mine_id - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room_push_rule_cache_metrics = room_push_rule_cache_metrics # Used to ensure only one thing mutates the cache at a time. Keyed off diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 39bb2acae..1710dd51b 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -66,7 +66,7 @@ class EmailPusher(Pusher): super().__init__(hs, pusher_config) self.mailer = mailer - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.email = pusher_config.pushkey self.timed_call: Optional[IDelayedCall] = None self.throttle_params: Dict[str, ThrottleParams] = {} diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 52c7ff357..581834452 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -133,7 +133,7 @@ class HttpPusher(Pusher): # XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems # to be largely redundant. perhaps we can remove it. badge = await push_tools.get_badge_count( - self.hs.get_datastore(), + self.hs.get_datastores().main, self.user_id, group_by_room=self._group_unread_count_by_room, ) @@ -283,7 +283,7 @@ class HttpPusher(Pusher): tweaks = push_rule_evaluator.tweaks_for_actions(push_action.actions) badge = await push_tools.get_badge_count( - self.hs.get_datastore(), + self.hs.get_datastores().main, self.user_id, group_by_room=self._group_unread_count_by_room, ) diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 3df8452ee..649a4f49d 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -112,7 +112,7 @@ class Mailer: self.template_text = template_text self.send_email_handler = hs.get_send_email_handler() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.state_store = self.hs.get_storage().state self.macaroon_gen = self.hs.get_macaroon_generator() self.state_handler = self.hs.get_state_handler() diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 7912311d2..d0cc657b4 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -59,7 +59,7 @@ class PusherPool: def __init__(self, hs: "HomeServer"): self.hs = hs self.pusher_factory = PusherFactory(hs) - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.clock = self.hs.get_clock() # We shard the handling of push notifications by user ID. diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index f2f40129f..3d6364572 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -63,7 +63,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): super().__init__(hs) self.device_list_updater = hs.get_device_handler().device_list_updater - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() @staticmethod diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index d529c8a19..3e7300b4a 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -68,7 +68,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.clock = hs.get_clock() self.federation_event_handler = hs.get_federation_event_handler() @@ -167,7 +167,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.registry = hs.get_federation_registry() @@ -214,7 +214,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.registry = hs.get_federation_registry() @@ -260,7 +260,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main @staticmethod async def _serialize_payload(room_id: str) -> JsonDict: # type: ignore[override] @@ -297,7 +297,7 @@ class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main @staticmethod async def _serialize_payload(room_id: str, room_version: RoomVersion) -> JsonDict: # type: ignore[override] diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 0145858e4..663bff573 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -50,7 +50,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): super().__init__(hs) self.federation_handler = hs.get_federation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() @staticmethod @@ -119,7 +119,7 @@ class ReplicationRemoteKnockRestServlet(ReplicationEndpoint): super().__init__(hs) self.federation_handler = hs.get_federation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() @staticmethod @@ -188,7 +188,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.member_handler = hs.get_room_member_handler() @@ -258,7 +258,7 @@ class ReplicationRemoteRescindKnockRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.member_handler = hs.get_room_member_handler() @@ -325,7 +325,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): super().__init__(hs) self.registeration_handler = hs.get_registration_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.distributor = hs.get_distributor() diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index c7f751b70..6c8f8388f 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -36,7 +36,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.registration_handler = hs.get_registration_handler() @staticmethod @@ -112,7 +112,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.registration_handler = hs.get_registration_handler() @staticmethod diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index 33e98daf8..ce7817683 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -69,7 +69,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.clock = hs.get_clock() diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index d59ce7ccf..1b8479b0b 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -111,7 +111,7 @@ class ReplicationDataHandler: """ def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self._reactor = hs.get_reactor() self._clock = hs.get_clock() @@ -340,7 +340,7 @@ class FederationSenderHandler: def __init__(self, hs: "HomeServer"): assert hs.should_send_federation() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._is_mine_id = hs.is_mine_id self._hs = hs diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 17e157239..0d2013a3c 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -95,7 +95,7 @@ class ReplicationCommandHandler: def __init__(self, hs: "HomeServer"): self._replication_data_handler = hs.get_replication_data_handler() self._presence_handler = hs.get_presence_handler() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._notifier = hs.get_notifier() self._clock = hs.get_clock() self._instance_id = hs.get_instance_id() diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index ecd6190f5..494e42a2b 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -72,7 +72,7 @@ class ReplicationStreamer: """ def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.notifier = hs.get_notifier() self._instance_name = hs.get_instance_name() diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 914b9eae8..23d631a76 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -239,7 +239,7 @@ class BackfillStream(Stream): ROW_TYPE = BackfillStreamRow def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main super().__init__( hs.get_instance_name(), self._current_token, @@ -267,7 +267,7 @@ class PresenceStream(Stream): ROW_TYPE = PresenceStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main if hs.get_instance_name() in hs.config.worker.writers.presence: # on the presence writer, query the presence handler @@ -355,7 +355,7 @@ class ReceiptsStream(Stream): ROW_TYPE = ReceiptsStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_max_receipt_stream_id), @@ -374,7 +374,7 @@ class PushRulesStream(Stream): ROW_TYPE = PushRulesStreamRow def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main super().__init__( hs.get_instance_name(), @@ -401,7 +401,7 @@ class PushersStream(Stream): ROW_TYPE = PushersStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), @@ -434,7 +434,7 @@ class CachesStream(Stream): ROW_TYPE = CachesStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), store.get_cache_stream_token_for_writer, @@ -455,7 +455,7 @@ class DeviceListsStream(Stream): ROW_TYPE = DeviceListsStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_device_stream_token), @@ -474,7 +474,7 @@ class ToDeviceStream(Stream): ROW_TYPE = ToDeviceStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_to_device_stream_token), @@ -495,7 +495,7 @@ class TagAccountDataStream(Stream): ROW_TYPE = TagAccountDataStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_max_account_data_stream_id), @@ -516,7 +516,7 @@ class AccountDataStream(Stream): ROW_TYPE = AccountDataStreamRow def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(self.store.get_max_account_data_stream_id), @@ -585,7 +585,7 @@ class GroupServerStream(Stream): ROW_TYPE = GroupsStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_group_stream_token), @@ -604,7 +604,7 @@ class UserSignatureStream(Stream): ROW_TYPE = UserSignatureStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_device_stream_token), diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 50c4a5ba0..26f4fa7cf 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -124,7 +124,7 @@ class EventsStream(Stream): NAME = "events" def __init__(self, hs: "HomeServer"): - self._store = hs.get_datastore() + self._store = hs.get_datastores().main super().__init__( hs.get_instance_name(), self._store._stream_id_gen.get_current_token_for_writer, diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index ba0d989d8..6de302f81 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -116,7 +116,7 @@ class PurgeHistoryRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.pagination_handler = hs.get_pagination_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( diff --git a/synapse/rest/admin/background_updates.py b/synapse/rest/admin/background_updates.py index e9bce22a3..93a78db81 100644 --- a/synapse/rest/admin/background_updates.py +++ b/synapse/rest/admin/background_updates.py @@ -112,7 +112,7 @@ class BackgroundUpdateStartJobRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self._auth, request) diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py index d9905ff56..cef46ba0d 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py @@ -44,7 +44,7 @@ class DeviceRestServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.is_mine = hs.is_mine async def on_GET( @@ -113,7 +113,7 @@ class DevicesRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.is_mine = hs.is_mine async def on_GET( @@ -144,7 +144,7 @@ class DeleteDevicesRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.is_mine = hs.is_mine async def on_POST( diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py index 38477f8ea..6d634eef7 100644 --- a/synapse/rest/admin/event_reports.py +++ b/synapse/rest/admin/event_reports.py @@ -53,7 +53,7 @@ class EventReportsRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) @@ -115,7 +115,7 @@ class EventReportDetailRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, report_id: str diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py index d162e0081..023ed9214 100644 --- a/synapse/rest/admin/federation.py +++ b/synapse/rest/admin/federation.py @@ -48,7 +48,7 @@ class ListDestinationsRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self._auth, request) @@ -105,7 +105,7 @@ class DestinationRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, destination: str @@ -165,7 +165,7 @@ class DestinationMembershipRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, destination: str @@ -221,7 +221,7 @@ class DestinationResetConnectionRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._authenticator = Authenticator(hs) async def on_POST( diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 299f5c9eb..8ca57bdb2 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -47,7 +47,7 @@ class QuarantineMediaInRoom(RestServlet): ] def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( @@ -74,7 +74,7 @@ class QuarantineMediaByUser(RestServlet): PATTERNS = admin_patterns("/user/(?P[^/]*)/media/quarantine$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( @@ -103,7 +103,7 @@ class QuarantineMediaByID(RestServlet): ) def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( @@ -132,7 +132,7 @@ class UnquarantineMediaByID(RestServlet): ) def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( @@ -156,7 +156,7 @@ class ProtectMediaByID(RestServlet): PATTERNS = admin_patterns("/media/protect/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( @@ -178,7 +178,7 @@ class UnprotectMediaByID(RestServlet): PATTERNS = admin_patterns("/media/unprotect/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( @@ -200,7 +200,7 @@ class ListMediaInRoom(RestServlet): PATTERNS = admin_patterns("/room/(?P[^/]*)/media$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_GET( @@ -251,7 +251,7 @@ class DeleteMediaByID(RestServlet): PATTERNS = admin_patterns("/media/(?P[^/]*)/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.server_name = hs.hostname self.media_repository = hs.get_media_repository() @@ -283,7 +283,7 @@ class DeleteMediaByDateSize(RestServlet): PATTERNS = admin_patterns("/media/(?P[^/]*)/delete$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.server_name = hs.hostname self.media_repository = hs.get_media_repository() @@ -352,7 +352,7 @@ class UserMediaRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.is_mine = hs.is_mine self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.media_repository = hs.get_media_repository() async def on_GET( diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py index 04948b640..af606e925 100644 --- a/synapse/rest/admin/registration_tokens.py +++ b/synapse/rest/admin/registration_tokens.py @@ -71,7 +71,7 @@ class ListRegistrationTokensRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) @@ -109,7 +109,7 @@ class NewRegistrationTokenRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() # A string of all the characters allowed to be in a registration_token self.allowed_chars = string.ascii_letters + string.digits + "._~-" @@ -260,7 +260,7 @@ class RegistrationTokenRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]: """Retrieve a registration token.""" diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 5b706efbc..f4736a3da 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -65,7 +65,7 @@ class RoomRestV2Servlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._pagination_handler = hs.get_pagination_handler() async def on_DELETE( @@ -188,7 +188,7 @@ class ListRoomRestServlet(RestServlet): PATTERNS = admin_patterns("/rooms$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() @@ -278,7 +278,7 @@ class RoomRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room_shutdown_handler = hs.get_room_shutdown_handler() self.pagination_handler = hs.get_pagination_handler() @@ -382,7 +382,7 @@ class RoomMembersRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, room_id: str @@ -408,7 +408,7 @@ class RoomStateRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self._event_serializer = hs.get_event_client_serializer() @@ -525,7 +525,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): def __init__(self, hs: "HomeServer"): super().__init__(hs) self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.event_creation_handler = hs.get_event_creation_handler() self.state_handler = hs.get_state_handler() self.is_mine_id = hs.is_mine_id @@ -670,7 +670,7 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet): def __init__(self, hs: "HomeServer"): super().__init__(hs) self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_DELETE( self, request: SynapseRequest, room_identifier: str @@ -781,7 +781,7 @@ class BlockRoomRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, room_id: str diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py index 7a6546372..3b142b840 100644 --- a/synapse/rest/admin/statistics.py +++ b/synapse/rest/admin/statistics.py @@ -38,7 +38,7 @@ class UserMediaStatisticsRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index c2617ee30..8e29ada8a 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -66,7 +66,7 @@ class UsersRestServletV2(RestServlet): """ def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() @@ -156,7 +156,7 @@ class UserRestServletV2(RestServlet): self.hs = hs self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth_handler = hs.get_auth_handler() self.profile_handler = hs.get_profile_handler() self.set_password_handler = hs.get_set_password_handler() @@ -588,7 +588,7 @@ class DeactivateAccountRestServlet(RestServlet): self._deactivate_account_handler = hs.get_deactivate_account_handler() self.auth = hs.get_auth() self.is_mine = hs.is_mine - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_POST( self, request: SynapseRequest, target_user_id: str @@ -674,7 +674,7 @@ class ResetPasswordRestServlet(RestServlet): PATTERNS = admin_patterns("/reset_password/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self._set_password_handler = hs.get_set_password_handler() @@ -717,7 +717,7 @@ class SearchUsersRestServlet(RestServlet): PATTERNS = admin_patterns("/search_users/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.is_mine = hs.is_mine @@ -775,7 +775,7 @@ class UserAdminServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/admin$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.is_mine = hs.is_mine @@ -835,7 +835,7 @@ class UserMembershipRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.is_mine = hs.is_mine self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, user_id: str @@ -864,7 +864,7 @@ class PushersRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.is_mine = hs.is_mine - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_GET( @@ -905,7 +905,7 @@ class UserTokenRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/login$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self.is_mine_id = hs.is_mine_id @@ -974,7 +974,7 @@ class ShadowBanRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/shadow_ban$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.is_mine_id = hs.is_mine_id @@ -1026,7 +1026,7 @@ class RateLimitRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/override_ratelimit$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.is_mine_id = hs.is_mine_id @@ -1129,7 +1129,7 @@ class AccountDataRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._is_mine_id = hs.is_mine_id async def on_GET( diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index 5802de5b7..4b217882e 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -60,7 +60,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs - self.datastore = hs.get_datastore() + self.datastore = hs.get_datastores().main self.config = hs.config self.identity_handler = hs.get_identity_handler() @@ -114,7 +114,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # This avoids a potential account hijack by requesting a password reset to # an email address which is controlled by the attacker but which, after # canonicalisation, matches the one in our database. - existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( + existing_user_id = await self.hs.get_datastores().main.get_user_id_by_threepid( "email", email ) @@ -168,7 +168,7 @@ class PasswordRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() - self.datastore = self.hs.get_datastore() + self.datastore = self.hs.get_datastores().main self.password_policy_handler = hs.get_password_policy_handler() self._set_password_handler = hs.get_set_password_handler() @@ -347,7 +347,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): self.hs = hs self.config = hs.config self.identity_handler = hs.get_identity_handler() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: self.mailer = Mailer( @@ -450,7 +450,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.hs = hs super().__init__() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.identity_handler = hs.get_identity_handler() async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: @@ -533,7 +533,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): super().__init__() self.config = hs.config self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: self._failure_email_template = ( self.config.email.email_add_threepid_template_failure_html @@ -600,7 +600,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): super().__init__() self.config = hs.config self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.identity_handler = hs.get_identity_handler() async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: @@ -634,7 +634,7 @@ class ThreepidRestServlet(RestServlet): self.identity_handler = hs.get_identity_handler() self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() - self.datastore = self.hs.get_datastore() + self.datastore = self.hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) @@ -768,7 +768,7 @@ class ThreepidUnbindRestServlet(RestServlet): self.hs = hs self.identity_handler = hs.get_identity_handler() self.auth = hs.get_auth() - self.datastore = self.hs.get_datastore() + self.datastore = self.hs.get_datastores().main async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: """Unbind the given 3pid from a specific identity server, or identity servers that are diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py index 58b8adbd3..bfe985939 100644 --- a/synapse/rest/client/account_data.py +++ b/synapse/rest/client/account_data.py @@ -42,7 +42,7 @@ class AccountDataServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.handler = hs.get_account_data_handler() async def on_PUT( @@ -90,7 +90,7 @@ class RoomAccountDataServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.handler = hs.get_account_data_handler() async def on_PUT( diff --git a/synapse/rest/client/directory.py b/synapse/rest/client/directory.py index ee247e3d1..e181a0dde 100644 --- a/synapse/rest/client/directory.py +++ b/synapse/rest/client/directory.py @@ -47,7 +47,7 @@ class ClientDirectoryServer(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.directory_handler = hs.get_directory_handler() self.auth = hs.get_auth() @@ -129,7 +129,7 @@ class ClientDirectoryListServer(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.directory_handler = hs.get_directory_handler() self.auth = hs.get_auth() @@ -173,7 +173,7 @@ class ClientAppserviceDirectoryListServer(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.directory_handler = hs.get_directory_handler() self.auth = hs.get_auth() diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py index 672c82106..916f5230f 100644 --- a/synapse/rest/client/events.py +++ b/synapse/rest/client/events.py @@ -39,7 +39,7 @@ class EventStreamRestServlet(RestServlet): super().__init__() self.event_stream_handler = hs.get_event_stream_handler() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) diff --git a/synapse/rest/client/groups.py b/synapse/rest/client/groups.py index a7e9aa3e9..7e1149c7f 100644 --- a/synapse/rest/client/groups.py +++ b/synapse/rest/client/groups.py @@ -705,7 +705,7 @@ class GroupAdminUsersInviteServlet(RestServlet): self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.is_mine_id = hs.is_mine_id @_validate_group_id @@ -854,7 +854,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main @_validate_group_id async def on_PUT( @@ -879,7 +879,7 @@ class PublicisedGroupsForUserServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.groups_handler = hs.get_groups_local_handler() async def on_GET( @@ -901,7 +901,7 @@ class PublicisedGroupsForUsersServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.groups_handler = hs.get_groups_local_handler() async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: diff --git a/synapse/rest/client/initial_sync.py b/synapse/rest/client/initial_sync.py index 49b1037b2..cfadcb8e5 100644 --- a/synapse/rest/client/initial_sync.py +++ b/synapse/rest/client/initial_sync.py @@ -33,7 +33,7 @@ class InitialSyncRestServlet(RestServlet): super().__init__() self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index 730c18f08..ce806e3c1 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -198,7 +198,7 @@ class KeyChangesServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index f9994658c..c9d44c596 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -104,13 +104,13 @@ class LoginRestServlet(RestServlet): self._well_known_builder = WellKnownBuilder(hs) self._address_ratelimiter = Ratelimiter( - store=hs.get_datastore(), + store=hs.get_datastores().main, clock=hs.get_clock(), rate_hz=self.hs.config.ratelimiting.rc_login_address.per_second, burst_count=self.hs.config.ratelimiting.rc_login_address.burst_count, ) self._account_ratelimiter = Ratelimiter( - store=hs.get_datastore(), + store=hs.get_datastores().main, clock=hs.get_clock(), rate_hz=self.hs.config.ratelimiting.rc_login_account.per_second, burst_count=self.hs.config.ratelimiting.rc_login_account.burst_count, diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py index 8e427a96a..20377a9ac 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py @@ -35,7 +35,7 @@ class NotificationsServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.clock = hs.get_clock() self._event_serializer = hs.get_event_client_serializer() diff --git a/synapse/rest/client/openid.py b/synapse/rest/client/openid.py index add56d699..820682ec4 100644 --- a/synapse/rest/client/openid.py +++ b/synapse/rest/client/openid.py @@ -67,7 +67,7 @@ class IdTokenServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.server_name = hs.config.server.server_name diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py index 8fe75bd75..a93f6fd5e 100644 --- a/synapse/rest/client/push_rule.py +++ b/synapse/rest/client/push_rule.py @@ -57,7 +57,7 @@ class PushRuleRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self._is_worker = hs.config.worker.worker_app is not None diff --git a/synapse/rest/client/pusher.py b/synapse/rest/client/pusher.py index 98604a938..d6487c31d 100644 --- a/synapse/rest/client/pusher.py +++ b/synapse/rest/client/pusher.py @@ -46,7 +46,9 @@ class PushersRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request) user = requester.user - pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) + pushers = await self.hs.get_datastores().main.get_pushers_by_user_id( + user.to_string() + ) filtered_pushers = [p.as_dict() for p in pushers] diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index b8a5135e0..70baf50fa 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -123,7 +123,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): request, "email", email ) - existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( + existing_user_id = await self.hs.get_datastores().main.get_user_id_by_threepid( "email", email ) @@ -203,7 +203,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): request, "msisdn", msisdn ) - existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( + existing_user_id = await self.hs.get_datastores().main.get_user_id_by_threepid( "msisdn", msisdn ) @@ -258,7 +258,7 @@ class RegistrationSubmitTokenServlet(RestServlet): self.auth = hs.get_auth() self.config = hs.config self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: self._failure_email_template = ( @@ -385,7 +385,7 @@ class RegistrationTokenValidityRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.ratelimiter = Ratelimiter( store=self.store, clock=hs.get_clock(), @@ -415,7 +415,7 @@ class RegisterRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() self.identity_handler = hs.get_identity_handler() diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 2cab83c4e..487ea38b5 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -85,7 +85,7 @@ class RelationPaginationServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self._event_serializer = hs.get_event_client_serializer() self.event_handler = hs.get_event_handler() @@ -190,7 +190,7 @@ class RelationAggregationPaginationServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.event_handler = hs.get_event_handler() async def on_GET( @@ -282,7 +282,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self._event_serializer = hs.get_event_client_serializer() self.event_handler = hs.get_event_handler() diff --git a/synapse/rest/client/report_event.py b/synapse/rest/client/report_event.py index d4a4adb50..6e962a453 100644 --- a/synapse/rest/client/report_event.py +++ b/synapse/rest/client/report_event.py @@ -38,7 +38,7 @@ class ReportEventRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_POST( self, request: SynapseRequest, room_id: str, event_id: str diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 90355e44b..5ccfe5a92 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -477,7 +477,7 @@ class RoomMemberListRestServlet(RestServlet): super().__init__() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, room_id: str @@ -553,7 +553,7 @@ class RoomMessageListRestServlet(RestServlet): self._hs = hs self.pagination_handler = hs.get_pagination_handler() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, room_id: str @@ -621,7 +621,7 @@ class RoomInitialSyncRestServlet(RestServlet): super().__init__() self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, room_id: str @@ -642,7 +642,7 @@ class RoomEventServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.clock = hs.get_clock() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self.event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() self.auth = hs.get_auth() @@ -1027,7 +1027,7 @@ class JoinedRoomsRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: @@ -1116,7 +1116,7 @@ class TimestampLookupRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self.timestamp_lookup_handler = hs.get_timestamp_lookup_handler() async def on_GET( diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py index 4b6be3832..0048973e5 100644 --- a/synapse/rest/client/room_batch.py +++ b/synapse/rest/client/room_batch.py @@ -75,7 +75,7 @@ class RoomBatchSendEventRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() self.room_batch_handler = hs.get_room_batch_handler() diff --git a/synapse/rest/client/shared_rooms.py b/synapse/rest/client/shared_rooms.py index 09a46737d..e669fa789 100644 --- a/synapse/rest/client/shared_rooms.py +++ b/synapse/rest/client/shared_rooms.py @@ -41,7 +41,7 @@ class UserSharedRoomsServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_directory_active = hs.config.server.update_user_directory async def on_GET( diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index f9615da52..f3018ff69 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -103,7 +103,7 @@ class SyncRestServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.sync_handler = hs.get_sync_handler() self.clock = hs.get_clock() self.filtering = hs.get_filtering() diff --git a/synapse/rest/client/tags.py b/synapse/rest/client/tags.py index c88cb9367..ca638755c 100644 --- a/synapse/rest/client/tags.py +++ b/synapse/rest/client/tags.py @@ -39,7 +39,7 @@ class TagListServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, user_id: str, room_id: str diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index 3d2afacc5..25f9ea285 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -78,7 +78,7 @@ class ConsentResource(DirectServeHtmlResource): super().__init__() self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.registration_handler = hs.get_registration_handler() # this is required by the request_handler wrapper diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 3923ba843..3525d6ae5 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -94,7 +94,7 @@ class RemoteKey(DirectServeJsonResource): super().__init__() self.fetcher = ServerKeyFetcher(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.federation_domain_whitelist = ( hs.config.federation.federation_domain_whitelist diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 71b9a34b1..6c414402b 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -75,7 +75,7 @@ class MediaRepository: self.client = hs.get_federation_http_client() self.clock = hs.get_clock() self.server_name = hs.hostname - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.max_upload_size = hs.config.media.max_upload_size self.max_image_pixels = hs.config.media.max_image_pixels diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index c08b60d10..14ea88b24 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -134,7 +134,7 @@ class PreviewUrlResource(DirectServeJsonResource): self.filepaths = media_repo.filepaths self.max_spider_size = hs.config.media.max_spider_size self.server_name = hs.hostname - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.client = SimpleHttpClient( hs, treq_args={"browser_like_redirects": True}, diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index ed91ef5a4..53b156524 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -50,7 +50,7 @@ class ThumbnailResource(DirectServeJsonResource): ): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.media_repo = media_repo self.media_storage = media_storage self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index fde28d08c..e73e431dc 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -37,7 +37,7 @@ class UploadResource(DirectServeJsonResource): self.media_repo = media_repo self.filepaths = media_repo.filepaths - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.server_name = hs.hostname self.auth = hs.get_auth() diff --git a/synapse/rest/synapse/client/password_reset.py b/synapse/rest/synapse/client/password_reset.py index 28a67f04e..6ac9dbc7c 100644 --- a/synapse/rest/synapse/client/password_reset.py +++ b/synapse/rest/synapse/client/password_reset.py @@ -44,7 +44,7 @@ class PasswordResetSubmitTokenResource(DirectServeHtmlResource): super().__init__() self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._local_threepid_handling_disabled_due_to_email_config = ( hs.config.email.local_threepid_handling_disabled_due_to_email_config diff --git a/synapse/server.py b/synapse/server.py index 4c07f2101..b5e2a319b 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -17,7 +17,7 @@ # homeservers; either as a full homeserver as a real application, or a small # partial one for unit test mocking. -# Imports required for the default HomeServer() implementation + import abc import functools import logging @@ -134,7 +134,7 @@ from synapse.server_notices.worker_server_notices_sender import ( WorkerServerNoticesSender, ) from synapse.state import StateHandler, StateResolutionHandler -from synapse.storage import Databases, DataStore, Storage +from synapse.storage import Databases, Storage from synapse.streams.events import EventSources from synapse.types import DomainSpecificString, ISynapseReactor from synapse.util import Clock @@ -225,7 +225,7 @@ class HomeServer(metaclass=abc.ABCMeta): # This is overridden in derived application classes # (such as synapse.app.homeserver.SynapseHomeServer) and gives the class to be - # instantiated during setup() for future return by get_datastore() + # instantiated during setup() for future return by get_datastores() DATASTORE_CLASS = abc.abstractproperty() tls_server_context_factory: Optional[IOpenSSLContextFactory] @@ -355,12 +355,6 @@ class HomeServer(metaclass=abc.ABCMeta): def get_clock(self) -> Clock: return Clock(self._reactor) - def get_datastore(self) -> DataStore: - if not self.datastores: - raise Exception("HomeServer.setup must be called before getting datastores") - - return self.datastores.main - def get_datastores(self) -> Databases: if not self.datastores: raise Exception("HomeServer.setup must be called before getting datastores") @@ -374,7 +368,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_registration_ratelimiter(self) -> Ratelimiter: return Ratelimiter( - store=self.get_datastore(), + store=self.get_datastores().main, clock=self.get_clock(), rate_hz=self.config.ratelimiting.rc_registration.per_second, burst_count=self.config.ratelimiting.rc_registration.burst_count, @@ -847,7 +841,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_request_ratelimiter(self) -> RequestRatelimiter: return RequestRatelimiter( - self.get_datastore(), + self.get_datastores().main, self.get_clock(), self.config.ratelimiting.rc_message, self.config.ratelimiting.rc_admin_redaction, diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py index e09a25591..698ca742e 100644 --- a/synapse/server_notices/consent_server_notices.py +++ b/synapse/server_notices/consent_server_notices.py @@ -32,7 +32,7 @@ class ConsentServerNotices: def __init__(self, hs: "HomeServer"): self._server_notices_manager = hs.get_server_notices_manager() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._users_in_progress: Set[str] = set() diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index 8522930b5..015dd08f0 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -36,7 +36,7 @@ class ResourceLimitsServerNotices: def __init__(self, hs: "HomeServer"): self._server_notices_manager = hs.get_server_notices_manager() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._auth = hs.get_auth() self._config = hs.config self._resouce_limited = False diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index 0cf60236f..7b4814e04 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -29,7 +29,7 @@ SERVER_NOTICE_ROOM_TAG = "m.server_notice" class ServerNoticesManager: def __init__(self, hs: "HomeServer"): - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._config = hs.config self._account_data_handler = hs.get_account_data_handler() self._room_creation_handler = hs.get_room_creation_handler() diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 67e8bc6ec..fcc24ad12 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -126,7 +126,7 @@ class StateHandler: def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state_store = hs.get_storage().state self.hs = hs self._state_resolution_handler = hs.get_state_resolution_handler() diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index 8f09dd8e8..e9a0cdc6b 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -112,7 +112,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): for tp in self.hs.config.server.mau_limits_reserved_threepids[ : self.hs.config.server.max_mau_value ]: - user_id = await self.hs.get_datastore().get_user_id_by_threepid( + user_id = await self.hs.get_datastores().main.get_user_id_by_threepid( tp["medium"], canonicalise_email(tp["address"]) ) if user_id: diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 4ec2a713c..fb8fe1729 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -48,7 +48,7 @@ class EventSources: # all the attributes of `_EventSourcesInner` are annotated. *(attribute.type(hs) for attribute in attr.fields(_EventSourcesInner)) # type: ignore[misc] ) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def get_current_token(self) -> StreamToken: push_rules_key = self.store.get_max_push_rules_stream_id() diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 4b53b6d40..686d17c0d 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -16,6 +16,8 @@ from unittest.mock import Mock import pymacaroons +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.auth import Auth from synapse.api.constants import UserTypes from synapse.api.errors import ( @@ -26,8 +28,10 @@ from synapse.api.errors import ( ResourceLimitError, ) from synapse.appservice import ApplicationService +from synapse.server import HomeServer from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import Requester +from synapse.util import Clock from tests import unittest from tests.test_utils import simple_async_mock @@ -36,10 +40,10 @@ from tests.utils import mock_getRawHeaders class AuthTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): self.store = Mock() - hs.get_datastore = Mock(return_value=self.store) + hs.datastores.main = self.store hs.get_auth_handler().store = self.store self.auth = Auth(hs) diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index b7fc33dc9..973f0f7fa 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -40,7 +40,7 @@ def MockEvent(**kwargs): class FilteringTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.filtering = hs.get_filtering() - self.datastore = hs.get_datastore() + self.datastore = hs.get_datastores().main def test_errors_on_invalid_filters(self): invalid_filters = [ diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index dcf0110c1..4ef754a18 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -8,7 +8,7 @@ from tests import unittest class TestRatelimiter(unittest.HomeserverTestCase): def test_allowed_via_can_do_action(self): limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(None, key="test_id", _time_now_s=0) @@ -39,7 +39,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): as_requester = create_requester("@user:example.com", app_service=appservice) limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=0) @@ -70,7 +70,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): as_requester = create_requester("@user:example.com", app_service=appservice) limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=0) @@ -92,7 +92,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): def test_allowed_via_ratelimit(self): limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) # Shouldn't raise @@ -116,7 +116,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): """ # Create a Ratelimiter with a very low allowed rate_hz and burst_count limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) # First attempt should be allowed @@ -162,7 +162,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): """ # Create a Ratelimiter with a very low allowed rate_hz and burst_count limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) # First attempt should be allowed @@ -190,7 +190,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): def test_pruning(self): limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) self.get_success_or_raise( limiter.can_do_action(None, key="test_id_1", _time_now_s=0) @@ -208,7 +208,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): """Test that users that have ratelimiting disabled in the DB aren't ratelimited. """ - store = self.hs.get_datastore() + store = self.hs.get_datastores().main user_id = "@user:test" requester = create_requester(user_id) @@ -233,7 +233,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): def test_multiple_actions(self): limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=3 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3 ) # Test that 4 actions aren't allowed with a maximum burst of 3. allowed, time_allowed = self.get_success_or_raise( diff --git a/tests/app/test_phone_stats_home.py b/tests/app/test_phone_stats_home.py index 19eb4c79d..df731eb59 100644 --- a/tests/app/test_phone_stats_home.py +++ b/tests/app/test_phone_stats_home.py @@ -32,7 +32,7 @@ class PhoneHomeTestCase(HomeserverTestCase): self.helper.send(room_id, "message", tok=access_token) # Check the R30 results do not count that user. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) # Advance 30 days (+ 1 second, because strict inequality causes issues if we are @@ -40,7 +40,7 @@ class PhoneHomeTestCase(HomeserverTestCase): self.reactor.advance(30 * ONE_DAY_IN_SECONDS + 1) # (Make sure the user isn't somehow counted by this point.) - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) # Send a message (this counts as activity) @@ -51,21 +51,21 @@ class PhoneHomeTestCase(HomeserverTestCase): self.reactor.advance(2 * 60 * 60) # *Now* the user is counted. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 1, "unknown": 1}) # Advance 29 days. The user has now not posted for 29 days. self.reactor.advance(29 * ONE_DAY_IN_SECONDS) # The user is still counted. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 1, "unknown": 1}) # Advance another day. The user has now not posted for 30 days. self.reactor.advance(ONE_DAY_IN_SECONDS) # The user is now no longer counted in R30. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) def test_r30_minimum_usage_using_default_config(self): @@ -84,7 +84,7 @@ class PhoneHomeTestCase(HomeserverTestCase): self.helper.send(room_id, "message", tok=access_token) # Check the R30 results do not count that user. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) # Advance 30 days (+ 1 second, because strict inequality causes issues if we are @@ -92,7 +92,7 @@ class PhoneHomeTestCase(HomeserverTestCase): self.reactor.advance(30 * ONE_DAY_IN_SECONDS + 1) # (Make sure the user isn't somehow counted by this point.) - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) # Send a message (this counts as activity) @@ -103,14 +103,14 @@ class PhoneHomeTestCase(HomeserverTestCase): self.reactor.advance(2 * 60 * 60) # *Now* the user is counted. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 1, "unknown": 1}) # Advance 27 days. The user has now not posted for 27 days. self.reactor.advance(27 * ONE_DAY_IN_SECONDS) # The user is still counted. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 1, "unknown": 1}) # Advance another day. The user has now not posted for 28 days. @@ -119,7 +119,7 @@ class PhoneHomeTestCase(HomeserverTestCase): # The user is now no longer counted in R30. # (This is because the user_ips table has been pruned, which by default # only preserves the last 28 days of entries.) - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) def test_r30_user_must_be_retained_for_at_least_a_month(self): @@ -135,7 +135,7 @@ class PhoneHomeTestCase(HomeserverTestCase): self.helper.send(room_id, "message", tok=access_token) # Check the user does not contribute to R30 yet. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) for _ in range(30): @@ -144,14 +144,16 @@ class PhoneHomeTestCase(HomeserverTestCase): self.helper.send(room_id, "I'm still here", tok=access_token) # Notice that the user *still* does not contribute to R30! - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success( + self.hs.get_datastores().main.count_r30_users() + ) self.assertEqual(r30_results, {"all": 0}) self.reactor.advance(ONE_DAY_IN_SECONDS) self.helper.send(room_id, "Still here!", tok=access_token) # *Now* the user appears in R30. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 1, "unknown": 1}) @@ -196,7 +198,7 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase): # (user_daily_visits is updated every 5 minutes using a looping call.) self.reactor.advance(FIVE_MINUTES_IN_SECONDS) - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Check the R30 results do not count that user. r30_results = self.get_success(store.count_r30v2_users()) @@ -275,7 +277,7 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase): # (user_daily_visits is updated every 5 minutes using a looping call.) self.reactor.advance(FIVE_MINUTES_IN_SECONDS) - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Check the user does not contribute to R30 yet. r30_results = self.get_success(store.count_r30v2_users()) @@ -347,7 +349,7 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase): # (user_daily_visits is updated every 5 minutes using a looping call.) self.reactor.advance(FIVE_MINUTES_IN_SECONDS) - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Check that the user does not contribute to R30v2, even though it's been # more than 30 days since registration. diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 17a9fb63a..3a4d50271 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -179,7 +179,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): kr = keyring.Keyring(self.hs) key1 = signedjson.key.generate_signing_key(1) - r = self.hs.get_datastore().store_server_verify_keys( + r = self.hs.get_datastores().main.store_server_verify_keys( "server9", time.time() * 1000, [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))], @@ -272,7 +272,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): ) key1 = signedjson.key.generate_signing_key(1) - r = self.hs.get_datastore().store_server_verify_keys( + r = self.hs.get_datastores().main.store_server_verify_keys( "server9", time.time() * 1000, [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), None))], @@ -448,7 +448,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): # check that the perspectives store is correctly updated lookup_triplet = (SERVER_NAME, testverifykey_id, None) key_json = self.get_success( - self.hs.get_datastore().get_server_keys_json([lookup_triplet]) + self.hs.get_datastores().main.get_server_keys_json([lookup_triplet]) ) res = key_json[lookup_triplet] self.assertEqual(len(res), 1) @@ -564,7 +564,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): # check that the perspectives store is correctly updated lookup_triplet = (SERVER_NAME, testverifykey_id, None) key_json = self.get_success( - self.hs.get_datastore().get_server_keys_json([lookup_triplet]) + self.hs.get_datastores().main.get_server_keys_json([lookup_triplet]) ) res = key_json[lookup_triplet] self.assertEqual(len(res), 1) @@ -683,7 +683,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): # check that the perspectives store is correctly updated lookup_triplet = (SERVER_NAME, testverifykey_id, None) key_json = self.get_success( - self.hs.get_datastore().get_server_keys_json([lookup_triplet]) + self.hs.get_datastores().main.get_server_keys_json([lookup_triplet]) ) res = key_json[lookup_triplet] self.assertEqual(len(res), 1) diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py index ca27388ae..defbc68c1 100644 --- a/tests/events/test_snapshot.py +++ b/tests/events/test_snapshot.py @@ -28,7 +28,7 @@ class TestEventContext(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.user_id = self.register_user("u1", "pass") diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index e40ef9587..9336181c9 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -55,7 +55,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): self.assertTrue(complexity > 0, complexity) # Artificially raise the complexity - store = self.hs.get_datastore() + store = self.hs.get_datastores().main store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23) # Get the room complexity again -- make sure it's our artificial value @@ -149,7 +149,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): ) # Artificially raise the complexity - self.hs.get_datastore().get_current_state_event_counts = ( + self.hs.get_datastores().main.get_current_state_event_counts = ( lambda x: make_awaitable(600) ) diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py index f0aa8ed9d..2873b4d43 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py @@ -64,7 +64,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): Dictionary of { event_id: str, stream_ordering: int } """ event_id, stream_ordering = self.get_success( - self.hs.get_datastore().db_pool.execute( + self.hs.get_datastores().main.db_pool.execute( "test:get_destination_rooms", None, """ @@ -125,7 +125,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): self.pump() lsso_1 = self.get_success( - self.hs.get_datastore().get_destination_last_successful_stream_ordering( + self.hs.get_datastores().main.get_destination_last_successful_stream_ordering( "host2" ) ) @@ -141,7 +141,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): event_id_2 = self.helper.send(room, "rabbits!", tok=u1_token)["event_id"] lsso_2 = self.get_success( - self.hs.get_datastore().get_destination_last_successful_stream_ordering( + self.hs.get_datastores().main.get_destination_last_successful_stream_ordering( "host2" ) ) @@ -216,7 +216,9 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): # let's also clear any backoffs self.get_success( - self.hs.get_datastore().set_destination_retry_timings("host2", None, 0, 0) + self.hs.get_datastores().main.set_destination_retry_timings( + "host2", None, 0, 0 + ) ) # bring the remote online and clear the received pdu list @@ -296,13 +298,13 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): # destination_rooms should already be populated, but let us pretend that we already # sent (successfully) up to and including event id 2 - event_2 = self.get_success(self.hs.get_datastore().get_event(event_id_2)) + event_2 = self.get_success(self.hs.get_datastores().main.get_event(event_id_2)) # also fetch event 5 so we know its last_successful_stream_ordering later - event_5 = self.get_success(self.hs.get_datastore().get_event(event_id_5)) + event_5 = self.get_success(self.hs.get_datastores().main.get_event(event_id_5)) self.get_success( - self.hs.get_datastore().set_destination_last_successful_stream_ordering( + self.hs.get_datastores().main.set_destination_last_successful_stream_ordering( "host2", event_2.internal_metadata.stream_ordering ) ) @@ -359,7 +361,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): # ASSERT: # - All servers are up to date so none should have outstanding catch-up outstanding_when_successful = self.get_success( - self.hs.get_datastore().get_catch_up_outstanding_destinations(None) + self.hs.get_datastores().main.get_catch_up_outstanding_destinations(None) ) self.assertEqual(outstanding_when_successful, []) @@ -370,7 +372,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): # - Mark zzzerver as being backed-off from now = self.clock.time_msec() self.get_success( - self.hs.get_datastore().set_destination_retry_timings( + self.hs.get_datastores().main.set_destination_retry_timings( "zzzerver", now, now, 24 * 60 * 60 * 1000 # retry in 1 day ) ) @@ -382,14 +384,14 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): # - all remotes are outstanding # - they are returned in batches of 25, in order outstanding_1 = self.get_success( - self.hs.get_datastore().get_catch_up_outstanding_destinations(None) + self.hs.get_datastores().main.get_catch_up_outstanding_destinations(None) ) self.assertEqual(len(outstanding_1), 25) self.assertEqual(outstanding_1, server_names[0:25]) outstanding_2 = self.get_success( - self.hs.get_datastore().get_catch_up_outstanding_destinations( + self.hs.get_datastores().main.get_catch_up_outstanding_destinations( outstanding_1[-1] ) ) @@ -457,7 +459,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): ) self.get_success( - self.hs.get_datastore().set_destination_last_successful_stream_ordering( + self.hs.get_datastores().main.set_destination_last_successful_stream_ordering( "host2", event_1.internal_metadata.stream_ordering ) ) diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index b2376e2db..60e0c31f4 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -176,7 +176,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): def get_users_who_share_room_with_user(user_id): return defer.succeed({"@user2:host2"}) - hs.get_datastore().get_users_who_share_room_with_user = ( + hs.get_datastores().main.get_users_who_share_room_with_user = ( get_users_who_share_room_with_user ) @@ -395,7 +395,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # run the prune job self.reactor.advance(10) self.get_success( - self.hs.get_datastore()._prune_old_outbound_device_pokes(prune_age=1) + self.hs.get_datastores().main._prune_old_outbound_device_pokes(prune_age=1) ) # recover the server @@ -445,7 +445,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # run the prune job self.reactor.advance(10) self.get_success( - self.hs.get_datastore()._prune_old_outbound_device_pokes(prune_age=1) + self.hs.get_datastores().main._prune_old_outbound_device_pokes(prune_age=1) ) # recover the server diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py index 686f42ab4..adf0535d9 100644 --- a/tests/federation/transport/test_knocking.py +++ b/tests/federation/transport/test_knocking.py @@ -198,7 +198,7 @@ class FederationKnockingTestCase( ] def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main # We're not going to be properly signing events as our remote homeserver is fake, # therefore disable event signature checks. diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index fe57ff267..9918ff680 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -38,7 +38,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_as_api = Mock() self.mock_scheduler = Mock() hs = Mock() - hs.get_datastore.return_value = self.mock_store + hs.get_datastores.return_value = Mock(main=self.mock_store) self.mock_store.get_received_ts.return_value = make_awaitable(0) self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None) self.mock_store.set_appservice_stream_type_pos.return_value = make_awaitable( @@ -355,7 +355,9 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): # Mock out application services, and allow defining our own in tests self._services: List[ApplicationService] = [] - self.hs.get_datastore().get_app_services = Mock(return_value=self._services) + self.hs.get_datastores().main.get_app_services = Mock( + return_value=self._services + ) # A user on the homeserver. self.local_user_device_id = "local_device" @@ -494,7 +496,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): # Create a fake device per message. We can't send to-device messages to # a device that doesn't exist. self.get_success( - self.hs.get_datastore().db_pool.simple_insert_many( + self.hs.get_datastores().main.db_pool.simple_insert_many( desc="test_application_services_receive_burst_of_to_device", table="devices", keys=("user_id", "device_id"), @@ -510,7 +512,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): # Seed the device_inbox table with our fake messages self.get_success( - self.hs.get_datastore().add_messages_to_device_inbox(messages, {}) + self.hs.get_datastores().main.add_messages_to_device_inbox(messages, {}) ) # Now have local_user send a final to-device message to exclusive_as_user. All unsent diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 03b8b8615..0c6e55e72 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -129,7 +129,7 @@ class AuthTestCase(unittest.HomeserverTestCase): def test_mau_limits_exceeded_large(self): self.auth_blocking._limit_usage_by_mau = True - self.hs.get_datastore().get_monthly_active_count = Mock( + self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.large_number_of_users) ) @@ -140,7 +140,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ResourceLimitError, ) - self.hs.get_datastore().get_monthly_active_count = Mock( + self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.large_number_of_users) ) self.get_failure( @@ -156,7 +156,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.auth_blocking._limit_usage_by_mau = True # Set the server to be at the edge of too many users. - self.hs.get_datastore().get_monthly_active_count = Mock( + self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.auth_blocking._max_mau_value) ) @@ -175,7 +175,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ) # If in monthly active cohort - self.hs.get_datastore().user_last_seen_monthly_active = Mock( + self.hs.get_datastores().main.user_last_seen_monthly_active = Mock( return_value=make_awaitable(self.clock.time_msec()) ) self.get_success( @@ -192,7 +192,7 @@ class AuthTestCase(unittest.HomeserverTestCase): def test_mau_limits_not_exceeded(self): self.auth_blocking._limit_usage_by_mau = True - self.hs.get_datastore().get_monthly_active_count = Mock( + self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.small_number_of_users) ) # Ensure does not raise exception @@ -202,7 +202,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ) ) - self.hs.get_datastore().get_monthly_active_count = Mock( + self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.small_number_of_users) ) self.get_success( diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index 8705ff894..a26722884 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -77,7 +77,7 @@ class CasHandlerTestCase(HomeserverTestCase): def test_map_cas_user_to_existing_user(self): """Existing users can log in with CAS account.""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) ) diff --git a/tests/handlers/test_deactivate_account.py b/tests/handlers/test_deactivate_account.py index 01096a158..ddda36c5a 100644 --- a/tests/handlers/test_deactivate_account.py +++ b/tests/handlers/test_deactivate_account.py @@ -34,7 +34,7 @@ class DeactivateAccountTestCase(HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self.user = self.register_user("user", "pass") self.token = self.login("user", "pass") diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 43031e07e..683677fd0 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -28,7 +28,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver("server", federation_http_client=None) self.handler = hs.get_device_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main return hs def prepare(self, reactor, clock, hs): @@ -263,7 +263,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase): self.handler = hs.get_device_handler() self.registration = hs.get_registration_handler() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main return hs def test_dehydrate_and_rehydrate_device(self): diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 0ea4e753e..65ab107d0 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -46,7 +46,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.handler = hs.get_directory_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.my_room = RoomAlias.from_string("#my-room:test") self.your_room = RoomAlias.from_string("#your-room:test") @@ -174,7 +174,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.handler = hs.get_directory_handler() self.state_handler = hs.get_state_handler() @@ -289,7 +289,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.handler = hs.get_directory_handler() self.state_handler = hs.get_state_handler() diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 734ed84d7..9338ab92e 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -34,7 +34,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.handler = hs.get_e2e_keys_handler() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main def test_query_local_devices_no_devices(self): """If the user has no devices, we expect an empty list.""" diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 496b58172..e8b4e39d1 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -45,7 +45,7 @@ class FederationTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver(federation_http_client=None) self.handler = hs.get_federation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state_store = hs.get_storage().state self._event_auth_handler = hs.get_event_auth_handler() return hs diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index 5816295d8..f4f7ab484 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -44,7 +44,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token) self.info = self.get_success( - self.hs.get_datastore().get_user_by_access_token( + self.hs.get_datastores().main.get_user_by_access_token( self.access_token, ) ) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index a552d8182..e8418b663 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -856,7 +856,7 @@ class OidcHandlerTestCase(HomeserverTestCase): auth_handler.complete_sso_login.reset_mock() # Test if the mxid is already taken - store = self.hs.get_datastore() + store = self.hs.get_datastores().main user3 = UserID.from_string("@test_user_3:test") self.get_success( store.register_user(user_id=user3.to_string(), password_hash=None) @@ -872,7 +872,7 @@ class OidcHandlerTestCase(HomeserverTestCase): @override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}}) def test_map_userinfo_to_existing_user(self): """Existing users can log in with OpenID Connect when allow_existing_users is True.""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main user = UserID.from_string("@test_user:test") self.get_success( store.register_user(user_id=user.to_string(), password_hash=None) @@ -996,7 +996,7 @@ class OidcHandlerTestCase(HomeserverTestCase): auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) ) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 671dc7d08..61d28603a 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -43,7 +43,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): servlets = [admin.register_servlets] def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main def test_offline_to_online(self): wheel_timer = Mock() @@ -891,7 +891,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): # self.event_builder_for_2 = EventBuilderFactory(hs) # self.event_builder_for_2.hostname = "test2" - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state = hs.get_state_handler() self._event_auth_handler = hs.get_event_auth_handler() diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 60235e569..69e299fc1 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -48,7 +48,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): return hs def prepare(self, reactor, clock, hs: HomeServer): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.frank = UserID.from_string("@1234abcd:test") self.bob = UserID.from_string("@4567:test") @@ -325,7 +325,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): properties are "mimetype" (for the file's type) and "size" (for the file's size). """ - store = self.hs.get_datastore() + store = self.hs.get_datastores().main for name, props in names_and_props.items(): self.get_success( diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index cd6f2c77a..51ee667ab 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -154,7 +154,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.handler = self.hs.get_registration_handler() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.lots_of_users = 100 self.small_number_of_users = 1 @@ -172,7 +172,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertGreater(len(result_token), 20) def test_if_user_exists(self): - store = self.hs.get_datastore() + store = self.hs.get_datastores().main frank = UserID.from_string("@frank:test") self.get_success( store.register_user(user_id=frank.to_string(), password_hash=None) @@ -760,7 +760,7 @@ class RemoteAutoJoinTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.handler = self.hs.get_registration_handler() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main @override_config({"auto_join_rooms": ["#room:remotetest"]}) def test_auto_create_auto_join_remote_room(self): diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index 50551aa6e..23941abed 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -142,7 +142,7 @@ class SamlHandlerTestCase(HomeserverTestCase): @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) def test_map_saml_response_to_existing_user(self): """Existing users can log in with SAML account.""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) ) @@ -217,7 +217,7 @@ class SamlHandlerTestCase(HomeserverTestCase): sso_handler.render_error = Mock(return_value=None) # register a user to occupy the first-choice MXID - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) ) diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 56207f4db..ecd78fa36 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -33,7 +33,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.handler = self.hs.get_stats_handler() def _add_background_updates(self): diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 07a760e91..66b0bd4d1 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -41,7 +41,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs: HomeServer): self.sync_handler = self.hs.get_sync_handler() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main # AuthBlocking reads from the hs' config on initialization. We need to # modify its config instead of the hs' @@ -248,7 +248,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): # the prev_events used when creating the join event, such that the ban does not # precede the join. mocked_get_prev_events = patch.object( - self.hs.get_datastore(), + self.hs.get_datastores().main, "get_prev_events_for_room", new_callable=MagicMock, return_value=make_awaitable([last_room_creation_event_id]), diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 000f9b9fd..e461e0359 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -91,7 +91,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.event_source = hs.get_event_sources().sources.typing - self.datastore = hs.get_datastore() + self.datastore = hs.get_datastores().main self.datastore.get_destination_retry_timings = Mock( return_value=defer.succeed(None) ) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 482c90ef6..e159169e2 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -77,7 +77,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): return hs def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.handler = hs.get_user_directory_handler() self.event_builder_factory = self.hs.get_event_builder_factory() self.event_creation_handler = self.hs.get_event_creation_handler() diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index d16cd141a..c3f20f969 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -41,7 +41,7 @@ class ModuleApiTestCase(HomeserverTestCase): ] def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main self.module_api = homeserver.get_module_api() self.event_creation_handler = homeserver.get_event_creation_handler() self.sync_handler = homeserver.get_sync_handler() diff --git a/tests/push/test_email.py b/tests/push/test_email.py index f8cba7b64..7a3b0d675 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -102,13 +102,13 @@ class EmailPusherTests(HomeserverTestCase): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(self.access_token) + self.hs.get_datastores().main.get_user_by_access_token(self.access_token) ) self.token_id = user_tuple.token_id # We need to add email to account before we can create a pusher. self.get_success( - hs.get_datastore().user_add_threepid( + hs.get_datastores().main.user_add_threepid( self.user_id, "email", "a@example.com", 0, 0 ) ) @@ -128,7 +128,7 @@ class EmailPusherTests(HomeserverTestCase): ) self.auth_handler = hs.get_auth_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def test_need_validated_email(self): """Test that we can only add an email pusher if the user has validated @@ -375,7 +375,7 @@ class EmailPusherTests(HomeserverTestCase): # check that the pusher for that email address has been deleted pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 0) @@ -388,14 +388,14 @@ class EmailPusherTests(HomeserverTestCase): # This resembles the old behaviour, which the background update below is intended # to clean up. self.get_success( - self.hs.get_datastore().user_delete_threepid( + self.hs.get_datastores().main.user_delete_threepid( self.user_id, "email", "a@example.com" ) ) # Run the "remove_deleted_email_pushers" background job self.get_success( - self.hs.get_datastore().db_pool.simple_insert( + self.hs.get_datastores().main.db_pool.simple_insert( table="background_updates", values={ "update_name": "remove_deleted_email_pushers", @@ -406,14 +406,14 @@ class EmailPusherTests(HomeserverTestCase): ) # ... and tell the DataStore that it hasn't finished all updates yet - self.hs.get_datastore().db_pool.updates._all_done = False + self.hs.get_datastores().main.db_pool.updates._all_done = False # Now let's actually drive the updates to completion self.wait_for_background_updates() # Check that all pushers with unlinked addresses were deleted pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 0) @@ -428,7 +428,7 @@ class EmailPusherTests(HomeserverTestCase): """ # Get the stream ordering before it gets sent pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -439,7 +439,7 @@ class EmailPusherTests(HomeserverTestCase): # It hasn't succeeded yet, so the stream ordering shouldn't have moved pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -458,7 +458,7 @@ class EmailPusherTests(HomeserverTestCase): # The stream ordering has increased pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) diff --git a/tests/push/test_http.py b/tests/push/test_http.py index e1e3fb97c..c284beb37 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -62,7 +62,7 @@ class HTTPPusherTests(HomeserverTestCase): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id @@ -108,7 +108,7 @@ class HTTPPusherTests(HomeserverTestCase): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id @@ -138,7 +138,7 @@ class HTTPPusherTests(HomeserverTestCase): # Get the stream ordering before it gets sent pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -149,7 +149,7 @@ class HTTPPusherTests(HomeserverTestCase): # It hasn't succeeded yet, so the stream ordering shouldn't have moved pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -170,7 +170,7 @@ class HTTPPusherTests(HomeserverTestCase): # The stream ordering has increased pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -192,7 +192,7 @@ class HTTPPusherTests(HomeserverTestCase): # The stream ordering has increased, again pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -224,7 +224,7 @@ class HTTPPusherTests(HomeserverTestCase): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id @@ -344,7 +344,7 @@ class HTTPPusherTests(HomeserverTestCase): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id @@ -430,7 +430,7 @@ class HTTPPusherTests(HomeserverTestCase): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id @@ -507,7 +507,7 @@ class HTTPPusherTests(HomeserverTestCase): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id @@ -613,7 +613,7 @@ class HTTPPusherTests(HomeserverTestCase): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 9fc50f885..a7a05a564 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -68,7 +68,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): # Since we use sqlite in memory databases we need to make sure the # databases objects are the same. - self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool + self.worker_hs.get_datastores().main.db_pool = hs.get_datastores().main.db_pool # Normally we'd pass in the handler to `setup_test_homeserver`, which would # eventually hit "Install @cache_in_self attributes" in tests/utils.py. @@ -233,7 +233,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): # We may have an attempt to connect to redis for the external cache already. self.connect_any_redis_attempts() - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.database_pool = store.db_pool self.reactor.lookups["testserv"] = "1.2.3.4" @@ -332,7 +332,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): lambda: self._handle_http_replication_attempt(worker_hs, port), ) - store = worker_hs.get_datastore() + store = worker_hs.get_datastores().main store.db_pool._db_pool = self.database_pool._db_pool # Set up TCP replication between master and the new worker if we don't diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 83e89383f..85be79d19 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -30,8 +30,8 @@ class BaseSlavedStoreTestCase(BaseStreamTestCase): self.reconnect() - self.master_store = hs.get_datastore() - self.slaved_store = self.worker_hs.get_datastore() + self.master_store = hs.get_datastores().main + self.slaved_store = self.worker_hs.get_datastores().main self.storage = hs.get_storage() def replicate(self): diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py index cdd052001..50fbff5f3 100644 --- a/tests/replication/tcp/streams/test_account_data.py +++ b/tests/replication/tcp/streams/test_account_data.py @@ -23,7 +23,7 @@ from tests.replication._base import BaseStreamTestCase class AccountDataStreamTestCase(BaseStreamTestCase): def test_update_function_room_account_data_limit(self): """Test replication with many room account data updates""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # generate lots of account data updates updates = [] @@ -69,7 +69,7 @@ class AccountDataStreamTestCase(BaseStreamTestCase): def test_update_function_global_account_data_limit(self): """Test replication with many global account data updates""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # generate lots of account data updates updates = [] diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index f198a9488..f9d5da723 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -136,7 +136,7 @@ class EventsStreamTestCase(BaseStreamTestCase): # this is the point in the DAG where we make a fork fork_point: List[str] = self.get_success( - self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id) + self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) ) events = [ @@ -291,7 +291,7 @@ class EventsStreamTestCase(BaseStreamTestCase): # this is the point in the DAG where we make a fork fork_point: List[str] = self.get_success( - self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id) + self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) ) events: List[EventBase] = [] diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index 38e292c1a..eb0011784 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -32,7 +32,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): # tell the master to send a new receipt self.get_success( - self.hs.get_datastore().insert_receipt( + self.hs.get_datastores().main.insert_receipt( "!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1} ) ) @@ -56,7 +56,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): self.test_handler.on_rdata.reset_mock() self.get_success( - self.hs.get_datastore().insert_receipt( + self.hs.get_datastores().main.insert_receipt( "!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2} ) ) diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index 92a5b53e1..ba1a63c0d 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -204,7 +204,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): def create_room_with_remote_server(self, user, token, remote_server="other_server"): room = self.helper.create_room_as(user, tok=token) - store = self.hs.get_datastore() + store = self.hs.get_datastores().main federation = self.hs.get_federation_event_handler() prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room)) diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py index 4094a75f3..8f4f6688c 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py @@ -50,7 +50,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): # Register a pusher user_dict = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_dict.token_id diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py index 596ba5a0c..5f142e84c 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py @@ -47,7 +47,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): self.other_access_token = self.login("otheruser", "pass") self.room_creator = self.hs.get_room_creation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def default_config(self): conf = super().default_config() @@ -99,7 +99,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): persisted_on_1 = False persisted_on_2 = False - store = self.hs.get_datastore() + store = self.hs.get_datastores().main user_id = self.register_user("user", "pass") access_token = self.login("user", "pass") @@ -166,7 +166,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): user_id = self.register_user("user", "pass") access_token = self.login("user", "pass") - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Create two room on the different workers. self._create_room(room_id1, user_id, access_token) @@ -194,7 +194,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): # # Worker2's event stream position will not advance until we call # __aexit__ again. - worker_store2 = worker_hs2.get_datastore() + worker_store2 = worker_hs2.get_datastores().main assert isinstance(worker_store2._stream_id_gen, MultiWriterIdGenerator) actx = worker_store2._stream_id_gen.get_next() diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py index 1e3fe9c62..fb36aa994 100644 --- a/tests/rest/admin/test_background_updates.py +++ b/tests/rest/admin/test_background_updates.py @@ -36,7 +36,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py index 71068d16c..929bbdc37 100644 --- a/tests/rest/admin/test_federation.py +++ b/tests/rest/admin/test_federation.py @@ -35,7 +35,7 @@ class FederationTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -537,7 +537,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index 86aff7575..0d47dd0af 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -634,7 +634,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: media_repo = hs.get_media_repository_resource() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.server_name = hs.hostname self.admin_user = self.register_user("admin", "pass", admin=True) @@ -767,7 +767,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: media_repo = hs.get_media_repository_resource() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py index 8513b1d2d..8354250ec 100644 --- a/tests/rest/admin/test_registration_tokens.py +++ b/tests/rest/admin/test_registration_tokens.py @@ -34,7 +34,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 23da0ad73..09c48e85c 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -50,7 +50,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" self.event_creation_handler._consent_uri_builder = consent_uri_builder - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -465,7 +465,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" self.event_creation_handler._consent_uri_builder = consent_uri_builder - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -2239,7 +2239,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py index 3c59f5f76..2c855bff9 100644 --- a/tests/rest/admin/test_server_notice.py +++ b/tests/rest/admin/test_server_notice.py @@ -38,7 +38,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room_shutdown_handler = hs.get_room_shutdown_handler() self.pagination_handler = hs.get_pagination_handler() self.server_notices_manager = self.hs.get_server_notices_manager() diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 272637e96..a60ea0a56 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -410,7 +410,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): even if the MAU limit is reached. """ handler = self.hs.get_registration_handler() - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Set monthly active users to the limit store.get_monthly_active_count = Mock( @@ -455,7 +455,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v2/users" def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -913,7 +913,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -1167,7 +1167,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth_handler = hs.get_auth_handler() # create users and get access tokens @@ -2609,7 +2609,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -2737,7 +2737,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.media_repo = hs.get_media_repository_resource() self.filepaths = MediaFilePaths(hs.config.media.media_store_path) @@ -3317,7 +3317,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -3609,7 +3609,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -3687,7 +3687,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -3913,7 +3913,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index afaa597f6..aa019c9a4 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -77,7 +77,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): return hs def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.submit_token_resource = PasswordResetSubmitTokenResource(hs) def test_basic_password_reset(self): @@ -398,7 +398,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase): self.deactivate(user_id, tok) - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Check that the user has been marked as deactivated. self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id))) @@ -409,7 +409,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase): def test_pending_invites(self): """Tests that deactivating a user rejects every pending invite for them.""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main inviter_id = self.register_user("inviter", "test") inviter_tok = self.login("inviter", "test") @@ -527,7 +527,7 @@ class WhoamiTestCase(unittest.HomeserverTestCase): namespaces={"users": [{"regex": user_id, "exclusive": True}]}, sender=user_id, ) - self.hs.get_datastore().services_cache.append(appservice) + self.hs.get_datastores().main.services_cache.append(appservice) whoami = self._whoami(as_token) self.assertEqual( @@ -586,7 +586,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): return self.hs def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_id = self.register_user("kermit", "test") self.user_id_tok = self.login("kermit", "test") diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py index 475c6bed3..a573cc3c2 100644 --- a/tests/rest/client/test_filter.py +++ b/tests/rest/client/test_filter.py @@ -32,7 +32,7 @@ class FilterTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.filtering = hs.get_filtering() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def test_add_filter(self): channel = self.make_request( diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index 19f5e4653..26d0d83e0 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -1101,8 +1101,8 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): }, ) - self.hs.get_datastore().services_cache.append(self.service) - self.hs.get_datastore().services_cache.append(self.another_service) + self.hs.get_datastores().main.services_cache.append(self.service) + self.hs.get_datastores().main.services_cache.append(self.another_service) return self.hs def test_login_appservice_user(self): diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py index ead883ded..b9647d5bd 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py @@ -292,7 +292,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): properties are "mimetype" (for the file's type) and "size" (for the file's size). """ - store = self.hs.get_datastore() + store = self.hs.get_datastores().main for name, props in names_and_props.items(): self.get_success( diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index 0f1c47dcb..2835d86e5 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -56,7 +56,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): sender="@as:test", ) - self.hs.get_datastore().services_cache.append(appservice) + self.hs.get_datastores().main.services_cache.append(appservice) request_data = json.dumps( {"username": "as_user_kermit", "type": APP_SERVICE_REGISTRATION_TYPE} ) @@ -80,7 +80,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): sender="@as:test", ) - self.hs.get_datastore().services_cache.append(appservice) + self.hs.get_datastores().main.services_cache.append(appservice) request_data = json.dumps({"username": "as_user_kermit"}) channel = self.make_request( @@ -210,7 +210,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): username = "kermit" device_id = "frogfone" token = "abcd" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.db_pool.simple_insert( "registration_tokens", @@ -316,7 +316,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): @override_config({"registration_requires_token": True}) def test_POST_registration_token_limit_uses(self): token = "abcd" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Create token that can be used once self.get_success( store.db_pool.simple_insert( @@ -391,7 +391,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): def test_POST_registration_token_expiry(self): token = "abcd" now = self.hs.get_clock().time_msec() - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Create token that expired yesterday self.get_success( store.db_pool.simple_insert( @@ -439,7 +439,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): def test_POST_registration_token_session_expiry(self): """Test `pending` is decremented when an uncompleted session expires.""" token = "abcd" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.db_pool.simple_insert( "registration_tokens", @@ -530,7 +530,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): 3. Expire the session """ token = "abcd" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.db_pool.simple_insert( "registration_tokens", @@ -657,7 +657,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): # Add a threepid self.get_success( - self.hs.get_datastore().user_add_threepid( + self.hs.get_datastores().main.user_add_threepid( user_id=user_id, medium="email", address=email, @@ -941,7 +941,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): self.email_attempts = [] self.hs.get_send_email_handler()._sendmail = sendmail - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main return self.hs @@ -1126,10 +1126,12 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): # We need to set these directly, instead of in the homeserver config dict above. # This is due to account validity-related config options not being read by # Synapse when account_validity.enabled is False. - self.hs.get_datastore()._account_validity_period = self.validity_period - self.hs.get_datastore()._account_validity_startup_job_max_delta = self.max_delta + self.hs.get_datastores().main._account_validity_period = self.validity_period + self.hs.get_datastores().main._account_validity_startup_job_max_delta = ( + self.max_delta + ) - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main return self.hs @@ -1163,7 +1165,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): def test_GET_token_valid(self): token = "abcd" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.db_pool.simple_insert( "registration_tokens", diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index dfd9ffcb9..5687dea48 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -53,7 +53,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): return config def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_id, self.user_token = self._create_user("alice") self.user2_id, self.user2_token = self._create_user("bob") @@ -107,7 +107,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): # Unless that event is referenced from another event! self.get_success( - self.hs.get_datastore().db_pool.simple_insert( + self.hs.get_datastores().main.db_pool.simple_insert( table="event_relations", values={ "event_id": "bar", diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index fe5b536d9..c41a1c14a 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -51,7 +51,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): self.user_id = self.register_user("user", "password") self.token = self.login("user", "password") - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.serializer = self.hs.get_event_client_serializer() self.clock = self.hs.get_clock() @@ -114,7 +114,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): """Tests that synapse.visibility.filter_events_for_client correctly filters out outdated events """ - store = self.hs.get_datastore() + store = self.hs.get_datastores().main storage = self.hs.get_storage() room_id = self.helper.create_room_as(self.user_id, tok=self.token) events = [] diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index b7f086927..1afd96b8f 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -65,7 +65,7 @@ class RoomBase(unittest.HomeserverTestCase): async def _insert_client_ip(*args, **kwargs): return None - self.hs.get_datastore().insert_client_ip = _insert_client_ip + self.hs.get_datastores().main.insert_client_ip = _insert_client_ip return self.hs @@ -667,7 +667,7 @@ class RoomsCreateTestCase(RoomBase): # Add the current user to the ratelimit overrides, allowing them no ratelimiting. self.get_success( - self.hs.get_datastore().set_ratelimit_for_user(self.user_id, 0, 0) + self.hs.get_datastores().main.set_ratelimit_for_user(self.user_id, 0, 0) ) # Test that the invites aren't ratelimited anymore. @@ -1060,7 +1060,9 @@ class RoomJoinRatelimitTestCase(RoomBase): user_id = self.register_user("testuser", "password") # Check that the new user successfully joined the four rooms - rooms = self.get_success(self.hs.get_datastore().get_rooms_for_user(user_id)) + rooms = self.get_success( + self.hs.get_datastores().main.get_rooms_for_user(user_id) + ) self.assertEqual(len(rooms), 4) @@ -1184,7 +1186,7 @@ class RoomMessageListTestCase(RoomBase): self.assertTrue("end" in channel.json_body) def test_room_messages_purge(self): - store = self.hs.get_datastore() + store = self.hs.get_datastores().main pagination_handler = self.hs.get_pagination_handler() # Send a first message in the room, which will be removed by the purge. diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index b0c44af03..7d0e66b53 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -34,7 +34,7 @@ class _ShadowBannedBase(unittest.HomeserverTestCase): self.banned_user_id = self.register_user("banned", "test") self.banned_access_token = self.login("banned", "test") - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.get_success( self.store.set_shadow_banned(UserID.from_string(self.banned_user_id), True) diff --git a/tests/rest/client/test_shared_rooms.py b/tests/rest/client/test_shared_rooms.py index 283eccd53..c42c8aff6 100644 --- a/tests/rest/client/test_shared_rooms.py +++ b/tests/rest/client/test_shared_rooms.py @@ -36,7 +36,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): return self.setup_test_homeserver(config=config) def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.handler = hs.get_user_directory_handler() def _get_shared_rooms(self, token, other_user) -> FakeChannel: diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index cd4af2b1f..e06256136 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -299,7 +299,7 @@ class SyncKnockTestCase( ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.url = "/sync?since=%s" self.next_batch = "s0" diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py index ee0abd529..de312cb63 100644 --- a/tests/rest/client/test_typing.py +++ b/tests/rest/client/test_typing.py @@ -57,7 +57,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): async def _insert_client_ip(*args, **kwargs): return None - hs.get_datastore().insert_client_ip = _insert_client_ip + hs.get_datastores().main.insert_client_ip = _insert_client_ip return hs diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py index a42388b26..7f79336ab 100644 --- a/tests/rest/client/test_upgrade_room.py +++ b/tests/rest/client/test_upgrade_room.py @@ -32,7 +32,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.creator = self.register_user("creator", "pass") self.creator_token = self.login(self.creator, "pass") diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 4cf1ed5dd..6878ccddb 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -243,7 +243,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): media_resource = hs.get_media_repository_resource() self.download_resource = media_resource.children[b"download"] self.thumbnail_resource = media_resource.children[b"thumbnail"] - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.media_repo = hs.get_media_repository() self.media_id = "example.com/12345" diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 36c495954..02b96c9e6 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -242,7 +242,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): return c def prepare(self, reactor, clock, hs): - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.server_notices_sender = self.hs.get_server_notices_sender() self.server_notices_manager = self.hs.get_server_notices_manager() self.event_source = self.hs.get_event_sources() diff --git a/tests/storage/databases/main/test_deviceinbox.py b/tests/storage/databases/main/test_deviceinbox.py index 36c933b9e..50c20c5b9 100644 --- a/tests/storage/databases/main/test_deviceinbox.py +++ b/tests/storage/databases/main/test_deviceinbox.py @@ -26,7 +26,7 @@ class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_id = self.register_user("foo", "pass") def test_background_remove_deleted_devices_from_device_inbox(self): diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 5ae491ff5..59def6e59 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -37,7 +37,7 @@ from tests import unittest class HaveSeenEventsTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store: EventsWorkerStore = hs.get_datastore() + self.store: EventsWorkerStore = hs.get_datastores().main # insert some test data for rid in ("room1", "room2"): @@ -122,7 +122,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store: EventsWorkerStore = hs.get_datastore() + self.store: EventsWorkerStore = hs.get_datastores().main self.user = self.register_user("user", "pass") self.token = self.login(self.user, "pass") @@ -163,7 +163,7 @@ class DatabaseOutageTestCase(unittest.HomeserverTestCase): """Test event fetching during a database outage.""" def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): - self.store: EventsWorkerStore = hs.get_datastore() + self.store: EventsWorkerStore = hs.get_datastores().main self.room_id = f"!room:{hs.hostname}" self.event_ids = [f"event{i}" for i in range(20)] diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py index d326a1d6a..3ac464696 100644 --- a/tests/storage/databases/main/test_lock.py +++ b/tests/storage/databases/main/test_lock.py @@ -20,7 +20,7 @@ from tests import unittest class LockTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs: HomeServer): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def test_simple_lock(self): """Test that we can take out a lock and that while we hold it nobody diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py index 7496974da..9abd0cb44 100644 --- a/tests/storage/databases/main/test_room.py +++ b/tests/storage/databases/main/test_room.py @@ -28,7 +28,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_id = self.register_user("foo", "pass") self.token = self.login("foo", "pass") diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 200b9198f..4899cd5c3 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -20,7 +20,7 @@ from tests import unittest class UpsertManyTests(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.storage = hs.get_datastore() + self.storage = hs.get_datastores().main self.table_name = "table_" + secrets.token_hex(6) self.get_success( diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py index d697d2bc1..272cd3540 100644 --- a/tests/storage/test_account_data.py +++ b/tests/storage/test_account_data.py @@ -21,7 +21,7 @@ from tests import unittest class IgnoredUsersTestCase(unittest.HomeserverTestCase): def prepare(self, hs, reactor, clock): - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.user = "@user:test" def _update_ignore_list( diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index ddcb7f554..50703ccae 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -467,7 +467,7 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer ) -> None: self.service = Mock(id="foo") - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.get_success( self.store.set_appservice_state(self.service, ApplicationServiceState.UP) ) diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 6156dfac4..39dcc094b 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -24,7 +24,7 @@ from tests.test_utils import make_awaitable, simple_async_mock class BackgroundUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): - self.updates: BackgroundUpdater = self.hs.get_datastore().db_pool.updates + self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates # the base test class should have run the real bg updates for us self.assertTrue( self.get_success(self.updates.has_completed_background_updates()) @@ -42,7 +42,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): # the target runtime for each bg update target_background_update_duration_ms = 100 - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.db_pool.simple_insert( "background_updates", @@ -102,7 +102,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): - self.updates: BackgroundUpdater = self.hs.get_datastore().db_pool.updates + self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates # the base test class should have run the real bg updates for us self.assertTrue( self.get_success(self.updates.has_completed_background_updates()) @@ -138,7 +138,7 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase): ) def test_controller(self): - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.db_pool.simple_insert( "background_updates", diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index a59c28f89..ce89c9691 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -30,7 +30,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): """ def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main self.room_creator = homeserver.get_room_creation_handler() # Create a test user and room @@ -242,7 +242,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): return self.setup_test_homeserver(config=config) def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main self.room_creator = homeserver.get_room_creation_handler() self.event_creator_handler = homeserver.get_event_creation_handler() diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index c8ac67e35..49ad3c132 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -35,7 +35,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): return hs def prepare(self, hs, reactor, clock): - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main def test_insert_new_client_ip(self): self.reactor.advance(12345678) @@ -666,7 +666,7 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase): return hs def prepare(self, hs, reactor, clock): - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.user_id = self.register_user("bob", "abc123", True) def test_request_with_xforwarded(self): diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index b547bf8d9..21ffc5a90 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -19,7 +19,7 @@ from tests.unittest import HomeserverTestCase class DeviceStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def test_store_new_device(self): self.get_success( diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py index 43628ce44..7b72a9242 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py @@ -19,7 +19,7 @@ from tests.unittest import HomeserverTestCase class DirectoryStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room = RoomID.from_string("!abcde:test") self.alias = RoomAlias.from_string("#my-room:test") diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py index 7556171d8..fb96ab3a2 100644 --- a/tests/storage/test_e2e_room_keys.py +++ b/tests/storage/test_e2e_room_keys.py @@ -28,7 +28,7 @@ room_key: RoomKey = { class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver("server", federation_http_client=None) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main return hs def test_room_keys_version_delete(self): diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index 3bf6e337f..0f04493ad 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -17,7 +17,7 @@ from tests.unittest import HomeserverTestCase class EndToEndKeyStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def test_key_without_device_name(self): now = 1470174257070 diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index e3273a93f..401020fd6 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -30,7 +30,7 @@ from tests.unittest import HomeserverTestCase class EventChainStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._next_stream_ordering = 1 def test_simple(self): @@ -492,7 +492,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_id = self.register_user("foo", "pass") self.token = self.login("foo", "pass") self.requester = create_requester(self.user_id) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 667ca90a4..645d564d1 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -31,7 +31,7 @@ import tests.utils class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def test_get_prev_events_for_room(self): room_id = "@ROOM:local" diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 738f3ad1d..c9e3b9fa7 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -30,7 +30,7 @@ HIGHLIGHT = [ class EventPushActionsStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.persist_events_store = hs.get_datastores().persist_events def test_get_unread_push_actions_for_user_in_range_for_http(self): diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index a8639d8f8..ef5e25873 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -32,7 +32,7 @@ class ExtremPruneTestCase(HomeserverTestCase): def prepare(self, reactor, clock, homeserver): self.state = self.hs.get_state_handler() self.persistence = self.hs.get_storage().persistence - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.register_user("user", "pass") self.token = self.login("user", "pass") @@ -341,7 +341,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase): def prepare(self, reactor, clock, homeserver): self.state = self.hs.get_state_handler() self.persistence = self.hs.get_storage().persistence - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main def test_remote_user_rooms_cache_invalidated(self): """Test that if the server leaves a room the `get_rooms_for_user` cache diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 748607828..6ac4b93f9 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -26,7 +26,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): skip = "Requires Postgres" def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) @@ -459,7 +459,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): skip = "Requires Postgres" def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) @@ -585,7 +585,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): skip = "Requires Postgres" def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py index a94b5fd72..905909552 100644 --- a/tests/storage/test_keys.py +++ b/tests/storage/test_keys.py @@ -37,7 +37,7 @@ KEY_2 = decode_verify_key_base64( class KeyStoreTestCase(tests.unittest.HomeserverTestCase): def test_get_server_verify_keys(self): - store = self.hs.get_datastore() + store = self.hs.get_datastores().main key_id_1 = "ed25519:key1" key_id_2 = "ed25519:KEY_ID_2" @@ -74,7 +74,7 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): def test_cache(self): """Check that updates correctly invalidate the cache.""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main key_id_1 = "ed25519:key1" key_id_2 = "ed25519:key2" diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py index f8d11bac4..4ca212fd1 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py @@ -22,7 +22,7 @@ class DataStoreTestCase(unittest.HomeserverTestCase): def setUp(self) -> None: super(DataStoreTestCase, self).setUp() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.user = UserID.from_string("@abcde:test") self.displayname = "Frank" diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index d6b4cdd78..79648d45d 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -45,7 +45,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): return config def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main # Advance the clock a bit reactor.advance(FORTY_DAYS) diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index d37736edf..b6f99af2f 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -22,7 +22,7 @@ from tests import unittest class ProfileStoreTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.u_frank = UserID.from_string("@frank:test") diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index 22a77c3cc..08cc60237 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -30,7 +30,7 @@ class PurgeTests(HomeserverTestCase): def prepare(self, reactor, clock, hs): self.room_id = self.helper.create_room_as(self.user_id) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = self.hs.get_storage() def test_purge_history(self): @@ -47,7 +47,7 @@ class PurgeTests(HomeserverTestCase): token = self.get_success( self.store.get_topological_token_for_event(last["event_id"]) ) - token_str = self.get_success(token.to_string(self.hs.get_datastore())) + token_str = self.get_success(token.to_string(self.hs.get_datastores().main)) # Purge everything before this topological token self.get_success( diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 8c95a0a2f..03e9cc7d4 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -30,7 +30,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): return config def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 974806528..1fa495f77 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -20,7 +20,7 @@ from tests.unittest import HomeserverTestCase class RegistrationStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_id = "@my-user:test" self.tokens = ["AbCdEfGhIjKlMnOpQrStUvWxYz", "BcDeFgHiJkLmNoPqRsTuVwXyZa"] diff --git a/tests/storage/test_rollback_worker.py b/tests/storage/test_rollback_worker.py index cfc8098af..0baa54312 100644 --- a/tests/storage/test_rollback_worker.py +++ b/tests/storage/test_rollback_worker.py @@ -56,7 +56,7 @@ class WorkerSchemaTests(HomeserverTestCase): def test_rolling_back(self): """Test that workers can start if the DB is a newer schema version""" - db_pool = self.hs.get_datastore().db_pool + db_pool = self.hs.get_datastores().main.db_pool db_conn = LoggingDatabaseConnection( db_pool._db_pool.connect(), db_pool.engine, @@ -72,7 +72,7 @@ class WorkerSchemaTests(HomeserverTestCase): def test_not_upgraded_old_schema_version(self): """Test that workers don't start if the DB has an older schema version""" - db_pool = self.hs.get_datastore().db_pool + db_pool = self.hs.get_datastores().main.db_pool db_conn = LoggingDatabaseConnection( db_pool._db_pool.connect(), db_pool.engine, @@ -92,7 +92,7 @@ class WorkerSchemaTests(HomeserverTestCase): Test that workers don't start if the DB is on the current schema version, but there are still outstanding delta migrations to run. """ - db_pool = self.hs.get_datastore().db_pool + db_pool = self.hs.get_datastores().main.db_pool db_conn = LoggingDatabaseConnection( db_pool._db_pool.connect(), db_pool.engine, diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 31ce7f625..42bfca2a8 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -23,7 +23,7 @@ class RoomStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): # We can't test RoomStore on its own without the DirectoryStore, for # management of the 'room_aliases' table - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room = RoomID.from_string("!abcde:test") self.alias = RoomAlias.from_string("#a-room-name:test") @@ -71,7 +71,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): # Room events need the full datastore, for persist_event() and # get_room_state() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.event_factory = hs.get_event_factory() diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py index 8971ecccb..befaa0fce 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py @@ -46,7 +46,7 @@ class NullByteInsertionTest(HomeserverTestCase): self.assertIn("event_id", response) # Check that search works for the message where the null byte was replaced - store = self.hs.get_datastore() + store = self.hs.get_datastores().main result = self.get_success( store.search_msgs([room_id], "hi bob", ["content.body"]) ) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 5cfdfe9b8..7028f0dfb 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -35,7 +35,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # We can't test the RoomMemberStore on its own without the other event # storage logic - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.u_alice = self.register_user("alice", "pass") self.t_alice = self.login("alice", "pass") @@ -212,7 +212,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main self.room_creator = homeserver.get_room_creation_handler() def test_can_rerun_update(self): diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 28c767ecf..f88f1c55f 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) class StateStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_datastore = self.storage.state.stores.state self.event_builder_factory = hs.get_event_builder_factory() diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index ce782c7e1..6a1cf3305 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -115,7 +115,7 @@ class PaginationTestCase(HomeserverTestCase): ) events, next_key = self.get_success( - self.hs.get_datastore().paginate_room_events( + self.hs.get_datastores().main.paginate_room_events( room_id=self.room_id, from_key=from_token.room_key, to_key=None, diff --git a/tests/storage/test_transactions.py b/tests/storage/test_transactions.py index bea9091d3..e05daa285 100644 --- a/tests/storage/test_transactions.py +++ b/tests/storage/test_transactions.py @@ -20,7 +20,7 @@ from tests.unittest import HomeserverTestCase class TransactionStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main def test_get_set_transactions(self): """Tests that we can successfully get a non-existent entry for diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 48f1e9d84..7f1964eb6 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -149,7 +149,7 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): return hs def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_dir_helper = GetUserDirectoryTables(self.store) def _purge_and_rebuild_user_dir(self) -> None: @@ -415,7 +415,7 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): class UserDirectoryStoreTestCase(HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main # alice and bob are both in !room_id. bobby is not but shares # a homeserver with alice. diff --git a/tests/test_federation.py b/tests/test_federation.py index 2b9804aba..c39816de8 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -52,11 +52,13 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) )[0]["room_id"] - self.store = self.homeserver.get_datastore() + self.store = self.homeserver.get_datastores().main # Figure out what the most recent event is most_recent = self.get_success( - self.homeserver.get_datastore().get_latest_event_ids_in_room(self.room_id) + self.homeserver.get_datastores().main.get_latest_event_ids_in_room( + self.room_id + ) )[0] join_event = make_event_from_dict( @@ -185,7 +187,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # Register a mock on the store so that the incoming update doesn't fail because # we don't share a room with the user. - store = self.homeserver.get_datastore() + store = self.homeserver.get_datastores().main store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"])) # Manually inject a fake device list update. We need this update to include at diff --git a/tests/test_mau.py b/tests/test_mau.py index 80ab40e25..46bd3075d 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -52,7 +52,7 @@ class TestMauLimit(unittest.HomeserverTestCase): return config def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main def test_simple_deny_mau(self): # Create and sync so that the MAU counts get updated diff --git a/tests/test_state.py b/tests/test_state.py index 76e0e8ca7..90800421f 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -162,7 +162,7 @@ class StateTestCase(unittest.TestCase): hs = Mock( spec_set=[ "config", - "get_datastore", + "get_datastores", "get_storage", "get_auth", "get_state_handler", @@ -173,7 +173,7 @@ class StateTestCase(unittest.TestCase): ] ) hs.config = default_config("tesths", True) - hs.get_datastore.return_value = self.store + hs.get_datastores.return_value = Mock(main=self.store) hs.get_state_handler.return_value = None hs.get_clock.return_value = MockClock() hs.get_auth.return_value = Auth(hs) diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index e9ec9e085..c654e36ee 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -85,7 +85,9 @@ async def create_event( **kwargs, ) -> Tuple[EventBase, EventContext]: if room_version is None: - room_version = await hs.get_datastore().get_room_version_id(kwargs["room_id"]) + room_version = await hs.get_datastores().main.get_room_version_id( + kwargs["room_id"] + ) builder = hs.get_event_builder_factory().for_room_version( KNOWN_ROOM_VERSIONS[room_version], kwargs diff --git a/tests/test_visibility.py b/tests/test_visibility.py index e0b08d67d..219b5660b 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -93,7 +93,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): events_to_filter.append(evt) # the erasey user gets erased - self.get_success(self.hs.get_datastore().mark_user_erased("@erased:local_hs")) + self.get_success( + self.hs.get_datastores().main.mark_user_erased("@erased:local_hs") + ) # ... and the filtering happens. filtered = self.get_success( diff --git a/tests/unittest.py b/tests/unittest.py index 7983c1e8b..0caa8e7a4 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -280,7 +280,7 @@ class HomeserverTestCase(TestCase): # We need a valid token ID to satisfy foreign key constraints. token_id = self.get_success( - self.hs.get_datastore().add_access_token_to_user( + self.hs.get_datastores().main.add_access_token_to_user( self.helper.auth_user_id, "some_fake_token", None, @@ -337,7 +337,7 @@ class HomeserverTestCase(TestCase): def wait_for_background_updates(self) -> None: """Block until all background database updates have completed.""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main while not self.get_success( store.db_pool.updates.has_completed_background_updates() ): @@ -504,7 +504,7 @@ class HomeserverTestCase(TestCase): self.get_success(stor.db_pool.updates.run_background_updates(False)) hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) - stor = hs.get_datastore() + stor = hs.get_datastores().main # Run the database background updates, when running against "master". if hs.__class__.__name__ == "TestHomeServer": @@ -722,14 +722,16 @@ class HomeserverTestCase(TestCase): Add the given event as an extremity to the room. """ self.get_success( - self.hs.get_datastore().db_pool.simple_insert( + self.hs.get_datastores().main.db_pool.simple_insert( table="event_forward_extremities", values={"room_id": room_id, "event_id": event_id}, desc="test_add_extremity", ) ) - self.hs.get_datastore().get_latest_event_ids_in_room.invalidate((room_id,)) + self.hs.get_datastores().main.get_latest_event_ids_in_room.invalidate( + (room_id,) + ) def attempt_wrong_password_login(self, username, password): """Attempts to login as the user with the given password, asserting @@ -775,7 +777,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase): verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version) self.get_success( - hs.get_datastore().store_server_verify_keys( + hs.get_datastores().main.store_server_verify_keys( from_server=self.OTHER_SERVER_NAME, ts_added_ms=clock.time_msec(), verify_keys=[ diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py index 9e1bebdc8..26cb71c64 100644 --- a/tests/util/test_retryutils.py +++ b/tests/util/test_retryutils.py @@ -24,7 +24,7 @@ from tests.unittest import HomeserverTestCase class RetryLimiterTestCase(HomeserverTestCase): def test_new_destination(self): """A happy-path case with a new destination and a successful operation""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) # advance the clock a bit before making the request @@ -38,7 +38,7 @@ class RetryLimiterTestCase(HomeserverTestCase): def test_limiter(self): """General test case which walks through the process of a failing request""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) diff --git a/tests/utils.py b/tests/utils.py index c06fc320f..ef99c72e0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -367,7 +367,7 @@ async def create_room(hs, room_id: str, creator_id: str): """Creates and persist a creation event for the given room""" persistence_store = hs.get_storage().persistence - store = hs.get_datastore() + store = hs.get_datastores().main event_builder_factory = hs.get_event_builder_factory() event_creation_handler = hs.get_event_creation_handler() From 5b2b36809fc3543ed0c9ec587398a09f2e176265 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 23 Feb 2022 12:35:53 +0000 Subject: [PATCH 50/84] Remove more references to `get_datastore` (#12067) These have snuck in since #12031 was started. Also a couple of other cleanups while we're in the area. --- changelog.d/12067.feature | 1 + synapse/handlers/account.py | 4 ++-- synapse/rest/client/account.py | 3 --- tests/rest/client/test_account.py | 4 +++- 4 files changed, 6 insertions(+), 6 deletions(-) create mode 100644 changelog.d/12067.feature diff --git a/changelog.d/12067.feature b/changelog.d/12067.feature new file mode 100644 index 000000000..dc1153c49 --- /dev/null +++ b/changelog.d/12067.feature @@ -0,0 +1 @@ +Implement experimental support for [MSC3720](https://github.com/matrix-org/matrix-doc/pull/3720) (account status endpoints). diff --git a/synapse/handlers/account.py b/synapse/handlers/account.py index f8cfe9f6d..d5badf635 100644 --- a/synapse/handlers/account.py +++ b/synapse/handlers/account.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: class AccountHandler: def __init__(self, hs: "HomeServer"): - self._store = hs.get_datastore() + self._main_store = hs.get_datastores().main self._is_mine = hs.is_mine self._federation_client = hs.get_federation_client() @@ -98,7 +98,7 @@ class AccountHandler: """ status = {"exists": False} - userinfo = await self._store.get_userinfo_by_id(user_id.to_string()) + userinfo = await self._main_store.get_userinfo_by_id(user_id.to_string()) if userinfo is not None: status = { diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index 4b217882e..5587cae98 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -904,9 +904,6 @@ class AccountStatusRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self._auth = hs.get_auth() - self._store = hs.get_datastore() - self._is_mine = hs.is_mine - self._federation_client = hs.get_federation_client() self._account_handler = hs.get_account_handler() async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index aa019c9a4..008d635b7 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -1119,7 +1119,9 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): """Tests that the account status endpoint correctly reports a deactivated user.""" user = self.register_user("someuser", "password") self.get_success( - self.hs.get_datastore().set_user_deactivated_status(user, deactivated=True) + self.hs.get_datastores().main.set_user_deactivated_status( + user, deactivated=True + ) ) self._test_status( From 64c73c6ac88a740ee480a0ad1f9afc8596bccfa4 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Wed, 23 Feb 2022 14:33:19 +0100 Subject: [PATCH 51/84] Add type hints to `tests/rest/client` (#12066) --- changelog.d/12066.misc | 1 + tests/rest/client/test_auth.py | 70 ++++++++------- tests/rest/client/test_capabilities.py | 30 ++++--- tests/rest/client/test_login.py | 120 ++++++++++++++----------- tests/rest/client/test_sync.py | 47 +++++----- 5 files changed, 149 insertions(+), 119 deletions(-) create mode 100644 changelog.d/12066.misc diff --git a/changelog.d/12066.misc b/changelog.d/12066.misc new file mode 100644 index 000000000..0360dbd61 --- /dev/null +++ b/changelog.d/12066.misc @@ -0,0 +1 @@ +Add type hints to `tests/rest/client`. diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index 4a68d6657..9653f4583 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -13,17 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. from http import HTTPStatus -from typing import Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from twisted.internet.defer import succeed +from twisted.test.proto_helpers import MemoryReactor +from twisted.web.resource import Resource import synapse.rest.admin from synapse.api.constants import LoginType from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker from synapse.rest.client import account, auth, devices, login, logout, register from synapse.rest.synapse.client import build_synapse_client_resource_tree +from synapse.server import HomeServer from synapse.storage.database import LoggingTransaction from synapse.types import JsonDict, UserID +from synapse.util import Clock from tests import unittest from tests.handlers.test_oidc import HAS_OIDC @@ -33,11 +37,11 @@ from tests.unittest import override_config, skip_unless class DummyRecaptchaChecker(UserInteractiveAuthChecker): - def __init__(self, hs): + def __init__(self, hs: HomeServer) -> None: super().__init__(hs) - self.recaptcha_attempts = [] + self.recaptcha_attempts: List[Tuple[dict, str]] = [] - def check_auth(self, authdict, clientip): + def check_auth(self, authdict: dict, clientip: str) -> Any: self.recaptcha_attempts.append((authdict, clientip)) return succeed(True) @@ -50,7 +54,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): ] hijack_auth = False - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() @@ -61,7 +65,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): hs = self.setup_test_homeserver(config=config) return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.recaptcha_checker = DummyRecaptchaChecker(hs) auth_handler = hs.get_auth_handler() auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker @@ -101,7 +105,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): self.assertEqual(len(attempts), 1) self.assertEqual(attempts[0][0]["response"], "a") - def test_fallback_captcha(self): + def test_fallback_captcha(self) -> None: """Ensure that fallback auth via a captcha works.""" # Returns a 401 as per the spec channel = self.register( @@ -132,7 +136,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): # We're given a registered user. self.assertEqual(channel.json_body["user_id"], "@user:test") - def test_complete_operation_unknown_session(self): + def test_complete_operation_unknown_session(self) -> None: """ Attempting to mark an invalid session as complete should error. """ @@ -165,7 +169,7 @@ class UIAuthTests(unittest.HomeserverTestCase): register.register_servlets, ] - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() # public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns @@ -182,12 +186,12 @@ class UIAuthTests(unittest.HomeserverTestCase): return config - def create_resource_dict(self): + def create_resource_dict(self) -> Dict[str, Resource]: resource_dict = super().create_resource_dict() resource_dict.update(build_synapse_client_resource_tree(self.hs)) return resource_dict - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_pass = "pass" self.user = self.register_user("test", self.user_pass) self.device_id = "dev1" @@ -229,7 +233,7 @@ class UIAuthTests(unittest.HomeserverTestCase): return channel - def test_ui_auth(self): + def test_ui_auth(self) -> None: """ Test user interactive authentication outside of registration. """ @@ -259,7 +263,7 @@ class UIAuthTests(unittest.HomeserverTestCase): }, ) - def test_grandfathered_identifier(self): + def test_grandfathered_identifier(self) -> None: """Check behaviour without "identifier" dict Synapse used to require clients to submit a "user" field for m.login.password @@ -286,7 +290,7 @@ class UIAuthTests(unittest.HomeserverTestCase): }, ) - def test_can_change_body(self): + def test_can_change_body(self) -> None: """ The client dict can be modified during the user interactive authentication session. @@ -325,7 +329,7 @@ class UIAuthTests(unittest.HomeserverTestCase): }, ) - def test_cannot_change_uri(self): + def test_cannot_change_uri(self) -> None: """ The initial requested URI cannot be modified during the user interactive authentication session. """ @@ -362,7 +366,7 @@ class UIAuthTests(unittest.HomeserverTestCase): ) @unittest.override_config({"ui_auth": {"session_timeout": "5s"}}) - def test_can_reuse_session(self): + def test_can_reuse_session(self) -> None: """ The session can be reused if configured. @@ -409,7 +413,7 @@ class UIAuthTests(unittest.HomeserverTestCase): @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) - def test_ui_auth_via_sso(self): + def test_ui_auth_via_sso(self) -> None: """Test a successful UI Auth flow via SSO This includes: @@ -452,7 +456,7 @@ class UIAuthTests(unittest.HomeserverTestCase): @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) - def test_does_not_offer_password_for_sso_user(self): + def test_does_not_offer_password_for_sso_user(self) -> None: login_resp = self.helper.login_via_oidc("username") user_tok = login_resp["access_token"] device_id = login_resp["device_id"] @@ -464,7 +468,7 @@ class UIAuthTests(unittest.HomeserverTestCase): flows = channel.json_body["flows"] self.assertEqual(flows, [{"stages": ["m.login.sso"]}]) - def test_does_not_offer_sso_for_password_user(self): + def test_does_not_offer_sso_for_password_user(self) -> None: channel = self.delete_device( self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED ) @@ -474,7 +478,7 @@ class UIAuthTests(unittest.HomeserverTestCase): @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) - def test_offers_both_flows_for_upgraded_user(self): + def test_offers_both_flows_for_upgraded_user(self) -> None: """A user that had a password and then logged in with SSO should get both flows""" login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) self.assertEqual(login_resp["user_id"], self.user) @@ -491,7 +495,7 @@ class UIAuthTests(unittest.HomeserverTestCase): @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) - def test_ui_auth_fails_for_incorrect_sso_user(self): + def test_ui_auth_fails_for_incorrect_sso_user(self) -> None: """If the user tries to authenticate with the wrong SSO user, they get an error""" # log the user in login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) @@ -534,7 +538,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): ] hijack_auth = False - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_pass = "pass" self.user = self.register_user("test", self.user_pass) @@ -548,7 +552,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): {"refresh_token": refresh_token}, ) - def is_access_token_valid(self, access_token) -> bool: + def is_access_token_valid(self, access_token: str) -> bool: """ Checks whether an access token is valid, returning whether it is or not. """ @@ -561,7 +565,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): return code == HTTPStatus.OK - def test_login_issue_refresh_token(self): + def test_login_issue_refresh_token(self) -> None: """ A login response should include a refresh_token only if asked. """ @@ -591,7 +595,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): self.assertIn("refresh_token", login_with_refresh.json_body) self.assertIn("expires_in_ms", login_with_refresh.json_body) - def test_register_issue_refresh_token(self): + def test_register_issue_refresh_token(self) -> None: """ A register response should include a refresh_token only if asked. """ @@ -627,7 +631,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): self.assertIn("refresh_token", register_with_refresh.json_body) self.assertIn("expires_in_ms", register_with_refresh.json_body) - def test_token_refresh(self): + def test_token_refresh(self) -> None: """ A refresh token can be used to issue a new access token. """ @@ -665,7 +669,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): ) @override_config({"refreshable_access_token_lifetime": "1m"}) - def test_refreshable_access_token_expiration(self): + def test_refreshable_access_token_expiration(self) -> None: """ The access token should have some time as specified in the config. """ @@ -722,7 +726,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "nonrefreshable_access_token_lifetime": "10m", } ) - def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(self): + def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens( + self, + ) -> None: """ Tests that the expiry times for refreshable and non-refreshable access tokens can be different. @@ -782,7 +788,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): @override_config( {"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"} ) - def test_refresh_token_expiry(self): + def test_refresh_token_expiry(self) -> None: """ The refresh token can be configured to have a limited lifetime. When that lifetime has ended, the refresh token can no longer be used to @@ -834,7 +840,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "session_lifetime": "3m", } ) - def test_ultimate_session_expiry(self): + def test_ultimate_session_expiry(self) -> None: """ The session can be configured to have an ultimate, limited lifetime. """ @@ -882,7 +888,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): refresh_response.code, HTTPStatus.FORBIDDEN, refresh_response.result ) - def test_refresh_token_invalidation(self): + def test_refresh_token_invalidation(self) -> None: """Refresh tokens are invalidated after first use of the next token. A refresh token is considered invalid if: @@ -987,7 +993,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result ) - def test_many_token_refresh(self): + def test_many_token_refresh(self) -> None: """ If a refresh is performed many times during a session, there shouldn't be extra 'cruft' built up over time. diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py index 989e80176..d1751e155 100644 --- a/tests/rest/client/test_capabilities.py +++ b/tests/rest/client/test_capabilities.py @@ -13,9 +13,13 @@ # limitations under the License. from http import HTTPStatus +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.rest.client import capabilities, login +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.unittest import override_config @@ -29,24 +33,24 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.url = b"/capabilities" hs = self.setup_test_homeserver() self.config = hs.config self.auth_handler = hs.get_auth_handler() return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.localpart = "user" self.password = "pass" self.user = self.register_user(self.localpart, self.password) - def test_check_auth_required(self): + def test_check_auth_required(self) -> None: channel = self.make_request("GET", self.url) self.assertEqual(channel.code, 401) - def test_get_room_version_capabilities(self): + def test_get_room_version_capabilities(self) -> None: access_token = self.login(self.localpart, self.password) channel = self.make_request("GET", self.url, access_token=access_token) @@ -61,7 +65,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): capabilities["m.room_versions"]["default"], ) - def test_get_change_password_capabilities_password_login(self): + def test_get_change_password_capabilities_password_login(self) -> None: access_token = self.login(self.localpart, self.password) channel = self.make_request("GET", self.url, access_token=access_token) @@ -71,7 +75,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertTrue(capabilities["m.change_password"]["enabled"]) @override_config({"password_config": {"localdb_enabled": False}}) - def test_get_change_password_capabilities_localdb_disabled(self): + def test_get_change_password_capabilities_localdb_disabled(self) -> None: access_token = self.get_success( self.auth_handler.create_access_token_for_user_id( self.user, device_id=None, valid_until_ms=None @@ -85,7 +89,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertFalse(capabilities["m.change_password"]["enabled"]) @override_config({"password_config": {"enabled": False}}) - def test_get_change_password_capabilities_password_disabled(self): + def test_get_change_password_capabilities_password_disabled(self) -> None: access_token = self.get_success( self.auth_handler.create_access_token_for_user_id( self.user, device_id=None, valid_until_ms=None @@ -98,7 +102,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertFalse(capabilities["m.change_password"]["enabled"]) - def test_get_change_users_attributes_capabilities(self): + def test_get_change_users_attributes_capabilities(self) -> None: """Test that server returns capabilities by default.""" access_token = self.login(self.localpart, self.password) @@ -112,7 +116,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertTrue(capabilities["m.3pid_changes"]["enabled"]) @override_config({"enable_set_displayname": False}) - def test_get_set_displayname_capabilities_displayname_disabled(self): + def test_get_set_displayname_capabilities_displayname_disabled(self) -> None: """Test if set displayname is disabled that the server responds it.""" access_token = self.login(self.localpart, self.password) @@ -123,7 +127,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertFalse(capabilities["m.set_displayname"]["enabled"]) @override_config({"enable_set_avatar_url": False}) - def test_get_set_avatar_url_capabilities_avatar_url_disabled(self): + def test_get_set_avatar_url_capabilities_avatar_url_disabled(self) -> None: """Test if set avatar_url is disabled that the server responds it.""" access_token = self.login(self.localpart, self.password) @@ -134,7 +138,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertFalse(capabilities["m.set_avatar_url"]["enabled"]) @override_config({"enable_3pid_changes": False}) - def test_get_change_3pid_capabilities_3pid_disabled(self): + def test_get_change_3pid_capabilities_3pid_disabled(self) -> None: """Test if change 3pid is disabled that the server responds it.""" access_token = self.login(self.localpart, self.password) @@ -145,7 +149,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertFalse(capabilities["m.3pid_changes"]["enabled"]) @override_config({"experimental_features": {"msc3244_enabled": False}}) - def test_get_does_not_include_msc3244_fields_when_disabled(self): + def test_get_does_not_include_msc3244_fields_when_disabled(self) -> None: access_token = self.get_success( self.auth_handler.create_access_token_for_user_id( self.user, device_id=None, valid_until_ms=None @@ -160,7 +164,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): "org.matrix.msc3244.room_capabilities", capabilities["m.room_versions"] ) - def test_get_does_include_msc3244_fields_when_enabled(self): + def test_get_does_include_msc3244_fields_when_enabled(self) -> None: access_token = self.get_success( self.auth_handler.create_access_token_for_user_id( self.user, device_id=None, valid_until_ms=None diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index 26d0d83e0..d48defda6 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -20,6 +20,7 @@ from urllib.parse import urlencode import pymacaroons +from twisted.test.proto_helpers import MemoryReactor from twisted.web.resource import Resource import synapse.rest.admin @@ -27,12 +28,15 @@ from synapse.appservice import ApplicationService from synapse.rest.client import devices, login, logout, register from synapse.rest.client.account import WhoamiRestServlet from synapse.rest.synapse.client import build_synapse_client_resource_tree +from synapse.server import HomeServer from synapse.types import create_requester +from synapse.util import Clock from tests import unittest from tests.handlers.test_oidc import HAS_OIDC from tests.handlers.test_saml import has_saml2 from tests.rest.client.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG +from tests.server import FakeChannel from tests.test_utils.html_parsers import TestHtmlParser from tests.unittest import HomeserverTestCase, override_config, skip_unless @@ -95,7 +99,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): lambda hs, http_server: WhoamiRestServlet(hs).register(http_server), ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.hs = self.setup_test_homeserver() self.hs.config.registration.enable_registration = True self.hs.config.registration.registrations_require_3pid = [] @@ -117,7 +121,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } } ) - def test_POST_ratelimiting_per_address(self): + def test_POST_ratelimiting_per_address(self) -> None: # Create different users so we're sure not to be bothered by the per-user # ratelimiter. for i in range(0, 6): @@ -165,7 +169,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } } ) - def test_POST_ratelimiting_per_account(self): + def test_POST_ratelimiting_per_account(self) -> None: self.register_user("kermit", "monkey") for i in range(0, 6): @@ -210,7 +214,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } } ) - def test_POST_ratelimiting_per_account_failed_attempts(self): + def test_POST_ratelimiting_per_account_failed_attempts(self) -> None: self.register_user("kermit", "monkey") for i in range(0, 6): @@ -243,7 +247,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"403", channel.result) @override_config({"session_lifetime": "24h"}) - def test_soft_logout(self): + def test_soft_logout(self) -> None: self.register_user("kermit", "monkey") # we shouldn't be able to make requests without an access token @@ -298,7 +302,9 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEquals(channel.json_body["soft_logout"], False) - def _delete_device(self, access_token, user_id, password, device_id): + def _delete_device( + self, access_token: str, user_id: str, password: str, device_id: str + ) -> None: """Perform the UI-Auth to delete a device""" channel = self.make_request( b"DELETE", "devices/" + device_id, access_token=access_token @@ -329,7 +335,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.code, 200, channel.result) @override_config({"session_lifetime": "24h"}) - def test_session_can_hard_logout_after_being_soft_logged_out(self): + def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None: self.register_user("kermit", "monkey") # log in as normal @@ -353,7 +359,9 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) @override_config({"session_lifetime": "24h"}) - def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(self): + def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out( + self, + ) -> None: self.register_user("kermit", "monkey") # log in as normal @@ -432,7 +440,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): d.update(build_synapse_client_resource_tree(self.hs)) return d - def test_get_login_flows(self): + def test_get_login_flows(self) -> None: """GET /login should return password and SSO flows""" channel = self.make_request("GET", "/_matrix/client/r0/login") self.assertEqual(channel.code, 200, channel.result) @@ -459,12 +467,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): ], ) - def test_multi_sso_redirect(self): + def test_multi_sso_redirect(self) -> None: """/login/sso/redirect should redirect to an identity picker""" # first hit the redirect url, which should redirect to our idp picker channel = self._make_sso_redirect_request(None) self.assertEqual(channel.code, 302, channel.result) - uri = channel.headers.getRawHeaders("Location")[0] + location_headers = channel.headers.getRawHeaders("Location") + assert location_headers + uri = location_headers[0] # hitting that picker should give us some HTML channel = self.make_request("GET", uri) @@ -487,7 +497,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"]) - def test_multi_sso_redirect_to_cas(self): + def test_multi_sso_redirect_to_cas(self) -> None: """If CAS is chosen, should redirect to the CAS server""" channel = self.make_request( @@ -514,7 +524,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): service_uri_params = urllib.parse.parse_qs(service_uri_query) self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL) - def test_multi_sso_redirect_to_saml(self): + def test_multi_sso_redirect_to_saml(self) -> None: """If SAML is chosen, should redirect to the SAML server""" channel = self.make_request( "GET", @@ -536,7 +546,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): relay_state_param = saml_uri_params["RelayState"][0] self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL) - def test_login_via_oidc(self): + def test_login_via_oidc(self) -> None: """If OIDC is chosen, should redirect to the OIDC auth endpoint""" # pick the default OIDC provider @@ -604,7 +614,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): self.assertEqual(chan.code, 200, chan.result) self.assertEqual(chan.json_body["user_id"], "@user1:test") - def test_multi_sso_redirect_to_unknown(self): + def test_multi_sso_redirect_to_unknown(self) -> None: """An unknown IdP should cause a 400""" channel = self.make_request( "GET", @@ -612,23 +622,25 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 400, channel.result) - def test_client_idp_redirect_to_unknown(self): + def test_client_idp_redirect_to_unknown(self) -> None: """If the client tries to pick an unknown IdP, return a 404""" channel = self._make_sso_redirect_request("xxx") self.assertEqual(channel.code, 404, channel.result) self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") - def test_client_idp_redirect_to_oidc(self): + def test_client_idp_redirect_to_oidc(self) -> None: """If the client pick a known IdP, redirect to it""" channel = self._make_sso_redirect_request("oidc") self.assertEqual(channel.code, 302, channel.result) - oidc_uri = channel.headers.getRawHeaders("Location")[0] + location_headers = channel.headers.getRawHeaders("Location") + assert location_headers + oidc_uri = location_headers[0] oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) # it should redirect us to the auth page of the OIDC server self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) - def _make_sso_redirect_request(self, idp_prov: Optional[str] = None): + def _make_sso_redirect_request(self, idp_prov: Optional[str] = None) -> FakeChannel: """Send a request to /_matrix/client/r0/login/sso/redirect ... possibly specifying an IDP provider @@ -659,7 +671,7 @@ class CASTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.base_url = "https://matrix.goodserver.com/" self.redirect_path = "_synapse/client/login/sso/redirect/confirm" @@ -675,7 +687,7 @@ class CASTestCase(unittest.HomeserverTestCase): cas_user_id = "username" self.user_id = "@%s:test" % cas_user_id - async def get_raw(uri, args): + async def get_raw(uri: str, args: Any) -> bytes: """Return an example response payload from a call to the `/proxyValidate` endpoint of a CAS server, copied from https://apereo.github.io/cas/5.0.x/protocol/CAS-Protocol-V2-Specification.html#26-proxyvalidate-cas-20 @@ -709,10 +721,10 @@ class CASTestCase(unittest.HomeserverTestCase): return self.hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.deactivate_account_handler = hs.get_deactivate_account_handler() - def test_cas_redirect_confirm(self): + def test_cas_redirect_confirm(self) -> None: """Tests that the SSO login flow serves a confirmation page before redirecting a user to the redirect URL. """ @@ -754,15 +766,15 @@ class CASTestCase(unittest.HomeserverTestCase): } } ) - def test_cas_redirect_whitelisted(self): + def test_cas_redirect_whitelisted(self) -> None: """Tests that the SSO login flow serves a redirect to a whitelisted url""" self._test_redirect("https://legit-site.com/") @override_config({"public_baseurl": "https://example.com"}) - def test_cas_redirect_login_fallback(self): + def test_cas_redirect_login_fallback(self) -> None: self._test_redirect("https://example.com/_matrix/static/client/login") - def _test_redirect(self, redirect_url): + def _test_redirect(self, redirect_url: str) -> None: """Tests that the SSO login flow serves a redirect for the given redirect URL.""" cas_ticket_url = ( "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" @@ -778,7 +790,7 @@ class CASTestCase(unittest.HomeserverTestCase): self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url) @override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}}) - def test_deactivated_user(self): + def test_deactivated_user(self) -> None: """Logging in as a deactivated account should error.""" redirect_url = "https://legit-site.com/" @@ -821,7 +833,7 @@ class JWTTestCase(unittest.HomeserverTestCase): "algorithm": jwt_algorithm, } - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() # If jwt_config has been defined (eg via @override_config), don't replace it. @@ -837,23 +849,23 @@ class JWTTestCase(unittest.HomeserverTestCase): return result.decode("ascii") return result - def jwt_login(self, *args): + def jwt_login(self, *args: Any) -> FakeChannel: params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} channel = self.make_request(b"POST", LOGIN_URL, params) return channel - def test_login_jwt_valid_registered(self): + def test_login_jwt_valid_registered(self) -> None: self.register_user("kermit", "monkey") channel = self.jwt_login({"sub": "kermit"}) self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") - def test_login_jwt_valid_unregistered(self): + def test_login_jwt_valid_unregistered(self) -> None: channel = self.jwt_login({"sub": "frog"}) self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["user_id"], "@frog:test") - def test_login_jwt_invalid_signature(self): + def test_login_jwt_invalid_signature(self) -> None: channel = self.jwt_login({"sub": "frog"}, "notsecret") self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -862,7 +874,7 @@ class JWTTestCase(unittest.HomeserverTestCase): "JWT validation failed: Signature verification failed", ) - def test_login_jwt_expired(self): + def test_login_jwt_expired(self) -> None: channel = self.jwt_login({"sub": "frog", "exp": 864000}) self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -870,7 +882,7 @@ class JWTTestCase(unittest.HomeserverTestCase): channel.json_body["error"], "JWT validation failed: Signature has expired" ) - def test_login_jwt_not_before(self): + def test_login_jwt_not_before(self) -> None: now = int(time.time()) channel = self.jwt_login({"sub": "frog", "nbf": now + 3600}) self.assertEqual(channel.result["code"], b"403", channel.result) @@ -880,14 +892,14 @@ class JWTTestCase(unittest.HomeserverTestCase): "JWT validation failed: The token is not yet valid (nbf)", ) - def test_login_no_sub(self): + def test_login_no_sub(self) -> None: channel = self.jwt_login({"username": "root"}) self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["error"], "Invalid JWT") @override_config({"jwt_config": {**base_config, "issuer": "test-issuer"}}) - def test_login_iss(self): + def test_login_iss(self) -> None: """Test validating the issuer claim.""" # A valid issuer. channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"}) @@ -911,14 +923,14 @@ class JWTTestCase(unittest.HomeserverTestCase): 'JWT validation failed: Token is missing the "iss" claim', ) - def test_login_iss_no_config(self): + def test_login_iss_no_config(self) -> None: """Test providing an issuer claim without requiring it in the configuration.""" channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") @override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}}) - def test_login_aud(self): + def test_login_aud(self) -> None: """Test validating the audience claim.""" # A valid audience. channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"}) @@ -942,7 +954,7 @@ class JWTTestCase(unittest.HomeserverTestCase): 'JWT validation failed: Token is missing the "aud" claim', ) - def test_login_aud_no_config(self): + def test_login_aud_no_config(self) -> None: """Test providing an audience without requiring it in the configuration.""" channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) self.assertEqual(channel.result["code"], b"403", channel.result) @@ -951,20 +963,20 @@ class JWTTestCase(unittest.HomeserverTestCase): channel.json_body["error"], "JWT validation failed: Invalid audience" ) - def test_login_default_sub(self): + def test_login_default_sub(self) -> None: """Test reading user ID from the default subject claim.""" channel = self.jwt_login({"sub": "kermit"}) self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") @override_config({"jwt_config": {**base_config, "subject_claim": "username"}}) - def test_login_custom_sub(self): + def test_login_custom_sub(self) -> None: """Test reading user ID from a custom subject claim.""" channel = self.jwt_login({"username": "frog"}) self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["user_id"], "@frog:test") - def test_login_no_token(self): + def test_login_no_token(self) -> None: params = {"type": "org.matrix.login.jwt"} channel = self.make_request(b"POST", LOGIN_URL, params) self.assertEqual(channel.result["code"], b"403", channel.result) @@ -1026,7 +1038,7 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): ] ) - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["jwt_config"] = { "enabled": True, @@ -1042,17 +1054,17 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): return result.decode("ascii") return result - def jwt_login(self, *args): + def jwt_login(self, *args: Any) -> FakeChannel: params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} channel = self.make_request(b"POST", LOGIN_URL, params) return channel - def test_login_jwt_valid(self): + def test_login_jwt_valid(self) -> None: channel = self.jwt_login({"sub": "kermit"}) self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") - def test_login_jwt_invalid_signature(self): + def test_login_jwt_invalid_signature(self) -> None: channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey) self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -1071,7 +1083,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): register.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.hs = self.setup_test_homeserver() self.service = ApplicationService( @@ -1105,7 +1117,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): self.hs.get_datastores().main.services_cache.append(self.another_service) return self.hs - def test_login_appservice_user(self): + def test_login_appservice_user(self) -> None: """Test that an appservice user can use /login""" self.register_appservice_user(AS_USER, self.service.token) @@ -1119,7 +1131,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) - def test_login_appservice_user_bot(self): + def test_login_appservice_user_bot(self) -> None: """Test that the appservice bot can use /login""" self.register_appservice_user(AS_USER, self.service.token) @@ -1133,7 +1145,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) - def test_login_appservice_wrong_user(self): + def test_login_appservice_wrong_user(self) -> None: """Test that non-as users cannot login with the as token""" self.register_appservice_user(AS_USER, self.service.token) @@ -1147,7 +1159,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"403", channel.result) - def test_login_appservice_wrong_as(self): + def test_login_appservice_wrong_as(self) -> None: """Test that as users cannot login with wrong as token""" self.register_appservice_user(AS_USER, self.service.token) @@ -1161,7 +1173,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"403", channel.result) - def test_login_appservice_no_token(self): + def test_login_appservice_no_token(self) -> None: """Test that users must provide a token when using the appservice login method """ @@ -1182,7 +1194,7 @@ class UsernamePickerTestCase(HomeserverTestCase): servlets = [login.register_servlets] - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["public_baseurl"] = BASE_URL @@ -1202,7 +1214,7 @@ class UsernamePickerTestCase(HomeserverTestCase): d.update(build_synapse_client_resource_tree(self.hs)) return d - def test_username_picker(self): + def test_username_picker(self) -> None: """Test the happy path of a username picker flow.""" # do the start of the login flow diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index e06256136..69b4ef537 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -13,9 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import json +from typing import List, Optional from parameterized import parameterized +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.constants import ( EventContentFields, @@ -24,6 +27,9 @@ from synapse.api.constants import ( RelationTypes, ) from synapse.rest.client import devices, knock, login, read_marker, receipts, room, sync +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest from tests.federation.transport.test_knocking import ( @@ -43,7 +49,7 @@ class FilterTestCase(unittest.HomeserverTestCase): sync.register_servlets, ] - def test_sync_argless(self): + def test_sync_argless(self) -> None: channel = self.make_request("GET", "/sync") self.assertEqual(channel.code, 200) @@ -58,7 +64,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): sync.register_servlets, ] - def test_sync_filter_labels(self): + def test_sync_filter_labels(self) -> None: """Test that we can filter by a label.""" sync_filter = json.dumps( { @@ -77,7 +83,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) - def test_sync_filter_not_labels(self): + def test_sync_filter_not_labels(self) -> None: """Test that we can filter by the absence of a label.""" sync_filter = json.dumps( { @@ -99,7 +105,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): events[2]["content"]["body"], "with two wrong labels", events[2] ) - def test_sync_filter_labels_not_labels(self): + def test_sync_filter_labels_not_labels(self) -> None: """Test that we can filter by both a label and the absence of another label.""" sync_filter = json.dumps( { @@ -118,7 +124,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): self.assertEqual(len(events), 1, [event["content"] for event in events]) self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) - def _test_sync_filter_labels(self, sync_filter): + def _test_sync_filter_labels(self, sync_filter: str) -> List[JsonDict]: user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test") @@ -194,7 +200,7 @@ class SyncTypingTests(unittest.HomeserverTestCase): user_id = True hijack_auth = False - def test_sync_backwards_typing(self): + def test_sync_backwards_typing(self) -> None: """ If the typing serial goes backwards and the typing handler is then reset (such as when the master restarts and sets the typing serial to 0), we @@ -298,7 +304,7 @@ class SyncKnockTestCase( knock.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.url = "/sync?since=%s" self.next_batch = "s0" @@ -336,7 +342,7 @@ class SyncKnockTestCase( ) @override_config({"experimental_features": {"msc2403_enabled": True}}) - def test_knock_room_state(self): + def test_knock_room_state(self) -> None: """Tests that /sync returns state from a room after knocking on it.""" # Knock on a room channel = self.make_request( @@ -383,7 +389,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): sync.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.url = "/sync?since=%s" self.next_batch = "s0" @@ -402,7 +408,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) @override_config({"experimental_features": {"msc2285_enabled": True}}) - def test_hidden_read_receipts(self): + def test_hidden_read_receipts(self) -> None: # Send a message as the first user res = self.helper.send(self.room_id, body="hello", tok=self.tok) @@ -441,8 +447,8 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): ] ) def test_read_receipt_with_empty_body( - self, name, user_agent: str, expected_status_code: int - ): + self, name: str, user_agent: str, expected_status_code: int + ) -> None: # Send a message as the first user res = self.helper.send(self.room_id, body="hello", tok=self.tok) @@ -455,11 +461,11 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, expected_status_code) - def _get_read_receipt(self): + def _get_read_receipt(self) -> Optional[JsonDict]: """Syncs and returns the read receipt.""" # Checks if event is a read receipt - def is_read_receipt(event): + def is_read_receipt(event: JsonDict) -> bool: return event["type"] == "m.receipt" # Sync @@ -477,7 +483,8 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): ephemeral_events = channel.json_body["rooms"]["join"][self.room_id][ "ephemeral" ]["events"] - return next(filter(is_read_receipt, ephemeral_events), None) + receipt_event = filter(is_read_receipt, ephemeral_events) + return next(receipt_event, None) class UnreadMessagesTestCase(unittest.HomeserverTestCase): @@ -490,7 +497,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): receipts.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.url = "/sync?since=%s" self.next_batch = "s0" @@ -533,7 +540,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): tok=self.tok, ) - def test_unread_counts(self): + def test_unread_counts(self) -> None: """Tests that /sync returns the right value for the unread count (MSC2654).""" # Check that our own messages don't increase the unread count. @@ -640,7 +647,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): ) self._check_unread_count(5) - def _check_unread_count(self, expected_count: int): + def _check_unread_count(self, expected_count: int) -> None: """Syncs and compares the unread count with the expected value.""" channel = self.make_request( @@ -669,7 +676,7 @@ class SyncCacheTestCase(unittest.HomeserverTestCase): sync.register_servlets, ] - def test_noop_sync_does_not_tightloop(self): + def test_noop_sync_does_not_tightloop(self) -> None: """If the sync times out, we shouldn't cache the result Essentially a regression test for #8518. @@ -720,7 +727,7 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase): devices.register_servlets, ] - def test_user_with_no_rooms_receives_self_device_list_updates(self): + def test_user_with_no_rooms_receives_self_device_list_updates(self) -> None: """Tests that a user with no rooms still receives their own device list updates""" device_id = "TESTDEVICE" From a711ae78a8f8ba406ff122035c8bf096fac9a26c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 23 Feb 2022 14:22:22 +0000 Subject: [PATCH 52/84] Add logging to `/sync` for debugging #11916 (#12068) --- changelog.d/12068.misc | 1 + synapse/handlers/sync.py | 9 +++++++++ 2 files changed, 10 insertions(+) create mode 100644 changelog.d/12068.misc diff --git a/changelog.d/12068.misc b/changelog.d/12068.misc new file mode 100644 index 000000000..72b211e4f --- /dev/null +++ b/changelog.d/12068.misc @@ -0,0 +1 @@ +Add some logging to `/sync` to try and track down #11916. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 98eaad331..0aa3052fd 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -697,6 +697,15 @@ class SyncHandler: else: # no events in this room - so presumably no state state = {} + + # (erikj) This should be rarely hit, but we've had some reports that + # we get more state down gappy syncs than we should, so let's add + # some logging. + logger.info( + "Failed to find any events in room %s at %s", + room_id, + stream_position.room_key, + ) return state async def compute_summary( From c56bfb08bc071368db23f3b1c593724eb4f205f0 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 23 Feb 2022 17:49:04 -0500 Subject: [PATCH 53/84] Add documentation for missing worker types. (#11599) And clean-up the endpoints which should be routed to workers. --- changelog.d/11599.doc | 1 + docs/workers.md | 90 ++++++++++++++++++++++++++++++++++++------- 2 files changed, 77 insertions(+), 14 deletions(-) create mode 100644 changelog.d/11599.doc diff --git a/changelog.d/11599.doc b/changelog.d/11599.doc new file mode 100644 index 000000000..f07cfbef4 --- /dev/null +++ b/changelog.d/11599.doc @@ -0,0 +1 @@ +Document support for the `to_device`, `account_data`, `receipts`, and `presence` stream writers for workers. diff --git a/docs/workers.md b/docs/workers.md index dadde4d72..b82a6900a 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -178,8 +178,11 @@ recommend the use of `systemd` where available: for information on setting up ### `synapse.app.generic_worker` -This worker can handle API requests matching the following regular -expressions: +This worker can handle API requests matching the following regular expressions. +These endpoints can be routed to any worker. If a worker is set up to handle a +stream then, for maximum efficiency, additional endpoints should be routed to that +worker: refer to the [stream writers](#stream-writers) section below for further +information. # Sync requests ^/_matrix/client/(v2_alpha|r0|v3)/sync$ @@ -225,19 +228,23 @@ expressions: ^/_matrix/client/unstable/org.matrix.msc2946/rooms/.*/spaces$ ^/_matrix/client/(v1|unstable/org.matrix.msc2946)/rooms/.*/hierarchy$ ^/_matrix/client/unstable/im.nheko.summary/rooms/.*/summary$ - ^/_matrix/client/(api/v1|r0|v3|unstable)/account/3pid$ - ^/_matrix/client/(api/v1|r0|v3|unstable)/devices$ - ^/_matrix/client/(api/v1|r0|v3|unstable)/keys/query$ - ^/_matrix/client/(api/v1|r0|v3|unstable)/keys/changes$ + ^/_matrix/client/(r0|v3|unstable)/account/3pid$ + ^/_matrix/client/(r0|v3|unstable)/devices$ ^/_matrix/client/versions$ ^/_matrix/client/(api/v1|r0|v3|unstable)/voip/turnServer$ - ^/_matrix/client/(api/v1|r0|v3|unstable)/joined_groups$ - ^/_matrix/client/(api/v1|r0|v3|unstable)/publicised_groups$ - ^/_matrix/client/(api/v1|r0|v3|unstable)/publicised_groups/ + ^/_matrix/client/(r0|v3|unstable)/joined_groups$ + ^/_matrix/client/(r0|v3|unstable)/publicised_groups$ + ^/_matrix/client/(r0|v3|unstable)/publicised_groups/ ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/event/ ^/_matrix/client/(api/v1|r0|v3|unstable)/joined_rooms$ ^/_matrix/client/(api/v1|r0|v3|unstable)/search$ + # Encryption requests + ^/_matrix/client/(r0|v3|unstable)/keys/query$ + ^/_matrix/client/(r0|v3|unstable)/keys/changes$ + ^/_matrix/client/(r0|v3|unstable)/keys/claim$ + ^/_matrix/client/(r0|v3|unstable)/room_keys/ + # Registration/login requests ^/_matrix/client/(api/v1|r0|v3|unstable)/login$ ^/_matrix/client/(r0|v3|unstable)/register$ @@ -251,6 +258,20 @@ expressions: ^/_matrix/client/(api/v1|r0|v3|unstable)/join/ ^/_matrix/client/(api/v1|r0|v3|unstable)/profile/ + # Device requests + ^/_matrix/client/(r0|v3|unstable)/sendToDevice/ + + # Account data requests + ^/_matrix/client/(r0|v3|unstable)/.*/tags + ^/_matrix/client/(r0|v3|unstable)/.*/account_data + + # Receipts requests + ^/_matrix/client/(r0|v3|unstable)/rooms/.*/receipt + ^/_matrix/client/(r0|v3|unstable)/rooms/.*/read_markers + + # Presence requests + ^/_matrix/client/(api/v1|r0|v3|unstable)/presence/ + Additionally, the following REST endpoints can be handled for GET requests: @@ -330,12 +351,10 @@ Additionally, there is *experimental* support for moving writing of specific streams (such as events) off of the main process to a particular worker. (This is only supported with Redis-based replication.) -Currently supported streams are `events` and `typing`. - To enable this, the worker must have a HTTP replication listener configured, -have a `worker_name` and be listed in the `instance_map` config. For example to -move event persistence off to a dedicated worker, the shared configuration would -include: +have a `worker_name` and be listed in the `instance_map` config. The same worker +can handle multiple streams. For example, to move event persistence off to a +dedicated worker, the shared configuration would include: ```yaml instance_map: @@ -347,6 +366,12 @@ stream_writers: events: event_persister1 ``` +Some of the streams have associated endpoints which, for maximum efficiency, should +be routed to the workers handling that stream. See below for the currently supported +streams and the endpoints associated with them: + +##### The `events` stream + The `events` stream also experimentally supports having multiple writers, where work is sharded between them by room ID. Note that you *must* restart all worker instances when adding or removing event persisters. An example `stream_writers` @@ -359,6 +384,43 @@ stream_writers: - event_persister2 ``` +##### The `typing` stream + +The following endpoints should be routed directly to the workers configured as +stream writers for the `typing` stream: + + ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/typing + +##### The `to_device` stream + +The following endpoints should be routed directly to the workers configured as +stream writers for the `to_device` stream: + + ^/_matrix/client/(api/v1|r0|v3|unstable)/sendToDevice/ + +##### The `account_data` stream + +The following endpoints should be routed directly to the workers configured as +stream writers for the `account_data` stream: + + ^/_matrix/client/(api/v1|r0|v3|unstable)/.*/tags + ^/_matrix/client/(api/v1|r0|v3|unstable)/.*/account_data + +##### The `receipts` stream + +The following endpoints should be routed directly to the workers configured as +stream writers for the `receipts` stream: + + ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/receipt + ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/read_markers + +##### The `presence` stream + +The following endpoints should be routed directly to the workers configured as +stream writers for the `presence` stream: + + ^/_matrix/client/(api/v1|r0|v3|unstable)/presence/ + #### Background tasks There is also *experimental* support for moving background tasks to a separate From 41cf4c2cf6432336cc7477f130a2847449cff99a Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Thu, 24 Feb 2022 11:52:28 +0000 Subject: [PATCH 54/84] Fix non-strings in the `event_search` table (#12037) Don't attempt to add non-string `value`s to `event_search` and add a background update to clear out bad rows from `event_search` when using sqlite. Signed-off-by: Sean Quah --- changelog.d/12037.bugfix | 1 + synapse/storage/databases/main/events.py | 18 +-- synapse/storage/databases/main/search.py | 26 ++++ ...e_non_strings_from_event_search.sql.sqlite | 22 ++++ tests/storage/test_room_search.py | 117 +++++++++++++++++- 5 files changed, 173 insertions(+), 11 deletions(-) create mode 100644 changelog.d/12037.bugfix create mode 100644 synapse/storage/schema/main/delta/68/05_delete_non_strings_from_event_search.sql.sqlite diff --git a/changelog.d/12037.bugfix b/changelog.d/12037.bugfix new file mode 100644 index 000000000..9295cb4dc --- /dev/null +++ b/changelog.d/12037.bugfix @@ -0,0 +1 @@ +Properly fix a long-standing bug where wrong data could be inserted in the `event_search` table when using sqlite. This could block running `synapse_port_db` with an "argument of type 'int' is not iterable" error. This bug was partially fixed by a change in Synapse 1.44.0. diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index a1d7a9b41..e53e84054 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1473,10 +1473,10 @@ class PersistEventsStore: def _update_metadata_tables_txn( self, - txn, + txn: LoggingTransaction, *, - events_and_contexts, - all_events_and_contexts, + events_and_contexts: List[Tuple[EventBase, EventContext]], + all_events_and_contexts: List[Tuple[EventBase, EventContext]], inhibit_local_membership_updates: bool = False, ): """Update all the miscellaneous tables for new events @@ -1953,20 +1953,20 @@ class PersistEventsStore: txn, table="event_relations", keyvalues={"event_id": redacted_event_id} ) - def _store_room_topic_txn(self, txn, event): - if hasattr(event, "content") and "topic" in event.content: + def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase): + if isinstance(event.content.get("topic"), str): self.store_event_search_txn( txn, event, "content.topic", event.content["topic"] ) - def _store_room_name_txn(self, txn, event): - if hasattr(event, "content") and "name" in event.content: + def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase): + if isinstance(event.content.get("name"), str): self.store_event_search_txn( txn, event, "content.name", event.content["name"] ) - def _store_room_message_txn(self, txn, event): - if hasattr(event, "content") and "body" in event.content: + def _store_room_message_txn(self, txn: LoggingTransaction, event: EventBase): + if isinstance(event.content.get("body"), str): self.store_event_search_txn( txn, event, "content.body", event.content["body"] ) diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index acea300ed..e23b11907 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -115,6 +115,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order" EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist" EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" + EVENT_SEARCH_DELETE_NON_STRINGS = "event_search_sqlite_delete_non_strings" def __init__( self, @@ -147,6 +148,10 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search ) + self.db_pool.updates.register_background_update_handler( + self.EVENT_SEARCH_DELETE_NON_STRINGS, self._background_delete_non_strings + ) + async def _background_reindex_search(self, progress, batch_size): # we work through the events table from highest stream id to lowest target_min_stream_id = progress["target_min_stream_id_inclusive"] @@ -372,6 +377,27 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): return num_rows + async def _background_delete_non_strings( + self, progress: JsonDict, batch_size: int + ) -> int: + """Deletes rows with non-string `value`s from `event_search` if using sqlite. + + Prior to Synapse 1.44.0, malformed events received over federation could cause integers + to be inserted into the `event_search` table when using sqlite. + """ + + def delete_non_strings_txn(txn: LoggingTransaction) -> None: + txn.execute("DELETE FROM event_search WHERE typeof(value) != 'text'") + + await self.db_pool.runInteraction( + self.EVENT_SEARCH_DELETE_NON_STRINGS, delete_non_strings_txn + ) + + await self.db_pool.updates._end_background_update( + self.EVENT_SEARCH_DELETE_NON_STRINGS + ) + return 1 + class SearchStore(SearchBackgroundUpdateStore): def __init__( diff --git a/synapse/storage/schema/main/delta/68/05_delete_non_strings_from_event_search.sql.sqlite b/synapse/storage/schema/main/delta/68/05_delete_non_strings_from_event_search.sql.sqlite new file mode 100644 index 000000000..140df6526 --- /dev/null +++ b/synapse/storage/schema/main/delta/68/05_delete_non_strings_from_event_search.sql.sqlite @@ -0,0 +1,22 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- Delete rows with non-string `value`s from `event_search` if using sqlite. +-- +-- Prior to Synapse 1.44.0, malformed events received over federation could +-- cause integers to be inserted into the `event_search` table. +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (6805, 'event_search_sqlite_delete_non_strings', '{}'); diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py index befaa0fce..d62e01726 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py @@ -13,13 +13,16 @@ # limitations under the License. import synapse.rest.admin +from synapse.api.constants import EventTypes +from synapse.api.errors import StoreError from synapse.rest.client import login, room from synapse.storage.engines import PostgresEngine -from tests.unittest import HomeserverTestCase +from tests.unittest import HomeserverTestCase, skip_unless +from tests.utils import USE_POSTGRES_FOR_TESTS -class NullByteInsertionTest(HomeserverTestCase): +class EventSearchInsertionTest(HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, @@ -72,3 +75,113 @@ class NullByteInsertionTest(HomeserverTestCase): ) if isinstance(store.database_engine, PostgresEngine): self.assertIn("alice", result.get("highlights")) + + def test_non_string(self): + """Test that non-string `value`s are not inserted into `event_search`. + + This is particularly important when using sqlite, since a sqlite column can hold + both strings and integers. When using Postgres, integers are automatically + converted to strings. + + Regression test for #11918. + """ + store = self.hs.get_datastores().main + + # Register a user and create a room + user_id = self.register_user("alice", "password") + access_token = self.login("alice", "password") + room_id = self.helper.create_room_as("alice", tok=access_token) + room_version = self.get_success(store.get_room_version(room_id)) + + # Construct a message with a numeric body to be received over federation + # The message can't be sent using the client API, since Synapse's event + # validation will reject it. + prev_event_ids = self.get_success(store.get_prev_events_for_room(room_id)) + prev_event = self.get_success(store.get_event(prev_event_ids[0])) + prev_state_map = self.get_success( + self.hs.get_storage().state.get_state_ids_for_event(prev_event_ids[0]) + ) + + event_dict = { + "type": EventTypes.Message, + "content": {"msgtype": "m.text", "body": 2}, + "room_id": room_id, + "sender": user_id, + "depth": prev_event.depth + 1, + "prev_events": prev_event_ids, + "origin_server_ts": self.clock.time_msec(), + } + builder = self.hs.get_event_builder_factory().for_room_version( + room_version, event_dict + ) + event = self.get_success( + builder.build( + prev_event_ids=prev_event_ids, + auth_event_ids=self.hs.get_event_auth_handler().compute_auth_events( + builder, + prev_state_map, + for_verification=False, + ), + depth=event_dict["depth"], + ) + ) + + # Receive the event + self.get_success( + self.hs.get_federation_event_handler().on_receive_pdu( + self.hs.hostname, event + ) + ) + + # The event should not have an entry in the `event_search` table + f = self.get_failure( + store.db_pool.simple_select_one_onecol( + "event_search", + {"room_id": room_id, "event_id": event.event_id}, + "event_id", + ), + StoreError, + ) + self.assertEqual(f.value.code, 404) + + @skip_unless(not USE_POSTGRES_FOR_TESTS, "requires sqlite") + def test_sqlite_non_string_deletion_background_update(self): + """Test the background update to delete bad rows from `event_search`.""" + store = self.hs.get_datastores().main + + # Populate `event_search` with dummy data + self.get_success( + store.db_pool.simple_insert_many( + "event_search", + keys=["event_id", "room_id", "key", "value"], + values=[ + ("event1", "room_id", "content.body", "hi"), + ("event2", "room_id", "content.body", "2"), + ("event3", "room_id", "content.body", 3), + ], + desc="populate_event_search", + ) + ) + + # Run the background update + store.db_pool.updates._all_done = False + self.get_success( + store.db_pool.simple_insert( + "background_updates", + { + "update_name": "event_search_sqlite_delete_non_strings", + "progress_json": "{}", + }, + ) + ) + self.wait_for_background_updates() + + # The non-string `value`s ought to be gone now. + values = self.get_success( + store.db_pool.simple_select_onecol( + "event_search", + {"room_id": "room_id"}, + "value", + ), + ) + self.assertCountEqual(values, ["hi", "2"]) From 2cc5ea933dbe65445e3711bb3f05022b007029ea Mon Sep 17 00:00:00 2001 From: reivilibre Date: Thu, 24 Feb 2022 17:55:45 +0000 Subject: [PATCH 55/84] Add support for MSC3202: sending one-time key counts and fallback key usage states to Application Services. (#11617) Co-authored-by: Erik Johnston --- changelog.d/11617.feature | 1 + synapse/appservice/__init__.py | 16 ++ synapse/appservice/api.py | 20 +- synapse/appservice/scheduler.py | 98 ++++++++- synapse/config/appservice.py | 13 +- synapse/config/experimental.py | 16 +- synapse/storage/databases/main/appservice.py | 33 ++- .../storage/databases/main/end_to_end_keys.py | 112 ++++++++++ tests/appservice/test_scheduler.py | 55 +++-- tests/handlers/test_appservice.py | 194 +++++++++++++++++- tests/storage/test_appservice.py | 8 +- 11 files changed, 528 insertions(+), 38 deletions(-) create mode 100644 changelog.d/11617.feature diff --git a/changelog.d/11617.feature b/changelog.d/11617.feature new file mode 100644 index 000000000..cf03f00e7 --- /dev/null +++ b/changelog.d/11617.feature @@ -0,0 +1 @@ +Add support for MSC3202: sending one-time key counts and fallback key usage states to Application Services. \ No newline at end of file diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index a340a8c9c..4d3f8e492 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -31,6 +31,14 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +# Type for the `device_one_time_key_counts` field in an appservice transaction +# user ID -> {device ID -> {algorithm -> count}} +TransactionOneTimeKeyCounts = Dict[str, Dict[str, Dict[str, int]]] + +# Type for the `device_unused_fallback_keys` field in an appservice transaction +# user ID -> {device ID -> [algorithm]} +TransactionUnusedFallbackKeys = Dict[str, Dict[str, List[str]]] + class ApplicationServiceState(Enum): DOWN = "down" @@ -72,6 +80,7 @@ class ApplicationService: rate_limited: bool = True, ip_range_whitelist: Optional[IPSet] = None, supports_ephemeral: bool = False, + msc3202_transaction_extensions: bool = False, ): self.token = token self.url = ( @@ -84,6 +93,7 @@ class ApplicationService: self.id = id self.ip_range_whitelist = ip_range_whitelist self.supports_ephemeral = supports_ephemeral + self.msc3202_transaction_extensions = msc3202_transaction_extensions if "|" in self.id: raise Exception("application service ID cannot contain '|' character") @@ -339,12 +349,16 @@ class AppServiceTransaction: events: List[EventBase], ephemeral: List[JsonDict], to_device_messages: List[JsonDict], + one_time_key_counts: TransactionOneTimeKeyCounts, + unused_fallback_keys: TransactionUnusedFallbackKeys, ): self.service = service self.id = id self.events = events self.ephemeral = ephemeral self.to_device_messages = to_device_messages + self.one_time_key_counts = one_time_key_counts + self.unused_fallback_keys = unused_fallback_keys async def send(self, as_api: "ApplicationServiceApi") -> bool: """Sends this transaction using the provided AS API interface. @@ -359,6 +373,8 @@ class AppServiceTransaction: events=self.events, ephemeral=self.ephemeral, to_device_messages=self.to_device_messages, + one_time_key_counts=self.one_time_key_counts, + unused_fallback_keys=self.unused_fallback_keys, txn_id=self.id, ) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 73be7ff3d..a0ea958af 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -19,6 +19,11 @@ from prometheus_client import Counter from synapse.api.constants import EventTypes, Membership, ThirdPartyEntityKind from synapse.api.errors import CodeMessageException +from synapse.appservice import ( + ApplicationService, + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, +) from synapse.events import EventBase from synapse.events.utils import serialize_event from synapse.http.client import SimpleHttpClient @@ -26,7 +31,6 @@ from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.util.caches.response_cache import ResponseCache if TYPE_CHECKING: - from synapse.appservice import ApplicationService from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -219,6 +223,8 @@ class ApplicationServiceApi(SimpleHttpClient): events: List[EventBase], ephemeral: List[JsonDict], to_device_messages: List[JsonDict], + one_time_key_counts: TransactionOneTimeKeyCounts, + unused_fallback_keys: TransactionUnusedFallbackKeys, txn_id: Optional[int] = None, ) -> bool: """ @@ -252,7 +258,7 @@ class ApplicationServiceApi(SimpleHttpClient): uri = service.url + ("/transactions/%s" % urllib.parse.quote(str(txn_id))) # Never send ephemeral events to appservices that do not support it - body: Dict[str, List[JsonDict]] = {"events": serialized_events} + body: JsonDict = {"events": serialized_events} if service.supports_ephemeral: body.update( { @@ -262,6 +268,16 @@ class ApplicationServiceApi(SimpleHttpClient): } ) + if service.msc3202_transaction_extensions: + if one_time_key_counts: + body[ + "org.matrix.msc3202.device_one_time_key_counts" + ] = one_time_key_counts + if unused_fallback_keys: + body[ + "org.matrix.msc3202.device_unused_fallback_keys" + ] = unused_fallback_keys + try: await self.put_json( uri=uri, diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index b4e602e88..72417151b 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -54,12 +54,19 @@ from typing import ( Callable, Collection, Dict, + Iterable, List, Optional, Set, + Tuple, ) -from synapse.appservice import ApplicationService, ApplicationServiceState +from synapse.appservice import ( + ApplicationService, + ApplicationServiceState, + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, +) from synapse.appservice.api import ApplicationServiceApi from synapse.events import EventBase from synapse.logging.context import run_in_background @@ -96,7 +103,7 @@ class ApplicationServiceScheduler: self.as_api = hs.get_application_service_api() self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api) - self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock) + self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock, hs) async def start(self) -> None: logger.info("Starting appservice scheduler") @@ -153,7 +160,9 @@ class _ServiceQueuer: appservice at a given time. """ - def __init__(self, txn_ctrl: "_TransactionController", clock: Clock): + def __init__( + self, txn_ctrl: "_TransactionController", clock: Clock, hs: "HomeServer" + ): # dict of {service_id: [events]} self.queued_events: Dict[str, List[EventBase]] = {} # dict of {service_id: [events]} @@ -165,6 +174,10 @@ class _ServiceQueuer: self.requests_in_flight: Set[str] = set() self.txn_ctrl = txn_ctrl self.clock = clock + self._msc3202_transaction_extensions_enabled: bool = ( + hs.config.experimental.msc3202_transaction_extensions + ) + self._store = hs.get_datastores().main def start_background_request(self, service: ApplicationService) -> None: # start a sender for this appservice if we don't already have one @@ -202,15 +215,84 @@ class _ServiceQueuer: if not events and not ephemeral and not to_device_messages_to_send: return + one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None + unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None + + if ( + self._msc3202_transaction_extensions_enabled + and service.msc3202_transaction_extensions + ): + # Compute the one-time key counts and fallback key usage states + # for the users which are mentioned in this transaction, + # as well as the appservice's sender. + ( + one_time_key_counts, + unused_fallback_keys, + ) = await self._compute_msc3202_otk_counts_and_fallback_keys( + service, events, ephemeral, to_device_messages_to_send + ) + try: await self.txn_ctrl.send( - service, events, ephemeral, to_device_messages_to_send + service, + events, + ephemeral, + to_device_messages_to_send, + one_time_key_counts, + unused_fallback_keys, ) except Exception: logger.exception("AS request failed") finally: self.requests_in_flight.discard(service.id) + async def _compute_msc3202_otk_counts_and_fallback_keys( + self, + service: ApplicationService, + events: Iterable[EventBase], + ephemerals: Iterable[JsonDict], + to_device_messages: Iterable[JsonDict], + ) -> Tuple[TransactionOneTimeKeyCounts, TransactionUnusedFallbackKeys]: + """ + Given a list of the events, ephemeral messages and to-device messages, + - first computes a list of application services users that may have + interesting updates to the one-time key counts or fallback key usage. + - then computes one-time key counts and fallback key usages for those users. + Given a list of application service users that are interesting, + compute one-time key counts and fallback key usages for the users. + """ + + # Set of 'interesting' users who may have updates + users: Set[str] = set() + + # The sender is always included + users.add(service.sender) + + # All AS users that would receive the PDUs or EDUs sent to these rooms + # are classed as 'interesting'. + rooms_of_interesting_users: Set[str] = set() + # PDUs + rooms_of_interesting_users.update(event.room_id for event in events) + # EDUs + rooms_of_interesting_users.update( + ephemeral["room_id"] for ephemeral in ephemerals + ) + + # Look up the AS users in those rooms + for room_id in rooms_of_interesting_users: + users.update( + await self._store.get_app_service_users_in_room(room_id, service) + ) + + # Add recipients of to-device messages. + # device_message["user_id"] is the ID of the recipient. + users.update(device_message["user_id"] for device_message in to_device_messages) + + # Compute and return the counts / fallback key usage states + otk_counts = await self._store.count_bulk_e2e_one_time_keys_for_as(users) + unused_fbks = await self._store.get_e2e_bulk_unused_fallback_key_types(users) + return otk_counts, unused_fbks + class _TransactionController: """Transaction manager. @@ -238,6 +320,8 @@ class _TransactionController: events: List[EventBase], ephemeral: Optional[List[JsonDict]] = None, to_device_messages: Optional[List[JsonDict]] = None, + one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None, + unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None, ) -> None: """ Create a transaction with the given data and send to the provided @@ -248,6 +332,10 @@ class _TransactionController: events: The persistent events to include in the transaction. ephemeral: The ephemeral events to include in the transaction. to_device_messages: The to-device messages to include in the transaction. + one_time_key_counts: Counts of remaining one-time keys for relevant + appservice devices in the transaction. + unused_fallback_keys: Lists of unused fallback keys for relevant + appservice devices in the transaction. """ try: txn = await self.store.create_appservice_txn( @@ -255,6 +343,8 @@ class _TransactionController: events=events, ephemeral=ephemeral or [], to_device_messages=to_device_messages or [], + one_time_key_counts=one_time_key_counts or {}, + unused_fallback_keys=unused_fallback_keys or {}, ) service_is_up = await self._is_service_up(service) if service_is_up: diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 7fad2e042..439bfe152 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -166,6 +166,16 @@ def _load_appservice( supports_ephemeral = as_info.get("de.sorunome.msc2409.push_ephemeral", False) + # Opt-in flag for the MSC3202-specific transactional behaviour. + # When enabled, appservice transactions contain the following information: + # - device One-Time Key counts + # - device unused fallback key usage states + msc3202_transaction_extensions = as_info.get("org.matrix.msc3202", False) + if not isinstance(msc3202_transaction_extensions, bool): + raise ValueError( + "The `org.matrix.msc3202` option should be true or false if specified." + ) + return ApplicationService( token=as_info["as_token"], hostname=hostname, @@ -174,8 +184,9 @@ def _load_appservice( hs_token=as_info["hs_token"], sender=user_id, id=as_info["id"], - supports_ephemeral=supports_ephemeral, protocols=protocols, rate_limited=rate_limited, ip_range_whitelist=ip_range_whitelist, + supports_ephemeral=supports_ephemeral, + msc3202_transaction_extensions=msc3202_transaction_extensions, ) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 772eb3501..41338b39d 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -47,11 +47,6 @@ class ExperimentalConfig(Config): # MSC3030 (Jump to date API endpoint) self.msc3030_enabled: bool = experimental.get("msc3030_enabled", False) - # The portion of MSC3202 which is related to device masquerading. - self.msc3202_device_masquerading_enabled: bool = experimental.get( - "msc3202_device_masquerading", False - ) - # MSC2409 (this setting only relates to optionally sending to-device messages). # Presence, typing and read receipt EDUs are already sent to application services that # have opted in to receive them. If enabled, this adds to-device messages to that list. @@ -59,6 +54,17 @@ class ExperimentalConfig(Config): "msc2409_to_device_messages_enabled", False ) + # The portion of MSC3202 which is related to device masquerading. + self.msc3202_device_masquerading_enabled: bool = experimental.get( + "msc3202_device_masquerading", False + ) + + # Portion of MSC3202 related to transaction extensions: + # sending one-time key counts and fallback key usage to application services. + self.msc3202_transaction_extensions: bool = experimental.get( + "msc3202_transaction_extensions", False + ) + # MSC3706 (server-side support for partial state in /send_join responses) self.msc3706_enabled: bool = experimental.get("msc3706_enabled", False) diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 304814af5..069444655 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -20,14 +20,18 @@ from synapse.appservice import ( ApplicationService, ApplicationServiceState, AppServiceTransaction, + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, ) from synapse.config.appservice import load_appservices from synapse.events import EventBase -from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage._base import db_to_json from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.types import JsonDict from synapse.util import json_encoder +from synapse.util.caches.descriptors import _CacheContext, cached if TYPE_CHECKING: from synapse.server import HomeServer @@ -56,7 +60,7 @@ def _make_exclusive_regex( return exclusive_user_pattern -class ApplicationServiceWorkerStore(SQLBaseStore): +class ApplicationServiceWorkerStore(RoomMemberWorkerStore): def __init__( self, database: DatabasePool, @@ -124,6 +128,18 @@ class ApplicationServiceWorkerStore(SQLBaseStore): return service return None + @cached(iterable=True, cache_context=True) + async def get_app_service_users_in_room( + self, + room_id: str, + app_service: "ApplicationService", + cache_context: _CacheContext, + ) -> List[str]: + users_in_room = await self.get_users_in_room( + room_id, on_invalidate=cache_context.invalidate + ) + return list(filter(app_service.is_interested_in_user, users_in_room)) + class ApplicationServiceStore(ApplicationServiceWorkerStore): # This is currently empty due to there not being any AS storage functions @@ -199,6 +215,8 @@ class ApplicationServiceTransactionWorkerStore( events: List[EventBase], ephemeral: List[JsonDict], to_device_messages: List[JsonDict], + one_time_key_counts: TransactionOneTimeKeyCounts, + unused_fallback_keys: TransactionUnusedFallbackKeys, ) -> AppServiceTransaction: """Atomically creates a new transaction for this application service with the given list of events. Ephemeral events are NOT persisted to the @@ -209,6 +227,10 @@ class ApplicationServiceTransactionWorkerStore( events: A list of persistent events to put in the transaction. ephemeral: A list of ephemeral events to put in the transaction. to_device_messages: A list of to-device messages to put in the transaction. + one_time_key_counts: Counts of remaining one-time keys for relevant + appservice devices in the transaction. + unused_fallback_keys: Lists of unused fallback keys for relevant + appservice devices in the transaction. Returns: A new transaction. @@ -244,6 +266,8 @@ class ApplicationServiceTransactionWorkerStore( events=events, ephemeral=ephemeral, to_device_messages=to_device_messages, + one_time_key_counts=one_time_key_counts, + unused_fallback_keys=unused_fallback_keys, ) return await self.db_pool.runInteraction( @@ -335,12 +359,17 @@ class ApplicationServiceTransactionWorkerStore( events = await self.get_events_as_list(event_ids) + # TODO: to-device messages, one-time key counts and unused fallback keys + # are not yet populated for catch-up transactions. + # We likely want to populate those for reliability. return AppServiceTransaction( service=service, id=entry["txn_id"], events=events, ephemeral=[], to_device_messages=[], + one_time_key_counts={}, + unused_fallback_keys={}, ) def _get_last_txn(self, txn, service_id: Optional[str]) -> int: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 1f8447b50..9b293475c 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -29,6 +29,10 @@ import attr from canonicaljson import encode_canonical_json from synapse.api.constants import DeviceKeyAlgorithms +from synapse.appservice import ( + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, +) from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import ( @@ -439,6 +443,114 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker "count_e2e_one_time_keys", _count_e2e_one_time_keys ) + async def count_bulk_e2e_one_time_keys_for_as( + self, user_ids: Collection[str] + ) -> TransactionOneTimeKeyCounts: + """ + Counts, in bulk, the one-time keys for all the users specified. + Intended to be used by application services for populating OTK counts in + transactions. + + Return structure is of the shape: + user_id -> device_id -> algorithm -> count + Empty algorithm -> count dicts are created if needed to represent a + lack of unused one-time keys. + """ + + def _count_bulk_e2e_one_time_keys_txn( + txn: LoggingTransaction, + ) -> TransactionOneTimeKeyCounts: + user_in_where_clause, user_parameters = make_in_list_sql_clause( + self.database_engine, "user_id", user_ids + ) + sql = f""" + SELECT user_id, device_id, algorithm, COUNT(key_id) + FROM devices + LEFT JOIN e2e_one_time_keys_json USING (user_id, device_id) + WHERE {user_in_where_clause} + GROUP BY user_id, device_id, algorithm + """ + txn.execute(sql, user_parameters) + + result: TransactionOneTimeKeyCounts = {} + + for user_id, device_id, algorithm, count in txn: + # We deliberately construct empty dictionaries for + # users and devices without any unused one-time keys. + # We *could* omit these empty dicts if there have been no + # changes since the last transaction, but we currently don't + # do any change tracking! + device_count_by_algo = result.setdefault(user_id, {}).setdefault( + device_id, {} + ) + if algorithm is not None: + # algorithm will be None if this device has no keys. + device_count_by_algo[algorithm] = count + + return result + + return await self.db_pool.runInteraction( + "count_bulk_e2e_one_time_keys", _count_bulk_e2e_one_time_keys_txn + ) + + async def get_e2e_bulk_unused_fallback_key_types( + self, user_ids: Collection[str] + ) -> TransactionUnusedFallbackKeys: + """ + Finds, in bulk, the types of unused fallback keys for all the users specified. + Intended to be used by application services for populating unused fallback + keys in transactions. + + Return structure is of the shape: + user_id -> device_id -> algorithms + Empty lists are created for devices if there are no unused fallback + keys. This matches the response structure of MSC3202. + """ + if len(user_ids) == 0: + return {} + + def _get_bulk_e2e_unused_fallback_keys_txn( + txn: LoggingTransaction, + ) -> TransactionUnusedFallbackKeys: + user_in_where_clause, user_parameters = make_in_list_sql_clause( + self.database_engine, "devices.user_id", user_ids + ) + # We can't use USING here because we require the `.used` condition + # to be part of the JOIN condition so that we generate empty lists + # when all keys are used (as opposed to just when there are no keys at all). + sql = f""" + SELECT devices.user_id, devices.device_id, algorithm + FROM devices + LEFT JOIN e2e_fallback_keys_json AS fallback_keys + ON devices.user_id = fallback_keys.user_id + AND devices.device_id = fallback_keys.device_id + AND NOT fallback_keys.used + WHERE + {user_in_where_clause} + """ + txn.execute(sql, user_parameters) + + result: TransactionUnusedFallbackKeys = {} + + for user_id, device_id, algorithm in txn: + # We deliberately construct empty dictionaries and lists for + # users and devices without any unused fallback keys. + # We *could* omit these empty dicts if there have been no + # changes since the last transaction, but we currently don't + # do any change tracking! + device_unused_keys = result.setdefault(user_id, {}).setdefault( + device_id, [] + ) + if algorithm is not None: + # algorithm will be None if this device has no keys. + device_unused_keys.append(algorithm) + + return result + + return await self.db_pool.runInteraction( + "_get_bulk_e2e_unused_fallback_keys", _get_bulk_e2e_unused_fallback_keys_txn + ) + async def set_e2e_fallback_keys( self, user_id: str, device_id: str, fallback_keys: JsonDict ) -> None: diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index 8fb6687f8..b9dc4dfe1 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -68,6 +68,8 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): events=events, ephemeral=[], to_device_messages=[], # txn made and saved + one_time_key_counts={}, + unused_fallback_keys={}, ) self.assertEquals(0, len(self.txnctrl.recoverers)) # no recoverer made txn.complete.assert_called_once_with(self.store) # txn completed @@ -92,6 +94,8 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): events=events, ephemeral=[], to_device_messages=[], # txn made and saved + one_time_key_counts={}, + unused_fallback_keys={}, ) self.assertEquals(0, txn.send.call_count) # txn not sent though self.assertEquals(0, txn.complete.call_count) # or completed @@ -114,7 +118,12 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( - service=service, events=events, ephemeral=[], to_device_messages=[] + service=service, + events=events, + ephemeral=[], + to_device_messages=[], + one_time_key_counts={}, + unused_fallback_keys={}, ) self.assertEquals(1, self.recoverer_fn.call_count) # recoverer made self.assertEquals(1, self.recoverer.recover.call_count) # and invoked @@ -216,7 +225,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): service = Mock(id=4) event = Mock() self.scheduler.enqueue_for_appservice(service, events=[event]) - self.txn_ctrl.send.assert_called_once_with(service, [event], [], []) + self.txn_ctrl.send.assert_called_once_with(service, [event], [], [], None, None) def test_send_single_event_with_queue(self): d = defer.Deferred() @@ -231,11 +240,13 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): # (call enqueue_for_appservice multiple times deliberately) self.scheduler.enqueue_for_appservice(service, events=[event2]) self.scheduler.enqueue_for_appservice(service, events=[event3]) - self.txn_ctrl.send.assert_called_with(service, [event], [], []) + self.txn_ctrl.send.assert_called_with(service, [event], [], [], None, None) self.assertEquals(1, self.txn_ctrl.send.call_count) # Resolve the send event: expect the queued events to be sent d.callback(service) - self.txn_ctrl.send.assert_called_with(service, [event2, event3], [], []) + self.txn_ctrl.send.assert_called_with( + service, [event2, event3], [], [], None, None + ) self.assertEquals(2, self.txn_ctrl.send.call_count) def test_multiple_service_queues(self): @@ -261,15 +272,15 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): # send events for different ASes and make sure they are sent self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event]) self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event2]) - self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], []) + self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], [], None, None) self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event]) self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event2]) - self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], []) + self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], [], None, None) # make sure callbacks for a service only send queued events for THAT # service srv_2_defer.callback(srv2) - self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], []) + self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], [], None, None) self.assertEquals(3, self.txn_ctrl.send.call_count) def test_send_large_txns(self): @@ -288,13 +299,19 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): self.scheduler.enqueue_for_appservice(service, [event], []) # Expect the first event to be sent immediately. - self.txn_ctrl.send.assert_called_with(service, [event_list[0]], [], []) + self.txn_ctrl.send.assert_called_with( + service, [event_list[0]], [], [], None, None + ) srv_1_defer.callback(service) # Then send the next 100 events - self.txn_ctrl.send.assert_called_with(service, event_list[1:101], [], []) + self.txn_ctrl.send.assert_called_with( + service, event_list[1:101], [], [], None, None + ) srv_2_defer.callback(service) # Then the final 99 events - self.txn_ctrl.send.assert_called_with(service, event_list[101:], [], []) + self.txn_ctrl.send.assert_called_with( + service, event_list[101:], [], [], None, None + ) self.assertEquals(3, self.txn_ctrl.send.call_count) def test_send_single_ephemeral_no_queue(self): @@ -302,14 +319,18 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): service = Mock(id=4, name="service") event_list = [Mock(name="event")] self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) - self.txn_ctrl.send.assert_called_once_with(service, [], event_list, []) + self.txn_ctrl.send.assert_called_once_with( + service, [], event_list, [], None, None + ) def test_send_multiple_ephemeral_no_queue(self): # Expect the event to be sent immediately. service = Mock(id=4, name="service") event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")] self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) - self.txn_ctrl.send.assert_called_once_with(service, [], event_list, []) + self.txn_ctrl.send.assert_called_once_with( + service, [], event_list, [], None, None + ) def test_send_single_ephemeral_with_queue(self): d = defer.Deferred() @@ -324,13 +345,13 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): # Send more events: expect send() to NOT be called multiple times. self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_2) self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_3) - self.txn_ctrl.send.assert_called_with(service, [], event_list_1, []) + self.txn_ctrl.send.assert_called_with(service, [], event_list_1, [], None, None) self.assertEquals(1, self.txn_ctrl.send.call_count) # Resolve txn_ctrl.send d.callback(service) # Expect the queued events to be sent self.txn_ctrl.send.assert_called_with( - service, [], event_list_2 + event_list_3, [] + service, [], event_list_2 + event_list_3, [], None, None ) self.assertEquals(2, self.txn_ctrl.send.call_count) @@ -343,7 +364,9 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): second_chunk = [Mock(name="event%i" % (i + 101)) for i in range(50)] event_list = first_chunk + second_chunk self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) - self.txn_ctrl.send.assert_called_once_with(service, [], first_chunk, []) + self.txn_ctrl.send.assert_called_once_with( + service, [], first_chunk, [], None, None + ) d.callback(service) - self.txn_ctrl.send.assert_called_with(service, [], second_chunk, []) + self.txn_ctrl.send.assert_called_with(service, [], second_chunk, [], None, None) self.assertEquals(2, self.txn_ctrl.send.call_count) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 9918ff680..6e0ec3796 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -16,17 +16,25 @@ from typing import Dict, Iterable, List, Optional from unittest.mock import Mock from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin import synapse.storage -from synapse.appservice import ApplicationService +from synapse.appservice import ( + ApplicationService, + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, +) from synapse.handlers.appservice import ApplicationServicesHandler -from synapse.rest.client import login, receipts, room, sendtodevice +from synapse.rest.client import login, receipts, register, room, sendtodevice +from synapse.server import HomeServer from synapse.types import RoomStreamToken +from synapse.util import Clock from synapse.util.stringutils import random_string from tests import unittest from tests.test_utils import make_awaitable, simple_async_mock +from tests.unittest import override_config from tests.utils import MockClock @@ -428,7 +436,14 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): # # The uninterested application service should not have been notified at all. self.send_mock.assert_called_once() - service, _events, _ephemeral, to_device_messages = self.send_mock.call_args[0] + ( + service, + _events, + _ephemeral, + to_device_messages, + _otks, + _fbks, + ) = self.send_mock.call_args[0] # Assert that this was the same to-device message that local_user sent self.assertEqual(service, interested_appservice) @@ -540,7 +555,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): service_id_to_message_count: Dict[str, int] = {} for call in self.send_mock.call_args_list: - service, _events, _ephemeral, to_device_messages = call[0] + service, _events, _ephemeral, to_device_messages, _otks, _fbks = call[0] # Check that this was made to an interested service self.assertIn(service, interested_appservices) @@ -582,3 +597,174 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): self._services.append(appservice) return appservice + + +class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase): + # Argument indices for pulling out arguments from a `send_mock`. + ARG_OTK_COUNTS = 4 + ARG_FALLBACK_KEYS = 5 + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + register.register_servlets, + room.register_servlets, + sendtodevice.register_servlets, + receipts.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + # Mock the ApplicationServiceScheduler's _TransactionController's send method so that + # we can track what's going out + self.send_mock = simple_async_mock() + hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] # We assign to a method. + + # Define an application service for the tests + self._service_token = "VERYSECRET" + self._service = ApplicationService( + self._service_token, + "as1.invalid", + "as1", + "@as.sender:test", + namespaces={ + "users": [ + {"regex": "@_as_.*:test", "exclusive": True}, + {"regex": "@as.sender:test", "exclusive": True}, + ] + }, + msc3202_transaction_extensions=True, + ) + self.hs.get_datastores().main.services_cache = [self._service] + + # Register some appservice users + self._sender_user, self._sender_device = self.register_appservice_user( + "as.sender", self._service_token + ) + self._namespaced_user, self._namespaced_device = self.register_appservice_user( + "_as_user1", self._service_token + ) + + # Register a real user as well. + self._real_user = self.register_user("real.user", "meow") + self._real_user_token = self.login("real.user", "meow") + + async def _add_otks_for_device( + self, user_id: str, device_id: str, otk_count: int + ) -> None: + """ + Add some dummy keys. It doesn't matter if they're not a real algorithm; + that should be opaque to the server anyway. + """ + await self.hs.get_datastores().main.add_e2e_one_time_keys( + user_id, + device_id, + self.clock.time_msec(), + [("algo", f"k{i}", "{}") for i in range(otk_count)], + ) + + async def _add_fallback_key_for_device( + self, user_id: str, device_id: str, used: bool + ) -> None: + """ + Adds a fake fallback key to a device, optionally marking it as used + right away. + """ + store = self.hs.get_datastores().main + await store.set_e2e_fallback_keys(user_id, device_id, {"algo:fk": "fall back!"}) + if used is True: + # Mark the key as used + await store.db_pool.simple_update_one( + table="e2e_fallback_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": "algo", + "key_id": "fk", + }, + updatevalues={"used": True}, + desc="_get_fallback_key_set_used", + ) + + def _set_up_devices_and_a_room(self) -> str: + """ + Helper to set up devices for all the users + and a room for the users to talk in. + """ + + async def preparation(): + await self._add_otks_for_device(self._sender_user, self._sender_device, 42) + await self._add_fallback_key_for_device( + self._sender_user, self._sender_device, used=True + ) + await self._add_otks_for_device( + self._namespaced_user, self._namespaced_device, 36 + ) + await self._add_fallback_key_for_device( + self._namespaced_user, self._namespaced_device, used=False + ) + + # Register a device for the real user, too, so that we can later ensure + # that we don't leak information to the AS about the non-AS user. + await self.hs.get_datastores().main.store_device( + self._real_user, "REALDEV", "UltraMatrix 3000" + ) + await self._add_otks_for_device(self._real_user, "REALDEV", 50) + + self.get_success(preparation()) + + room_id = self.helper.create_room_as( + self._real_user, is_public=True, tok=self._real_user_token + ) + self.helper.join( + room_id, + self._namespaced_user, + tok=self._service_token, + appservice_user_id=self._namespaced_user, + ) + + # Check it was called for sanity. (This was to send the join event to the AS.) + self.send_mock.assert_called() + self.send_mock.reset_mock() + + return room_id + + @override_config( + {"experimental_features": {"msc3202_transaction_extensions": True}} + ) + def test_application_services_receive_otk_counts_and_fallback_key_usages_with_pdus( + self, + ) -> None: + """ + Tests that: + - the AS receives one-time key counts and unused fallback keys for: + - the specified sender; and + - any user who is in receipt of the PDUs + """ + + room_id = self._set_up_devices_and_a_room() + + # Send a message into the AS's room + self.helper.send(room_id, "woof woof", tok=self._real_user_token) + + # Capture what was sent as an AS transaction. + self.send_mock.assert_called() + last_args, _last_kwargs = self.send_mock.call_args + otks: Optional[TransactionOneTimeKeyCounts] = last_args[self.ARG_OTK_COUNTS] + unused_fallbacks: Optional[TransactionUnusedFallbackKeys] = last_args[ + self.ARG_FALLBACK_KEYS + ] + + self.assertEqual( + otks, + { + "@as.sender:test": {self._sender_device: {"algo": 42}}, + "@_as_user1:test": {self._namespaced_device: {"algo": 36}}, + }, + ) + self.assertEqual( + unused_fallbacks, + { + "@as.sender:test": {self._sender_device: []}, + "@_as_user1:test": {self._namespaced_device: ["algo"]}, + }, + ) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 50703ccae..d2f654214 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -267,7 +267,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) txn = self.get_success( defer.ensureDeferred( - self.store.create_appservice_txn(service, events, [], []) + self.store.create_appservice_txn(service, events, [], [], {}, {}) ) ) self.assertEquals(txn.id, 1) @@ -283,7 +283,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): self.get_success(self._insert_txn(service.id, 9644, events)) self.get_success(self._insert_txn(service.id, 9645, events)) txn = self.get_success( - self.store.create_appservice_txn(service, events, [], []) + self.store.create_appservice_txn(service, events, [], [], {}, {}) ) self.assertEquals(txn.id, 9646) self.assertEquals(txn.events, events) @@ -296,7 +296,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) self.get_success(self._set_last_txn(service.id, 9643)) txn = self.get_success( - self.store.create_appservice_txn(service, events, [], []) + self.store.create_appservice_txn(service, events, [], [], {}, {}) ) self.assertEquals(txn.id, 9644) self.assertEquals(txn.events, events) @@ -320,7 +320,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events)) txn = self.get_success( - self.store.create_appservice_txn(service, events, [], []) + self.store.create_appservice_txn(service, events, [], [], {}, {}) ) self.assertEquals(txn.id, 9644) self.assertEquals(txn.events, events) From 54e74cc15f30585f5874780437614c0df6f639d9 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 24 Feb 2022 19:56:38 +0100 Subject: [PATCH 56/84] Add type hints to `tests/rest/client` (#12072) --- changelog.d/12072.misc | 1 + tests/rest/client/test_consent.py | 19 +++-- tests/rest/client/test_device_lists.py | 16 ++-- tests/rest/client/test_ephemeral_message.py | 19 +++-- tests/rest/client/test_identity.py | 13 +++- tests/rest/client/test_keys.py | 6 +- tests/rest/client/test_password_policy.py | 39 +++++----- tests/rest/client/test_power_levels.py | 47 +++++++----- tests/rest/client/test_presence.py | 15 ++-- tests/rest/client/test_room_batch.py | 2 +- tests/rest/client/utils.py | 85 +++++++++++++-------- 11 files changed, 160 insertions(+), 102 deletions(-) create mode 100644 changelog.d/12072.misc diff --git a/changelog.d/12072.misc b/changelog.d/12072.misc new file mode 100644 index 000000000..0360dbd61 --- /dev/null +++ b/changelog.d/12072.misc @@ -0,0 +1 @@ +Add type hints to `tests/rest/client`. diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py index fcdc56581..b1ca81a91 100644 --- a/tests/rest/client/test_consent.py +++ b/tests/rest/client/test_consent.py @@ -13,11 +13,16 @@ # limitations under the License. import os +from http import HTTPStatus + +from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.urls import ConsentURIBuilder from synapse.rest.client import login, room from synapse.rest.consent import consent_resource +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.server import FakeSite, make_request @@ -32,7 +37,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): user_id = True hijack_auth = False - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["form_secret"] = "123abc" @@ -56,7 +61,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): hs = self.setup_test_homeserver(config=config) return hs - def test_render_public_consent(self): + def test_render_public_consent(self) -> None: """You can observe the terms form without specifying a user""" resource = consent_resource.ConsentResource(self.hs) channel = make_request( @@ -66,9 +71,9 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): "/consent?v=1", shorthand=False, ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) - def test_accept_consent(self): + def test_accept_consent(self) -> None: """ A user can use the consent form to accept the terms. """ @@ -92,7 +97,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): access_token=access_token, shorthand=False, ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) # Get the version from the body, and whether we've consented version, consented = channel.result["body"].decode("ascii").split(",") @@ -107,7 +112,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): access_token=access_token, shorthand=False, ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) # Fetch the consent page, to get the consent version -- it should have # changed @@ -119,7 +124,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): access_token=access_token, shorthand=False, ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) # Get the version from the body, and check that it's the version we # agreed to, and that we've consented to it. diff --git a/tests/rest/client/test_device_lists.py b/tests/rest/client/test_device_lists.py index 16070cf02..a8af4e243 100644 --- a/tests/rest/client/test_device_lists.py +++ b/tests/rest/client/test_device_lists.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus + from synapse.rest import admin, devices, room, sync from synapse.rest.client import account, login, register @@ -30,7 +32,7 @@ class DeviceListsTestCase(unittest.HomeserverTestCase): devices.register_servlets, ] - def test_receiving_local_device_list_changes(self): + def test_receiving_local_device_list_changes(self) -> None: """Tests that a local users that share a room receive each other's device list changes. """ @@ -84,7 +86,7 @@ class DeviceListsTestCase(unittest.HomeserverTestCase): }, access_token=alice_access_token, ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) # Check that bob's incremental sync contains the updated device list. # If not, the client would only receive the device list update on the @@ -97,7 +99,7 @@ class DeviceListsTestCase(unittest.HomeserverTestCase): ) self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body) - def test_not_receiving_local_device_list_changes(self): + def test_not_receiving_local_device_list_changes(self) -> None: """Tests a local users DO NOT receive device updates from each other if they do not share a room. """ @@ -119,7 +121,7 @@ class DeviceListsTestCase(unittest.HomeserverTestCase): "/sync", access_token=bob_access_token, ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) next_batch_token = channel.json_body["next_batch"] # ...and then an incremental sync. This should block until the sync stream is woken up, @@ -141,11 +143,13 @@ class DeviceListsTestCase(unittest.HomeserverTestCase): }, access_token=alice_access_token, ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) # Check that bob's incremental sync does not contain the updated device list. bob_sync_channel.await_result() - self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body) + self.assertEqual( + bob_sync_channel.code, HTTPStatus.OK, bob_sync_channel.json_body + ) changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get( "changed", [] diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py index 3d7aa8ec8..9fa1f82df 100644 --- a/tests/rest/client/test_ephemeral_message.py +++ b/tests/rest/client/test_ephemeral_message.py @@ -11,9 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus + +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventContentFields, EventTypes from synapse.rest import admin from synapse.rest.client import room +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest @@ -27,7 +34,7 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["enable_ephemeral_messages"] = True @@ -35,10 +42,10 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase): self.hs = self.setup_test_homeserver(config=config) return self.hs - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.room_id = self.helper.create_room_as(self.user_id) - def test_message_expiry_no_delay(self): + def test_message_expiry_no_delay(self) -> None: """Tests that sending a message sent with a m.self_destruct_after field set to the past results in that event being deleted right away. """ @@ -61,7 +68,7 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase): event_content = self.get_event(self.room_id, event_id)["content"] self.assertFalse(bool(event_content), event_content) - def test_message_expiry_delay(self): + def test_message_expiry_delay(self) -> None: """Tests that sending a message with a m.self_destruct_after field set to the future results in that event not being deleted right away, but advancing the clock to after that expiry timestamp causes the event to be deleted. @@ -89,7 +96,9 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase): event_content = self.get_event(self.room_id, event_id)["content"] self.assertFalse(bool(event_content), event_content) - def get_event(self, room_id, event_id, expected_code=200): + def get_event( + self, room_id: str, event_id: str, expected_code: int = HTTPStatus.OK + ) -> JsonDict: url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) channel = self.make_request("GET", url) diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py index becb4e8dc..299b9d21e 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py @@ -13,9 +13,14 @@ # limitations under the License. import json +from http import HTTPStatus + +from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -28,7 +33,7 @@ class IdentityTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["enable_3pid_lookup"] = False @@ -36,14 +41,14 @@ class IdentityTestCase(unittest.HomeserverTestCase): return self.hs - def test_3pid_lookup_disabled(self): + def test_3pid_lookup_disabled(self) -> None: self.hs.config.registration.enable_3pid_lookup = False self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") channel = self.make_request(b"POST", "/createRoom", b"{}", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) room_id = channel.json_body["room_id"] params = { @@ -56,4 +61,4 @@ class IdentityTestCase(unittest.HomeserverTestCase): channel = self.make_request( b"POST", request_url, request_data, access_token=tok ) - self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result) diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py index d7fa635ea..bbc8e7424 100644 --- a/tests/rest/client/test_keys.py +++ b/tests/rest/client/test_keys.py @@ -28,7 +28,7 @@ class KeyQueryTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def test_rejects_device_id_ice_key_outside_of_list(self): + def test_rejects_device_id_ice_key_outside_of_list(self) -> None: self.register_user("alice", "wonderland") alice_token = self.login("alice", "wonderland") bob = self.register_user("bob", "uncle") @@ -49,7 +49,7 @@ class KeyQueryTestCase(unittest.HomeserverTestCase): channel.result, ) - def test_rejects_device_key_given_as_map_to_bool(self): + def test_rejects_device_key_given_as_map_to_bool(self) -> None: self.register_user("alice", "wonderland") alice_token = self.login("alice", "wonderland") bob = self.register_user("bob", "uncle") @@ -73,7 +73,7 @@ class KeyQueryTestCase(unittest.HomeserverTestCase): channel.result, ) - def test_requires_device_key(self): + def test_requires_device_key(self) -> None: """`device_keys` is required. We should complain if it's missing.""" self.register_user("alice", "wonderland") alice_token = self.login("alice", "wonderland") diff --git a/tests/rest/client/test_password_policy.py b/tests/rest/client/test_password_policy.py index 3cf587189..3a74d2e96 100644 --- a/tests/rest/client/test_password_policy.py +++ b/tests/rest/client/test_password_policy.py @@ -13,11 +13,16 @@ # limitations under the License. import json +from http import HTTPStatus + +from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import LoginType from synapse.api.errors import Codes from synapse.rest import admin from synapse.rest.client import account, login, password_policy, register +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -46,7 +51,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): account.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.register_url = "/_matrix/client/r0/register" self.policy = { "enabled": True, @@ -65,12 +70,12 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): hs = self.setup_test_homeserver(config=config) return hs - def test_get_policy(self): + def test_get_policy(self) -> None: """Tests if the /password_policy endpoint returns the configured policy.""" channel = self.make_request("GET", "/_matrix/client/r0/password_policy") - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual( channel.json_body, { @@ -83,70 +88,70 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): channel.result, ) - def test_password_too_short(self): + def test_password_too_short(self) -> None: request_data = json.dumps({"username": "kermit", "password": "shorty"}) channel = self.make_request("POST", self.register_url, request_data) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual( channel.json_body["errcode"], Codes.PASSWORD_TOO_SHORT, channel.result, ) - def test_password_no_digit(self): + def test_password_no_digit(self) -> None: request_data = json.dumps({"username": "kermit", "password": "longerpassword"}) channel = self.make_request("POST", self.register_url, request_data) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual( channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT, channel.result, ) - def test_password_no_symbol(self): + def test_password_no_symbol(self) -> None: request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"}) channel = self.make_request("POST", self.register_url, request_data) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual( channel.json_body["errcode"], Codes.PASSWORD_NO_SYMBOL, channel.result, ) - def test_password_no_uppercase(self): + def test_password_no_uppercase(self) -> None: request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"}) channel = self.make_request("POST", self.register_url, request_data) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual( channel.json_body["errcode"], Codes.PASSWORD_NO_UPPERCASE, channel.result, ) - def test_password_no_lowercase(self): + def test_password_no_lowercase(self) -> None: request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"}) channel = self.make_request("POST", self.register_url, request_data) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual( channel.json_body["errcode"], Codes.PASSWORD_NO_LOWERCASE, channel.result, ) - def test_password_compliant(self): + def test_password_compliant(self) -> None: request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"}) channel = self.make_request("POST", self.register_url, request_data) # Getting a 401 here means the password has passed validation and the server has # responded with a list of registration flows. - self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) - def test_password_change(self): + def test_password_change(self) -> None: """This doesn't test every possible use case, only that hitting /account/password triggers the password validation code. """ @@ -173,5 +178,5 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): access_token=tok, ) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual(channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT) diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py index c0de4c93a..27dcfc83d 100644 --- a/tests/rest/client/test_power_levels.py +++ b/tests/rest/client/test_power_levels.py @@ -11,11 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus + +from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import Codes from synapse.events.utils import CANONICALJSON_MAX_INT, CANONICALJSON_MIN_INT from synapse.rest import admin from synapse.rest.client import login, room, sync +from synapse.server import HomeServer +from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -30,12 +35,12 @@ class PowerLevelsTestCase(HomeserverTestCase): sync.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() return self.setup_test_homeserver(config=config) - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # register a room admin, moderator and regular user self.admin_user_id = self.register_user("admin", "pass") self.admin_access_token = self.login("admin", "pass") @@ -88,7 +93,7 @@ class PowerLevelsTestCase(HomeserverTestCase): tok=self.admin_access_token, ) - def test_non_admins_cannot_enable_room_encryption(self): + def test_non_admins_cannot_enable_room_encryption(self) -> None: # have the mod try to enable room encryption self.helper.send_state( self.room_id, @@ -104,10 +109,10 @@ class PowerLevelsTestCase(HomeserverTestCase): "m.room.encryption", {"algorithm": "m.megolm.v1.aes-sha2"}, tok=self.user_access_token, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure ) - def test_non_admins_cannot_send_server_acl(self): + def test_non_admins_cannot_send_server_acl(self) -> None: # have the mod try to send a server ACL self.helper.send_state( self.room_id, @@ -118,7 +123,7 @@ class PowerLevelsTestCase(HomeserverTestCase): "deny": ["*.evil.com", "evil.com"], }, tok=self.mod_access_token, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure ) # have the user try to send a server ACL @@ -131,10 +136,10 @@ class PowerLevelsTestCase(HomeserverTestCase): "deny": ["*.evil.com", "evil.com"], }, tok=self.user_access_token, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure ) - def test_non_admins_cannot_tombstone_room(self): + def test_non_admins_cannot_tombstone_room(self) -> None: # Create another room that will serve as our "upgraded room" self.upgraded_room_id = self.helper.create_room_as( self.admin_user_id, tok=self.admin_access_token @@ -149,7 +154,7 @@ class PowerLevelsTestCase(HomeserverTestCase): "replacement_room": self.upgraded_room_id, }, tok=self.mod_access_token, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure ) # have the user try to send a tombstone event @@ -164,17 +169,17 @@ class PowerLevelsTestCase(HomeserverTestCase): expect_code=403, # expect failure ) - def test_admins_can_enable_room_encryption(self): + def test_admins_can_enable_room_encryption(self) -> None: # have the admin try to enable room encryption self.helper.send_state( self.room_id, "m.room.encryption", {"algorithm": "m.megolm.v1.aes-sha2"}, tok=self.admin_access_token, - expect_code=200, # expect success + expect_code=HTTPStatus.OK, # expect success ) - def test_admins_can_send_server_acl(self): + def test_admins_can_send_server_acl(self) -> None: # have the admin try to send a server ACL self.helper.send_state( self.room_id, @@ -185,10 +190,10 @@ class PowerLevelsTestCase(HomeserverTestCase): "deny": ["*.evil.com", "evil.com"], }, tok=self.admin_access_token, - expect_code=200, # expect success + expect_code=HTTPStatus.OK, # expect success ) - def test_admins_can_tombstone_room(self): + def test_admins_can_tombstone_room(self) -> None: # Create another room that will serve as our "upgraded room" self.upgraded_room_id = self.helper.create_room_as( self.admin_user_id, tok=self.admin_access_token @@ -203,10 +208,10 @@ class PowerLevelsTestCase(HomeserverTestCase): "replacement_room": self.upgraded_room_id, }, tok=self.admin_access_token, - expect_code=200, # expect success + expect_code=HTTPStatus.OK, # expect success ) - def test_cannot_set_string_power_levels(self): + def test_cannot_set_string_power_levels(self) -> None: room_power_levels = self.helper.get_state( self.room_id, "m.room.power_levels", @@ -221,7 +226,7 @@ class PowerLevelsTestCase(HomeserverTestCase): "m.room.power_levels", room_power_levels, tok=self.admin_access_token, - expect_code=400, # expect failure + expect_code=HTTPStatus.BAD_REQUEST, # expect failure ) self.assertEqual( @@ -230,7 +235,7 @@ class PowerLevelsTestCase(HomeserverTestCase): body, ) - def test_cannot_set_unsafe_large_power_levels(self): + def test_cannot_set_unsafe_large_power_levels(self) -> None: room_power_levels = self.helper.get_state( self.room_id, "m.room.power_levels", @@ -247,7 +252,7 @@ class PowerLevelsTestCase(HomeserverTestCase): "m.room.power_levels", room_power_levels, tok=self.admin_access_token, - expect_code=400, # expect failure + expect_code=HTTPStatus.BAD_REQUEST, # expect failure ) self.assertEqual( @@ -256,7 +261,7 @@ class PowerLevelsTestCase(HomeserverTestCase): body, ) - def test_cannot_set_unsafe_small_power_levels(self): + def test_cannot_set_unsafe_small_power_levels(self) -> None: room_power_levels = self.helper.get_state( self.room_id, "m.room.power_levels", @@ -273,7 +278,7 @@ class PowerLevelsTestCase(HomeserverTestCase): "m.room.power_levels", room_power_levels, tok=self.admin_access_token, - expect_code=400, # expect failure + expect_code=HTTPStatus.BAD_REQUEST, # expect failure ) self.assertEqual( diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py index 56fe1a3d0..0abe378fe 100644 --- a/tests/rest/client/test_presence.py +++ b/tests/rest/client/test_presence.py @@ -11,14 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from http import HTTPStatus from unittest.mock import Mock from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactor from synapse.handlers.presence import PresenceHandler from synapse.rest.client import presence +from synapse.server import HomeServer from synapse.types import UserID +from synapse.util import Clock from tests import unittest @@ -31,7 +34,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): user = UserID.from_string(user_id) servlets = [presence.register_servlets] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: presence_handler = Mock(spec=PresenceHandler) presence_handler.set_state.return_value = defer.succeed(None) @@ -45,7 +48,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): return hs - def test_put_presence(self): + def test_put_presence(self) -> None: """ PUT to the status endpoint with use_presence enabled will call set_state on the presence handler. @@ -57,11 +60,11 @@ class PresenceTestCase(unittest.HomeserverTestCase): "PUT", "/presence/%s/status" % (self.user_id,), body ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1) @unittest.override_config({"use_presence": False}) - def test_put_presence_disabled(self): + def test_put_presence_disabled(self) -> None: """ PUT to the status endpoint with use_presence disabled will NOT call set_state on the presence handler. @@ -72,5 +75,5 @@ class PresenceTestCase(unittest.HomeserverTestCase): "PUT", "/presence/%s/status" % (self.user_id,), body ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0) diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py index e9f870403..44f333a0e 100644 --- a/tests/rest/client/test_room_batch.py +++ b/tests/rest/client/test_room_batch.py @@ -134,7 +134,7 @@ class RoomBatchTestCase(unittest.HomeserverTestCase): return room_id, event_id_a, event_id_b, event_id_c @unittest.override_config({"experimental_features": {"msc2716_enabled": True}}) - def test_same_state_groups_for_whole_historical_batch(self): + def test_same_state_groups_for_whole_historical_batch(self) -> None: """Make sure that when using the `/batch_send` endpoint to import a bunch of historical messages, it re-uses the same `state_group` across the whole batch. This is an easy optimization to make sure we're getting diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 2b3fdadff..46cd5f70a 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -19,6 +19,7 @@ import json import re import time import urllib.parse +from http import HTTPStatus from typing import ( Any, AnyStr, @@ -89,7 +90,7 @@ class RestHelper: is_public: Optional[bool] = None, room_version: Optional[str] = None, tok: Optional[str] = None, - expect_code: int = 200, + expect_code: int = HTTPStatus.OK, extra_content: Optional[Dict] = None, custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, ) -> Optional[str]: @@ -137,12 +138,19 @@ class RestHelper: assert channel.result["code"] == b"%d" % expect_code, channel.result self.auth_user_id = temp_id - if expect_code == 200: + if expect_code == HTTPStatus.OK: return channel.json_body["room_id"] else: return None - def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None): + def invite( + self, + room: Optional[str] = None, + src: Optional[str] = None, + targ: Optional[str] = None, + expect_code: int = HTTPStatus.OK, + tok: Optional[str] = None, + ) -> None: self.change_membership( room=room, src=src, @@ -156,7 +164,7 @@ class RestHelper: self, room: str, user: Optional[str] = None, - expect_code: int = 200, + expect_code: int = HTTPStatus.OK, tok: Optional[str] = None, appservice_user_id: Optional[str] = None, ) -> None: @@ -170,7 +178,14 @@ class RestHelper: expect_code=expect_code, ) - def knock(self, room=None, user=None, reason=None, expect_code=200, tok=None): + def knock( + self, + room: Optional[str] = None, + user: Optional[str] = None, + reason: Optional[str] = None, + expect_code: int = HTTPStatus.OK, + tok: Optional[str] = None, + ) -> None: temp_id = self.auth_user_id self.auth_user_id = user path = "/knock/%s" % room @@ -199,7 +214,13 @@ class RestHelper: self.auth_user_id = temp_id - def leave(self, room=None, user=None, expect_code=200, tok=None): + def leave( + self, + room: Optional[str] = None, + user: Optional[str] = None, + expect_code: int = HTTPStatus.OK, + tok: Optional[str] = None, + ) -> None: self.change_membership( room=room, src=user, @@ -209,7 +230,7 @@ class RestHelper: expect_code=expect_code, ) - def ban(self, room: str, src: str, targ: str, **kwargs: object): + def ban(self, room: str, src: str, targ: str, **kwargs: object) -> None: """A convenience helper: `change_membership` with `membership` preset to "ban".""" self.change_membership( room=room, @@ -228,7 +249,7 @@ class RestHelper: extra_data: Optional[dict] = None, tok: Optional[str] = None, appservice_user_id: Optional[str] = None, - expect_code: int = 200, + expect_code: int = HTTPStatus.OK, expect_errcode: Optional[str] = None, ) -> None: """ @@ -294,13 +315,13 @@ class RestHelper: def send( self, - room_id, - body=None, - txn_id=None, - tok=None, - expect_code=200, + room_id: str, + body: Optional[str] = None, + txn_id: Optional[str] = None, + tok: Optional[str] = None, + expect_code: int = HTTPStatus.OK, custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, - ): + ) -> JsonDict: if body is None: body = "body_text_here" @@ -318,14 +339,14 @@ class RestHelper: def send_event( self, - room_id, - type, + room_id: str, + type: str, content: Optional[dict] = None, - txn_id=None, - tok=None, - expect_code=200, + txn_id: Optional[str] = None, + tok: Optional[str] = None, + expect_code: int = HTTPStatus.OK, custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, - ): + ) -> JsonDict: if txn_id is None: txn_id = "m%s" % (str(time.time())) @@ -358,10 +379,10 @@ class RestHelper: event_type: str, body: Optional[Dict[str, Any]], tok: str, - expect_code: int = 200, + expect_code: int = HTTPStatus.OK, state_key: str = "", method: str = "GET", - ) -> Dict: + ) -> JsonDict: """Read or write some state from a given room Args: @@ -410,9 +431,9 @@ class RestHelper: room_id: str, event_type: str, tok: str, - expect_code: int = 200, + expect_code: int = HTTPStatus.OK, state_key: str = "", - ): + ) -> JsonDict: """Gets some state from a room Args: @@ -438,9 +459,9 @@ class RestHelper: event_type: str, body: Dict[str, Any], tok: str, - expect_code: int = 200, + expect_code: int = HTTPStatus.OK, state_key: str = "", - ): + ) -> JsonDict: """Set some state in a room Args: @@ -467,8 +488,8 @@ class RestHelper: image_data: bytes, tok: str, filename: str = "test.png", - expect_code: int = 200, - ) -> dict: + expect_code: int = HTTPStatus.OK, + ) -> JsonDict: """Upload a piece of test media to the media repo Args: resource: The resource that will handle the upload request @@ -513,7 +534,7 @@ class RestHelper: channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url) # expect a confirmation page - assert channel.code == 200, channel.result + assert channel.code == HTTPStatus.OK, channel.result # fish the matrix login token out of the body of the confirmation page m = re.search( @@ -532,7 +553,7 @@ class RestHelper: "/login", content={"type": "m.login.token", "token": login_token}, ) - assert channel.code == 200 + assert channel.code == HTTPStatus.OK return channel.json_body def auth_via_oidc( @@ -641,7 +662,7 @@ class RestHelper: (expected_uri, resp_obj) = expected_requests.pop(0) assert uri == expected_uri resp = FakeResponse( - code=200, + code=HTTPStatus.OK, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"), ) @@ -739,7 +760,7 @@ class RestHelper: self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint ) # that should serve a confirmation page - assert channel.code == 200, channel.text_body + assert channel.code == HTTPStatus.OK, channel.text_body channel.extract_cookies(cookies) # parse the confirmation page to fish out the link. From f3fd8558cdb5d91d0e54ca35b55a3dba2610b215 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Fri, 25 Feb 2022 10:19:49 +0000 Subject: [PATCH 57/84] Minor typing fixes for `synapse/storage/persist_events.py` (#12069) Signed-off-by: Sean Quah --- changelog.d/12069.misc | 1 + synapse/storage/databases/main/events.py | 23 ++++++++++++---------- synapse/storage/persist_events.py | 25 ++++++++++++------------ 3 files changed, 26 insertions(+), 23 deletions(-) create mode 100644 changelog.d/12069.misc diff --git a/changelog.d/12069.misc b/changelog.d/12069.misc new file mode 100644 index 000000000..8374a6322 --- /dev/null +++ b/changelog.d/12069.misc @@ -0,0 +1 @@ +Minor typing fixes. diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index e53e84054..23fa089bc 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -130,7 +130,7 @@ class PersistEventsStore: *, current_state_for_room: Dict[str, StateMap[str]], state_delta_for_room: Dict[str, DeltaState], - new_forward_extremeties: Dict[str, List[str]], + new_forward_extremities: Dict[str, Set[str]], use_negative_stream_ordering: bool = False, inhibit_local_membership_updates: bool = False, ) -> None: @@ -143,7 +143,7 @@ class PersistEventsStore: the room based on forward extremities state_delta_for_room: Map from room_id to the delta to apply to room state - new_forward_extremities: Map from room_id to list of event IDs + new_forward_extremities: Map from room_id to set of event IDs that are the new forward extremities of the room. use_negative_stream_ordering: Whether to start stream_ordering on the negative side and decrement. This should be set as True @@ -193,7 +193,7 @@ class PersistEventsStore: events_and_contexts=events_and_contexts, inhibit_local_membership_updates=inhibit_local_membership_updates, state_delta_for_room=state_delta_for_room, - new_forward_extremeties=new_forward_extremeties, + new_forward_extremities=new_forward_extremities, ) persist_event_counter.inc(len(events_and_contexts)) @@ -220,7 +220,7 @@ class PersistEventsStore: for room_id, new_state in current_state_for_room.items(): self.store.get_current_state_ids.prefill((room_id,), new_state) - for room_id, latest_event_ids in new_forward_extremeties.items(): + for room_id, latest_event_ids in new_forward_extremities.items(): self.store.get_latest_event_ids_in_room.prefill( (room_id,), list(latest_event_ids) ) @@ -334,8 +334,8 @@ class PersistEventsStore: events_and_contexts: List[Tuple[EventBase, EventContext]], inhibit_local_membership_updates: bool = False, state_delta_for_room: Optional[Dict[str, DeltaState]] = None, - new_forward_extremeties: Optional[Dict[str, List[str]]] = None, - ): + new_forward_extremities: Optional[Dict[str, Set[str]]] = None, + ) -> None: """Insert some number of room events into the necessary database tables. Rejected events are only inserted into the events table, the events_json table, @@ -353,13 +353,13 @@ class PersistEventsStore: from the database. This is useful when retrying due to IntegrityError. state_delta_for_room: The current-state delta for each room. - new_forward_extremetie: The new forward extremities for each room. + new_forward_extremities: The new forward extremities for each room. For each room, a list of the event ids which are the forward extremities. """ state_delta_for_room = state_delta_for_room or {} - new_forward_extremeties = new_forward_extremeties or {} + new_forward_extremities = new_forward_extremities or {} all_events_and_contexts = events_and_contexts @@ -372,7 +372,7 @@ class PersistEventsStore: self._update_forward_extremities_txn( txn, - new_forward_extremities=new_forward_extremeties, + new_forward_extremities=new_forward_extremities, max_stream_order=max_stream_order, ) @@ -1158,7 +1158,10 @@ class PersistEventsStore: ) def _update_forward_extremities_txn( - self, txn, new_forward_extremities, max_stream_order + self, + txn: LoggingTransaction, + new_forward_extremities: Dict[str, Set[str]], + max_stream_order: int, ): for room_id in new_forward_extremities.keys(): self.db_pool.simple_delete_txn( diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 428d66a61..7d543fdbe 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -427,21 +427,21 @@ class EventsPersistenceStorage: # NB: Assumes that we are only persisting events for one room # at a time. - # map room_id->list[event_ids] giving the new forward + # map room_id->set[event_ids] giving the new forward # extremities in each room - new_forward_extremeties = {} + new_forward_extremities: Dict[str, Set[str]] = {} # map room_id->(type,state_key)->event_id tracking the full # state in each room after adding these events. # This is simply used to prefill the get_current_state_ids # cache - current_state_for_room = {} + current_state_for_room: Dict[str, StateMap[str]] = {} # map room_id->(to_delete, to_insert) where to_delete is a list # of type/state keys to remove from current state, and to_insert # is a map (type,key)->event_id giving the state delta in each # room - state_delta_for_room = {} + state_delta_for_room: Dict[str, DeltaState] = {} # Set of remote users which were in rooms the server has left. We # should check if we still share any rooms and if not we mark their @@ -460,14 +460,13 @@ class EventsPersistenceStorage: ) for room_id, ev_ctx_rm in events_by_room.items(): - latest_event_ids = ( + latest_event_ids = set( await self.main_store.get_latest_event_ids_in_room(room_id) ) new_latest_event_ids = await self._calculate_new_extremities( room_id, ev_ctx_rm, latest_event_ids ) - latest_event_ids = set(latest_event_ids) if new_latest_event_ids == latest_event_ids: # No change in extremities, so no change in state continue @@ -478,7 +477,7 @@ class EventsPersistenceStorage: # extremities, so we'll `continue` above and skip this bit.) assert new_latest_event_ids, "No forward extremities left!" - new_forward_extremeties[room_id] = new_latest_event_ids + new_forward_extremities[room_id] = new_latest_event_ids len_1 = ( len(latest_event_ids) == 1 @@ -533,7 +532,7 @@ class EventsPersistenceStorage: # extremities, so we'll `continue` above and skip this bit.) assert new_latest_event_ids, "No forward extremities left!" - new_forward_extremeties[room_id] = new_latest_event_ids + new_forward_extremities[room_id] = new_latest_event_ids # If either are not None then there has been a change, # and we need to work out the delta (or use that @@ -567,7 +566,7 @@ class EventsPersistenceStorage: ) if not is_still_joined: logger.info("Server no longer in room %s", room_id) - latest_event_ids = [] + latest_event_ids = set() current_state = {} delta.no_longer_in_room = True @@ -582,7 +581,7 @@ class EventsPersistenceStorage: chunk, current_state_for_room=current_state_for_room, state_delta_for_room=state_delta_for_room, - new_forward_extremeties=new_forward_extremeties, + new_forward_extremities=new_forward_extremities, use_negative_stream_ordering=backfilled, inhibit_local_membership_updates=backfilled, ) @@ -596,7 +595,7 @@ class EventsPersistenceStorage: room_id: str, event_contexts: List[Tuple[EventBase, EventContext]], latest_event_ids: Collection[str], - ): + ) -> Set[str]: """Calculates the new forward extremities for a room given events to persist. @@ -906,9 +905,9 @@ class EventsPersistenceStorage: # Ideally we'd figure out a way of still being able to drop old # dummy events that reference local events, but this is good enough # as a first cut. - events_to_check = [event] + events_to_check: Collection[EventBase] = [event] while events_to_check: - new_events = set() + new_events: Set[str] = set() for event_to_check in events_to_check: if self.is_mine_id(event_to_check.sender): if event_to_check.type != EventTypes.Dummy: From b43c3ef8e2306829074d847bed50575d5e7c7ea3 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Fri, 25 Feb 2022 10:20:40 +0000 Subject: [PATCH 58/84] Ensure that `get_datastores().main` is typed (#12070) Signed-off-by: Sean Quah --- changelog.d/12070.misc | 1 + synapse/storage/databases/__init__.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12070.misc diff --git a/changelog.d/12070.misc b/changelog.d/12070.misc new file mode 100644 index 000000000..d4bedc6b9 --- /dev/null +++ b/changelog.d/12070.misc @@ -0,0 +1 @@ +Remove legacy `HomeServer.get_datastore()`. diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py index cfe887b7f..ce3d1d4e9 100644 --- a/synapse/storage/databases/__init__.py +++ b/synapse/storage/databases/__init__.py @@ -24,6 +24,7 @@ from synapse.storage.prepare_database import prepare_database if TYPE_CHECKING: from synapse.server import HomeServer + from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) @@ -44,7 +45,7 @@ class Databases(Generic[DataStoreT]): """ databases: List[DatabasePool] - main: DataStoreT + main: "DataStore" # FIXME: #11165: actually an instance of `main_store_class` state: StateGroupDataStore persist_events: Optional[PersistEventsStore] From ab3ef49059e465198754a3d818d1f3b21771f5ef Mon Sep 17 00:00:00 2001 From: lukasdenk <63459921+lukasdenk@users.noreply.github.com> Date: Mon, 28 Feb 2022 12:42:13 +0100 Subject: [PATCH 59/84] synctl: print warning if synctl_cache_factor is set in config (#11865) Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> --- changelog.d/11865.removal | 1 + synctl | 8 ++++++++ 2 files changed, 9 insertions(+) create mode 100644 changelog.d/11865.removal diff --git a/changelog.d/11865.removal b/changelog.d/11865.removal new file mode 100644 index 000000000..9fcabfc72 --- /dev/null +++ b/changelog.d/11865.removal @@ -0,0 +1 @@ +Deprecate using `synctl` with the config option `synctl_cache_factor` and print a warning if a user still uses this option. diff --git a/synctl b/synctl index 0e54f4847..1ab36949c 100755 --- a/synctl +++ b/synctl @@ -37,6 +37,13 @@ YELLOW = "\x1b[1;33m" RED = "\x1b[1;31m" NORMAL = "\x1b[m" +SYNCTL_CACHE_FACTOR_WARNING = """\ +Setting 'synctl_cache_factor' in the config is deprecated. Instead, please do +one of the following: + - Either set the environment variable 'SYNAPSE_CACHE_FACTOR' + - or set 'caches.global_factor' in the homeserver config. +--------------------------------------------------------------------------------""" + def pid_running(pid): try: @@ -228,6 +235,7 @@ def main(): start_stop_synapse = True if cache_factor: + write(SYNCTL_CACHE_FACTOR_WARNING) os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor) cache_factors = config.get("synctl_cache_factors", {}) From 02d708568b476f2f7716000b35c0adfa4cbd31b3 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 28 Feb 2022 07:12:29 -0500 Subject: [PATCH 60/84] Replace assertEquals and friends with non-deprecated versions. (#12092) --- changelog.d/12092.misc | 1 + tests/api/test_auth.py | 36 +-- tests/api/test_filtering.py | 20 +- tests/api/test_ratelimiting.py | 28 +- tests/appservice/test_scheduler.py | 62 ++--- tests/crypto/test_event_signing.py | 8 +- tests/crypto/test_keyring.py | 6 +- tests/events/test_utils.py | 16 +- tests/federation/test_complexity.py | 4 +- tests/federation/test_federation_server.py | 12 +- tests/federation/transport/test_knocking.py | 16 +- tests/federation/transport/test_server.py | 4 +- tests/handlers/test_appservice.py | 18 +- tests/handlers/test_directory.py | 28 +- tests/handlers/test_presence.py | 78 +++--- tests/handlers/test_profile.py | 20 +- tests/handlers/test_receipts.py | 2 +- tests/handlers/test_register.py | 4 +- tests/handlers/test_sync.py | 6 +- tests/handlers/test_typing.py | 40 +-- tests/handlers/test_user_directory.py | 4 +- tests/http/federation/test_srv_resolver.py | 24 +- .../replication/slave/storage/test_events.py | 2 +- tests/rest/admin/test_room.py | 14 +- tests/rest/client/test_account.py | 22 +- tests/rest/client/test_events.py | 10 +- tests/rest/client/test_filter.py | 10 +- tests/rest/client/test_groups.py | 12 +- tests/rest/client/test_login.py | 78 +++--- tests/rest/client/test_profile.py | 2 +- tests/rest/client/test_register.py | 166 ++++++------ tests/rest/client/test_relations.py | 248 +++++++++--------- tests/rest/client/test_rooms.py | 192 +++++++------- tests/rest/client/test_shadow_banned.py | 20 +- tests/rest/client/test_shared_rooms.py | 24 +- tests/rest/client/test_sync.py | 16 +- tests/rest/client/test_third_party_rules.py | 14 +- tests/rest/client/test_typing.py | 18 +- tests/rest/client/test_upgrade_room.py | 18 +- tests/rest/media/v1/test_media_storage.py | 2 +- .../databases/main/test_events_worker.py | 12 +- tests/storage/test_appservice.py | 86 +++--- tests/storage/test_base.py | 6 +- tests/storage/test_directory.py | 2 +- tests/storage/test_event_push_actions.py | 2 +- tests/storage/test_main.py | 8 +- tests/storage/test_profile.py | 4 +- tests/storage/test_registration.py | 6 +- tests/storage/test_room.py | 4 +- tests/storage/test_room_search.py | 6 +- tests/storage/test_roommember.py | 2 +- tests/test_distributor.py | 2 +- tests/test_terms_auth.py | 6 +- tests/test_test_utils.py | 2 +- tests/test_types.py | 16 +- tests/unittest.py | 6 +- tests/util/caches/test_deferred_cache.py | 2 +- tests/util/caches/test_descriptors.py | 70 ++--- tests/util/test_expiring_cache.py | 40 +-- tests/util/test_logcontext.py | 2 +- tests/util/test_lrucache.py | 140 +++++----- tests/util/test_treecache.py | 48 ++-- 62 files changed, 888 insertions(+), 889 deletions(-) create mode 100644 changelog.d/12092.misc diff --git a/changelog.d/12092.misc b/changelog.d/12092.misc new file mode 100644 index 000000000..62653d6f8 --- /dev/null +++ b/changelog.d/12092.misc @@ -0,0 +1 @@ +User `assertEqual` instead of the deprecated `assertEquals` in test code. diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 686d17c0d..3e0578992 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -71,7 +71,7 @@ class AuthTestCase(unittest.HomeserverTestCase): request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = self.get_success(self.auth.get_user_by_req(request)) - self.assertEquals(requester.user.to_string(), self.test_user) + self.assertEqual(requester.user.to_string(), self.test_user) def test_get_user_by_req_user_bad_token(self): self.store.get_user_by_access_token = simple_async_mock(None) @@ -109,7 +109,7 @@ class AuthTestCase(unittest.HomeserverTestCase): request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = self.get_success(self.auth.get_user_by_req(request)) - self.assertEquals(requester.user.to_string(), self.test_user) + self.assertEqual(requester.user.to_string(), self.test_user) def test_get_user_by_req_appservice_valid_token_good_ip(self): from netaddr import IPSet @@ -128,7 +128,7 @@ class AuthTestCase(unittest.HomeserverTestCase): request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = self.get_success(self.auth.get_user_by_req(request)) - self.assertEquals(requester.user.to_string(), self.test_user) + self.assertEqual(requester.user.to_string(), self.test_user) def test_get_user_by_req_appservice_valid_token_bad_ip(self): from netaddr import IPSet @@ -195,7 +195,7 @@ class AuthTestCase(unittest.HomeserverTestCase): request.args[b"user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = self.get_success(self.auth.get_user_by_req(request)) - self.assertEquals( + self.assertEqual( requester.user.to_string(), masquerading_user_id.decode("utf8") ) @@ -242,10 +242,10 @@ class AuthTestCase(unittest.HomeserverTestCase): request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = self.get_success(self.auth.get_user_by_req(request)) - self.assertEquals( + self.assertEqual( requester.user.to_string(), masquerading_user_id.decode("utf8") ) - self.assertEquals(requester.device_id, masquerading_device_id.decode("utf8")) + self.assertEqual(requester.device_id, masquerading_device_id.decode("utf8")) @override_config({"experimental_features": {"msc3202_device_masquerading": True}}) def test_get_user_by_req_appservice_valid_token_invalid_device_id(self): @@ -275,8 +275,8 @@ class AuthTestCase(unittest.HomeserverTestCase): request.requestHeaders.getRawHeaders = mock_getRawHeaders() failure = self.get_failure(self.auth.get_user_by_req(request), AuthError) - self.assertEquals(failure.value.code, 400) - self.assertEquals(failure.value.errcode, Codes.EXCLUSIVE) + self.assertEqual(failure.value.code, 400) + self.assertEqual(failure.value.errcode, Codes.EXCLUSIVE) def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self): self.store.get_user_by_access_token = simple_async_mock( @@ -309,7 +309,7 @@ class AuthTestCase(unittest.HomeserverTestCase): request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() self.get_success(self.auth.get_user_by_req(request)) - self.assertEquals(self.store.insert_client_ip.call_count, 2) + self.assertEqual(self.store.insert_client_ip.call_count, 2) def test_get_user_from_macaroon(self): self.store.get_user_by_access_token = simple_async_mock( @@ -369,9 +369,9 @@ class AuthTestCase(unittest.HomeserverTestCase): self.store.get_monthly_active_count = simple_async_mock(lots_of_users) e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) - self.assertEquals(e.value.admin_contact, self.hs.config.server.admin_contact) - self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) - self.assertEquals(e.value.code, 403) + self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact) + self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + self.assertEqual(e.value.code, 403) # Ensure does not throw an error self.store.get_monthly_active_count = simple_async_mock(small_number_of_users) @@ -473,9 +473,9 @@ class AuthTestCase(unittest.HomeserverTestCase): self.auth_blocking._hs_disabled = True self.auth_blocking._hs_disabled_message = "Reason for being disabled" e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) - self.assertEquals(e.value.admin_contact, self.hs.config.server.admin_contact) - self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) - self.assertEquals(e.value.code, 403) + self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact) + self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + self.assertEqual(e.value.code, 403) def test_hs_disabled_no_server_notices_user(self): """Check that 'hs_disabled_message' works correctly when there is no @@ -488,9 +488,9 @@ class AuthTestCase(unittest.HomeserverTestCase): self.auth_blocking._hs_disabled = True self.auth_blocking._hs_disabled_message = "Reason for being disabled" e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) - self.assertEquals(e.value.admin_contact, self.hs.config.server.admin_contact) - self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) - self.assertEquals(e.value.code, 403) + self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact) + self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + self.assertEqual(e.value.code, 403) def test_server_notices_mxid_special_cased(self): self.auth_blocking._hs_disabled = True diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 973f0f7fa..2525018e9 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -364,7 +364,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) results = self.get_success(user_filter.filter_presence(events=events)) - self.assertEquals(events, results) + self.assertEqual(events, results) def test_filter_presence_no_match(self): user_filter_json = {"presence": {"types": ["m.*"]}} @@ -388,7 +388,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) results = self.get_success(user_filter.filter_presence(events=events)) - self.assertEquals([], results) + self.assertEqual([], results) def test_filter_room_state_match(self): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} @@ -407,7 +407,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) results = self.get_success(user_filter.filter_room_state(events=events)) - self.assertEquals(events, results) + self.assertEqual(events, results) def test_filter_room_state_no_match(self): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} @@ -428,7 +428,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) results = self.get_success(user_filter.filter_room_state(events)) - self.assertEquals([], results) + self.assertEqual([], results) def test_filter_rooms(self): definition = { @@ -444,7 +444,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): filtered_room_ids = list(Filter(self.hs, definition).filter_rooms(room_ids)) - self.assertEquals(filtered_room_ids, ["!allowed:example.com"]) + self.assertEqual(filtered_room_ids, ["!allowed:example.com"]) @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_filter_relations(self): @@ -486,7 +486,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): Filter(self.hs, definition)._check_event_relations(events) ) ) - self.assertEquals(filtered_events, events[1:]) + self.assertEqual(filtered_events, events[1:]) def test_add_filter(self): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} @@ -497,8 +497,8 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) ) - self.assertEquals(filter_id, 0) - self.assertEquals( + self.assertEqual(filter_id, 0) + self.assertEqual( user_filter_json, ( self.get_success( @@ -524,6 +524,6 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) ) - self.assertEquals(filter.get_filter_json(), user_filter_json) + self.assertEqual(filter.get_filter_json(), user_filter_json) - self.assertRegexpMatches(repr(filter), r"") + self.assertRegex(repr(filter), r"") diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index 4ef754a18..483d5463a 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -14,19 +14,19 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter.can_do_action(None, key="test_id", _time_now_s=0) ) self.assertTrue(allowed) - self.assertEquals(10.0, time_allowed) + self.assertEqual(10.0, time_allowed) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(None, key="test_id", _time_now_s=5) ) self.assertFalse(allowed) - self.assertEquals(10.0, time_allowed) + self.assertEqual(10.0, time_allowed) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(None, key="test_id", _time_now_s=10) ) self.assertTrue(allowed) - self.assertEquals(20.0, time_allowed) + self.assertEqual(20.0, time_allowed) def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): appservice = ApplicationService( @@ -45,19 +45,19 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter.can_do_action(as_requester, _time_now_s=0) ) self.assertTrue(allowed) - self.assertEquals(10.0, time_allowed) + self.assertEqual(10.0, time_allowed) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=5) ) self.assertFalse(allowed) - self.assertEquals(10.0, time_allowed) + self.assertEqual(10.0, time_allowed) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=10) ) self.assertTrue(allowed) - self.assertEquals(20.0, time_allowed) + self.assertEqual(20.0, time_allowed) def test_allowed_appservice_via_can_requester_do_action(self): appservice = ApplicationService( @@ -76,19 +76,19 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter.can_do_action(as_requester, _time_now_s=0) ) self.assertTrue(allowed) - self.assertEquals(-1, time_allowed) + self.assertEqual(-1, time_allowed) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=5) ) self.assertTrue(allowed) - self.assertEquals(-1, time_allowed) + self.assertEqual(-1, time_allowed) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=10) ) self.assertTrue(allowed) - self.assertEquals(-1, time_allowed) + self.assertEqual(-1, time_allowed) def test_allowed_via_ratelimit(self): limiter = Ratelimiter( @@ -246,7 +246,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter.can_do_action(None, key="test_id", n_actions=3, _time_now_s=0) ) self.assertTrue(allowed) - self.assertEquals(10.0, time_allowed) + self.assertEqual(10.0, time_allowed) # Test that, after doing these 3 actions, we can't do any more action without # waiting. @@ -254,7 +254,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter.can_do_action(None, key="test_id", n_actions=1, _time_now_s=0) ) self.assertFalse(allowed) - self.assertEquals(10.0, time_allowed) + self.assertEqual(10.0, time_allowed) # Test that after waiting we can do only 1 action. allowed, time_allowed = self.get_success_or_raise( @@ -269,7 +269,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): self.assertTrue(allowed) # The time allowed is the current time because we could still repeat the action # once. - self.assertEquals(10.0, time_allowed) + self.assertEqual(10.0, time_allowed) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(None, key="test_id", n_actions=2, _time_now_s=10) @@ -277,7 +277,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): self.assertFalse(allowed) # The time allowed doesn't change despite allowed being False because, while we # don't allow 2 actions, we could still do 1. - self.assertEquals(10.0, time_allowed) + self.assertEqual(10.0, time_allowed) # Test that after waiting a bit more we can do 2 actions. allowed, time_allowed = self.get_success_or_raise( @@ -286,4 +286,4 @@ class TestRatelimiter(unittest.HomeserverTestCase): self.assertTrue(allowed) # The time allowed is the current time because we could still repeat the action # once. - self.assertEquals(20.0, time_allowed) + self.assertEqual(20.0, time_allowed) diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index b9dc4dfe1..1cbb05935 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -71,7 +71,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): one_time_key_counts={}, unused_fallback_keys={}, ) - self.assertEquals(0, len(self.txnctrl.recoverers)) # no recoverer made + self.assertEqual(0, len(self.txnctrl.recoverers)) # no recoverer made txn.complete.assert_called_once_with(self.store) # txn completed def test_single_service_down(self): @@ -97,8 +97,8 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): one_time_key_counts={}, unused_fallback_keys={}, ) - self.assertEquals(0, txn.send.call_count) # txn not sent though - self.assertEquals(0, txn.complete.call_count) # or completed + self.assertEqual(0, txn.send.call_count) # txn not sent though + self.assertEqual(0, txn.complete.call_count) # or completed def test_single_service_up_txn_not_sent(self): # Test: The AS is up and the txn is not sent. A Recoverer is made and @@ -125,10 +125,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): one_time_key_counts={}, unused_fallback_keys={}, ) - self.assertEquals(1, self.recoverer_fn.call_count) # recoverer made - self.assertEquals(1, self.recoverer.recover.call_count) # and invoked - self.assertEquals(1, len(self.txnctrl.recoverers)) # and stored - self.assertEquals(0, txn.complete.call_count) # txn not completed + self.assertEqual(1, self.recoverer_fn.call_count) # recoverer made + self.assertEqual(1, self.recoverer.recover.call_count) # and invoked + self.assertEqual(1, len(self.txnctrl.recoverers)) # and stored + self.assertEqual(0, txn.complete.call_count) # txn not completed self.store.set_appservice_state.assert_called_once_with( service, ApplicationServiceState.DOWN # service marked as down ) @@ -161,17 +161,17 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.recoverer.recover() # shouldn't have called anything prior to waiting for exp backoff - self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count) + self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count) txn.send = simple_async_mock(True) txn.complete = simple_async_mock(None) # wait for exp backoff self.clock.advance_time(2) - self.assertEquals(1, txn.send.call_count) - self.assertEquals(1, txn.complete.call_count) + self.assertEqual(1, txn.send.call_count) + self.assertEqual(1, txn.complete.call_count) # 2 because it needs to get None to know there are no more txns - self.assertEquals(2, self.store.get_oldest_unsent_txn.call_count) + self.assertEqual(2, self.store.get_oldest_unsent_txn.call_count) self.callback.assert_called_once_with(self.recoverer) - self.assertEquals(self.recoverer.service, self.service) + self.assertEqual(self.recoverer.service, self.service) def test_recover_retry_txn(self): txn = Mock() @@ -187,26 +187,26 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.store.get_oldest_unsent_txn = Mock(side_effect=take_txn) self.recoverer.recover() - self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count) + self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count) txn.send = simple_async_mock(False) txn.complete = simple_async_mock(None) self.clock.advance_time(2) - self.assertEquals(1, txn.send.call_count) - self.assertEquals(0, txn.complete.call_count) - self.assertEquals(0, self.callback.call_count) + self.assertEqual(1, txn.send.call_count) + self.assertEqual(0, txn.complete.call_count) + self.assertEqual(0, self.callback.call_count) self.clock.advance_time(4) - self.assertEquals(2, txn.send.call_count) - self.assertEquals(0, txn.complete.call_count) - self.assertEquals(0, self.callback.call_count) + self.assertEqual(2, txn.send.call_count) + self.assertEqual(0, txn.complete.call_count) + self.assertEqual(0, self.callback.call_count) self.clock.advance_time(8) - self.assertEquals(3, txn.send.call_count) - self.assertEquals(0, txn.complete.call_count) - self.assertEquals(0, self.callback.call_count) + self.assertEqual(3, txn.send.call_count) + self.assertEqual(0, txn.complete.call_count) + self.assertEqual(0, self.callback.call_count) txn.send = simple_async_mock(True) # successfully send the txn pop_txn = True # returns the txn the first time, then no more. self.clock.advance_time(16) - self.assertEquals(1, txn.send.call_count) # new mock reset call count - self.assertEquals(1, txn.complete.call_count) + self.assertEqual(1, txn.send.call_count) # new mock reset call count + self.assertEqual(1, txn.complete.call_count) self.callback.assert_called_once_with(self.recoverer) @@ -241,13 +241,13 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): self.scheduler.enqueue_for_appservice(service, events=[event2]) self.scheduler.enqueue_for_appservice(service, events=[event3]) self.txn_ctrl.send.assert_called_with(service, [event], [], [], None, None) - self.assertEquals(1, self.txn_ctrl.send.call_count) + self.assertEqual(1, self.txn_ctrl.send.call_count) # Resolve the send event: expect the queued events to be sent d.callback(service) self.txn_ctrl.send.assert_called_with( service, [event2, event3], [], [], None, None ) - self.assertEquals(2, self.txn_ctrl.send.call_count) + self.assertEqual(2, self.txn_ctrl.send.call_count) def test_multiple_service_queues(self): # Tests that each service has its own queue, and that they don't block @@ -281,7 +281,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): # service srv_2_defer.callback(srv2) self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], [], None, None) - self.assertEquals(3, self.txn_ctrl.send.call_count) + self.assertEqual(3, self.txn_ctrl.send.call_count) def test_send_large_txns(self): srv_1_defer = defer.Deferred() @@ -312,7 +312,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): self.txn_ctrl.send.assert_called_with( service, event_list[101:], [], [], None, None ) - self.assertEquals(3, self.txn_ctrl.send.call_count) + self.assertEqual(3, self.txn_ctrl.send.call_count) def test_send_single_ephemeral_no_queue(self): # Expect the event to be sent immediately. @@ -346,14 +346,14 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_2) self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_3) self.txn_ctrl.send.assert_called_with(service, [], event_list_1, [], None, None) - self.assertEquals(1, self.txn_ctrl.send.call_count) + self.assertEqual(1, self.txn_ctrl.send.call_count) # Resolve txn_ctrl.send d.callback(service) # Expect the queued events to be sent self.txn_ctrl.send.assert_called_with( service, [], event_list_2 + event_list_3, [], None, None ) - self.assertEquals(2, self.txn_ctrl.send.call_count) + self.assertEqual(2, self.txn_ctrl.send.call_count) def test_send_large_txns_ephemeral(self): d = defer.Deferred() @@ -369,4 +369,4 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): ) d.callback(service) self.txn_ctrl.send.assert_called_with(service, [], second_chunk, [], None, None) - self.assertEquals(2, self.txn_ctrl.send.call_count) + self.assertEqual(2, self.txn_ctrl.send.call_count) diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py index a72a0103d..694020fbe 100644 --- a/tests/crypto/test_event_signing.py +++ b/tests/crypto/test_event_signing.py @@ -63,14 +63,14 @@ class EventSigningTestCase(unittest.TestCase): self.assertTrue(hasattr(event, "hashes")) self.assertIn("sha256", event.hashes) - self.assertEquals( + self.assertEqual( event.hashes["sha256"], "6tJjLpXtggfke8UxFhAKg82QVkJzvKOVOOSjUDK4ZSI" ) self.assertTrue(hasattr(event, "signatures")) self.assertIn(HOSTNAME, event.signatures) self.assertIn(KEY_NAME, event.signatures["domain"]) - self.assertEquals( + self.assertEqual( event.signatures[HOSTNAME][KEY_NAME], "2Wptgo4CwmLo/Y8B8qinxApKaCkBG2fjTWB7AbP5Uy+" "aIbygsSdLOFzvdDjww8zUVKCmI02eP9xtyJxc/cLiBA", @@ -97,14 +97,14 @@ class EventSigningTestCase(unittest.TestCase): self.assertTrue(hasattr(event, "hashes")) self.assertIn("sha256", event.hashes) - self.assertEquals( + self.assertEqual( event.hashes["sha256"], "onLKD1bGljeBWQhWZ1kaP9SorVmRQNdN5aM2JYU2n/g" ) self.assertTrue(hasattr(event, "signatures")) self.assertIn(HOSTNAME, event.signatures) self.assertIn(KEY_NAME, event.signatures["domain"]) - self.assertEquals( + self.assertEqual( event.signatures[HOSTNAME][KEY_NAME], "Wm+VzmOUOz08Ds+0NTWb1d4CZrVsJSikkeRxh6aCcUw" "u6pNC78FunoD7KNWzqFn241eYHYMGCA5McEiVPdhzBA", diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 3a4d50271..d00ef24ca 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -76,7 +76,7 @@ class FakeRequest: @logcontext_clean class KeyringTestCase(unittest.HomeserverTestCase): def check_context(self, val, expected): - self.assertEquals(getattr(current_context(), "request", None), expected) + self.assertEqual(getattr(current_context(), "request", None), expected) return val def test_verify_json_objects_for_server_awaits_previous_requests(self): @@ -96,7 +96,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): async def first_lookup_fetch( server_name: str, key_ids: List[str], minimum_valid_until_ts: int ) -> Dict[str, FetchKeyResult]: - # self.assertEquals(current_context().request.id, "context_11") + # self.assertEqual(current_context().request.id, "context_11") self.assertEqual(server_name, "server10") self.assertEqual(key_ids, [get_key_id(key1)]) self.assertEqual(minimum_valid_until_ts, 0) @@ -137,7 +137,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): async def second_lookup_fetch( server_name: str, key_ids: List[str], minimum_valid_until_ts: int ) -> Dict[str, FetchKeyResult]: - # self.assertEquals(current_context().request.id, "context_12") + # self.assertEqual(current_context().request.id, "context_12") return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)} mock_fetcher.get_keys.reset_mock() diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index 1dea09e48..45e3395b3 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -395,7 +395,7 @@ class SerializeEventTestCase(unittest.TestCase): return serialize_event(ev, 1479807801915, only_event_fields=fields) def test_event_fields_works_with_keys(self): - self.assertEquals( + self.assertEqual( self.serialize( MockEvent(sender="@alice:localhost", room_id="!foo:bar"), ["room_id"] ), @@ -403,7 +403,7 @@ class SerializeEventTestCase(unittest.TestCase): ) def test_event_fields_works_with_nested_keys(self): - self.assertEquals( + self.assertEqual( self.serialize( MockEvent( sender="@alice:localhost", @@ -416,7 +416,7 @@ class SerializeEventTestCase(unittest.TestCase): ) def test_event_fields_works_with_dot_keys(self): - self.assertEquals( + self.assertEqual( self.serialize( MockEvent( sender="@alice:localhost", @@ -429,7 +429,7 @@ class SerializeEventTestCase(unittest.TestCase): ) def test_event_fields_works_with_nested_dot_keys(self): - self.assertEquals( + self.assertEqual( self.serialize( MockEvent( sender="@alice:localhost", @@ -445,7 +445,7 @@ class SerializeEventTestCase(unittest.TestCase): ) def test_event_fields_nops_with_unknown_keys(self): - self.assertEquals( + self.assertEqual( self.serialize( MockEvent( sender="@alice:localhost", @@ -458,7 +458,7 @@ class SerializeEventTestCase(unittest.TestCase): ) def test_event_fields_nops_with_non_dict_keys(self): - self.assertEquals( + self.assertEqual( self.serialize( MockEvent( sender="@alice:localhost", @@ -471,7 +471,7 @@ class SerializeEventTestCase(unittest.TestCase): ) def test_event_fields_nops_with_array_keys(self): - self.assertEquals( + self.assertEqual( self.serialize( MockEvent( sender="@alice:localhost", @@ -484,7 +484,7 @@ class SerializeEventTestCase(unittest.TestCase): ) def test_event_fields_all_fields_if_empty(self): - self.assertEquals( + self.assertEqual( self.serialize( MockEvent( type="foo", diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 9336181c9..9f1115dd2 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -50,7 +50,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): channel = self.make_signed_federation_request( "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) complexity = channel.json_body["v1"] self.assertTrue(complexity > 0, complexity) @@ -62,7 +62,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): channel = self.make_signed_federation_request( "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) complexity = channel.json_body["v1"] self.assertEqual(complexity, 1.23) diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index d084919ef..30e7e5093 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -59,7 +59,7 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase): "/_matrix/federation/v1/get_missing_events/%s" % (room_1,), query_content, ) - self.assertEquals(400, channel.code, channel.result) + self.assertEqual(400, channel.code, channel.result) self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON") @@ -125,7 +125,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase): channel = self.make_signed_federation_request( "GET", "/_matrix/federation/v1/state/%s" % (room_1,) ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertEqual( channel.json_body["room_version"], @@ -157,7 +157,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase): channel = self.make_signed_federation_request( "GET", "/_matrix/federation/v1/state/%s" % (room_1,) ) - self.assertEquals(403, channel.code, channel.result) + self.assertEqual(403, channel.code, channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -189,7 +189,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): f"/_matrix/federation/v1/make_join/{self._room_id}/{user_id}" f"?ver={DEFAULT_ROOM_VERSION}", ) - self.assertEquals(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, 200, channel.json_body) return channel.json_body def test_send_join(self): @@ -209,7 +209,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): f"/_matrix/federation/v2/send_join/{self._room_id}/x", content=join_event_dict, ) - self.assertEquals(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, 200, channel.json_body) # we should get complete room state back returned_state = [ @@ -266,7 +266,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): f"/_matrix/federation/v2/send_join/{self._room_id}/x?org.matrix.msc3706.partial_state=true", content=join_event_dict, ) - self.assertEquals(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, 200, channel.json_body) # expect a reduced room state returned_state = [ diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py index adf0535d9..648a01618 100644 --- a/tests/federation/transport/test_knocking.py +++ b/tests/federation/transport/test_knocking.py @@ -169,7 +169,7 @@ class KnockingStrippedStateEventHelperMixin(TestCase): self.assertIn(event_type, expected_room_state) # Check the state content matches - self.assertEquals( + self.assertEqual( expected_room_state[event_type]["content"], event["content"] ) @@ -256,7 +256,7 @@ class FederationKnockingTestCase( RoomVersions.V7.identifier, ), ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # Note: We don't expect the knock membership event to be sent over federation as # part of the stripped room state, as the knocking homeserver already has that @@ -266,11 +266,11 @@ class FederationKnockingTestCase( knock_event = channel.json_body["event"] # Check that the event has things we expect in it - self.assertEquals(knock_event["room_id"], room_id) - self.assertEquals(knock_event["sender"], fake_knocking_user_id) - self.assertEquals(knock_event["state_key"], fake_knocking_user_id) - self.assertEquals(knock_event["type"], EventTypes.Member) - self.assertEquals(knock_event["content"]["membership"], Membership.KNOCK) + self.assertEqual(knock_event["room_id"], room_id) + self.assertEqual(knock_event["sender"], fake_knocking_user_id) + self.assertEqual(knock_event["state_key"], fake_knocking_user_id) + self.assertEqual(knock_event["type"], EventTypes.Member) + self.assertEqual(knock_event["content"]["membership"], Membership.KNOCK) # Turn the event json dict into a proper event. # We won't sign it properly, but that's OK as we stub out event auth in `prepare` @@ -294,7 +294,7 @@ class FederationKnockingTestCase( % (room_id, signed_knock_event.event_id), signed_knock_event_json, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # Check that we got the stripped room state in return room_state_events = channel.json_body["knock_state_events"] diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py index eb62addda..ce49d094d 100644 --- a/tests/federation/transport/test_server.py +++ b/tests/federation/transport/test_server.py @@ -26,7 +26,7 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase): "GET", "/_matrix/federation/v1/publicRooms", ) - self.assertEquals(403, channel.code) + self.assertEqual(403, channel.code) @override_config({"allow_public_rooms_over_federation": True}) def test_open_public_room_list_over_federation(self): @@ -37,4 +37,4 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase): "GET", "/_matrix/federation/v1/publicRooms", ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 6e0ec3796..072e6bbcd 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -147,8 +147,8 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_as_api.query_alias.assert_called_once_with( interested_service, room_alias_str ) - self.assertEquals(result.room_id, room_id) - self.assertEquals(result.servers, servers) + self.assertEqual(result.room_id, room_id) + self.assertEqual(result.servers, servers) def test_get_3pe_protocols_no_appservices(self): self.mock_store.get_app_services.return_value = [] @@ -156,7 +156,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol")) ) self.mock_as_api.get_3pe_protocol.assert_not_called() - self.assertEquals(response, {}) + self.assertEqual(response, {}) def test_get_3pe_protocols_no_protocols(self): service = self._mkservice(False, []) @@ -165,7 +165,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): defer.ensureDeferred(self.handler.get_3pe_protocols()) ) self.mock_as_api.get_3pe_protocol.assert_not_called() - self.assertEquals(response, {}) + self.assertEqual(response, {}) def test_get_3pe_protocols_protocol_no_response(self): service = self._mkservice(False, ["my-protocol"]) @@ -177,7 +177,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_as_api.get_3pe_protocol.assert_called_once_with( service, "my-protocol" ) - self.assertEquals(response, {}) + self.assertEqual(response, {}) def test_get_3pe_protocols_select_one_protocol(self): service = self._mkservice(False, ["my-protocol"]) @@ -191,7 +191,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_as_api.get_3pe_protocol.assert_called_once_with( service, "my-protocol" ) - self.assertEquals( + self.assertEqual( response, {"my-protocol": {"x-protocol-data": 42, "instances": []}} ) @@ -207,7 +207,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_as_api.get_3pe_protocol.assert_called_once_with( service, "my-protocol" ) - self.assertEquals( + self.assertEqual( response, {"my-protocol": {"x-protocol-data": 42, "instances": []}} ) @@ -222,7 +222,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): defer.ensureDeferred(self.handler.get_3pe_protocols()) ) self.mock_as_api.get_3pe_protocol.assert_called() - self.assertEquals( + self.assertEqual( response, { "my-protocol": {"x-protocol-data": 42, "instances": []}, @@ -254,7 +254,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): defer.ensureDeferred(self.handler.get_3pe_protocols()) ) # It's expected that the second service's data doesn't appear in the response - self.assertEquals( + self.assertEqual( response, { "my-protocol": { diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 65ab107d0..6e403a87c 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -63,7 +63,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): result = self.get_success(self.handler.get_association(self.my_room)) - self.assertEquals({"room_id": "!8765qwer:test", "servers": ["test"]}, result) + self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result) def test_get_remote_association(self): self.mock_federation.make_query.return_value = make_awaitable( @@ -72,7 +72,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): result = self.get_success(self.handler.get_association(self.remote_room)) - self.assertEquals( + self.assertEqual( {"room_id": "!8765qwer:test", "servers": ["test", "remote"]}, result ) self.mock_federation.make_query.assert_called_with( @@ -94,7 +94,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.handler.on_directory_query({"room_alias": "#your-room:test"}) ) - self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response) + self.assertEqual({"room_id": "!8765asdf:test", "servers": ["test"]}, response) class TestCreateAlias(unittest.HomeserverTestCase): @@ -224,7 +224,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): create_requester(self.test_user), self.room_alias ) ) - self.assertEquals(self.room_id, result) + self.assertEqual(self.room_id, result) # Confirm the alias is gone. self.get_failure( @@ -243,7 +243,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): create_requester(self.admin_user), self.room_alias ) ) - self.assertEquals(self.room_id, result) + self.assertEqual(self.room_id, result) # Confirm the alias is gone. self.get_failure( @@ -269,7 +269,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): create_requester(self.test_user), self.room_alias ) ) - self.assertEquals(self.room_id, result) + self.assertEqual(self.room_id, result) # Confirm the alias is gone. self.get_failure( @@ -411,7 +411,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): b"directory/room/%23test%3Atest", {"room_id": room_id}, ) - self.assertEquals(403, channel.code, channel.result) + self.assertEqual(403, channel.code, channel.result) def test_allowed(self): room_id = self.helper.create_room_as(self.user_id) @@ -421,7 +421,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): b"directory/room/%23unofficial_test%3Atest", {"room_id": room_id}, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) def test_denied_during_creation(self): """A room alias that is not allowed should be rejected during creation.""" @@ -443,8 +443,8 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): "GET", b"directory/room/%23unofficial_test%3Atest", ) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(channel.json_body["room_id"], room_id) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(channel.json_body["room_id"], room_id) class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): @@ -572,7 +572,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): channel = self.make_request( "PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}" ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.room_list_handler = hs.get_room_list_handler() self.directory_handler = hs.get_directory_handler() @@ -585,7 +585,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): # Room list is enabled so we should get some results channel = self.make_request("GET", b"publicRooms") - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["chunk"]) > 0) self.room_list_handler.enable_room_list_search = False @@ -593,7 +593,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): # Room list disabled so we should get no results channel = self.make_request("GET", b"publicRooms") - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["chunk"]) == 0) # Room list disabled so we shouldn't be allowed to publish rooms @@ -601,4 +601,4 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): channel = self.make_request( "PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}" ) - self.assertEquals(403, channel.code, channel.result) + self.assertEqual(403, channel.code, channel.result) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 61d28603a..6ddec9ecf 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -61,11 +61,11 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): self.assertTrue(persist_and_notify) self.assertTrue(state.currently_active) - self.assertEquals(new_state.state, state.state) - self.assertEquals(new_state.status_msg, state.status_msg) - self.assertEquals(state.last_federation_update_ts, now) + self.assertEqual(new_state.state, state.state) + self.assertEqual(new_state.status_msg, state.status_msg) + self.assertEqual(state.last_federation_update_ts, now) - self.assertEquals(wheel_timer.insert.call_count, 3) + self.assertEqual(wheel_timer.insert.call_count, 3) wheel_timer.insert.assert_has_calls( [ call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER), @@ -104,11 +104,11 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): self.assertFalse(persist_and_notify) self.assertTrue(federation_ping) self.assertTrue(state.currently_active) - self.assertEquals(new_state.state, state.state) - self.assertEquals(new_state.status_msg, state.status_msg) - self.assertEquals(state.last_federation_update_ts, now) + self.assertEqual(new_state.state, state.state) + self.assertEqual(new_state.status_msg, state.status_msg) + self.assertEqual(state.last_federation_update_ts, now) - self.assertEquals(wheel_timer.insert.call_count, 3) + self.assertEqual(wheel_timer.insert.call_count, 3) wheel_timer.insert.assert_has_calls( [ call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER), @@ -149,11 +149,11 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): self.assertFalse(persist_and_notify) self.assertTrue(federation_ping) self.assertTrue(state.currently_active) - self.assertEquals(new_state.state, state.state) - self.assertEquals(new_state.status_msg, state.status_msg) - self.assertEquals(state.last_federation_update_ts, now) + self.assertEqual(new_state.state, state.state) + self.assertEqual(new_state.status_msg, state.status_msg) + self.assertEqual(state.last_federation_update_ts, now) - self.assertEquals(wheel_timer.insert.call_count, 3) + self.assertEqual(wheel_timer.insert.call_count, 3) wheel_timer.insert.assert_has_calls( [ call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER), @@ -191,11 +191,11 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): self.assertTrue(persist_and_notify) self.assertFalse(state.currently_active) - self.assertEquals(new_state.state, state.state) - self.assertEquals(new_state.status_msg, state.status_msg) - self.assertEquals(state.last_federation_update_ts, now) + self.assertEqual(new_state.state, state.state) + self.assertEqual(new_state.status_msg, state.status_msg) + self.assertEqual(state.last_federation_update_ts, now) - self.assertEquals(wheel_timer.insert.call_count, 2) + self.assertEqual(wheel_timer.insert.call_count, 2) wheel_timer.insert.assert_has_calls( [ call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER), @@ -227,10 +227,10 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): self.assertFalse(persist_and_notify) self.assertFalse(federation_ping) self.assertFalse(state.currently_active) - self.assertEquals(new_state.state, state.state) - self.assertEquals(new_state.status_msg, state.status_msg) + self.assertEqual(new_state.state, state.state) + self.assertEqual(new_state.status_msg, state.status_msg) - self.assertEquals(wheel_timer.insert.call_count, 1) + self.assertEqual(wheel_timer.insert.call_count, 1) wheel_timer.insert.assert_has_calls( [ call( @@ -259,10 +259,10 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): ) self.assertTrue(persist_and_notify) - self.assertEquals(new_state.state, state.state) - self.assertEquals(state.last_federation_update_ts, now) + self.assertEqual(new_state.state, state.state) + self.assertEqual(state.last_federation_update_ts, now) - self.assertEquals(wheel_timer.insert.call_count, 0) + self.assertEqual(wheel_timer.insert.call_count, 0) def test_online_to_idle(self): wheel_timer = Mock() @@ -281,12 +281,12 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): ) self.assertTrue(persist_and_notify) - self.assertEquals(new_state.state, state.state) - self.assertEquals(state.last_federation_update_ts, now) - self.assertEquals(new_state.state, state.state) - self.assertEquals(new_state.status_msg, state.status_msg) + self.assertEqual(new_state.state, state.state) + self.assertEqual(state.last_federation_update_ts, now) + self.assertEqual(new_state.state, state.state) + self.assertEqual(new_state.status_msg, state.status_msg) - self.assertEquals(wheel_timer.insert.call_count, 1) + self.assertEqual(wheel_timer.insert.call_count, 1) wheel_timer.insert.assert_has_calls( [ call( @@ -357,8 +357,8 @@ class PresenceTimeoutTestCase(unittest.TestCase): new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) - self.assertEquals(new_state.state, PresenceState.UNAVAILABLE) - self.assertEquals(new_state.status_msg, status_msg) + self.assertEqual(new_state.state, PresenceState.UNAVAILABLE) + self.assertEqual(new_state.status_msg, status_msg) def test_busy_no_idle(self): """ @@ -380,8 +380,8 @@ class PresenceTimeoutTestCase(unittest.TestCase): new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) - self.assertEquals(new_state.state, PresenceState.BUSY) - self.assertEquals(new_state.status_msg, status_msg) + self.assertEqual(new_state.state, PresenceState.BUSY) + self.assertEqual(new_state.status_msg, status_msg) def test_sync_timeout(self): user_id = "@foo:bar" @@ -399,8 +399,8 @@ class PresenceTimeoutTestCase(unittest.TestCase): new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) - self.assertEquals(new_state.state, PresenceState.OFFLINE) - self.assertEquals(new_state.status_msg, status_msg) + self.assertEqual(new_state.state, PresenceState.OFFLINE) + self.assertEqual(new_state.status_msg, status_msg) def test_sync_online(self): user_id = "@foo:bar" @@ -420,8 +420,8 @@ class PresenceTimeoutTestCase(unittest.TestCase): ) self.assertIsNotNone(new_state) - self.assertEquals(new_state.state, PresenceState.ONLINE) - self.assertEquals(new_state.status_msg, status_msg) + self.assertEqual(new_state.state, PresenceState.ONLINE) + self.assertEqual(new_state.status_msg, status_msg) def test_federation_ping(self): user_id = "@foo:bar" @@ -440,7 +440,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) - self.assertEquals(state, new_state) + self.assertEqual(state, new_state) def test_no_timeout(self): user_id = "@foo:bar" @@ -477,8 +477,8 @@ class PresenceTimeoutTestCase(unittest.TestCase): ) self.assertIsNotNone(new_state) - self.assertEquals(new_state.state, PresenceState.OFFLINE) - self.assertEquals(new_state.status_msg, status_msg) + self.assertEqual(new_state.state, PresenceState.OFFLINE) + self.assertEqual(new_state.status_msg, status_msg) def test_last_active(self): user_id = "@foo:bar" @@ -497,7 +497,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) - self.assertEquals(state, new_state) + self.assertEqual(state, new_state) class PresenceHandlerTestCase(unittest.HomeserverTestCase): diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 69e299fc1..972cbac6e 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -65,7 +65,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): displayname = self.get_success(self.handler.get_displayname(self.frank)) - self.assertEquals("Frank", displayname) + self.assertEqual("Frank", displayname) def test_set_my_name(self): self.get_success( @@ -74,7 +74,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) ) - self.assertEquals( + self.assertEqual( ( self.get_success( self.store.get_profile_displayname(self.frank.localpart) @@ -90,7 +90,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) ) - self.assertEquals( + self.assertEqual( ( self.get_success( self.store.get_profile_displayname(self.frank.localpart) @@ -118,7 +118,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.store.set_profile_displayname(self.frank.localpart, "Frank") ) - self.assertEquals( + self.assertEqual( ( self.get_success( self.store.get_profile_displayname(self.frank.localpart) @@ -150,7 +150,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): displayname = self.get_success(self.handler.get_displayname(self.alice)) - self.assertEquals(displayname, "Alice") + self.assertEqual(displayname, "Alice") self.mock_federation.make_query.assert_called_with( destination="remote", query_type="profile", @@ -172,7 +172,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) ) - self.assertEquals({"displayname": "Caroline"}, response) + self.assertEqual({"displayname": "Caroline"}, response) def test_get_my_avatar(self): self.get_success( @@ -182,7 +182,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) avatar_url = self.get_success(self.handler.get_avatar_url(self.frank)) - self.assertEquals("http://my.server/me.png", avatar_url) + self.assertEqual("http://my.server/me.png", avatar_url) def test_set_my_avatar(self): self.get_success( @@ -193,7 +193,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) ) - self.assertEquals( + self.assertEqual( (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), "http://my.server/pic.gif", ) @@ -207,7 +207,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) ) - self.assertEquals( + self.assertEqual( (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), "http://my.server/me.png", ) @@ -235,7 +235,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) ) - self.assertEquals( + self.assertEqual( (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), "http://my.server/me.png", ) diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py index 5de89c873..5081b9757 100644 --- a/tests/handlers/test_receipts.py +++ b/tests/handlers/test_receipts.py @@ -314,4 +314,4 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ): """Tests that the _filter_out_hidden returns the expected output""" filtered_events = self.event_source.filter_out_hidden(events, "@me:server.org") - self.assertEquals(filtered_events, expected_output) + self.assertEqual(filtered_events, expected_output) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 51ee667ab..45fd30cf4 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -167,7 +167,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): result_user_id, result_token = self.get_success( self.get_or_create_user(requester, frank.localpart, "Frankie") ) - self.assertEquals(result_user_id, user_id) + self.assertEqual(result_user_id, user_id) self.assertIsInstance(result_token, str) self.assertGreater(len(result_token), 20) @@ -183,7 +183,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): result_user_id, result_token = self.get_success( self.get_or_create_user(requester, local_part, None) ) - self.assertEquals(result_user_id, user_id) + self.assertEqual(result_user_id, user_id) self.assertTrue(result_token is not None) @override_config({"limit_usage_by_mau": False}) diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 66b0bd4d1..3aedc0767 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -69,7 +69,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): self.sync_handler.wait_for_sync_for_user(requester, sync_config), ResourceLimitError, ) - self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.auth_blocking._hs_disabled = False @@ -80,7 +80,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): self.sync_handler.wait_for_sync_for_user(requester, sync_config), ResourceLimitError, ) - self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) def test_unknown_room_version(self): """ @@ -122,7 +122,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): b"{}", tok, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # The rooms should appear in the sync response. result = self.get_success( diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index e461e0359..f91a80b9f 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -156,7 +156,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): def test_started_typing_local(self): self.room_members = [U_APPLE, U_BANANA] - self.assertEquals(self.event_source.get_current_key(), 0) + self.assertEqual(self.event_source.get_current_key(), 0) self.get_success( self.handler.started_typing( @@ -169,13 +169,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) - self.assertEquals(self.event_source.get_current_key(), 1) + self.assertEqual(self.event_source.get_current_key(), 1) events = self.get_success( self.event_source.get_new_events( user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False ) ) - self.assertEquals( + self.assertEqual( events[0], [ { @@ -220,7 +220,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): def test_started_typing_remote_recv(self): self.room_members = [U_APPLE, U_ONION] - self.assertEquals(self.event_source.get_current_key(), 0) + self.assertEqual(self.event_source.get_current_key(), 0) channel = self.make_request( "PUT", @@ -239,13 +239,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) - self.assertEquals(self.event_source.get_current_key(), 1) + self.assertEqual(self.event_source.get_current_key(), 1) events = self.get_success( self.event_source.get_new_events( user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False ) ) - self.assertEquals( + self.assertEqual( events[0], [ { @@ -259,7 +259,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): def test_started_typing_remote_recv_not_in_room(self): self.room_members = [U_APPLE, U_ONION] - self.assertEquals(self.event_source.get_current_key(), 0) + self.assertEqual(self.event_source.get_current_key(), 0) channel = self.make_request( "PUT", @@ -278,7 +278,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_not_called() - self.assertEquals(self.event_source.get_current_key(), 0) + self.assertEqual(self.event_source.get_current_key(), 0) events = self.get_success( self.event_source.get_new_events( user=U_APPLE, @@ -288,8 +288,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): is_guest=False, ) ) - self.assertEquals(events[0], []) - self.assertEquals(events[1], 0) + self.assertEqual(events[0], []) + self.assertEqual(events[1], 0) @override_config({"send_federation": True}) def test_stopped_typing(self): @@ -302,7 +302,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.handler._member_typing_until[member] = 1002000 self.handler._room_typing[ROOM_ID] = {U_APPLE.to_string()} - self.assertEquals(self.event_source.get_current_key(), 0) + self.assertEqual(self.event_source.get_current_key(), 0) self.get_success( self.handler.stopped_typing( @@ -332,13 +332,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): try_trailing_slash_on_400=True, ) - self.assertEquals(self.event_source.get_current_key(), 1) + self.assertEqual(self.event_source.get_current_key(), 1) events = self.get_success( self.event_source.get_new_events( user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False ) ) - self.assertEquals( + self.assertEqual( events[0], [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], ) @@ -346,7 +346,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): def test_typing_timeout(self): self.room_members = [U_APPLE, U_BANANA] - self.assertEquals(self.event_source.get_current_key(), 0) + self.assertEqual(self.event_source.get_current_key(), 0) self.get_success( self.handler.started_typing( @@ -360,7 +360,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) self.on_new_event.reset_mock() - self.assertEquals(self.event_source.get_current_key(), 1) + self.assertEqual(self.event_source.get_current_key(), 1) events = self.get_success( self.event_source.get_new_events( user=U_APPLE, @@ -370,7 +370,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): is_guest=False, ) ) - self.assertEquals( + self.assertEqual( events[0], [ { @@ -385,7 +385,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 2, rooms=[ROOM_ID])]) - self.assertEquals(self.event_source.get_current_key(), 2) + self.assertEqual(self.event_source.get_current_key(), 2) events = self.get_success( self.event_source.get_new_events( user=U_APPLE, @@ -395,7 +395,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): is_guest=False, ) ) - self.assertEquals( + self.assertEqual( events[0], [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], ) @@ -414,7 +414,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 3, rooms=[ROOM_ID])]) self.on_new_event.reset_mock() - self.assertEquals(self.event_source.get_current_key(), 3) + self.assertEqual(self.event_source.get_current_key(), 3) events = self.get_success( self.event_source.get_new_events( user=U_APPLE, @@ -424,7 +424,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): is_guest=False, ) ) - self.assertEquals( + self.assertEqual( events[0], [ { diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index e159169e2..92012cd6f 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -1042,7 +1042,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): b'{"search_term":"user2"}', access_token=u1_token, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["results"]) > 0) # Disable user directory and check search returns nothing @@ -1053,5 +1053,5 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): b'{"search_term":"user2"}', access_token=u1_token, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["results"]) == 0) diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index c49be33b9..77ce8432a 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -65,9 +65,9 @@ class SrvResolverTestCase(unittest.TestCase): servers = self.successResultOf(test_d) - self.assertEquals(len(servers), 1) - self.assertEquals(servers, cache[service_name]) - self.assertEquals(servers[0].host, host_name) + self.assertEqual(len(servers), 1) + self.assertEqual(servers, cache[service_name]) + self.assertEqual(servers[0].host, host_name) @defer.inlineCallbacks def test_from_cache_expired_and_dns_fail(self): @@ -88,8 +88,8 @@ class SrvResolverTestCase(unittest.TestCase): dns_client_mock.lookupService.assert_called_once_with(service_name) - self.assertEquals(len(servers), 1) - self.assertEquals(servers, cache[service_name]) + self.assertEqual(len(servers), 1) + self.assertEqual(servers, cache[service_name]) @defer.inlineCallbacks def test_from_cache(self): @@ -114,8 +114,8 @@ class SrvResolverTestCase(unittest.TestCase): self.assertFalse(dns_client_mock.lookupService.called) - self.assertEquals(len(servers), 1) - self.assertEquals(servers, cache[service_name]) + self.assertEqual(len(servers), 1) + self.assertEqual(servers, cache[service_name]) @defer.inlineCallbacks def test_empty_cache(self): @@ -144,8 +144,8 @@ class SrvResolverTestCase(unittest.TestCase): servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) - self.assertEquals(len(servers), 0) - self.assertEquals(len(cache), 0) + self.assertEqual(len(servers), 0) + self.assertEqual(len(cache), 0) def test_disabled_service(self): """ @@ -201,6 +201,6 @@ class SrvResolverTestCase(unittest.TestCase): servers = self.successResultOf(resolve_d) - self.assertEquals(len(servers), 1) - self.assertEquals(servers, cache[service_name]) - self.assertEquals(servers[0].host, b"host") + self.assertEqual(len(servers), 1) + self.assertEqual(servers, cache[service_name]) + self.assertEqual(servers[0].host, b"host") diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index eca6a443a..17dc42fd3 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -59,7 +59,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): def setUp(self): # Patch up the equality operator for events so that we can check - # whether lists of events match using assertEquals + # whether lists of events match using assertEqual self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)] return super().setUp() diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 09c48e85c..95282f078 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1909,7 +1909,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0]) def test_join_private_room_if_not_member(self) -> None: @@ -1957,7 +1957,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok, ) - self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) # Join user to room. @@ -1980,7 +1980,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) def test_join_private_room_if_owner(self) -> None: @@ -2010,7 +2010,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) def test_context_as_non_admin(self) -> None: @@ -2044,7 +2044,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): % (room_id, events[midway]["event_id"]), access_token=tok, ) - self.assertEquals(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_context_as_admin(self) -> None: @@ -2074,8 +2074,8 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): % (room_id, events[midway]["event_id"]), access_token=self.admin_user_tok, ) - self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) - self.assertEquals( + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual( channel.json_body["event"]["event_id"], events[midway]["event_id"] ) diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index 008d635b7..6c4462e74 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -104,7 +104,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): client_secret = "foobar" session_id = self._request_token(email, client_secret) - self.assertEquals(len(self.email_attempts), 1) + self.assertEqual(len(self.email_attempts), 1) link = self._get_link_from_email() self._validate_token(link) @@ -143,7 +143,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): client_secret = "foobar" session_id = self._request_token(email, client_secret, ip) - self.assertEquals(len(self.email_attempts), 1) + self.assertEqual(len(self.email_attempts), 1) link = self._get_link_from_email() self._validate_token(link) @@ -193,7 +193,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): client_secret = "foobar" session_id = self._request_token(email_passwort_reset, client_secret) - self.assertEquals(len(self.email_attempts), 1) + self.assertEqual(len(self.email_attempts), 1) link = self._get_link_from_email() self._validate_token(link) @@ -230,7 +230,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): client_secret = "foobar" session_id = self._request_token(email, client_secret) - self.assertEquals(len(self.email_attempts), 1) + self.assertEqual(len(self.email_attempts), 1) # Attempt to reset password without clicking the link self._reset_password(new_password, session_id, client_secret, expected_code=401) @@ -322,7 +322,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): shorthand=False, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # Now POST to the same endpoint, mimicking the same behaviour as clicking the # password reset confirm button @@ -337,7 +337,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): shorthand=False, content_is_form=True, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) def _get_link_from_email(self): assert self.email_attempts, "No emails have been sent" @@ -376,7 +376,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): }, }, ) - self.assertEquals(expected_code, channel.code, channel.result) + self.assertEqual(expected_code, channel.code, channel.result) class DeactivateTestCase(unittest.HomeserverTestCase): @@ -676,7 +676,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): client_secret = "foobar" session_id = self._request_token(self.email, client_secret) - self.assertEquals(len(self.email_attempts), 1) + self.assertEqual(len(self.email_attempts), 1) link = self._get_link_from_email() self._validate_token(link) @@ -780,7 +780,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): client_secret = "foobar" session_id = self._request_token(self.email, client_secret) - self.assertEquals(len(self.email_attempts), 1) + self.assertEqual(len(self.email_attempts), 1) # Attempt to add email without clicking the link channel = self.make_request( @@ -981,7 +981,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): path = link.replace("https://example.com", "") channel = self.make_request("GET", path, shorthand=False) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) def _get_link_from_email(self): assert self.email_attempts, "No emails have been sent" @@ -1010,7 +1010,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): client_secret = "foobar" session_id = self._request_token(request_email, client_secret) - self.assertEquals(len(self.email_attempts) - previous_email_attempts, 1) + self.assertEqual(len(self.email_attempts) - previous_email_attempts, 1) link = self._get_link_from_email() self._validate_token(link) diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py index a90294003..145f24783 100644 --- a/tests/rest/client/test_events.py +++ b/tests/rest/client/test_events.py @@ -65,13 +65,13 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", "/events?access_token=%s" % ("invalid" + self.token,) ) - self.assertEquals(channel.code, 401, msg=channel.result) + self.assertEqual(channel.code, 401, msg=channel.result) # valid token, expect content channel = self.make_request( "GET", "/events?access_token=%s&timeout=0" % (self.token,) ) - self.assertEquals(channel.code, 200, msg=channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertTrue("chunk" in channel.json_body) self.assertTrue("start" in channel.json_body) self.assertTrue("end" in channel.json_body) @@ -89,10 +89,10 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", "/events?access_token=%s&timeout=0" % (self.token,) ) - self.assertEquals(channel.code, 200, msg=channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) # We may get a presence event for ourselves down - self.assertEquals( + self.assertEqual( 0, len( [ @@ -153,4 +153,4 @@ class GetEventsTestCase(unittest.HomeserverTestCase): "/events/" + event_id, access_token=self.token, ) - self.assertEquals(channel.code, 200, msg=channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py index a573cc3c2..5c31a5442 100644 --- a/tests/rest/client/test_filter.py +++ b/tests/rest/client/test_filter.py @@ -45,7 +45,7 @@ class FilterTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body, {"filter_id": "0"}) filter = self.store.get_user_filter(user_localpart="apple", filter_id=0) self.pump() - self.assertEquals(filter.result, self.EXAMPLE_FILTER) + self.assertEqual(filter.result, self.EXAMPLE_FILTER) def test_add_filter_for_other_user(self): channel = self.make_request( @@ -55,7 +55,7 @@ class FilterTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.result["code"], b"403") - self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) + self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) def test_add_filter_non_local_user(self): _is_mine = self.hs.is_mine @@ -68,7 +68,7 @@ class FilterTestCase(unittest.HomeserverTestCase): self.hs.is_mine = _is_mine self.assertEqual(channel.result["code"], b"403") - self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) + self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) def test_get_filter(self): filter_id = defer.ensureDeferred( @@ -83,7 +83,7 @@ class FilterTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.result["code"], b"200") - self.assertEquals(channel.json_body, self.EXAMPLE_FILTER) + self.assertEqual(channel.json_body, self.EXAMPLE_FILTER) def test_get_filter_non_existant(self): channel = self.make_request( @@ -91,7 +91,7 @@ class FilterTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.result["code"], b"404") - self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) # Currently invalid params do not have an appropriate errcode # in errors.py diff --git a/tests/rest/client/test_groups.py b/tests/rest/client/test_groups.py index ad0425ae6..c99f54cf4 100644 --- a/tests/rest/client/test_groups.py +++ b/tests/rest/client/test_groups.py @@ -30,8 +30,8 @@ class GroupsTestCase(unittest.HomeserverTestCase): # Alice creates a group channel = self.make_request("POST", "/create_group", {"localpart": "spqr"}) - self.assertEquals(channel.code, 200, msg=channel.text_body) - self.assertEquals(channel.json_body, {"group_id": group_id}) + self.assertEqual(channel.code, 200, msg=channel.text_body) + self.assertEqual(channel.json_body, {"group_id": group_id}) # Bob creates a private room room_id = self.helper.create_room_as(self.room_creator_user_id, is_public=False) @@ -45,12 +45,12 @@ class GroupsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "PUT", f"/groups/{group_id}/admin/rooms/{room_id}", {} ) - self.assertEquals(channel.code, 200, msg=channel.text_body) - self.assertEquals(channel.json_body, {}) + self.assertEqual(channel.code, 200, msg=channel.text_body) + self.assertEqual(channel.json_body, {}) # Alice now tries to retrieve the room list of the space. channel = self.make_request("GET", f"/groups/{group_id}/rooms") - self.assertEquals(channel.code, 200, msg=channel.text_body) - self.assertEquals( + self.assertEqual(channel.code, 200, msg=channel.text_body) + self.assertEqual( channel.json_body, {"chunk": [], "total_room_count_estimate": 0} ) diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index d48defda6..090d2d0a2 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -136,10 +136,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.result["code"], b"429", channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. @@ -154,7 +154,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) @override_config( { @@ -181,10 +181,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.result["code"], b"429", channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. @@ -199,7 +199,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) @override_config( { @@ -226,10 +226,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.result["code"], b"429", channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.result["code"], b"403", channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. @@ -244,7 +244,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.result["code"], b"403", channel.result) @override_config({"session_lifetime": "24h"}) def test_soft_logout(self) -> None: @@ -252,8 +252,8 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # we shouldn't be able to make requests without an access token channel = self.make_request(b"GET", TEST_URL) - self.assertEquals(channel.result["code"], b"401", channel.result) - self.assertEquals(channel.json_body["errcode"], "M_MISSING_TOKEN") + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN") # log in as normal params = { @@ -263,22 +263,22 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEquals(channel.code, 200, channel.result) + self.assertEqual(channel.code, 200, channel.result) access_token = channel.json_body["access_token"] device_id = channel.json_body["device_id"] # we should now be able to make requests with the access token channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 200, channel.result) + self.assertEqual(channel.code, 200, channel.result) # time passes self.reactor.advance(24 * 3600) # ... and we should be soft-logouted channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], True) + self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEqual(channel.json_body["soft_logout"], True) # # test behaviour after deleting the expired device @@ -290,17 +290,17 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # more requests with the expired token should still return a soft-logout self.reactor.advance(3600) channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], True) + self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEqual(channel.json_body["soft_logout"], True) # ... but if we delete that device, it will be a proper logout self._delete_device(access_token_2, "kermit", "monkey", device_id) channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], False) + self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEqual(channel.json_body["soft_logout"], False) def _delete_device( self, access_token: str, user_id: str, password: str, device_id: str @@ -309,7 +309,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request( b"DELETE", "devices/" + device_id, access_token=access_token ) - self.assertEquals(channel.code, 401, channel.result) + self.assertEqual(channel.code, 401, channel.result) # check it's a UI-Auth fail self.assertEqual( set(channel.json_body.keys()), @@ -332,7 +332,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): access_token=access_token, content={"auth": auth}, ) - self.assertEquals(channel.code, 200, channel.result) + self.assertEqual(channel.code, 200, channel.result) @override_config({"session_lifetime": "24h"}) def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None: @@ -343,20 +343,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # we should now be able to make requests with the access token channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 200, channel.result) + self.assertEqual(channel.code, 200, channel.result) # time passes self.reactor.advance(24 * 3600) # ... and we should be soft-logouted channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], True) + self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEqual(channel.json_body["soft_logout"], True) # Now try to hard logout this session channel = self.make_request(b"POST", "/logout", access_token=access_token) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) @override_config({"session_lifetime": "24h"}) def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out( @@ -369,20 +369,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # we should now be able to make requests with the access token channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 200, channel.result) + self.assertEqual(channel.code, 200, channel.result) # time passes self.reactor.advance(24 * 3600) # ... and we should be soft-logouted channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], True) + self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEqual(channel.json_body["soft_logout"], True) # Now try to hard log out all of the user's sessions channel = self.make_request(b"POST", "/logout/all", access_token=access_token) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) @skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC") @@ -1129,7 +1129,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): b"POST", LOGIN_URL, params, access_token=self.service.token ) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) def test_login_appservice_user_bot(self) -> None: """Test that the appservice bot can use /login""" @@ -1143,7 +1143,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): b"POST", LOGIN_URL, params, access_token=self.service.token ) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) def test_login_appservice_wrong_user(self) -> None: """Test that non-as users cannot login with the as token""" @@ -1157,7 +1157,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): b"POST", LOGIN_URL, params, access_token=self.service.token ) - self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.result["code"], b"403", channel.result) def test_login_appservice_wrong_as(self) -> None: """Test that as users cannot login with wrong as token""" @@ -1171,7 +1171,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): b"POST", LOGIN_URL, params, access_token=self.another_service.token ) - self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.result["code"], b"403", channel.result) def test_login_appservice_no_token(self) -> None: """Test that users must provide a token when using the appservice @@ -1185,7 +1185,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) @skip_unless(HAS_OIDC, "requires OIDC") diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py index b9647d5bd..4239e1e61 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py @@ -80,7 +80,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): def test_get_displayname_other(self): res = self._get_displayname(self.other) - self.assertEquals(res, "Bob") + self.assertEqual(res, "Bob") def test_set_displayname_other(self): channel = self.make_request( diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index 2835d86e5..4b95b8541 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -65,7 +65,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) det_data = {"user_id": user_id, "home_server": self.hs.hostname} self.assertDictContainsSubset(det_data, channel.json_body) @@ -87,7 +87,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - self.assertEquals(channel.result["code"], b"400", channel.result) + self.assertEqual(channel.result["code"], b"400", channel.result) def test_POST_appservice_registration_invalid(self): self.appservice = None # no application service exists @@ -98,21 +98,21 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) def test_POST_bad_password(self): request_data = json.dumps({"username": "kermit", "password": 666}) channel = self.make_request(b"POST", self.url, request_data) - self.assertEquals(channel.result["code"], b"400", channel.result) - self.assertEquals(channel.json_body["error"], "Invalid password") + self.assertEqual(channel.result["code"], b"400", channel.result) + self.assertEqual(channel.json_body["error"], "Invalid password") def test_POST_bad_username(self): request_data = json.dumps({"username": 777, "password": "monkey"}) channel = self.make_request(b"POST", self.url, request_data) - self.assertEquals(channel.result["code"], b"400", channel.result) - self.assertEquals(channel.json_body["error"], "Invalid username") + self.assertEqual(channel.result["code"], b"400", channel.result) + self.assertEqual(channel.json_body["error"], "Invalid username") def test_POST_user_valid(self): user_id = "@kermit:test" @@ -131,7 +131,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "home_server": self.hs.hostname, "device_id": device_id, } - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.assertDictContainsSubset(det_data, channel.json_body) @override_config({"enable_registration": False}) @@ -141,9 +141,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", self.url, request_data) - self.assertEquals(channel.result["code"], b"403", channel.result) - self.assertEquals(channel.json_body["error"], "Registration has been disabled") - self.assertEquals(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["error"], "Registration has been disabled") + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") def test_POST_guest_registration(self): self.hs.config.key.macaroon_secret_key = "test" @@ -152,7 +152,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"} - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.assertDictContainsSubset(det_data, channel.json_body) def test_POST_disabled_guest_registration(self): @@ -160,8 +160,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - self.assertEquals(channel.result["code"], b"403", channel.result) - self.assertEquals(channel.json_body["error"], "Guest access is disabled") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["error"], "Guest access is disabled") @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) def test_POST_ratelimiting_guest(self): @@ -170,16 +170,16 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", url, b"{}") if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.result["code"], b"429", channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) def test_POST_ratelimiting(self): @@ -194,16 +194,16 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", self.url, request_data) if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.result["code"], b"429", channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) @override_config({"registration_requires_token": True}) def test_POST_registration_requires_token(self): @@ -231,7 +231,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): # Request without auth to get flows and session channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] # Synapse adds a dummy stage to differentiate flows where otherwise one # flow would be a subset of another flow. @@ -249,7 +249,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): } request_data = json.dumps(params) channel = self.make_request(b"POST", self.url, request_data) - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) completed = channel.json_body["completed"] self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed) @@ -265,7 +265,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "home_server": self.hs.hostname, "device_id": device_id, } - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.assertDictContainsSubset(det_data, channel.json_body) # Check the `completed` counter has been incremented and pending is 0 @@ -276,8 +276,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): retcols=["pending", "completed"], ) ) - self.assertEquals(res["completed"], 1) - self.assertEquals(res["pending"], 0) + self.assertEqual(res["completed"], 1) + self.assertEqual(res["pending"], 0) @override_config({"registration_requires_token": True}) def test_POST_registration_token_invalid(self): @@ -295,23 +295,23 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "session": session, } channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEquals(channel.result["code"], b"401", channel.result) - self.assertEquals(channel.json_body["errcode"], Codes.MISSING_PARAM) - self.assertEquals(channel.json_body["completed"], []) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM) + self.assertEqual(channel.json_body["completed"], []) # Test with non-string (invalid) params["auth"]["token"] = 1234 channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEquals(channel.result["code"], b"401", channel.result) - self.assertEquals(channel.json_body["errcode"], Codes.INVALID_PARAM) - self.assertEquals(channel.json_body["completed"], []) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + self.assertEqual(channel.json_body["completed"], []) # Test with unknown token (invalid) params["auth"]["token"] = "1234" channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEquals(channel.result["code"], b"401", channel.result) - self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) - self.assertEquals(channel.json_body["completed"], []) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEqual(channel.json_body["completed"], []) @override_config({"registration_requires_token": True}) def test_POST_registration_token_limit_uses(self): @@ -354,7 +354,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): retcol="pending", ) ) - self.assertEquals(pending, 1) + self.assertEqual(pending, 1) # Check auth fails when using token with session2 params2["auth"] = { @@ -363,9 +363,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "session": session2, } channel = self.make_request(b"POST", self.url, json.dumps(params2)) - self.assertEquals(channel.result["code"], b"401", channel.result) - self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) - self.assertEquals(channel.json_body["completed"], []) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEqual(channel.json_body["completed"], []) # Complete registration with session1 params1["auth"]["type"] = LoginType.DUMMY @@ -378,14 +378,14 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): retcols=["pending", "completed"], ) ) - self.assertEquals(res["pending"], 0) - self.assertEquals(res["completed"], 1) + self.assertEqual(res["pending"], 0) + self.assertEqual(res["completed"], 1) # Check auth still fails when using token with session2 channel = self.make_request(b"POST", self.url, json.dumps(params2)) - self.assertEquals(channel.result["code"], b"401", channel.result) - self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) - self.assertEquals(channel.json_body["completed"], []) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEqual(channel.json_body["completed"], []) @override_config({"registration_requires_token": True}) def test_POST_registration_token_expiry(self): @@ -417,9 +417,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "session": session, } channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEquals(channel.result["code"], b"401", channel.result) - self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) - self.assertEquals(channel.json_body["completed"], []) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEqual(channel.json_body["completed"], []) # Update token so it expires tomorrow self.get_success( @@ -504,7 +504,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): retcol="result", ) ) - self.assertEquals(db_to_json(result2), token) + self.assertEqual(db_to_json(result2), token) # Delete both sessions (mimics expiry) self.get_success( @@ -519,7 +519,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): retcol="pending", ) ) - self.assertEquals(pending, 0) + self.assertEqual(pending, 0) @override_config({"registration_requires_token": True}) def test_POST_registration_token_session_expiry_deleted_token(self): @@ -572,7 +572,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): def test_advertised_flows(self): channel = self.make_request(b"POST", self.url, b"{}") - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] # with the stock config, we only expect the dummy flow @@ -595,7 +595,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) def test_advertised_flows_captcha_and_terms_and_3pids(self): channel = self.make_request(b"POST", self.url, b"{}") - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] self.assertCountEqual( @@ -627,7 +627,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) def test_advertised_flows_no_msisdn_email_required(self): channel = self.make_request(b"POST", self.url, b"{}") - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] # with the stock config, we expect all four combinations of 3pid @@ -671,7 +671,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): b"register/email/requestToken", {"client_secret": "foobar", "email": email, "send_attempt": 1}, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertIsNotNone(channel.json_body.get("sid")) @@ -694,9 +694,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): b"register/email/requestToken", {"client_secret": "foobar", "email": "email@@email", "send_attempt": 1}, ) - self.assertEquals(400, channel.code, channel.result) + self.assertEqual(400, channel.code, channel.result) # Check error to ensure that we're not erroring due to a bug in the test. - self.assertEquals( + self.assertEqual( channel.json_body, {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, ) @@ -707,8 +707,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): b"register/email/requestToken", {"client_secret": "foobar", "email": "email", "send_attempt": 1}, ) - self.assertEquals(400, channel.code, channel.result) - self.assertEquals( + self.assertEqual(400, channel.code, channel.result) + self.assertEqual( channel.json_body, {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, ) @@ -720,8 +720,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): b"register/email/requestToken", {"client_secret": "foobar", "email": email, "send_attempt": 1}, ) - self.assertEquals(400, channel.code, channel.result) - self.assertEquals( + self.assertEqual(400, channel.code, channel.result) + self.assertEqual( channel.json_body, {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, ) @@ -745,7 +745,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): # Check that /available correctly ignores the username provided despite the # username being already registered. channel = self.make_request("GET", "register/available?username=" + username) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # Test that when starting a UIA registration flow the request doesn't fail because # of a conflicting username @@ -799,14 +799,14 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): # endpoint. channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEquals(channel.result["code"], b"403", channel.result) - self.assertEquals( + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual( channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result ) @@ -826,12 +826,12 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): params = {"user_id": user_id} request_data = json.dumps(params) channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) def test_manual_expire(self): user_id = self.register_user("kermit", "monkey") @@ -848,13 +848,13 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): } request_data = json.dumps(params) channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEquals(channel.result["code"], b"403", channel.result) - self.assertEquals( + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual( channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result ) @@ -873,18 +873,18 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): } request_data = json.dumps(params) channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) # Try to log the user out channel = self.make_request(b"POST", "/logout", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) # Log the user in again (allowed for expired accounts) tok = self.login("kermit", "monkey") # Try to log out all of the user's sessions channel = self.make_request(b"POST", "/logout/all", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): @@ -959,7 +959,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id)) url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token channel = self.make_request(b"GET", url) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) # Check that we're getting HTML back. content_type = channel.headers.getRawHeaders(b"Content-Type") @@ -977,7 +977,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # Move 1 day forward. Try to renew with the same token again. url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token channel = self.make_request(b"GET", url) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) # Check that we're getting HTML back. content_type = channel.headers.getRawHeaders(b"Content-Type") @@ -997,14 +997,14 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # succeed. self.reactor.advance(datetime.timedelta(days=3).total_seconds()) channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) def test_renewal_invalid_token(self): # Hit the renewal endpoint with an invalid token and check that it behaves as # expected, i.e. that it responds with 404 Not Found and the correct HTML. url = "/_matrix/client/unstable/account_validity/renew?token=123" channel = self.make_request(b"GET", url) - self.assertEquals(channel.result["code"], b"404", channel.result) + self.assertEqual(channel.result["code"], b"404", channel.result) # Check that we're getting HTML back. content_type = channel.headers.getRawHeaders(b"Content-Type") @@ -1028,7 +1028,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): "/_matrix/client/unstable/account_validity/send_mail", access_token=tok, ) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(len(self.email_attempts), 1) @@ -1103,7 +1103,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): "/_matrix/client/unstable/account_validity/send_mail", access_token=tok, ) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(len(self.email_attempts), 1) @@ -1183,8 +1183,8 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): b"GET", f"{self.url}?token={token}", ) - self.assertEquals(channel.result["code"], b"200", channel.result) - self.assertEquals(channel.json_body["valid"], True) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["valid"], True) def test_GET_token_invalid(self): token = "1234" @@ -1192,8 +1192,8 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): b"GET", f"{self.url}?token={token}", ) - self.assertEquals(channel.result["code"], b"200", channel.result) - self.assertEquals(channel.json_body["valid"], False) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["valid"], False) @override_config( {"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}} @@ -1208,10 +1208,10 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): ) if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.result["code"], b"429", channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) @@ -1219,4 +1219,4 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): b"GET", f"{self.url}?token={token}", ) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 5687dea48..8f7181103 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -69,7 +69,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): """ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) event_id = channel.json_body["event_id"] @@ -78,7 +78,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "/rooms/%s/event/%s" % (self.room, event_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assert_dict( { @@ -103,7 +103,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): parent_id="foo", content={"body": "foo", "msgtype": "m.text"}, ) - self.assertEquals(400, channel.code, channel.json_body) + self.assertEqual(400, channel.code, channel.json_body) # Unless that event is referenced from another event! self.get_success( @@ -123,7 +123,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): parent_id="foo", content={"body": "foo", "msgtype": "m.text"}, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) def test_deny_invalid_room(self): """Test that we deny relations on non-existant events""" @@ -136,15 +136,15 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", parent_id=parent_id, key="A" ) - self.assertEquals(400, channel.code, channel.json_body) + self.assertEqual(400, channel.code, channel.json_body) def test_deny_double_react(self): """Test that we deny relations on membership events""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(400, channel.code, channel.json_body) + self.assertEqual(400, channel.code, channel.json_body) def test_deny_forked_thread(self): """It is invalid to start a thread off a thread.""" @@ -154,7 +154,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): content={"msgtype": "m.text", "body": "foo"}, parent_id=self.parent_id, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) parent_id = channel.json_body["event_id"] channel = self._send_relation( @@ -163,16 +163,16 @@ class RelationsTestCase(unittest.HomeserverTestCase): content={"msgtype": "m.text", "body": "foo"}, parent_id=parent_id, ) - self.assertEquals(400, channel.code, channel.json_body) + self.assertEqual(400, channel.code, channel.json_body) def test_basic_paginate_relations(self): """Tests that calling pagination API correctly the latest relations.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) first_annotation_id = channel.json_body["event_id"] channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) second_annotation_id = channel.json_body["event_id"] channel = self.make_request( @@ -180,11 +180,11 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # We expect to get back a single pagination result, which is the latest # full relation event we sent above. - self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body) + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) self.assert_dict( { "event_id": second_annotation_id, @@ -195,7 +195,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) # We also expect to get the original event (the id of which is self.parent_id) - self.assertEquals( + self.assertEqual( channel.json_body["original_event"]["event_id"], self.parent_id ) @@ -212,11 +212,11 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # We expect to get back a single pagination result, which is the earliest # full relation event we sent above. - self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body) + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) self.assert_dict( { "event_id": first_annotation_id, @@ -245,7 +245,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx) ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) expected_event_ids.append(channel.json_body["event_id"]) prev_token = "" @@ -260,12 +260,12 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) next_batch = channel.json_body.get("next_batch") - self.assertNotEquals(prev_token, next_batch) + self.assertNotEqual(prev_token, next_batch) prev_token = next_batch if not prev_token: @@ -273,7 +273,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): # We paginated backwards, so reverse found_event_ids.reverse() - self.assertEquals(found_event_ids, expected_event_ids) + self.assertEqual(found_event_ids, expected_event_ids) # Reset and try again, but convert the tokens to the legacy format. prev_token = "" @@ -288,12 +288,12 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) next_batch = channel.json_body.get("next_batch") - self.assertNotEquals(prev_token, next_batch) + self.assertNotEqual(prev_token, next_batch) prev_token = next_batch if not prev_token: @@ -301,12 +301,12 @@ class RelationsTestCase(unittest.HomeserverTestCase): # We paginated backwards, so reverse found_event_ids.reverse() - self.assertEquals(found_event_ids, expected_event_ids) + self.assertEqual(found_event_ids, expected_event_ids) def test_pagination_from_sync_and_messages(self): """Pagination tokens from /sync and /messages can be used to paginate /relations.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) annotation_id = channel.json_body["event_id"] # Send an event after the relation events. self.helper.send(self.room, body="Latest event", tok=self.user_token) @@ -319,7 +319,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", f"/sync?filter={filter}", access_token=self.user_token ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] sync_prev_batch = room_timeline["prev_batch"] self.assertIsNotNone(sync_prev_batch) @@ -335,7 +335,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/rooms/{self.room}/messages?dir=b&limit=1", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) messages_end = channel.json_body["end"] self.assertIsNotNone(messages_end) # Ensure the relation event is not in the chunk returned from /messages. @@ -355,7 +355,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # The relation should be in the returned chunk. self.assertIn( @@ -386,7 +386,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): key=key, access_token=access_tokens[idx], ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) idx += 1 idx %= len(access_tokens) @@ -404,7 +404,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, self.parent_id, from_token), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) @@ -419,13 +419,13 @@ class RelationsTestCase(unittest.HomeserverTestCase): next_batch = channel.json_body.get("next_batch") - self.assertNotEquals(prev_token, next_batch) + self.assertNotEqual(prev_token, next_batch) prev_token = next_batch if not prev_token: break - self.assertEquals(sent_groups, found_groups) + self.assertEqual(sent_groups, found_groups) def test_aggregation_pagination_within_group(self): """Test that we can paginate within an annotation group.""" @@ -449,14 +449,14 @@ class RelationsTestCase(unittest.HomeserverTestCase): key="👍", access_token=access_tokens[idx], ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) expected_event_ids.append(channel.json_body["event_id"]) idx += 1 # Also send a different type of reaction so that we test we don't see it channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) prev_token = "" found_event_ids: List[str] = [] @@ -473,7 +473,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/m.reaction/{encoded_key}?limit=1{from_token}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) @@ -481,7 +481,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): next_batch = channel.json_body.get("next_batch") - self.assertNotEquals(prev_token, next_batch) + self.assertNotEqual(prev_token, next_batch) prev_token = next_batch if not prev_token: @@ -489,7 +489,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): # We paginated backwards, so reverse found_event_ids.reverse() - self.assertEquals(found_event_ids, expected_event_ids) + self.assertEqual(found_event_ids, expected_event_ids) # Reset and try again, but convert the tokens to the legacy format. prev_token = "" @@ -506,7 +506,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/m.reaction/{encoded_key}?limit=1{from_token}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) @@ -514,7 +514,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): next_batch = channel.json_body.get("next_batch") - self.assertNotEquals(prev_token, next_batch) + self.assertNotEqual(prev_token, next_batch) prev_token = next_batch if not prev_token: @@ -522,21 +522,21 @@ class RelationsTestCase(unittest.HomeserverTestCase): # We paginated backwards, so reverse found_event_ids.reverse() - self.assertEquals(found_event_ids, expected_event_ids) + self.assertEqual(found_event_ids, expected_event_ids) def test_aggregation(self): """Test that annotations get correctly aggregated.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self.make_request( "GET", @@ -544,9 +544,9 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, self.parent_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) - self.assertEquals( + self.assertEqual( channel.json_body, { "chunk": [ @@ -560,13 +560,13 @@ class RelationsTestCase(unittest.HomeserverTestCase): """Test that annotations get correctly aggregated after a redaction.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) to_redact_event_id = channel.json_body["event_id"] channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # Now lets redact one of the 'a' reactions channel = self.make_request( @@ -575,7 +575,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): access_token=self.user_token, content={}, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self.make_request( "GET", @@ -583,9 +583,9 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, self.parent_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) - self.assertEquals( + self.assertEqual( channel.json_body, {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, ) @@ -599,7 +599,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, self.parent_id, RelationTypes.REPLACE), access_token=self.user_token, ) - self.assertEquals(400, channel.code, channel.json_body) + self.assertEqual(400, channel.code, channel.json_body) @unittest.override_config( {"experimental_features": {"msc3440_enabled": True, "msc3666_enabled": True}} @@ -615,29 +615,29 @@ class RelationsTestCase(unittest.HomeserverTestCase): """ # Setup by sending a variety of relations. channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) reply_1 = channel.json_body["event_id"] channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) reply_2 = channel.json_body["event_id"] channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) thread_2 = channel.json_body["event_id"] def assert_bundle(event_json: JsonDict) -> None: @@ -655,7 +655,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) # Check the values of each field. - self.assertEquals( + self.assertEqual( { "chunk": [ {"type": "m.reaction", "key": "a", "count": 2}, @@ -665,12 +665,12 @@ class RelationsTestCase(unittest.HomeserverTestCase): relations_dict[RelationTypes.ANNOTATION], ) - self.assertEquals( + self.assertEqual( {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]}, relations_dict[RelationTypes.REFERENCE], ) - self.assertEquals( + self.assertEqual( 2, relations_dict[RelationTypes.THREAD].get("count"), ) @@ -701,7 +701,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/rooms/{self.room}/event/{self.parent_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) assert_bundle(channel.json_body) # Request the room messages. @@ -710,7 +710,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/rooms/{self.room}/messages?dir=b", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) # Request the room context. @@ -719,12 +719,12 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/rooms/{self.room}/context/{self.parent_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) assert_bundle(channel.json_body["event"]) # Request sync. channel = self.make_request("GET", "/sync", access_token=self.user_token) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] self.assertTrue(room_timeline["limited"]) assert_bundle(self._find_event_in_chunk(room_timeline["events"])) @@ -737,7 +737,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): content={"search_categories": {"room_events": {"search_term": "Hi"}}}, access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) chunk = [ result["result"] for result in channel.json_body["search_categories"]["room_events"][ @@ -751,42 +751,42 @@ class RelationsTestCase(unittest.HomeserverTestCase): when directly requested. """ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) annotation_id = channel.json_body["event_id"] # Annotate the annotation. channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self.make_request( "GET", f"/rooms/{self.room}/event/{annotation_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertIsNone(channel.json_body["unsigned"].get("m.relations")) def test_aggregation_get_event_for_thread(self): """Test that threads get bundled aggregations included when directly requested.""" channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) thread_id = channel.json_body["event_id"] # Annotate the annotation. channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self.make_request( "GET", f"/rooms/{self.room}/event/{thread_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) - self.assertEquals( + self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual( channel.json_body["unsigned"].get("m.relations"), { RelationTypes.ANNOTATION: { @@ -801,11 +801,11 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(len(channel.json_body["chunk"]), 1) thread_message = channel.json_body["chunk"][0] - self.assertEquals( + self.assertEqual( thread_message["unsigned"].get("m.relations"), { RelationTypes.ANNOTATION: { @@ -905,7 +905,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/_matrix/client/unstable/rooms/{room2}/relations/{parent_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(channel.json_body["chunk"], []) # And when fetching aggregations. @@ -914,7 +914,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/_matrix/client/unstable/rooms/{room2}/aggregations/{parent_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(channel.json_body["chunk"], []) # And for bundled aggregations. @@ -923,7 +923,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/rooms/{room2}/event/{parent_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertNotIn("m.relations", channel.json_body["unsigned"]) @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) @@ -936,7 +936,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "m.room.message", content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) edit_event_id = channel.json_body["event_id"] @@ -958,8 +958,8 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/rooms/{self.room}/event/{self.parent_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) - self.assertEquals(channel.json_body["content"], new_body) + self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual(channel.json_body["content"], new_body) assert_bundle(channel.json_body) # Request the room messages. @@ -968,7 +968,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/rooms/{self.room}/messages?dir=b", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) # Request the room context. @@ -977,7 +977,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/rooms/{self.room}/context/{self.parent_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) assert_bundle(channel.json_body["event"]) # Request sync, but limit the timeline so it becomes limited (and includes @@ -988,7 +988,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", f"/sync?filter={filter}", access_token=self.user_token ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] self.assertTrue(room_timeline["limited"]) assert_bundle(self._find_event_in_chunk(room_timeline["events"])) @@ -1001,7 +1001,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): content={"search_categories": {"room_events": {"search_term": "Hi"}}}, access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) chunk = [ result["result"] for result in channel.json_body["search_categories"]["room_events"][ @@ -1024,7 +1024,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "m.new_content": {"msgtype": "m.text", "body": "First edit"}, }, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) new_body = {"msgtype": "m.text", "body": "I've been edited!"} channel = self._send_relation( @@ -1032,7 +1032,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "m.room.message", content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) edit_event_id = channel.json_body["event_id"] @@ -1045,16 +1045,16 @@ class RelationsTestCase(unittest.HomeserverTestCase): "m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"}, }, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self.make_request( "GET", "/rooms/%s/event/%s" % (self.room, self.parent_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) - self.assertEquals(channel.json_body["content"], new_body) + self.assertEqual(channel.json_body["content"], new_body) relations_dict = channel.json_body["unsigned"].get("m.relations") self.assertIn(RelationTypes.REPLACE, relations_dict) @@ -1076,7 +1076,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "m.room.message", content={"msgtype": "m.text", "body": "A reply!"}, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) reply = channel.json_body["event_id"] new_body = {"msgtype": "m.text", "body": "I've been edited!"} @@ -1086,7 +1086,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, parent_id=reply, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) edit_event_id = channel.json_body["event_id"] @@ -1095,7 +1095,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "/rooms/%s/event/%s" % (self.room, reply), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # We expect to see the new body in the dict, as well as the reference # metadata sill intact. @@ -1133,7 +1133,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "m.room.message", content={"msgtype": "m.text", "body": "A threaded reply!"}, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) threaded_event_id = channel.json_body["event_id"] new_body = {"msgtype": "m.text", "body": "I've been edited!"} @@ -1143,7 +1143,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, parent_id=threaded_event_id, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # Fetch the thread root, to get the bundled aggregation for the thread. channel = self.make_request( @@ -1151,7 +1151,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/rooms/{self.room}/event/{self.parent_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # We expect that the edit message appears in the thread summary in the # unsigned relations section. @@ -1161,9 +1161,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): thread_summary = relations_dict[RelationTypes.THREAD] self.assertIn("latest_event", thread_summary) latest_event_in_thread = thread_summary["latest_event"] - self.assertEquals( - latest_event_in_thread["content"]["body"], "I've been edited!" - ) + self.assertEqual(latest_event_in_thread["content"]["body"], "I've been edited!") def test_edit_edit(self): """Test that an edit cannot be edited.""" @@ -1177,7 +1175,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "m.new_content": new_body, }, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) edit_event_id = channel.json_body["event_id"] # Edit the edit event. @@ -1191,7 +1189,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): }, parent_id=edit_event_id, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # Request the original event. channel = self.make_request( @@ -1199,9 +1197,9 @@ class RelationsTestCase(unittest.HomeserverTestCase): "/rooms/%s/event/%s" % (self.room, self.parent_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # The edit to the edit should be ignored. - self.assertEquals(channel.json_body["content"], new_body) + self.assertEqual(channel.json_body["content"], new_body) # The relations information should not include the edit to the edit. relations_dict = channel.json_body["unsigned"].get("m.relations") @@ -1234,7 +1232,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "m.new_content": {"msgtype": "m.text", "body": "First edit"}, }, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # Check the relation is returned channel = self.make_request( @@ -1243,10 +1241,10 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, original_event_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertIn("chunk", channel.json_body) - self.assertEquals(len(channel.json_body["chunk"]), 1) + self.assertEqual(len(channel.json_body["chunk"]), 1) # Redact the original event channel = self.make_request( @@ -1256,7 +1254,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): access_token=self.user_token, content="{}", ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # Try to check for remaining m.replace relations channel = self.make_request( @@ -1265,11 +1263,11 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, original_event_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # Check that no relations are returned self.assertIn("chunk", channel.json_body) - self.assertEquals(channel.json_body["chunk"], []) + self.assertEqual(channel.json_body["chunk"], []) def test_aggregations_redaction_prevents_access_to_aggregations(self): """Test that annotations of an event are redacted when the original event @@ -1283,7 +1281,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", key="👍", parent_id=original_event_id ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # Redact the original channel = self.make_request( @@ -1297,7 +1295,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): access_token=self.user_token, content="{}", ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # Check that aggregations returns zero channel = self.make_request( @@ -1306,15 +1304,15 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, original_event_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertIn("chunk", channel.json_body) - self.assertEquals(channel.json_body["chunk"], []) + self.assertEqual(channel.json_body["chunk"], []) def test_unknown_relations(self): """Unknown relations should be accepted.""" channel = self._send_relation("m.relation.test", "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) event_id = channel.json_body["event_id"] channel = self.make_request( @@ -1323,18 +1321,18 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, self.parent_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # We expect to get back a single pagination result, which is the full # relation event we sent above. - self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body) + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) self.assert_dict( {"event_id": event_id, "sender": self.user_id, "type": "m.room.test"}, channel.json_body["chunk"][0], ) # We also expect to get the original event (the id of which is self.parent_id) - self.assertEquals( + self.assertEqual( channel.json_body["original_event"]["event_id"], self.parent_id ) @@ -1344,7 +1342,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "/rooms/%s/event/%s" % (self.room, self.parent_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertNotIn("m.relations", channel.json_body["unsigned"]) # But unknown relations can be directly queried. @@ -1354,8 +1352,8 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, self.parent_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) - self.assertEquals(channel.json_body["chunk"], []) + self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual(channel.json_body["chunk"], []) def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: """ @@ -1422,15 +1420,15 @@ class RelationsTestCase(unittest.HomeserverTestCase): def test_background_update(self): """Test the event_arbitrary_relations background update.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) annotation_event_id_good = channel.json_body["event_id"] channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="A") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) annotation_event_id_bad = channel.json_body["event_id"] channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) thread_event_id = channel.json_body["event_id"] # Clean-up the table as if the inserts did not happen during event creation. @@ -1450,8 +1448,8 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) - self.assertEquals( + self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual( [ev["event_id"] for ev in channel.json_body["chunk"]], [annotation_event_id_good], ) @@ -1475,7 +1473,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertCountEqual( [ev["event_id"] for ev in channel.json_body["chunk"]], [annotation_event_id_good, thread_event_id], diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 1afd96b8f..e0b11e726 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -95,7 +95,7 @@ class RoomPermissionsTestCase(RoomBase): channel = self.make_request( "PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}' ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # set topic for public room channel = self.make_request( @@ -103,7 +103,7 @@ class RoomPermissionsTestCase(RoomBase): ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"), b'{"topic":"Public Room Topic"}', ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # auth as user_id now self.helper.auth_user_id = self.user_id @@ -125,28 +125,28 @@ class RoomPermissionsTestCase(RoomBase): "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,), msg_content, ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) # send message in created room not joined (no state), expect 403 channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) # send message in created room and invited, expect 403 self.helper.invite( room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) # send message in created room and joined, expect 200 self.helper.join(room=self.created_rmid, user=self.user_id) channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) # send message in created room and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) def test_topic_perms(self): topic_content = b'{"topic":"My Topic Name"}' @@ -156,28 +156,28 @@ class RoomPermissionsTestCase(RoomBase): channel = self.make_request( "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) channel = self.make_request( "GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) # set/get topic in created PRIVATE room not joined, expect 403 channel = self.make_request("PUT", topic_path, topic_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", topic_path) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) # set topic in created PRIVATE room and invited, expect 403 self.helper.invite( room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) channel = self.make_request("PUT", topic_path, topic_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) # get topic in created PRIVATE room and invited, expect 403 channel = self.make_request("GET", topic_path) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) # set/get topic in created PRIVATE room and joined, expect 200 self.helper.join(room=self.created_rmid, user=self.user_id) @@ -185,25 +185,25 @@ class RoomPermissionsTestCase(RoomBase): # Only room ops can set topic by default self.helper.auth_user_id = self.rmcreator_id channel = self.make_request("PUT", topic_path, topic_content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.helper.auth_user_id = self.user_id channel = self.make_request("GET", topic_path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body) # set/get topic in created PRIVATE room and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) channel = self.make_request("PUT", topic_path, topic_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", topic_path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) # get topic in PUBLIC room, not joined, expect 403 channel = self.make_request( "GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) # set topic in PUBLIC room, not joined, expect 403 channel = self.make_request( @@ -211,7 +211,7 @@ class RoomPermissionsTestCase(RoomBase): "/rooms/%s/state/m.room.topic" % self.created_public_rmid, topic_content, ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) def _test_get_membership( self, room=None, members: Iterable = frozenset(), expect_code=None @@ -219,7 +219,7 @@ class RoomPermissionsTestCase(RoomBase): for member in members: path = "/rooms/%s/state/m.room.member/%s" % (room, member) channel = self.make_request("GET", path) - self.assertEquals(expect_code, channel.code) + self.assertEqual(expect_code, channel.code) def test_membership_basic_room_perms(self): # === room does not exist === @@ -478,16 +478,16 @@ class RoomsMemberListTestCase(RoomBase): def test_get_member_list(self): room_id = self.helper.create_room_as(self.user_id) channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) def test_get_member_list_no_room(self): channel = self.make_request("GET", "/rooms/roomdoesnotexist/members") - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission(self): room_id = self.helper.create_room_as("@some_other_guy:red") channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission_with_at_token(self): """ @@ -498,7 +498,7 @@ class RoomsMemberListTestCase(RoomBase): # first sync to get an at token channel = self.make_request("GET", "/sync") - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) sync_token = channel.json_body["next_batch"] # check that permission is denied for @sid1:red to get the @@ -507,7 +507,7 @@ class RoomsMemberListTestCase(RoomBase): "GET", f"/rooms/{room_id}/members?at={sync_token}", ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission_former_member(self): """ @@ -520,14 +520,14 @@ class RoomsMemberListTestCase(RoomBase): # check that the user can see the member list to start with channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) # ban the user self.helper.change_membership(room_id, "@alice:red", self.user_id, "ban") # check the user can no longer see the member list channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission_former_member_with_at_token(self): """ @@ -541,14 +541,14 @@ class RoomsMemberListTestCase(RoomBase): # sync to get an at token channel = self.make_request("GET", "/sync") - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) sync_token = channel.json_body["next_batch"] # check that the user can see the member list to start with channel = self.make_request( "GET", "/rooms/%s/members?at=%s" % (room_id, sync_token) ) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) # ban the user (Note: the user is actually allowed to see this event and # state so that they know they're banned!) @@ -560,14 +560,14 @@ class RoomsMemberListTestCase(RoomBase): # now, with the original user, sync again to get a new at token channel = self.make_request("GET", "/sync") - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) sync_token = channel.json_body["next_batch"] # check the user can no longer see the updated member list channel = self.make_request( "GET", "/rooms/%s/members?at=%s" % (room_id, sync_token) ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) def test_get_member_list_mixed_memberships(self): room_creator = "@some_other_guy:red" @@ -576,17 +576,17 @@ class RoomsMemberListTestCase(RoomBase): self.helper.invite(room=room_id, src=room_creator, targ=self.user_id) # can't see list if you're just invited. channel = self.make_request("GET", room_path) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) self.helper.join(room=room_id, user=self.user_id) # can see list now joined channel = self.make_request("GET", room_path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.helper.leave(room=room_id, user=self.user_id) # can see old list once left channel = self.make_request("GET", room_path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) class RoomsCreateTestCase(RoomBase): @@ -598,19 +598,19 @@ class RoomsCreateTestCase(RoomBase): # POST with no config keys, expect new room id channel = self.make_request("POST", "/createRoom", "{}") - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) def test_post_room_visibility_key(self): # POST with visibility config key, expect new room id channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}') - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_custom_key(self): # POST with custom config keys, expect new room id channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}') - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_known_and_unknown_keys(self): @@ -618,16 +618,16 @@ class RoomsCreateTestCase(RoomBase): channel = self.make_request( "POST", "/createRoom", b'{"visibility":"private","custom":"things"}' ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_invalid_content(self): # POST with invalid content / paths, expect 400 channel = self.make_request("POST", "/createRoom", b'{"visibili') - self.assertEquals(400, channel.code) + self.assertEqual(400, channel.code) channel = self.make_request("POST", "/createRoom", b'["hello"]') - self.assertEquals(400, channel.code) + self.assertEqual(400, channel.code) def test_post_room_invitees_invalid_mxid(self): # POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088 @@ -635,7 +635,7 @@ class RoomsCreateTestCase(RoomBase): channel = self.make_request( "POST", "/createRoom", b'{"invite":["@alice:example.com "]}' ) - self.assertEquals(400, channel.code) + self.assertEqual(400, channel.code) @unittest.override_config({"rc_invites": {"per_room": {"burst_count": 3}}}) def test_post_room_invitees_ratelimit(self): @@ -694,9 +694,9 @@ class RoomsCreateTestCase(RoomBase): "/createRoom", {}, ) - self.assertEquals(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, 200, channel.json_body) - self.assertEquals(join_mock.call_count, 0) + self.assertEqual(join_mock.call_count, 0) class RoomTopicTestCase(RoomBase): @@ -712,54 +712,54 @@ class RoomTopicTestCase(RoomBase): def test_invalid_puts(self): # missing keys or invalid json channel = self.make_request("PUT", self.path, "{}") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", self.path, '{"_name":"bo"}') - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", self.path, '{"nao') - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request( "PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]' ) - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", self.path, "text only") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", self.path, "") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) # valid key, wrong type content = '{"topic":["Topic name"]}' channel = self.make_request("PUT", self.path, content) - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) def test_rooms_topic(self): # nothing should be there channel = self.make_request("GET", self.path) - self.assertEquals(404, channel.code, msg=channel.result["body"]) + self.assertEqual(404, channel.code, msg=channel.result["body"]) # valid put content = '{"topic":"Topic name"}' channel = self.make_request("PUT", self.path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) # valid get channel = self.make_request("GET", self.path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) def test_rooms_topic_with_extra_keys(self): # valid put with extra keys content = '{"topic":"Seasons","subtopic":"Summer"}' channel = self.make_request("PUT", self.path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) # valid get channel = self.make_request("GET", self.path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) @@ -775,22 +775,22 @@ class RoomMemberStateTestCase(RoomBase): path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) # missing keys or invalid json channel = self.make_request("PUT", path, "{}") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, '{"_name":"bo"}') - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, '{"nao') - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, "text only") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, "") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) # valid keys, wrong types content = '{"membership":["%s","%s","%s"]}' % ( @@ -799,7 +799,7 @@ class RoomMemberStateTestCase(RoomBase): Membership.LEAVE, ) channel = self.make_request("PUT", path, content.encode("ascii")) - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) def test_rooms_members_self(self): path = "/rooms/%s/state/m.room.member/%s" % ( @@ -810,13 +810,13 @@ class RoomMemberStateTestCase(RoomBase): # valid join message (NOOP since we made the room) content = '{"membership":"%s"}' % Membership.JOIN channel = self.make_request("PUT", path, content.encode("ascii")) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", path, None) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) expected_response = {"membership": Membership.JOIN} - self.assertEquals(expected_response, channel.json_body) + self.assertEqual(expected_response, channel.json_body) def test_rooms_members_other(self): self.other_id = "@zzsid1:red" @@ -828,11 +828,11 @@ class RoomMemberStateTestCase(RoomBase): # valid invite message content = '{"membership":"%s"}' % Membership.INVITE channel = self.make_request("PUT", path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", path, None) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - self.assertEquals(json.loads(content), channel.json_body) + self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(json.loads(content), channel.json_body) def test_rooms_members_other_custom_keys(self): self.other_id = "@zzsid1:red" @@ -847,11 +847,11 @@ class RoomMemberStateTestCase(RoomBase): "Join us!", ) channel = self.make_request("PUT", path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", path, None) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - self.assertEquals(json.loads(content), channel.json_body) + self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(json.loads(content), channel.json_body) class RoomInviteRatelimitTestCase(RoomBase): @@ -937,7 +937,7 @@ class RoomJoinTestCase(RoomBase): False, ), ) - self.assertEquals( + self.assertEqual( callback_mock.call_args, expected_call_args, callback_mock.call_args, @@ -955,7 +955,7 @@ class RoomJoinTestCase(RoomBase): True, ), ) - self.assertEquals( + self.assertEqual( callback_mock.call_args, expected_call_args, callback_mock.call_args, @@ -1013,7 +1013,7 @@ class RoomJoinRatelimitTestCase(RoomBase): # Update the display name for the user. path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id channel = self.make_request("PUT", path, {"displayname": "John Doe"}) - self.assertEquals(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, 200, channel.json_body) # Check that all the rooms have been sent a profile update into. for room_id in room_ids: @@ -1023,10 +1023,10 @@ class RoomJoinRatelimitTestCase(RoomBase): ) channel = self.make_request("GET", path) - self.assertEquals(channel.code, 200) + self.assertEqual(channel.code, 200) self.assertIn("displayname", channel.json_body) - self.assertEquals(channel.json_body["displayname"], "John Doe") + self.assertEqual(channel.json_body["displayname"], "John Doe") @unittest.override_config( {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} @@ -1047,7 +1047,7 @@ class RoomJoinRatelimitTestCase(RoomBase): # if all of these requests ended up joining the user to a room. for _ in range(4): channel = self.make_request("POST", path % room_id, {}) - self.assertEquals(channel.code, 200) + self.assertEqual(channel.code, 200) @unittest.override_config( { @@ -1078,40 +1078,40 @@ class RoomMessagesTestCase(RoomBase): path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) # missing keys or invalid json channel = self.make_request("PUT", path, b"{}") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, b'{"_name":"bo"}') - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, b'{"nao') - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, b"text only") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, b"") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) def test_rooms_messages_sent(self): path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) content = b'{"body":"test","msgtype":{"type":"a"}}' channel = self.make_request("PUT", path, content) - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) # custom message types content = b'{"body":"test","msgtype":"test.custom.text"}' channel = self.make_request("PUT", path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) # m.text message type path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id)) content = b'{"body":"test2","msgtype":"m.text"}' channel = self.make_request("PUT", path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) class RoomInitialSyncTestCase(RoomBase): @@ -1125,10 +1125,10 @@ class RoomInitialSyncTestCase(RoomBase): def test_initial_sync(self): channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) - self.assertEquals(self.room_id, channel.json_body["room_id"]) - self.assertEquals("join", channel.json_body["membership"]) + self.assertEqual(self.room_id, channel.json_body["room_id"]) + self.assertEqual("join", channel.json_body["membership"]) # Room state is easier to assert on if we unpack it into a dict state = {} @@ -1152,7 +1152,7 @@ class RoomInitialSyncTestCase(RoomBase): e["content"]["user_id"]: e for e in channel.json_body["presence"] } self.assertTrue(self.user_id in presence_by_user) - self.assertEquals("m.presence", presence_by_user[self.user_id]["type"]) + self.assertEqual("m.presence", presence_by_user[self.user_id]["type"]) class RoomMessageListTestCase(RoomBase): @@ -1168,9 +1168,9 @@ class RoomMessageListTestCase(RoomBase): channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) self.assertTrue("start" in channel.json_body) - self.assertEquals(token, channel.json_body["start"]) + self.assertEqual(token, channel.json_body["start"]) self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) @@ -1179,9 +1179,9 @@ class RoomMessageListTestCase(RoomBase): channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) self.assertTrue("start" in channel.json_body) - self.assertEquals(token, channel.json_body["start"]) + self.assertEqual(token, channel.json_body["start"]) self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) @@ -2614,7 +2614,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): }, access_token=self.tok, ) - self.assertEquals(channel.code, 200) + self.assertEqual(channel.code, 200) # Check that the callback was called with the right params. mock.assert_called_with(self.user_id, "email", email_to_invite, self.room_id) @@ -2636,7 +2636,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): }, access_token=self.tok, ) - self.assertEquals(channel.code, 403) + self.assertEqual(channel.code, 403) # Also check that it stopped before calling _make_and_store_3pid_invite. make_invite_mock.assert_called_once() diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index 7d0e66b53..2634c98dd 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -96,7 +96,7 @@ class RoomTestCase(_ShadowBannedBase): {"id_server": "test", "medium": "email", "address": "test@test.test"}, access_token=self.banned_access_token, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # This should have raised an error earlier, but double check this wasn't called. identity_handler.lookup_3pid.assert_not_called() @@ -110,7 +110,7 @@ class RoomTestCase(_ShadowBannedBase): {"visibility": "public", "invite": [self.other_user_id]}, access_token=self.banned_access_token, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) room_id = channel.json_body["room_id"] # But the user wasn't actually invited. @@ -165,7 +165,7 @@ class RoomTestCase(_ShadowBannedBase): {"new_version": "6"}, access_token=self.banned_access_token, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # A new room_id should be returned. self.assertIn("replacement_room", channel.json_body) @@ -190,11 +190,11 @@ class RoomTestCase(_ShadowBannedBase): {"typing": True, "timeout": 30000}, access_token=self.banned_access_token, ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) # There should be no typing events. event_source = self.hs.get_event_sources().sources.typing - self.assertEquals(event_source.get_current_key(), 0) + self.assertEqual(event_source.get_current_key(), 0) # The other user can join and send typing events. self.helper.join(room_id, self.other_user_id, tok=self.other_access_token) @@ -205,10 +205,10 @@ class RoomTestCase(_ShadowBannedBase): {"typing": True, "timeout": 30000}, access_token=self.other_access_token, ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) # These appear in the room. - self.assertEquals(event_source.get_current_key(), 1) + self.assertEqual(event_source.get_current_key(), 1) events = self.get_success( event_source.get_new_events( user=UserID.from_string(self.other_user_id), @@ -218,7 +218,7 @@ class RoomTestCase(_ShadowBannedBase): is_guest=False, ) ) - self.assertEquals( + self.assertEqual( events[0], [ { @@ -257,7 +257,7 @@ class ProfileTestCase(_ShadowBannedBase): {"displayname": new_display_name}, access_token=self.banned_access_token, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertEqual(channel.json_body, {}) # The user's display name should be updated. @@ -299,7 +299,7 @@ class ProfileTestCase(_ShadowBannedBase): {"membership": "join", "displayname": new_display_name}, access_token=self.banned_access_token, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertIn("event_id", channel.json_body) # The display name in the room should not be changed. diff --git a/tests/rest/client/test_shared_rooms.py b/tests/rest/client/test_shared_rooms.py index c42c8aff6..294f46fb9 100644 --- a/tests/rest/client/test_shared_rooms.py +++ b/tests/rest/client/test_shared_rooms.py @@ -91,9 +91,9 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): # Check shared rooms from user1's perspective. # We should see the one room in common channel = self._get_shared_rooms(u1_token, u2) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 1) - self.assertEquals(channel.json_body["joined"][0], room_id_one) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 1) + self.assertEqual(channel.json_body["joined"][0], room_id_one) # Create another room and invite user2 to it room_id_two = self.helper.create_room_as( @@ -104,8 +104,8 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): # Check shared rooms again. We should now see both rooms. channel = self._get_shared_rooms(u1_token, u2) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 2) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 2) for room_id_id in channel.json_body["joined"]: self.assertIn(room_id_id, [room_id_one, room_id_two]) @@ -125,18 +125,18 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): # Assert user directory is not empty channel = self._get_shared_rooms(u1_token, u2) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 1) - self.assertEquals(channel.json_body["joined"][0], room) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 1) + self.assertEqual(channel.json_body["joined"][0], room) self.helper.leave(room, user=u1, tok=u1_token) # Check user1's view of shared rooms with user2 channel = self._get_shared_rooms(u1_token, u2) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 0) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 0) # Check user2's view of shared rooms with user1 channel = self._get_shared_rooms(u2_token, u1) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 0) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 0) diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index 69b4ef537..435101395 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -237,10 +237,10 @@ class SyncTypingTests(unittest.HomeserverTestCase): typing_url % (room, other_user_id, other_access_token), b'{"typing": true, "timeout": 30000}', ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) channel = self.make_request("GET", "/sync?access_token=%s" % (access_token,)) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) next_batch = channel.json_body["next_batch"] # Stop typing. @@ -249,7 +249,7 @@ class SyncTypingTests(unittest.HomeserverTestCase): typing_url % (room, other_user_id, other_access_token), b'{"typing": false}', ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) # Start typing. channel = self.make_request( @@ -257,11 +257,11 @@ class SyncTypingTests(unittest.HomeserverTestCase): typing_url % (room, other_user_id, other_access_token), b'{"typing": true, "timeout": 30000}', ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) # Should return immediately channel = self.make_request("GET", sync_url % (access_token, next_batch)) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) next_batch = channel.json_body["next_batch"] # Reset typing serial back to 0, as if the master had. @@ -273,7 +273,7 @@ class SyncTypingTests(unittest.HomeserverTestCase): self.helper.send(room, body="There!", tok=other_access_token) channel = self.make_request("GET", sync_url % (access_token, next_batch)) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) next_batch = channel.json_body["next_batch"] # This should time out! But it does not, because our stream token is @@ -281,7 +281,7 @@ class SyncTypingTests(unittest.HomeserverTestCase): # already seen) is new, since it's got a token above our new, now-reset # stream token. channel = self.make_request("GET", sync_url % (access_token, next_batch)) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) next_batch = channel.json_body["next_batch"] # Clear the typing information, so that it doesn't think everything is @@ -351,7 +351,7 @@ class SyncKnockTestCase( b"{}", self.knocker_tok, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # We expect to see the knock event in the stripped room state later self.expected_room_state[EventTypes.Member] = { diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index ac6b86ff6..9cca9edd3 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -139,7 +139,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): {}, access_token=self.tok, ) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) callback.assert_called_once() @@ -157,7 +157,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): {}, access_token=self.tok, ) - self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.result["code"], b"403", channel.result) def test_third_party_rules_workaround_synapse_errors_pass_through(self): """ @@ -193,7 +193,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): access_token=self.tok, ) # Check the error code - self.assertEquals(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.result["code"], b"429", channel.result) # Check the JSON body has had the `nasty` key injected self.assertEqual( channel.json_body, @@ -329,10 +329,10 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): self.hs.get_module_api().create_and_send_event_into_room(event_dict) ) - self.assertEquals(event.sender, self.user_id) - self.assertEquals(event.room_id, self.room_id) - self.assertEquals(event.type, "m.room.message") - self.assertEquals(event.content, content) + self.assertEqual(event.sender, self.user_id) + self.assertEqual(event.room_id, self.room_id) + self.assertEqual(event.type, "m.room.message") + self.assertEqual(event.content, content) @unittest.override_config( { diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py index de312cb63..8b2da88e8 100644 --- a/tests/rest/client/test_typing.py +++ b/tests/rest/client/test_typing.py @@ -72,9 +72,9 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): "/rooms/%s/typing/%s" % (self.room_id, self.user_id), b'{"typing": true, "timeout": 30000}', ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) - self.assertEquals(self.event_source.get_current_key(), 1) + self.assertEqual(self.event_source.get_current_key(), 1) events = self.get_success( self.event_source.get_new_events( user=UserID.from_string(self.user_id), @@ -84,7 +84,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): is_guest=False, ) ) - self.assertEquals( + self.assertEqual( events[0], [ { @@ -101,7 +101,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): "/rooms/%s/typing/%s" % (self.room_id, self.user_id), b'{"typing": false}', ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) def test_typing_timeout(self): channel = self.make_request( @@ -109,19 +109,19 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): "/rooms/%s/typing/%s" % (self.room_id, self.user_id), b'{"typing": true, "timeout": 30000}', ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) - self.assertEquals(self.event_source.get_current_key(), 1) + self.assertEqual(self.event_source.get_current_key(), 1) self.reactor.advance(36) - self.assertEquals(self.event_source.get_current_key(), 2) + self.assertEqual(self.event_source.get_current_key(), 2) channel = self.make_request( "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), b'{"typing": true, "timeout": 30000}', ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) - self.assertEquals(self.event_source.get_current_key(), 3) + self.assertEqual(self.event_source.get_current_key(), 3) diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py index 7f79336ab..658c21b2a 100644 --- a/tests/rest/client/test_upgrade_room.py +++ b/tests/rest/client/test_upgrade_room.py @@ -65,7 +65,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): Upgrading a room should work fine. """ channel = self._upgrade_room() - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertIn("replacement_room", channel.json_body) def test_not_in_room(self): @@ -77,7 +77,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): roomless_token = self.login(roomless, "pass") channel = self._upgrade_room(roomless_token) - self.assertEquals(403, channel.code, channel.result) + self.assertEqual(403, channel.code, channel.result) def test_power_levels(self): """ @@ -85,7 +85,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): """ # The other user doesn't have the proper power level. channel = self._upgrade_room(self.other_token) - self.assertEquals(403, channel.code, channel.result) + self.assertEqual(403, channel.code, channel.result) # Increase the power levels so that this user can upgrade. power_levels = self.helper.get_state( @@ -103,7 +103,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): # The upgrade should succeed! channel = self._upgrade_room(self.other_token) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) def test_power_levels_user_default(self): """ @@ -111,7 +111,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): """ # The other user doesn't have the proper power level. channel = self._upgrade_room(self.other_token) - self.assertEquals(403, channel.code, channel.result) + self.assertEqual(403, channel.code, channel.result) # Increase the power levels so that this user can upgrade. power_levels = self.helper.get_state( @@ -129,7 +129,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): # The upgrade should succeed! channel = self._upgrade_room(self.other_token) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) def test_power_levels_tombstone(self): """ @@ -137,7 +137,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): """ # The other user doesn't have the proper power level. channel = self._upgrade_room(self.other_token) - self.assertEquals(403, channel.code, channel.result) + self.assertEqual(403, channel.code, channel.result) # Increase the power levels so that this user can upgrade. power_levels = self.helper.get_state( @@ -155,7 +155,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): # The upgrade should succeed! channel = self._upgrade_room(self.other_token) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) power_levels = self.helper.get_state( self.room_id, @@ -197,7 +197,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): # Upgrade the room! channel = self._upgrade_room(room_id=space_id) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertIn("replacement_room", channel.json_body) new_space_id = channel.json_body["replacement_room"] diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 6878ccddb..cba9be17c 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -94,7 +94,7 @@ class MediaStorageTests(unittest.HomeserverTestCase): self.assertTrue(os.path.exists(local_path)) # Asserts the file is under the expected local cache directory - self.assertEquals( + self.assertEqual( os.path.commonprefix([self.primary_base_path, local_path]), self.primary_base_path, ) diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 59def6e59..1f6a9eb07 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -88,18 +88,18 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): res = self.get_success( self.store.have_seen_events("room1", ["event10", "event19"]) ) - self.assertEquals(res, {"event10"}) + self.assertEqual(res, {"event10"}) # that should result in a single db query - self.assertEquals(ctx.get_resource_usage().db_txn_count, 1) + self.assertEqual(ctx.get_resource_usage().db_txn_count, 1) # a second lookup of the same events should cause no queries with LoggingContext(name="test") as ctx: res = self.get_success( self.store.have_seen_events("room1", ["event10", "event19"]) ) - self.assertEquals(res, {"event10"}) - self.assertEquals(ctx.get_resource_usage().db_txn_count, 0) + self.assertEqual(res, {"event10"}) + self.assertEqual(ctx.get_resource_usage().db_txn_count, 0) def test_query_via_event_cache(self): # fetch an event into the event cache @@ -108,8 +108,8 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): # looking it up should now cause no db hits with LoggingContext(name="test") as ctx: res = self.get_success(self.store.have_seen_events("room1", ["event10"])) - self.assertEquals(res, {"event10"}) - self.assertEquals(ctx.get_resource_usage().db_txn_count, 0) + self.assertEqual(res, {"event10"}) + self.assertEqual(ctx.get_resource_usage().db_txn_count, 0) class EventCacheTestCase(unittest.HomeserverTestCase): diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index d2f654214..ee599f433 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -88,21 +88,21 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): def test_retrieve_unknown_service_token(self) -> None: service = self.store.get_app_service_by_token("invalid_token") - self.assertEquals(service, None) + self.assertEqual(service, None) def test_retrieval_of_service(self) -> None: stored_service = self.store.get_app_service_by_token(self.as_token) assert stored_service is not None - self.assertEquals(stored_service.token, self.as_token) - self.assertEquals(stored_service.id, self.as_id) - self.assertEquals(stored_service.url, self.as_url) - self.assertEquals(stored_service.namespaces[ApplicationService.NS_ALIASES], []) - self.assertEquals(stored_service.namespaces[ApplicationService.NS_ROOMS], []) - self.assertEquals(stored_service.namespaces[ApplicationService.NS_USERS], []) + self.assertEqual(stored_service.token, self.as_token) + self.assertEqual(stored_service.id, self.as_id) + self.assertEqual(stored_service.url, self.as_url) + self.assertEqual(stored_service.namespaces[ApplicationService.NS_ALIASES], []) + self.assertEqual(stored_service.namespaces[ApplicationService.NS_ROOMS], []) + self.assertEqual(stored_service.namespaces[ApplicationService.NS_USERS], []) def test_retrieval_of_all_services(self) -> None: services = self.store.get_app_services() - self.assertEquals(len(services), 3) + self.assertEqual(len(services), 3) class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): @@ -182,7 +182,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): ) -> None: service = Mock(id="999") state = self.get_success(self.store.get_appservice_state(service)) - self.assertEquals(None, state) + self.assertEqual(None, state) def test_get_appservice_state_up( self, @@ -194,7 +194,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): state = self.get_success( defer.ensureDeferred(self.store.get_appservice_state(service)) ) - self.assertEquals(ApplicationServiceState.UP, state) + self.assertEqual(ApplicationServiceState.UP, state) def test_get_appservice_state_down( self, @@ -210,7 +210,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): ) service = Mock(id=self.as_list[1]["id"]) state = self.get_success(self.store.get_appservice_state(service)) - self.assertEquals(ApplicationServiceState.DOWN, state) + self.assertEqual(ApplicationServiceState.DOWN, state) def test_get_appservices_by_state_none( self, @@ -218,7 +218,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): services = self.get_success( self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) - self.assertEquals(0, len(services)) + self.assertEqual(0, len(services)) def test_set_appservices_state_down( self, @@ -235,7 +235,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): (ApplicationServiceState.DOWN.value,), ) ) - self.assertEquals(service.id, rows[0][0]) + self.assertEqual(service.id, rows[0][0]) def test_set_appservices_state_multiple_up( self, @@ -258,7 +258,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): (ApplicationServiceState.UP.value,), ) ) - self.assertEquals(service.id, rows[0][0]) + self.assertEqual(service.id, rows[0][0]) def test_create_appservice_txn_first( self, @@ -270,9 +270,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): self.store.create_appservice_txn(service, events, [], [], {}, {}) ) ) - self.assertEquals(txn.id, 1) - self.assertEquals(txn.events, events) - self.assertEquals(txn.service, service) + self.assertEqual(txn.id, 1) + self.assertEqual(txn.events, events) + self.assertEqual(txn.service, service) def test_create_appservice_txn_older_last_txn( self, @@ -285,9 +285,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): txn = self.get_success( self.store.create_appservice_txn(service, events, [], [], {}, {}) ) - self.assertEquals(txn.id, 9646) - self.assertEquals(txn.events, events) - self.assertEquals(txn.service, service) + self.assertEqual(txn.id, 9646) + self.assertEqual(txn.events, events) + self.assertEqual(txn.service, service) def test_create_appservice_txn_up_to_date_last_txn( self, @@ -298,9 +298,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): txn = self.get_success( self.store.create_appservice_txn(service, events, [], [], {}, {}) ) - self.assertEquals(txn.id, 9644) - self.assertEquals(txn.events, events) - self.assertEquals(txn.service, service) + self.assertEqual(txn.id, 9644) + self.assertEqual(txn.events, events) + self.assertEqual(txn.service, service) def test_create_appservice_txn_up_fuzzing( self, @@ -322,9 +322,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): txn = self.get_success( self.store.create_appservice_txn(service, events, [], [], {}, {}) ) - self.assertEquals(txn.id, 9644) - self.assertEquals(txn.events, events) - self.assertEquals(txn.service, service) + self.assertEqual(txn.id, 9644) + self.assertEqual(txn.events, events) + self.assertEqual(txn.service, service) def test_complete_appservice_txn_first_txn( self, @@ -346,8 +346,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): (service.id,), ) ) - self.assertEquals(1, len(res)) - self.assertEquals(txn_id, res[0][0]) + self.assertEqual(1, len(res)) + self.assertEqual(txn_id, res[0][0]) res = self.get_success( self.db_pool.runQuery( @@ -357,7 +357,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): (txn_id,), ) ) - self.assertEquals(0, len(res)) + self.assertEqual(0, len(res)) def test_complete_appservice_txn_existing_in_state_table( self, @@ -379,9 +379,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): (service.id,), ) ) - self.assertEquals(1, len(res)) - self.assertEquals(txn_id, res[0][0]) - self.assertEquals(ApplicationServiceState.UP.value, res[0][1]) + self.assertEqual(1, len(res)) + self.assertEqual(txn_id, res[0][0]) + self.assertEqual(ApplicationServiceState.UP.value, res[0][1]) res = self.get_success( self.db_pool.runQuery( @@ -391,7 +391,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): (txn_id,), ) ) - self.assertEquals(0, len(res)) + self.assertEqual(0, len(res)) def test_get_oldest_unsent_txn_none( self, @@ -399,7 +399,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): service = Mock(id=self.as_list[0]["id"]) txn = self.get_success(self.store.get_oldest_unsent_txn(service)) - self.assertEquals(None, txn) + self.assertEqual(None, txn) def test_get_oldest_unsent_txn(self) -> None: service = Mock(id=self.as_list[0]["id"]) @@ -416,9 +416,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): self.get_success(self._insert_txn(service.id, 12, other_events)) txn = self.get_success(self.store.get_oldest_unsent_txn(service)) - self.assertEquals(service, txn.service) - self.assertEquals(10, txn.id) - self.assertEquals(events, txn.events) + self.assertEqual(service, txn.service) + self.assertEqual(10, txn.id) + self.assertEqual(events, txn.events) def test_get_appservices_by_state_single( self, @@ -433,8 +433,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): services = self.get_success( self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) - self.assertEquals(1, len(services)) - self.assertEquals(self.as_list[0]["id"], services[0].id) + self.assertEqual(1, len(services)) + self.assertEqual(self.as_list[0]["id"], services[0].id) def test_get_appservices_by_state_multiple( self, @@ -455,8 +455,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): services = self.get_success( self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) - self.assertEquals(2, len(services)) - self.assertEquals( + self.assertEqual(2, len(services)) + self.assertEqual( {self.as_list[2]["id"], self.as_list[0]["id"]}, {services[0].id, services[1].id}, ) @@ -476,12 +476,12 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): value = self.get_success( self.store.get_type_stream_id_for_appservice(self.service, "read_receipt") ) - self.assertEquals(value, 0) + self.assertEqual(value, 0) value = self.get_success( self.store.get_type_stream_id_for_appservice(self.service, "presence") ) - self.assertEquals(value, 0) + self.assertEqual(value, 0) def test_get_type_stream_id_for_appservice_invalid_type(self) -> None: self.get_failure( diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 3e4f0579c..a8ffb52c0 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -103,7 +103,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) ) - self.assertEquals("Value", value) + self.assertEqual("Value", value) self.mock_txn.execute.assert_called_with( "SELECT retcol FROM tablename WHERE keycol = ?", ["TheKey"] ) @@ -121,7 +121,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) ) - self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret) + self.assertEqual({"colA": 1, "colB": 2, "colC": 3}, ret) self.mock_txn.execute.assert_called_with( "SELECT colA, colB, colC FROM tablename WHERE keycol = ?", ["TheKey"] ) @@ -154,7 +154,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) ) - self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret) + self.assertEqual([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret) self.mock_txn.execute.assert_called_with( "SELECT colA FROM tablename WHERE keycol = ?", ["A set"] ) diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py index 7b72a9242..20bf3ca17 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py @@ -31,7 +31,7 @@ class DirectoryStoreTestCase(HomeserverTestCase): ) ) - self.assertEquals( + self.assertEqual( ["#my-room:test"], (self.get_success(self.store.get_aliases_for_room(self.room.to_string()))), ) diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index c9e3b9fa7..0f9add484 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -57,7 +57,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0 ) ) - self.assertEquals( + self.assertEqual( counts, NotifCounts( notify_count=noitf_count, diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py index 4ca212fd1..5806cb0e4 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py @@ -38,12 +38,12 @@ class DataStoreTestCase(unittest.HomeserverTestCase): self.store.get_users_paginate(0, 10, name="bc", guests=False) ) - self.assertEquals(1, total) - self.assertEquals(self.displayname, users.pop()["displayname"]) + self.assertEqual(1, total) + self.assertEqual(self.displayname, users.pop()["displayname"]) users, total = self.get_success( self.store.get_users_paginate(0, 10, name="BC", guests=False) ) - self.assertEquals(1, total) - self.assertEquals(self.displayname, users.pop()["displayname"]) + self.assertEqual(1, total) + self.assertEqual(self.displayname, users.pop()["displayname"]) diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index b6f99af2f..a019d06e0 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -33,7 +33,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase): self.store.set_profile_displayname(self.u_frank.localpart, "Frank") ) - self.assertEquals( + self.assertEqual( "Frank", ( self.get_success( @@ -60,7 +60,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase): ) ) - self.assertEquals( + self.assertEqual( "http://my.site/here", ( self.get_success( diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 1fa495f77..a49ac1525 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -30,7 +30,7 @@ class RegistrationStoreTestCase(HomeserverTestCase): def test_register(self): self.get_success(self.store.register_user(self.user_id, self.pwhash)) - self.assertEquals( + self.assertEqual( { # TODO(paul): Surely this field should be 'user_id', not 'name' "name": self.user_id, @@ -131,7 +131,7 @@ class RegistrationStoreTestCase(HomeserverTestCase): ), ThreepidValidationError, ) - self.assertEquals(e.value.msg, "Unknown session_id", e) + self.assertEqual(e.value.msg, "Unknown session_id", e) # Set the config setting to true. self.store._ignore_unknown_session_error = True @@ -146,4 +146,4 @@ class RegistrationStoreTestCase(HomeserverTestCase): ), ThreepidValidationError, ) - self.assertEquals(e.value.msg, "Validation token not found or has expired", e) + self.assertEqual(e.value.msg, "Validation token not found or has expired", e) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 42bfca2a8..5b011e18c 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -104,7 +104,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase): self.store.get_current_state(room_id=self.room.to_string()) ) - self.assertEquals(1, len(state)) + self.assertEqual(1, len(state)) self.assertObjectHasAttributes( {"type": "m.room.name", "room_id": self.room.to_string(), "name": name}, state[0], @@ -121,7 +121,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase): self.store.get_current_state(room_id=self.room.to_string()) ) - self.assertEquals(1, len(state)) + self.assertEqual(1, len(state)) self.assertObjectHasAttributes( {"type": "m.room.topic", "room_id": self.room.to_string(), "topic": topic}, state[0], diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py index d62e01726..8dfc1e1db 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py @@ -53,7 +53,7 @@ class EventSearchInsertionTest(HomeserverTestCase): result = self.get_success( store.search_msgs([room_id], "hi bob", ["content.body"]) ) - self.assertEquals(result.get("count"), 1) + self.assertEqual(result.get("count"), 1) if isinstance(store.database_engine, PostgresEngine): self.assertIn("hi", result.get("highlights")) self.assertIn("bob", result.get("highlights")) @@ -62,14 +62,14 @@ class EventSearchInsertionTest(HomeserverTestCase): result = self.get_success( store.search_msgs([room_id], "another", ["content.body"]) ) - self.assertEquals(result.get("count"), 1) + self.assertEqual(result.get("count"), 1) if isinstance(store.database_engine, PostgresEngine): self.assertIn("another", result.get("highlights")) # Check that search works for a search term that overlaps with the message # containing a null byte and an unrelated message. result = self.get_success(store.search_msgs([room_id], "hi", ["content.body"])) - self.assertEquals(result.get("count"), 2) + self.assertEqual(result.get("count"), 2) result = self.get_success( store.search_msgs([room_id], "hi alice", ["content.body"]) ) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 7028f0dfb..b8f09a8ee 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -55,7 +55,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): ) ) - self.assertEquals([self.room], [m.room_id for m in rooms_for_user]) + self.assertEqual([self.room], [m.room_id for m in rooms_for_user]) def test_count_known_servers(self): """ diff --git a/tests/test_distributor.py b/tests/test_distributor.py index f8341041e..31546ea52 100644 --- a/tests/test_distributor.py +++ b/tests/test_distributor.py @@ -48,7 +48,7 @@ class DistributorTestCase(unittest.TestCase): observers[0].assert_called_once_with("Go") observers[1].assert_called_once_with("Go") - self.assertEquals(mock_logger.warning.call_count, 1) + self.assertEqual(mock_logger.warning.call_count, 1) self.assertIsInstance(mock_logger.warning.call_args[0][0], str) def test_signal_prereg(self): diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index 67dcf567c..37fada5c5 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -54,7 +54,7 @@ class TermsTestCase(unittest.HomeserverTestCase): request_data = json.dumps({"username": "kermit", "password": "monkey"}) channel = self.make_request(b"POST", self.url, request_data) - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) self.assertTrue(channel.json_body is not None) self.assertIsInstance(channel.json_body["session"], str) @@ -99,7 +99,7 @@ class TermsTestCase(unittest.HomeserverTestCase): # We don't bother checking that the response is correct - we'll leave that to # other tests. We just want to make sure we're on the right path. - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) # Finish the UI auth for terms request_data = json.dumps( @@ -117,7 +117,7 @@ class TermsTestCase(unittest.HomeserverTestCase): # We're interested in getting a response that looks like a successful # registration, not so much that the details are exactly what we want. - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.assertTrue(channel.json_body is not None) self.assertIsInstance(channel.json_body["user_id"], str) diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py index f2ef1c605..d04bcae0f 100644 --- a/tests/test_test_utils.py +++ b/tests/test_test_utils.py @@ -25,7 +25,7 @@ class MockClockTestCase(unittest.TestCase): self.clock.advance_time(20) - self.assertEquals(20, self.clock.time() - start_time) + self.assertEqual(20, self.clock.time() - start_time) def test_later(self): invoked = [0, 0] diff --git a/tests/test_types.py b/tests/test_types.py index 0d0c00d97..80888a744 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -22,9 +22,9 @@ class UserIDTestCase(unittest.HomeserverTestCase): def test_parse(self): user = UserID.from_string("@1234abcd:test") - self.assertEquals("1234abcd", user.localpart) - self.assertEquals("test", user.domain) - self.assertEquals(True, self.hs.is_mine(user)) + self.assertEqual("1234abcd", user.localpart) + self.assertEqual("test", user.domain) + self.assertEqual(True, self.hs.is_mine(user)) def test_pase_empty(self): with self.assertRaises(SynapseError): @@ -33,7 +33,7 @@ class UserIDTestCase(unittest.HomeserverTestCase): def test_build(self): user = UserID("5678efgh", "my.domain") - self.assertEquals(user.to_string(), "@5678efgh:my.domain") + self.assertEqual(user.to_string(), "@5678efgh:my.domain") def test_compare(self): userA = UserID.from_string("@userA:my.domain") @@ -48,14 +48,14 @@ class RoomAliasTestCase(unittest.HomeserverTestCase): def test_parse(self): room = RoomAlias.from_string("#channel:test") - self.assertEquals("channel", room.localpart) - self.assertEquals("test", room.domain) - self.assertEquals(True, self.hs.is_mine(room)) + self.assertEqual("channel", room.localpart) + self.assertEqual("test", room.domain) + self.assertEqual(True, self.hs.is_mine(room)) def test_build(self): room = RoomAlias("channel", "my.domain") - self.assertEquals(room.to_string(), "#channel:my.domain") + self.assertEqual(room.to_string(), "#channel:my.domain") def test_validate(self): id_string = "#test:domain,test" diff --git a/tests/unittest.py b/tests/unittest.py index 0caa8e7a4..326895f4c 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -152,12 +152,12 @@ class TestCase(unittest.TestCase): def assertObjectHasAttributes(self, attrs, obj): """Asserts that the given object has each of the attributes given, and - that the value of each matches according to assertEquals.""" + that the value of each matches according to assertEqual.""" for key in attrs.keys(): if not hasattr(obj, key): raise AssertionError("Expected obj to have a '.%s'" % key) try: - self.assertEquals(attrs[key], getattr(obj, key)) + self.assertEqual(attrs[key], getattr(obj, key)) except AssertionError as e: raise (type(e))(f"Assert error for '.{key}':") from e @@ -169,7 +169,7 @@ class TestCase(unittest.TestCase): actual (dict): The test result. Extra keys will not be checked. """ for key in required: - self.assertEquals( + self.assertEqual( required[key], actual[key], msg="%s mismatch. %s" % (key, actual) ) diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py index c613ce3f1..02b99b466 100644 --- a/tests/util/caches/test_deferred_cache.py +++ b/tests/util/caches/test_deferred_cache.py @@ -31,7 +31,7 @@ class DeferredCacheTestCase(TestCase): cache = DeferredCache("test") cache.prefill("foo", 123) - self.assertEquals(self.successResultOf(cache.get("foo")), 123) + self.assertEqual(self.successResultOf(cache.get("foo")), 123) def test_hit_deferred(self): cache = DeferredCache("test") diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index ced3efd93..b92d3f0c1 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -434,8 +434,8 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): a = A() - self.assertEquals((yield a.func("foo")), "foo") - self.assertEquals((yield a.func("bar")), "bar") + self.assertEqual((yield a.func("foo")), "foo") + self.assertEqual((yield a.func("bar")), "bar") @defer.inlineCallbacks def test_hit(self): @@ -450,10 +450,10 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): a = A() yield a.func("foo") - self.assertEquals(callcount[0], 1) + self.assertEqual(callcount[0], 1) - self.assertEquals((yield a.func("foo")), "foo") - self.assertEquals(callcount[0], 1) + self.assertEqual((yield a.func("foo")), "foo") + self.assertEqual(callcount[0], 1) @defer.inlineCallbacks def test_invalidate(self): @@ -468,13 +468,13 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): a = A() yield a.func("foo") - self.assertEquals(callcount[0], 1) + self.assertEqual(callcount[0], 1) a.func.invalidate(("foo",)) yield a.func("foo") - self.assertEquals(callcount[0], 2) + self.assertEqual(callcount[0], 2) def test_invalidate_missing(self): class A: @@ -499,7 +499,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): for k in range(0, 12): yield a.func(k) - self.assertEquals(callcount[0], 12) + self.assertEqual(callcount[0], 12) # There must have been at least 2 evictions, meaning if we calculate # all 12 values again, we must get called at least 2 more times @@ -525,8 +525,8 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): a.func.prefill(("foo",), 456) - self.assertEquals(a.func("foo").result, 456) - self.assertEquals(callcount[0], 0) + self.assertEqual(a.func("foo").result, 456) + self.assertEqual(callcount[0], 0) @defer.inlineCallbacks def test_invalidate_context(self): @@ -547,19 +547,19 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): a = A() yield a.func2("foo") - self.assertEquals(callcount[0], 1) - self.assertEquals(callcount2[0], 1) + self.assertEqual(callcount[0], 1) + self.assertEqual(callcount2[0], 1) a.func.invalidate(("foo",)) yield a.func("foo") - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 1) + self.assertEqual(callcount[0], 2) + self.assertEqual(callcount2[0], 1) yield a.func2("foo") - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 2) + self.assertEqual(callcount[0], 2) + self.assertEqual(callcount2[0], 2) @defer.inlineCallbacks def test_eviction_context(self): @@ -581,22 +581,22 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): yield a.func2("foo") yield a.func2("foo2") - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 2) + self.assertEqual(callcount[0], 2) + self.assertEqual(callcount2[0], 2) yield a.func2("foo") - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 2) + self.assertEqual(callcount[0], 2) + self.assertEqual(callcount2[0], 2) yield a.func("foo3") - self.assertEquals(callcount[0], 3) - self.assertEquals(callcount2[0], 2) + self.assertEqual(callcount[0], 3) + self.assertEqual(callcount2[0], 2) yield a.func2("foo") - self.assertEquals(callcount[0], 4) - self.assertEquals(callcount2[0], 3) + self.assertEqual(callcount[0], 4) + self.assertEqual(callcount2[0], 3) @defer.inlineCallbacks def test_double_get(self): @@ -619,30 +619,30 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): yield a.func2("foo") - self.assertEquals(callcount[0], 1) - self.assertEquals(callcount2[0], 1) + self.assertEqual(callcount[0], 1) + self.assertEqual(callcount2[0], 1) a.func2.invalidate(("foo",)) - self.assertEquals(a.func2.cache.cache.del_multi.call_count, 1) + self.assertEqual(a.func2.cache.cache.del_multi.call_count, 1) yield a.func2("foo") a.func2.invalidate(("foo",)) - self.assertEquals(a.func2.cache.cache.del_multi.call_count, 2) + self.assertEqual(a.func2.cache.cache.del_multi.call_count, 2) - self.assertEquals(callcount[0], 1) - self.assertEquals(callcount2[0], 2) + self.assertEqual(callcount[0], 1) + self.assertEqual(callcount2[0], 2) a.func.invalidate(("foo",)) - self.assertEquals(a.func2.cache.cache.del_multi.call_count, 3) + self.assertEqual(a.func2.cache.cache.del_multi.call_count, 3) yield a.func("foo") - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 2) + self.assertEqual(callcount[0], 2) + self.assertEqual(callcount2[0], 2) yield a.func2("foo") - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 3) + self.assertEqual(callcount[0], 2) + self.assertEqual(callcount2[0], 3) class CachedListDescriptorTestCase(unittest.TestCase): diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py index e6e13ba06..7f60aae5b 100644 --- a/tests/util/test_expiring_cache.py +++ b/tests/util/test_expiring_cache.py @@ -26,8 +26,8 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase): cache = ExpiringCache("test", clock, max_len=1) cache["key"] = "value" - self.assertEquals(cache.get("key"), "value") - self.assertEquals(cache["key"], "value") + self.assertEqual(cache.get("key"), "value") + self.assertEqual(cache["key"], "value") def test_eviction(self): clock = MockClock() @@ -35,13 +35,13 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase): cache["key"] = "value" cache["key2"] = "value2" - self.assertEquals(cache.get("key"), "value") - self.assertEquals(cache.get("key2"), "value2") + self.assertEqual(cache.get("key"), "value") + self.assertEqual(cache.get("key2"), "value2") cache["key3"] = "value3" - self.assertEquals(cache.get("key"), None) - self.assertEquals(cache.get("key2"), "value2") - self.assertEquals(cache.get("key3"), "value3") + self.assertEqual(cache.get("key"), None) + self.assertEqual(cache.get("key2"), "value2") + self.assertEqual(cache.get("key3"), "value3") def test_iterable_eviction(self): clock = MockClock() @@ -51,15 +51,15 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase): cache["key2"] = [2, 3] cache["key3"] = [4, 5] - self.assertEquals(cache.get("key"), [1]) - self.assertEquals(cache.get("key2"), [2, 3]) - self.assertEquals(cache.get("key3"), [4, 5]) + self.assertEqual(cache.get("key"), [1]) + self.assertEqual(cache.get("key2"), [2, 3]) + self.assertEqual(cache.get("key3"), [4, 5]) cache["key4"] = [6, 7] - self.assertEquals(cache.get("key"), None) - self.assertEquals(cache.get("key2"), None) - self.assertEquals(cache.get("key3"), [4, 5]) - self.assertEquals(cache.get("key4"), [6, 7]) + self.assertEqual(cache.get("key"), None) + self.assertEqual(cache.get("key2"), None) + self.assertEqual(cache.get("key3"), [4, 5]) + self.assertEqual(cache.get("key4"), [6, 7]) def test_time_eviction(self): clock = MockClock() @@ -69,13 +69,13 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase): clock.advance_time(0.5) cache["key2"] = 2 - self.assertEquals(cache.get("key"), 1) - self.assertEquals(cache.get("key2"), 2) + self.assertEqual(cache.get("key"), 1) + self.assertEqual(cache.get("key2"), 2) clock.advance_time(0.9) - self.assertEquals(cache.get("key"), None) - self.assertEquals(cache.get("key2"), 2) + self.assertEqual(cache.get("key"), None) + self.assertEqual(cache.get("key2"), 2) clock.advance_time(1) - self.assertEquals(cache.get("key"), None) - self.assertEquals(cache.get("key2"), None) + self.assertEqual(cache.get("key"), None) + self.assertEqual(cache.get("key2"), None) diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index 621b0f9fc..2ad321e18 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -17,7 +17,7 @@ from .. import unittest class LoggingContextTestCase(unittest.TestCase): def _check_test_key(self, value): - self.assertEquals(current_context().name, value) + self.assertEqual(current_context().name, value) def test_with_context(self): with LoggingContext("test"): diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 291644eb7..321fc1776 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -27,37 +27,37 @@ class LruCacheTestCase(unittest.HomeserverTestCase): def test_get_set(self): cache = LruCache(1) cache["key"] = "value" - self.assertEquals(cache.get("key"), "value") - self.assertEquals(cache["key"], "value") + self.assertEqual(cache.get("key"), "value") + self.assertEqual(cache["key"], "value") def test_eviction(self): cache = LruCache(2) cache[1] = 1 cache[2] = 2 - self.assertEquals(cache.get(1), 1) - self.assertEquals(cache.get(2), 2) + self.assertEqual(cache.get(1), 1) + self.assertEqual(cache.get(2), 2) cache[3] = 3 - self.assertEquals(cache.get(1), None) - self.assertEquals(cache.get(2), 2) - self.assertEquals(cache.get(3), 3) + self.assertEqual(cache.get(1), None) + self.assertEqual(cache.get(2), 2) + self.assertEqual(cache.get(3), 3) def test_setdefault(self): cache = LruCache(1) - self.assertEquals(cache.setdefault("key", 1), 1) - self.assertEquals(cache.get("key"), 1) - self.assertEquals(cache.setdefault("key", 2), 1) - self.assertEquals(cache.get("key"), 1) + self.assertEqual(cache.setdefault("key", 1), 1) + self.assertEqual(cache.get("key"), 1) + self.assertEqual(cache.setdefault("key", 2), 1) + self.assertEqual(cache.get("key"), 1) cache["key"] = 2 # Make sure overriding works. - self.assertEquals(cache.get("key"), 2) + self.assertEqual(cache.get("key"), 2) def test_pop(self): cache = LruCache(1) cache["key"] = 1 - self.assertEquals(cache.pop("key"), 1) - self.assertEquals(cache.pop("key"), None) + self.assertEqual(cache.pop("key"), 1) + self.assertEqual(cache.pop("key"), None) def test_del_multi(self): cache = LruCache(4, cache_type=TreeCache) @@ -66,23 +66,23 @@ class LruCacheTestCase(unittest.HomeserverTestCase): cache[("vehicles", "car")] = "vroom" cache[("vehicles", "train")] = "chuff" - self.assertEquals(len(cache), 4) + self.assertEqual(len(cache), 4) - self.assertEquals(cache.get(("animal", "cat")), "mew") - self.assertEquals(cache.get(("vehicles", "car")), "vroom") + self.assertEqual(cache.get(("animal", "cat")), "mew") + self.assertEqual(cache.get(("vehicles", "car")), "vroom") cache.del_multi(("animal",)) - self.assertEquals(len(cache), 2) - self.assertEquals(cache.get(("animal", "cat")), None) - self.assertEquals(cache.get(("animal", "dog")), None) - self.assertEquals(cache.get(("vehicles", "car")), "vroom") - self.assertEquals(cache.get(("vehicles", "train")), "chuff") + self.assertEqual(len(cache), 2) + self.assertEqual(cache.get(("animal", "cat")), None) + self.assertEqual(cache.get(("animal", "dog")), None) + self.assertEqual(cache.get(("vehicles", "car")), "vroom") + self.assertEqual(cache.get(("vehicles", "train")), "chuff") # Man from del_multi say "Yes". def test_clear(self): cache = LruCache(1) cache["key"] = 1 cache.clear() - self.assertEquals(len(cache), 0) + self.assertEqual(len(cache), 0) @override_config({"caches": {"per_cache_factors": {"mycache": 10}}}) def test_special_size(self): @@ -105,10 +105,10 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): self.assertFalse(m.called) cache.set("key", "value2") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) cache.set("key", "value") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) def test_multi_get(self): m = Mock() @@ -124,10 +124,10 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): self.assertFalse(m.called) cache.set("key", "value2") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) cache.set("key", "value") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) def test_set(self): m = Mock() @@ -140,10 +140,10 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): self.assertFalse(m.called) cache.set("key", "value2") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) cache.set("key", "value") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) def test_pop(self): m = Mock() @@ -153,13 +153,13 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): self.assertFalse(m.called) cache.pop("key") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) cache.set("key", "value") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) cache.pop("key") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) def test_del_multi(self): m1 = Mock() @@ -173,17 +173,17 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): cache.set(("b", "1"), "value", callbacks=[m3]) cache.set(("b", "2"), "value", callbacks=[m4]) - self.assertEquals(m1.call_count, 0) - self.assertEquals(m2.call_count, 0) - self.assertEquals(m3.call_count, 0) - self.assertEquals(m4.call_count, 0) + self.assertEqual(m1.call_count, 0) + self.assertEqual(m2.call_count, 0) + self.assertEqual(m3.call_count, 0) + self.assertEqual(m4.call_count, 0) cache.del_multi(("a",)) - self.assertEquals(m1.call_count, 1) - self.assertEquals(m2.call_count, 1) - self.assertEquals(m3.call_count, 0) - self.assertEquals(m4.call_count, 0) + self.assertEqual(m1.call_count, 1) + self.assertEqual(m2.call_count, 1) + self.assertEqual(m3.call_count, 0) + self.assertEqual(m4.call_count, 0) def test_clear(self): m1 = Mock() @@ -193,13 +193,13 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): cache.set("key1", "value", callbacks=[m1]) cache.set("key2", "value", callbacks=[m2]) - self.assertEquals(m1.call_count, 0) - self.assertEquals(m2.call_count, 0) + self.assertEqual(m1.call_count, 0) + self.assertEqual(m2.call_count, 0) cache.clear() - self.assertEquals(m1.call_count, 1) - self.assertEquals(m2.call_count, 1) + self.assertEqual(m1.call_count, 1) + self.assertEqual(m2.call_count, 1) def test_eviction(self): m1 = Mock(name="m1") @@ -210,33 +210,33 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): cache.set("key1", "value", callbacks=[m1]) cache.set("key2", "value", callbacks=[m2]) - self.assertEquals(m1.call_count, 0) - self.assertEquals(m2.call_count, 0) - self.assertEquals(m3.call_count, 0) + self.assertEqual(m1.call_count, 0) + self.assertEqual(m2.call_count, 0) + self.assertEqual(m3.call_count, 0) cache.set("key3", "value", callbacks=[m3]) - self.assertEquals(m1.call_count, 1) - self.assertEquals(m2.call_count, 0) - self.assertEquals(m3.call_count, 0) + self.assertEqual(m1.call_count, 1) + self.assertEqual(m2.call_count, 0) + self.assertEqual(m3.call_count, 0) cache.set("key3", "value") - self.assertEquals(m1.call_count, 1) - self.assertEquals(m2.call_count, 0) - self.assertEquals(m3.call_count, 0) + self.assertEqual(m1.call_count, 1) + self.assertEqual(m2.call_count, 0) + self.assertEqual(m3.call_count, 0) cache.get("key2") - self.assertEquals(m1.call_count, 1) - self.assertEquals(m2.call_count, 0) - self.assertEquals(m3.call_count, 0) + self.assertEqual(m1.call_count, 1) + self.assertEqual(m2.call_count, 0) + self.assertEqual(m3.call_count, 0) cache.set("key1", "value", callbacks=[m1]) - self.assertEquals(m1.call_count, 1) - self.assertEquals(m2.call_count, 0) - self.assertEquals(m3.call_count, 1) + self.assertEqual(m1.call_count, 1) + self.assertEqual(m2.call_count, 0) + self.assertEqual(m3.call_count, 1) class LruCacheSizedTestCase(unittest.HomeserverTestCase): @@ -247,20 +247,20 @@ class LruCacheSizedTestCase(unittest.HomeserverTestCase): cache["key3"] = [3] cache["key4"] = [4] - self.assertEquals(cache["key1"], [0]) - self.assertEquals(cache["key2"], [1, 2]) - self.assertEquals(cache["key3"], [3]) - self.assertEquals(cache["key4"], [4]) - self.assertEquals(len(cache), 5) + self.assertEqual(cache["key1"], [0]) + self.assertEqual(cache["key2"], [1, 2]) + self.assertEqual(cache["key3"], [3]) + self.assertEqual(cache["key4"], [4]) + self.assertEqual(len(cache), 5) cache["key5"] = [5, 6] - self.assertEquals(len(cache), 4) - self.assertEquals(cache.get("key1"), None) - self.assertEquals(cache.get("key2"), None) - self.assertEquals(cache["key3"], [3]) - self.assertEquals(cache["key4"], [4]) - self.assertEquals(cache["key5"], [5, 6]) + self.assertEqual(len(cache), 4) + self.assertEqual(cache.get("key1"), None) + self.assertEqual(cache.get("key2"), None) + self.assertEqual(cache["key3"], [3]) + self.assertEqual(cache["key4"], [4]) + self.assertEqual(cache["key5"], [5, 6]) def test_zero_size_drop_from_cache(self) -> None: """Test that `drop_from_cache` works correctly with 0-sized entries.""" diff --git a/tests/util/test_treecache.py b/tests/util/test_treecache.py index 606637205..567cb1846 100644 --- a/tests/util/test_treecache.py +++ b/tests/util/test_treecache.py @@ -23,61 +23,61 @@ class TreeCacheTestCase(unittest.TestCase): cache = TreeCache() cache[("a",)] = "A" cache[("b",)] = "B" - self.assertEquals(cache.get(("a",)), "A") - self.assertEquals(cache.get(("b",)), "B") - self.assertEquals(len(cache), 2) + self.assertEqual(cache.get(("a",)), "A") + self.assertEqual(cache.get(("b",)), "B") + self.assertEqual(len(cache), 2) def test_pop_onelevel(self): cache = TreeCache() cache[("a",)] = "A" cache[("b",)] = "B" - self.assertEquals(cache.pop(("a",)), "A") - self.assertEquals(cache.pop(("a",)), None) - self.assertEquals(cache.get(("b",)), "B") - self.assertEquals(len(cache), 1) + self.assertEqual(cache.pop(("a",)), "A") + self.assertEqual(cache.pop(("a",)), None) + self.assertEqual(cache.get(("b",)), "B") + self.assertEqual(len(cache), 1) def test_get_set_twolevel(self): cache = TreeCache() cache[("a", "a")] = "AA" cache[("a", "b")] = "AB" cache[("b", "a")] = "BA" - self.assertEquals(cache.get(("a", "a")), "AA") - self.assertEquals(cache.get(("a", "b")), "AB") - self.assertEquals(cache.get(("b", "a")), "BA") - self.assertEquals(len(cache), 3) + self.assertEqual(cache.get(("a", "a")), "AA") + self.assertEqual(cache.get(("a", "b")), "AB") + self.assertEqual(cache.get(("b", "a")), "BA") + self.assertEqual(len(cache), 3) def test_pop_twolevel(self): cache = TreeCache() cache[("a", "a")] = "AA" cache[("a", "b")] = "AB" cache[("b", "a")] = "BA" - self.assertEquals(cache.pop(("a", "a")), "AA") - self.assertEquals(cache.get(("a", "a")), None) - self.assertEquals(cache.get(("a", "b")), "AB") - self.assertEquals(cache.pop(("b", "a")), "BA") - self.assertEquals(cache.pop(("b", "a")), None) - self.assertEquals(len(cache), 1) + self.assertEqual(cache.pop(("a", "a")), "AA") + self.assertEqual(cache.get(("a", "a")), None) + self.assertEqual(cache.get(("a", "b")), "AB") + self.assertEqual(cache.pop(("b", "a")), "BA") + self.assertEqual(cache.pop(("b", "a")), None) + self.assertEqual(len(cache), 1) def test_pop_mixedlevel(self): cache = TreeCache() cache[("a", "a")] = "AA" cache[("a", "b")] = "AB" cache[("b", "a")] = "BA" - self.assertEquals(cache.get(("a", "a")), "AA") + self.assertEqual(cache.get(("a", "a")), "AA") popped = cache.pop(("a",)) - self.assertEquals(cache.get(("a", "a")), None) - self.assertEquals(cache.get(("a", "b")), None) - self.assertEquals(cache.get(("b", "a")), "BA") - self.assertEquals(len(cache), 1) + self.assertEqual(cache.get(("a", "a")), None) + self.assertEqual(cache.get(("a", "b")), None) + self.assertEqual(cache.get(("b", "a")), "BA") + self.assertEqual(len(cache), 1) - self.assertEquals({"AA", "AB"}, set(iterate_tree_cache_entry(popped))) + self.assertEqual({"AA", "AB"}, set(iterate_tree_cache_entry(popped))) def test_clear(self): cache = TreeCache() cache[("a",)] = "A" cache[("b",)] = "B" cache.clear() - self.assertEquals(len(cache), 0) + self.assertEqual(len(cache), 0) def test_contains(self): cache = TreeCache() From 9e83521af860cb33a7459dbe74188ce5ef39f446 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 28 Feb 2022 07:52:44 -0500 Subject: [PATCH 61/84] Properly failover for unknown endpoints from Conduit/Dendrite. (#12077) Before this fix, a legitimate 404 from a federation endpoint (e.g. due to an unknown room) would be treated as an unknown endpoint. This could cause unnecessary federation traffic. --- changelog.d/12077.bugfix | 1 + synapse/federation/federation_client.py | 22 +++++++++++++--------- 2 files changed, 14 insertions(+), 9 deletions(-) create mode 100644 changelog.d/12077.bugfix diff --git a/changelog.d/12077.bugfix b/changelog.d/12077.bugfix new file mode 100644 index 000000000..1bce82082 --- /dev/null +++ b/changelog.d/12077.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where Synapse would make additional failing requests over federation for missing data. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 2121e92e3..a4bae3c4c 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -615,11 +615,15 @@ class FederationClient(FederationBase): synapse_error = e.to_synapse_error() # There is no good way to detect an "unknown" endpoint. # - # Dendrite returns a 404 (with no body); synapse returns a 400 + # Dendrite returns a 404 (with a body of "404 page not found"); + # Conduit returns a 404 (with no body); and Synapse returns a 400 # with M_UNRECOGNISED. - return e.code == 404 or ( - e.code == 400 and synapse_error.errcode == Codes.UNRECOGNIZED - ) + # + # This needs to be rather specific as some endpoints truly do return 404 + # errors. + return ( + e.code == 404 and (not e.response or e.response == b"404 page not found") + ) or (e.code == 400 and synapse_error.errcode == Codes.UNRECOGNIZED) async def _try_destination_list( self, @@ -1002,7 +1006,7 @@ class FederationClient(FederationBase): ) except HttpResponseException as e: # If an error is received that is due to an unrecognised endpoint, - # fallback to the v1 endpoint. Otherwise consider it a legitmate error + # fallback to the v1 endpoint. Otherwise, consider it a legitimate error # and raise. if not self._is_unknown_endpoint(e): raise @@ -1071,7 +1075,7 @@ class FederationClient(FederationBase): except HttpResponseException as e: # If an error is received that is due to an unrecognised endpoint, # fallback to the v1 endpoint if the room uses old-style event IDs. - # Otherwise consider it a legitmate error and raise. + # Otherwise, consider it a legitimate error and raise. err = e.to_synapse_error() if self._is_unknown_endpoint(e, err): if room_version.event_format != EventFormatVersions.V1: @@ -1132,7 +1136,7 @@ class FederationClient(FederationBase): ) except HttpResponseException as e: # If an error is received that is due to an unrecognised endpoint, - # fallback to the v1 endpoint. Otherwise consider it a legitmate error + # fallback to the v1 endpoint. Otherwise, consider it a legitimate error # and raise. if not self._is_unknown_endpoint(e): raise @@ -1458,8 +1462,8 @@ class FederationClient(FederationBase): ) except HttpResponseException as e: # If an error is received that is due to an unrecognised endpoint, - # fallback to the unstable endpoint. Otherwise consider it a - # legitmate error and raise. + # fallback to the unstable endpoint. Otherwise, consider it a + # legitimate error and raise. if not self._is_unknown_endpoint(e): raise From 5565f454e1b323b637dd418549f70fadac0f44b4 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 28 Feb 2022 14:10:36 +0000 Subject: [PATCH 62/84] Actually fix bad debug logging rejecting device list & signing key transactions (#12098) --- changelog.d/12098.bugfix | 1 + .../federation/transport/server/federation.py | 2 +- tests/federation/transport/test_server.py | 20 ++++++++++++++++++- 3 files changed, 21 insertions(+), 2 deletions(-) create mode 100644 changelog.d/12098.bugfix diff --git a/changelog.d/12098.bugfix b/changelog.d/12098.bugfix new file mode 100644 index 000000000..6b696692e --- /dev/null +++ b/changelog.d/12098.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.51.0rc1 where incoming federation transactions containing at least one EDU would be dropped if debug logging was enabled for `synapse.8631_debug`. \ No newline at end of file diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 9cc9a7339..23ce34305 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -110,7 +110,7 @@ class FederationSendServlet(BaseFederationServerServlet): if issue_8631_logger.isEnabledFor(logging.DEBUG): DEVICE_UPDATE_EDUS = ["m.device_list_update", "m.signing_key_update"] device_list_updates = [ - edu.content + edu.get("content", {}) for edu in transaction_data.get("edus", []) if edu.get("edu_type") in DEVICE_UPDATE_EDUS ] diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py index ce49d094d..5f001c33b 100644 --- a/tests/federation/transport/test_server.py +++ b/tests/federation/transport/test_server.py @@ -13,7 +13,7 @@ # limitations under the License. from tests import unittest -from tests.unittest import override_config +from tests.unittest import DEBUG, override_config class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase): @@ -38,3 +38,21 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase): "/_matrix/federation/v1/publicRooms", ) self.assertEqual(200, channel.code) + + @DEBUG + def test_edu_debugging_doesnt_explode(self): + """Sanity check incoming federation succeeds with `synapse.debug_8631` enabled. + + Remove this when we strip out issue_8631_logger. + """ + channel = self.make_signed_federation_request( + "PUT", + "/_matrix/federation/v1/send/txn_id_1234/", + content={ + "edus": [ + {"edu_type": "m.device_list_update", "content": {"foo": "bar"}} + ], + "pdus": [], + }, + ) + self.assertEqual(200, channel.code) From 6c0b44a3d73f73dc5913f081418347645dc84d6f Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 28 Feb 2022 17:40:24 +0000 Subject: [PATCH 63/84] Fix `PushRuleEvaluator` and `Filter` to work on frozendicts (#12100) * Fix `PushRuleEvaluator` to work on frozendicts frozendicts do not (necessarily) inherit from dict, so this needs to handle them correctly. * Fix event filtering for frozen events Looks like this one was introduced by #11194. --- changelog.d/12100.bugfix | 1 + synapse/api/filtering.py | 5 +++-- synapse/push/push_rule_evaluator.py | 8 ++++---- tests/api/test_filtering.py | 10 ++++++++++ tests/push/test_push_rule_evaluator.py | 9 +++++++++ 5 files changed, 27 insertions(+), 6 deletions(-) create mode 100644 changelog.d/12100.bugfix diff --git a/changelog.d/12100.bugfix b/changelog.d/12100.bugfix new file mode 100644 index 000000000..181095ad9 --- /dev/null +++ b/changelog.d/12100.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug which could cause push notifications to malfunction if `use_frozen_dicts` was set in the configuration. diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index fe4cc2e8e..cb532d723 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -22,6 +22,7 @@ from typing import ( Dict, Iterable, List, + Mapping, Optional, Set, TypeVar, @@ -361,10 +362,10 @@ class Filter: return self._check_fields(field_matchers) else: content = event.get("content") - # Content is assumed to be a dict below, so ensure it is. This should + # Content is assumed to be a mapping below, so ensure it is. This should # always be true for events, but account_data has been allowed to # have non-dict content. - if not isinstance(content, dict): + if not isinstance(content, Mapping): content = {} sender = event.get("sender", None) diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 659a53805..f617c759e 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -15,12 +15,12 @@ import logging import re -from typing import Any, Dict, List, Optional, Pattern, Tuple, Union +from typing import Any, Dict, List, Mapping, Optional, Pattern, Tuple, Union from matrix_common.regex import glob_to_regex, to_word_pattern from synapse.events import EventBase -from synapse.types import JsonDict, UserID +from synapse.types import UserID from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) @@ -223,7 +223,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: def _flatten_dict( - d: Union[EventBase, JsonDict], + d: Union[EventBase, Mapping[str, Any]], prefix: Optional[List[str]] = None, result: Optional[Dict[str, str]] = None, ) -> Dict[str, str]: @@ -234,7 +234,7 @@ def _flatten_dict( for key, value in d.items(): if isinstance(value, str): result[".".join(prefix + [key])] = value.lower() - elif isinstance(value, dict): + elif isinstance(value, Mapping): _flatten_dict(value, prefix=(prefix + [key]), result=result) return result diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 2525018e9..8c3354ce3 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -18,6 +18,7 @@ from unittest.mock import patch import jsonschema +from frozendict import frozendict from synapse.api.constants import EventContentFields from synapse.api.errors import SynapseError @@ -327,6 +328,15 @@ class FilteringTestCase(unittest.HomeserverTestCase): self.assertFalse(Filter(self.hs, definition)._check(event)) + # check it works with frozendicts too + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content=frozendict({EventContentFields.LABELS: ["#fun"]}), + ) + self.assertTrue(Filter(self.hs, definition)._check(event)) + def test_filter_not_labels(self): definition = {"org.matrix.not_labels": ["#fun"]} event = MockEvent( diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index a52e89e40..3849beb9d 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -14,6 +14,8 @@ from typing import Any, Dict +import frozendict + from synapse.api.room_versions import RoomVersions from synapse.events import FrozenEvent from synapse.push import push_rule_evaluator @@ -191,6 +193,13 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): "pattern should only match at the start/end of the value", ) + # it should work on frozendicts too + self._assert_matches( + condition, + frozendict.frozendict({"value": "FoobaZ"}), + "patterns should match on frozendicts", + ) + # wildcards should match condition = { "kind": "event_match", From 1901cb1d4a8b7d9af64493fbd336e9aa2561c20c Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Mon, 28 Feb 2022 18:47:37 +0100 Subject: [PATCH 64/84] Add type hints to `tests/rest/client` (#12084) --- changelog.d/12084.misc | 1 + mypy.ini | 3 +- tests/rest/client/test_profile.py | 78 +++++++++++++---------- tests/rest/client/test_push_rule_attrs.py | 26 ++++---- tests/rest/client/test_redactions.py | 26 +++++--- tests/rest/client/test_relations.py | 58 +++++++++-------- tests/rest/client/test_retention.py | 41 ++++++++---- tests/rest/client/test_sendtodevice.py | 8 +-- tests/rest/client/test_shadow_banned.py | 22 ++++--- tests/rest/client/test_shared_rooms.py | 20 +++--- tests/rest/client/test_upgrade_room.py | 17 +++-- tests/rest/client/utils.py | 36 +++++++---- 12 files changed, 198 insertions(+), 138 deletions(-) create mode 100644 changelog.d/12084.misc diff --git a/changelog.d/12084.misc b/changelog.d/12084.misc new file mode 100644 index 000000000..0360dbd61 --- /dev/null +++ b/changelog.d/12084.misc @@ -0,0 +1 @@ +Add type hints to `tests/rest/client`. diff --git a/mypy.ini b/mypy.ini index 610660b9b..bd75905c8 100644 --- a/mypy.ini +++ b/mypy.ini @@ -84,7 +84,6 @@ exclude = (?x) |tests/rest/client/test_third_party_rules.py |tests/rest/client/test_transactions.py |tests/rest/client/test_typing.py - |tests/rest/client/utils.py |tests/rest/key/v2/test_remote_key_resource.py |tests/rest/media/v1/test_base.py |tests/rest/media/v1/test_media_storage.py @@ -253,7 +252,7 @@ disallow_untyped_defs = True [mypy-tests.rest.admin.*] disallow_untyped_defs = True -[mypy-tests.rest.client.test_directory] +[mypy-tests.rest.client.*] disallow_untyped_defs = True [mypy-tests.federation.transport.test_client] diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py index 4239e1e61..77c3ced42 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py @@ -13,12 +13,16 @@ # limitations under the License. """Tests REST events for /profile paths.""" -from typing import Any, Dict +from typing import Any, Dict, Optional + +from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import Codes from synapse.rest import admin from synapse.rest.client import login, profile, room +from synapse.server import HomeServer from synapse.types import UserID +from synapse.util import Clock from tests import unittest @@ -32,20 +36,20 @@ class ProfileTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.hs = self.setup_test_homeserver() return self.hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.owner = self.register_user("owner", "pass") self.owner_tok = self.login("owner", "pass") self.other = self.register_user("other", "pass", displayname="Bob") - def test_get_displayname(self): + def test_get_displayname(self) -> None: res = self._get_displayname() self.assertEqual(res, "owner") - def test_set_displayname(self): + def test_set_displayname(self) -> None: channel = self.make_request( "PUT", "/profile/%s/displayname" % (self.owner,), @@ -57,7 +61,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): res = self._get_displayname() self.assertEqual(res, "test") - def test_set_displayname_noauth(self): + def test_set_displayname_noauth(self) -> None: channel = self.make_request( "PUT", "/profile/%s/displayname" % (self.owner,), @@ -65,7 +69,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 401, channel.result) - def test_set_displayname_too_long(self): + def test_set_displayname_too_long(self) -> None: """Attempts to set a stupid displayname should get a 400""" channel = self.make_request( "PUT", @@ -78,11 +82,11 @@ class ProfileTestCase(unittest.HomeserverTestCase): res = self._get_displayname() self.assertEqual(res, "owner") - def test_get_displayname_other(self): + def test_get_displayname_other(self) -> None: res = self._get_displayname(self.other) self.assertEqual(res, "Bob") - def test_set_displayname_other(self): + def test_set_displayname_other(self) -> None: channel = self.make_request( "PUT", "/profile/%s/displayname" % (self.other,), @@ -91,11 +95,11 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 400, channel.result) - def test_get_avatar_url(self): + def test_get_avatar_url(self) -> None: res = self._get_avatar_url() self.assertIsNone(res) - def test_set_avatar_url(self): + def test_set_avatar_url(self) -> None: channel = self.make_request( "PUT", "/profile/%s/avatar_url" % (self.owner,), @@ -107,7 +111,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): res = self._get_avatar_url() self.assertEqual(res, "http://my.server/pic.gif") - def test_set_avatar_url_noauth(self): + def test_set_avatar_url_noauth(self) -> None: channel = self.make_request( "PUT", "/profile/%s/avatar_url" % (self.owner,), @@ -115,7 +119,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 401, channel.result) - def test_set_avatar_url_too_long(self): + def test_set_avatar_url_too_long(self) -> None: """Attempts to set a stupid avatar_url should get a 400""" channel = self.make_request( "PUT", @@ -128,11 +132,11 @@ class ProfileTestCase(unittest.HomeserverTestCase): res = self._get_avatar_url() self.assertIsNone(res) - def test_get_avatar_url_other(self): + def test_get_avatar_url_other(self) -> None: res = self._get_avatar_url(self.other) self.assertIsNone(res) - def test_set_avatar_url_other(self): + def test_set_avatar_url_other(self) -> None: channel = self.make_request( "PUT", "/profile/%s/avatar_url" % (self.other,), @@ -141,14 +145,14 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 400, channel.result) - def _get_displayname(self, name=None): + def _get_displayname(self, name: Optional[str] = None) -> str: channel = self.make_request( "GET", "/profile/%s/displayname" % (name or self.owner,) ) self.assertEqual(channel.code, 200, channel.result) return channel.json_body["displayname"] - def _get_avatar_url(self, name=None): + def _get_avatar_url(self, name: Optional[str] = None) -> str: channel = self.make_request( "GET", "/profile/%s/avatar_url" % (name or self.owner,) ) @@ -156,7 +160,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): return channel.json_body.get("avatar_url") @unittest.override_config({"max_avatar_size": 50}) - def test_avatar_size_limit_global(self): + def test_avatar_size_limit_global(self) -> None: """Tests that the maximum size limit for avatars is enforced when updating a global profile. """ @@ -187,7 +191,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) @unittest.override_config({"max_avatar_size": 50}) - def test_avatar_size_limit_per_room(self): + def test_avatar_size_limit_per_room(self) -> None: """Tests that the maximum size limit for avatars is enforced when updating a per-room profile. """ @@ -220,7 +224,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]}) - def test_avatar_allowed_mime_type_global(self): + def test_avatar_allowed_mime_type_global(self) -> None: """Tests that the MIME type whitelist for avatars is enforced when updating a global profile. """ @@ -251,7 +255,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]}) - def test_avatar_allowed_mime_type_per_room(self): + def test_avatar_allowed_mime_type_per_room(self) -> None: """Tests that the MIME type whitelist for avatars is enforced when updating a per-room profile. """ @@ -283,7 +287,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 200, channel.result) - def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]): + def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]) -> None: """Stores metadata about files in the database. Args: @@ -316,8 +320,7 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): - + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["require_auth_for_profile_requests"] = True config["limit_profile_requests_to_users_who_share_rooms"] = True @@ -325,7 +328,7 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): return self.hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # User owning the requested profile. self.owner = self.register_user("owner", "pass") self.owner_tok = self.login("owner", "pass") @@ -337,22 +340,24 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): self.room_id = self.helper.create_room_as(self.owner, tok=self.owner_tok) - def test_no_auth(self): + def test_no_auth(self) -> None: self.try_fetch_profile(401) - def test_not_in_shared_room(self): + def test_not_in_shared_room(self) -> None: self.ensure_requester_left_room() self.try_fetch_profile(403, access_token=self.requester_tok) - def test_in_shared_room(self): + def test_in_shared_room(self) -> None: self.ensure_requester_left_room() self.helper.join(room=self.room_id, user=self.requester, tok=self.requester_tok) self.try_fetch_profile(200, self.requester_tok) - def try_fetch_profile(self, expected_code, access_token=None): + def try_fetch_profile( + self, expected_code: int, access_token: Optional[str] = None + ) -> None: self.request_profile(expected_code, access_token=access_token) self.request_profile( @@ -363,13 +368,18 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): expected_code, url_suffix="/avatar_url", access_token=access_token ) - def request_profile(self, expected_code, url_suffix="", access_token=None): + def request_profile( + self, + expected_code: int, + url_suffix: str = "", + access_token: Optional[str] = None, + ) -> None: channel = self.make_request( "GET", self.profile_url + url_suffix, access_token=access_token ) self.assertEqual(channel.code, expected_code, channel.result) - def ensure_requester_left_room(self): + def ensure_requester_left_room(self) -> None: try: self.helper.leave( room=self.room_id, user=self.requester, tok=self.requester_tok @@ -389,7 +399,7 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase): profile.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["require_auth_for_profile_requests"] = True config["limit_profile_requests_to_users_who_share_rooms"] = True @@ -397,12 +407,12 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase): return self.hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # User requesting the profile. self.requester = self.register_user("requester", "pass") self.requester_tok = self.login("requester", "pass") - def test_can_lookup_own_profile(self): + def test_can_lookup_own_profile(self) -> None: """Tests that a user can lookup their own profile without having to be in a room if 'require_auth_for_profile_requests' is set to true in the server's config. """ diff --git a/tests/rest/client/test_push_rule_attrs.py b/tests/rest/client/test_push_rule_attrs.py index d0ce91ccd..4f875b928 100644 --- a/tests/rest/client/test_push_rule_attrs.py +++ b/tests/rest/client/test_push_rule_attrs.py @@ -27,7 +27,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): ] hijack_auth = False - def test_enabled_on_creation(self): + def test_enabled_on_creation(self) -> None: """ Tests the GET and PUT of push rules' `enabled` endpoints. Tests that a rule is enabled upon creation, even though a rule with that @@ -56,7 +56,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["enabled"], True) - def test_enabled_on_recreation(self): + def test_enabled_on_recreation(self) -> None: """ Tests the GET and PUT of push rules' `enabled` endpoints. Tests that a rule is enabled upon creation, even if a rule with that @@ -113,7 +113,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["enabled"], True) - def test_enabled_disable(self): + def test_enabled_disable(self) -> None: """ Tests the GET and PUT of push rules' `enabled` endpoints. Tests that a rule is disabled and enabled when we ask for it. @@ -166,7 +166,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["enabled"], True) - def test_enabled_404_when_get_non_existent(self): + def test_enabled_404_when_get_non_existent(self) -> None: """ Tests that `enabled` gives 404 when the rule doesn't exist. """ @@ -212,7 +212,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_enabled_404_when_get_non_existent_server_rule(self): + def test_enabled_404_when_get_non_existent_server_rule(self) -> None: """ Tests that `enabled` gives 404 when the server-default rule doesn't exist. """ @@ -226,7 +226,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_enabled_404_when_put_non_existent_rule(self): + def test_enabled_404_when_put_non_existent_rule(self) -> None: """ Tests that `enabled` gives 404 when we put to a rule that doesn't exist. """ @@ -243,7 +243,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_enabled_404_when_put_non_existent_server_rule(self): + def test_enabled_404_when_put_non_existent_server_rule(self) -> None: """ Tests that `enabled` gives 404 when we put to a server-default rule that doesn't exist. """ @@ -260,7 +260,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_actions_get(self): + def test_actions_get(self) -> None: """ Tests that `actions` gives you what you expect on a fresh rule. """ @@ -289,7 +289,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): channel.json_body["actions"], ["notify", {"set_tweak": "highlight"}] ) - def test_actions_put(self): + def test_actions_put(self) -> None: """ Tests that PUT on actions updates the value you'd get from GET. """ @@ -325,7 +325,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["actions"], ["dont_notify"]) - def test_actions_404_when_get_non_existent(self): + def test_actions_404_when_get_non_existent(self) -> None: """ Tests that `actions` gives 404 when the rule doesn't exist. """ @@ -365,7 +365,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_actions_404_when_get_non_existent_server_rule(self): + def test_actions_404_when_get_non_existent_server_rule(self) -> None: """ Tests that `actions` gives 404 when the server-default rule doesn't exist. """ @@ -379,7 +379,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_actions_404_when_put_non_existent_rule(self): + def test_actions_404_when_put_non_existent_rule(self) -> None: """ Tests that `actions` gives 404 when putting to a rule that doesn't exist. """ @@ -396,7 +396,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_actions_404_when_put_non_existent_server_rule(self): + def test_actions_404_when_put_non_existent_server_rule(self) -> None: """ Tests that `actions` gives 404 when putting to a server-default rule that doesn't exist. """ diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py index 433d715f6..7401b5e0c 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py @@ -11,9 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import List + +from twisted.test.proto_helpers import MemoryReactor from synapse.rest import admin from synapse.rest.client import login, room, sync +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -28,7 +34,7 @@ class RedactionsTestCase(HomeserverTestCase): sync.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["rc_message"] = {"per_second": 0.2, "burst_count": 10} @@ -36,7 +42,7 @@ class RedactionsTestCase(HomeserverTestCase): return self.setup_test_homeserver(config=config) - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # register a couple of users self.mod_user_id = self.register_user("user1", "pass") self.mod_access_token = self.login("user1", "pass") @@ -60,7 +66,9 @@ class RedactionsTestCase(HomeserverTestCase): room=self.room_id, user=self.other_user_id, tok=self.other_access_token ) - def _redact_event(self, access_token, room_id, event_id, expect_code=200): + def _redact_event( + self, access_token: str, room_id: str, event_id: str, expect_code: int = 200 + ) -> JsonDict: """Helper function to send a redaction event. Returns the json body. @@ -71,13 +79,13 @@ class RedactionsTestCase(HomeserverTestCase): self.assertEqual(int(channel.result["code"]), expect_code) return channel.json_body - def _sync_room_timeline(self, access_token, room_id): + def _sync_room_timeline(self, access_token: str, room_id: str) -> List[JsonDict]: channel = self.make_request("GET", "sync", access_token=self.mod_access_token) self.assertEqual(channel.result["code"], b"200") room_sync = channel.json_body["rooms"]["join"][room_id] return room_sync["timeline"]["events"] - def test_redact_event_as_moderator(self): + def test_redact_event_as_moderator(self) -> None: # as a regular user, send a message to redact b = self.helper.send(room_id=self.room_id, tok=self.other_access_token) msg_id = b["event_id"] @@ -98,7 +106,7 @@ class RedactionsTestCase(HomeserverTestCase): self.assertEqual(timeline[-2]["unsigned"]["redacted_by"], redaction_id) self.assertEqual(timeline[-2]["content"], {}) - def test_redact_event_as_normal(self): + def test_redact_event_as_normal(self) -> None: # as a regular user, send a message to redact b = self.helper.send(room_id=self.room_id, tok=self.other_access_token) normal_msg_id = b["event_id"] @@ -133,7 +141,7 @@ class RedactionsTestCase(HomeserverTestCase): self.assertEqual(timeline[-3]["unsigned"]["redacted_by"], redaction_id) self.assertEqual(timeline[-3]["content"], {}) - def test_redact_nonexistent_event(self): + def test_redact_nonexistent_event(self) -> None: # control case: an existing event b = self.helper.send(room_id=self.room_id, tok=self.other_access_token) msg_id = b["event_id"] @@ -158,7 +166,7 @@ class RedactionsTestCase(HomeserverTestCase): self.assertEqual(timeline[-2]["unsigned"]["redacted_by"], redaction_id) self.assertEqual(timeline[-2]["content"], {}) - def test_redact_create_event(self): + def test_redact_create_event(self) -> None: # control case: an existing event b = self.helper.send(room_id=self.room_id, tok=self.mod_access_token) msg_id = b["event_id"] @@ -178,7 +186,7 @@ class RedactionsTestCase(HomeserverTestCase): self.other_access_token, self.room_id, create_event_id, expect_code=403 ) - def test_redact_event_as_moderator_ratelimit(self): + def test_redact_event_as_moderator_ratelimit(self) -> None: """Tests that the correct ratelimiting is applied to redactions""" message_ids = [] diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 8f7181103..c8db45719 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -18,11 +18,15 @@ import urllib.parse from typing import Dict, List, Optional, Tuple from unittest.mock import patch +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin from synapse.rest.client import login, register, relations, room, sync +from synapse.server import HomeServer from synapse.storage.relations import RelationPaginationToken from synapse.types import JsonDict, StreamToken +from synapse.util import Clock from tests import unittest from tests.server import FakeChannel @@ -52,7 +56,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): return config - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.user_id, self.user_token = self._create_user("alice") @@ -63,7 +67,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): res = self.helper.send(self.room, body="Hi!", tok=self.user_token) self.parent_id = res["event_id"] - def test_send_relation(self): + def test_send_relation(self) -> None: """Tests that sending a relation using the new /send_relation works creates the right shape of event. """ @@ -95,7 +99,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel.json_body, ) - def test_deny_invalid_event(self): + def test_deny_invalid_event(self) -> None: """Test that we deny relations on non-existant events""" channel = self._send_relation( RelationTypes.ANNOTATION, @@ -125,7 +129,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) self.assertEqual(200, channel.code, channel.json_body) - def test_deny_invalid_room(self): + def test_deny_invalid_room(self) -> None: """Test that we deny relations on non-existant events""" # Create another room and send a message in it. room2 = self.helper.create_room_as(self.user_id, tok=self.user_token) @@ -138,7 +142,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) self.assertEqual(400, channel.code, channel.json_body) - def test_deny_double_react(self): + def test_deny_double_react(self) -> None: """Test that we deny relations on membership events""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") self.assertEqual(200, channel.code, channel.json_body) @@ -146,7 +150,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEqual(400, channel.code, channel.json_body) - def test_deny_forked_thread(self): + def test_deny_forked_thread(self) -> None: """It is invalid to start a thread off a thread.""" channel = self._send_relation( RelationTypes.THREAD, @@ -165,7 +169,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) self.assertEqual(400, channel.code, channel.json_body) - def test_basic_paginate_relations(self): + def test_basic_paginate_relations(self) -> None: """Tests that calling pagination API correctly the latest relations.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEqual(200, channel.code, channel.json_body) @@ -235,7 +239,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ).to_string(self.store) ) - def test_repeated_paginate_relations(self): + def test_repeated_paginate_relations(self) -> None: """Test that if we paginate using a limit and tokens then we get the expected events. """ @@ -303,7 +307,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): found_event_ids.reverse() self.assertEqual(found_event_ids, expected_event_ids) - def test_pagination_from_sync_and_messages(self): + def test_pagination_from_sync_and_messages(self) -> None: """Pagination tokens from /sync and /messages can be used to paginate /relations.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A") self.assertEqual(200, channel.code, channel.json_body) @@ -362,7 +366,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]] ) - def test_aggregation_pagination_groups(self): + def test_aggregation_pagination_groups(self) -> None: """Test that we can paginate annotation groups correctly.""" # We need to create ten separate users to send each reaction. @@ -427,7 +431,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEqual(sent_groups, found_groups) - def test_aggregation_pagination_within_group(self): + def test_aggregation_pagination_within_group(self) -> None: """Test that we can paginate within an annotation group.""" # We need to create ten separate users to send each reaction. @@ -524,7 +528,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): found_event_ids.reverse() self.assertEqual(found_event_ids, expected_event_ids) - def test_aggregation(self): + def test_aggregation(self) -> None: """Test that annotations get correctly aggregated.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") @@ -556,7 +560,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): }, ) - def test_aggregation_redactions(self): + def test_aggregation_redactions(self) -> None: """Test that annotations get correctly aggregated after a redaction.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") @@ -590,7 +594,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, ) - def test_aggregation_must_be_annotation(self): + def test_aggregation_must_be_annotation(self) -> None: """Test that aggregations must be annotations.""" channel = self.make_request( @@ -604,7 +608,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): @unittest.override_config( {"experimental_features": {"msc3440_enabled": True, "msc3666_enabled": True}} ) - def test_bundled_aggregations(self): + def test_bundled_aggregations(self) -> None: """ Test that annotations, references, and threads get correctly bundled. @@ -746,7 +750,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ] assert_bundle(self._find_event_in_chunk(chunk)) - def test_aggregation_get_event_for_annotation(self): + def test_aggregation_get_event_for_annotation(self) -> None: """Test that annotations do not get bundled aggregations included when directly requested. """ @@ -768,7 +772,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, channel.json_body) self.assertIsNone(channel.json_body["unsigned"].get("m.relations")) - def test_aggregation_get_event_for_thread(self): + def test_aggregation_get_event_for_thread(self) -> None: """Test that threads get bundled aggregations included when directly requested.""" channel = self._send_relation(RelationTypes.THREAD, "m.room.test") self.assertEqual(200, channel.code, channel.json_body) @@ -815,7 +819,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) - def test_ignore_invalid_room(self): + def test_ignore_invalid_room(self) -> None: """Test that we ignore invalid relations over federation.""" # Create another room and send a message in it. room2 = self.helper.create_room_as(self.user_id, tok=self.user_token) @@ -927,7 +931,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertNotIn("m.relations", channel.json_body["unsigned"]) @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) - def test_edit(self): + def test_edit(self) -> None: """Test that a simple edit works.""" new_body = {"msgtype": "m.text", "body": "I've been edited!"} @@ -1010,7 +1014,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ] assert_bundle(self._find_event_in_chunk(chunk)) - def test_multi_edit(self): + def test_multi_edit(self) -> None: """Test that multiple edits, including attempts by people who shouldn't be allowed, are correctly handled. """ @@ -1067,7 +1071,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) - def test_edit_reply(self): + def test_edit_reply(self) -> None: """Test that editing a reply works.""" # Create a reply to edit. @@ -1124,7 +1128,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) - def test_edit_thread(self): + def test_edit_thread(self) -> None: """Test that editing a thread works.""" # Create a thread and edit the last event. @@ -1163,7 +1167,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): latest_event_in_thread = thread_summary["latest_event"] self.assertEqual(latest_event_in_thread["content"]["body"], "I've been edited!") - def test_edit_edit(self): + def test_edit_edit(self) -> None: """Test that an edit cannot be edited.""" new_body = {"msgtype": "m.text", "body": "Initial edit"} channel = self._send_relation( @@ -1213,7 +1217,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) - def test_relations_redaction_redacts_edits(self): + def test_relations_redaction_redacts_edits(self) -> None: """Test that edits of an event are redacted when the original event is redacted. """ @@ -1269,7 +1273,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertIn("chunk", channel.json_body) self.assertEqual(channel.json_body["chunk"], []) - def test_aggregations_redaction_prevents_access_to_aggregations(self): + def test_aggregations_redaction_prevents_access_to_aggregations(self) -> None: """Test that annotations of an event are redacted when the original event is redacted. """ @@ -1309,7 +1313,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertIn("chunk", channel.json_body) self.assertEqual(channel.json_body["chunk"], []) - def test_unknown_relations(self): + def test_unknown_relations(self) -> None: """Unknown relations should be accepted.""" channel = self._send_relation("m.relation.test", "m.room.test") self.assertEqual(200, channel.code, channel.json_body) @@ -1417,7 +1421,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): return user_id, access_token - def test_background_update(self): + def test_background_update(self) -> None: """Test the event_arbitrary_relations background update.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") self.assertEqual(200, channel.code, channel.json_body) diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index c41a1c14a..f3bf8d093 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -13,9 +13,14 @@ # limitations under the License. from unittest.mock import Mock +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventTypes from synapse.rest import admin from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from synapse.visibility import filter_events_for_client from tests import unittest @@ -31,7 +36,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["retention"] = { "enabled": True, @@ -47,7 +52,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): return self.hs - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = self.register_user("user", "password") self.token = self.login("user", "password") @@ -55,7 +60,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): self.serializer = self.hs.get_event_client_serializer() self.clock = self.hs.get_clock() - def test_retention_event_purged_with_state_event(self): + def test_retention_event_purged_with_state_event(self) -> None: """Tests that expired events are correctly purged when the room's retention policy is defined by a state event. """ @@ -72,7 +77,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): self._test_retention_event_purged(room_id, one_day_ms * 1.5) - def test_retention_event_purged_with_state_event_outside_allowed(self): + def test_retention_event_purged_with_state_event_outside_allowed(self) -> None: """Tests that the server configuration can override the policy for a room when running the purge jobs. """ @@ -102,7 +107,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): # instead of the one specified in the room's policy. self._test_retention_event_purged(room_id, one_day_ms * 0.5) - def test_retention_event_purged_without_state_event(self): + def test_retention_event_purged_without_state_event(self) -> None: """Tests that expired events are correctly purged when the room's retention policy is defined by the server's configuration's default retention policy. """ @@ -110,7 +115,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): self._test_retention_event_purged(room_id, one_day_ms * 2) - def test_visibility(self): + def test_visibility(self) -> None: """Tests that synapse.visibility.filter_events_for_client correctly filters out outdated events """ @@ -152,7 +157,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): # That event should be the second, not outdated event. self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events) - def _test_retention_event_purged(self, room_id: str, increment: float): + def _test_retention_event_purged(self, room_id: str, increment: float) -> None: """Run the following test scenario to test the message retention policy support: 1. Send event 1 @@ -186,6 +191,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): resp = self.helper.send(room_id=room_id, body="1", tok=self.token) expired_event_id = resp.get("event_id") + assert expired_event_id is not None # Check that we can retrieve the event. expired_event = self.get_event(expired_event_id) @@ -201,6 +207,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): resp = self.helper.send(room_id=room_id, body="2", tok=self.token) valid_event_id = resp.get("event_id") + assert valid_event_id is not None # Advance the time again. Now our first event should have expired but our second # one should still be kept. @@ -218,7 +225,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): # has been purged. self.get_event(room_id, create_event.event_id) - def get_event(self, event_id, expect_none=False): + def get_event(self, event_id: str, expect_none: bool = False) -> JsonDict: event = self.get_success(self.store.get_event(event_id, allow_none=True)) if expect_none: @@ -240,7 +247,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["retention"] = { "enabled": True, @@ -254,11 +261,11 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): ) return self.hs - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = self.register_user("user", "password") self.token = self.login("user", "password") - def test_no_default_policy(self): + def test_no_default_policy(self) -> None: """Tests that an event doesn't get expired if there is neither a default retention policy nor a policy specific to the room. """ @@ -266,7 +273,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): self._test_retention(room_id) - def test_state_policy(self): + def test_state_policy(self) -> None: """Tests that an event gets correctly expired if there is no default retention policy but there's a policy specific to the room. """ @@ -283,12 +290,15 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): self._test_retention(room_id, expected_code_for_first_event=404) - def _test_retention(self, room_id, expected_code_for_first_event=200): + def _test_retention( + self, room_id: str, expected_code_for_first_event: int = 200 + ) -> None: # Send a first event to the room. This is the event we'll want to be purged at the # end of the test. resp = self.helper.send(room_id=room_id, body="1", tok=self.token) first_event_id = resp.get("event_id") + assert first_event_id is not None # Check that we can retrieve the event. expired_event = self.get_event(room_id, first_event_id) @@ -304,6 +314,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): resp = self.helper.send(room_id=room_id, body="2", tok=self.token) second_event_id = resp.get("event_id") + assert second_event_id is not None # Advance the time by another month. self.reactor.advance(one_day_ms * 30 / 1000) @@ -322,7 +333,9 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): second_event = self.get_event(room_id, second_event_id) self.assertEqual(second_event.get("content", {}).get("body"), "2", second_event) - def get_event(self, room_id, event_id, expected_code=200): + def get_event( + self, room_id: str, event_id: str, expected_code: int = 200 + ) -> JsonDict: url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) channel = self.make_request("GET", url, access_token=self.token) diff --git a/tests/rest/client/test_sendtodevice.py b/tests/rest/client/test_sendtodevice.py index e2ed14457..c3942889e 100644 --- a/tests/rest/client/test_sendtodevice.py +++ b/tests/rest/client/test_sendtodevice.py @@ -26,7 +26,7 @@ class SendToDeviceTestCase(HomeserverTestCase): sync.register_servlets, ] - def test_user_to_user(self): + def test_user_to_user(self) -> None: """A to-device message from one user to another should get delivered""" user1 = self.register_user("u1", "pass") @@ -73,7 +73,7 @@ class SendToDeviceTestCase(HomeserverTestCase): self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), []) @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}}) - def test_local_room_key_request(self): + def test_local_room_key_request(self) -> None: """m.room_key_request has special-casing; test from local user""" user1 = self.register_user("u1", "pass") user1_tok = self.login("u1", "pass", "d1") @@ -128,7 +128,7 @@ class SendToDeviceTestCase(HomeserverTestCase): ) @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}}) - def test_remote_room_key_request(self): + def test_remote_room_key_request(self) -> None: """m.room_key_request has special-casing; test from remote user""" user2 = self.register_user("u2", "pass") user2_tok = self.login("u2", "pass", "d2") @@ -199,7 +199,7 @@ class SendToDeviceTestCase(HomeserverTestCase): }, ) - def test_limited_sync(self): + def test_limited_sync(self) -> None: """If a limited sync for to-devices happens the next /sync should respond immediately.""" self.register_user("u1", "pass") diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index 2634c98dd..ae5ada3be 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -14,6 +14,8 @@ from unittest.mock import Mock, patch +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.constants import EventTypes from synapse.rest.client import ( @@ -23,13 +25,15 @@ from synapse.rest.client import ( room, room_upgrade_rest_servlet, ) +from synapse.server import HomeServer from synapse.types import UserID +from synapse.util import Clock from tests import unittest class _ShadowBannedBase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # Create two users, one of which is shadow-banned. self.banned_user_id = self.register_user("banned", "test") self.banned_access_token = self.login("banned", "test") @@ -55,7 +59,7 @@ class RoomTestCase(_ShadowBannedBase): room_upgrade_rest_servlet.register_servlets, ] - def test_invite(self): + def test_invite(self) -> None: """Invites from shadow-banned users don't actually get sent.""" # The create works fine. @@ -77,7 +81,7 @@ class RoomTestCase(_ShadowBannedBase): ) self.assertEqual(invited_rooms, []) - def test_invite_3pid(self): + def test_invite_3pid(self) -> None: """Ensure that a 3PID invite does not attempt to contact the identity server.""" identity_handler = self.hs.get_identity_handler() identity_handler.lookup_3pid = Mock( @@ -101,7 +105,7 @@ class RoomTestCase(_ShadowBannedBase): # This should have raised an error earlier, but double check this wasn't called. identity_handler.lookup_3pid.assert_not_called() - def test_create_room(self): + def test_create_room(self) -> None: """Invitations during a room creation should be discarded, but the room still gets created.""" # The room creation is successful. channel = self.make_request( @@ -126,7 +130,7 @@ class RoomTestCase(_ShadowBannedBase): users = self.get_success(self.store.get_users_in_room(room_id)) self.assertCountEqual(users, ["@banned:test", "@otheruser:test"]) - def test_message(self): + def test_message(self) -> None: """Messages from shadow-banned users don't actually get sent.""" room_id = self.helper.create_room_as( @@ -151,7 +155,7 @@ class RoomTestCase(_ShadowBannedBase): ) self.assertNotIn(event_id, latest_events) - def test_upgrade(self): + def test_upgrade(self) -> None: """A room upgrade should fail, but look like it succeeded.""" # The create works fine. @@ -177,7 +181,7 @@ class RoomTestCase(_ShadowBannedBase): # The summary should be empty since the room doesn't exist. self.assertEqual(summary, {}) - def test_typing(self): + def test_typing(self) -> None: """Typing notifications should not be propagated into the room.""" # The create works fine. room_id = self.helper.create_room_as( @@ -240,7 +244,7 @@ class ProfileTestCase(_ShadowBannedBase): room.register_servlets, ] - def test_displayname(self): + def test_displayname(self) -> None: """Profile changes should succeed, but don't end up in a room.""" original_display_name = "banned" new_display_name = "new name" @@ -281,7 +285,7 @@ class ProfileTestCase(_ShadowBannedBase): event.content, {"membership": "join", "displayname": original_display_name} ) - def test_room_displayname(self): + def test_room_displayname(self) -> None: """Changes to state events for a room should be processed, but not end up in the room.""" original_display_name = "banned" new_display_name = "new name" diff --git a/tests/rest/client/test_shared_rooms.py b/tests/rest/client/test_shared_rooms.py index 294f46fb9..3818b7b14 100644 --- a/tests/rest/client/test_shared_rooms.py +++ b/tests/rest/client/test_shared_rooms.py @@ -11,8 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.rest.client import login, room, shared_rooms +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.server import FakeChannel @@ -30,16 +34,16 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): shared_rooms.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["update_user_directory"] = True return self.setup_test_homeserver(config=config) - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.handler = hs.get_user_directory_handler() - def _get_shared_rooms(self, token, other_user) -> FakeChannel: + def _get_shared_rooms(self, token: str, other_user: str) -> FakeChannel: return self.make_request( "GET", "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s" @@ -47,14 +51,14 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): access_token=token, ) - def test_shared_room_list_public(self): + def test_shared_room_list_public(self) -> None: """ A room should show up in the shared list of rooms between two users if it is public. """ self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=True) - def test_shared_room_list_private(self): + def test_shared_room_list_private(self) -> None: """ A room should show up in the shared list of rooms between two users if it is private. @@ -63,7 +67,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): room_one_is_public=False, room_two_is_public=False ) - def test_shared_room_list_mixed(self): + def test_shared_room_list_mixed(self) -> None: """ The shared room list between two users should contain both public and private rooms. @@ -72,7 +76,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): def _check_shared_rooms_with( self, room_one_is_public: bool, room_two_is_public: bool - ): + ) -> None: """Checks that shared public or private rooms between two users appear in their shared room lists """ @@ -109,7 +113,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): for room_id_id in channel.json_body["joined"]: self.assertIn(room_id_id, [room_id_one, room_id_two]) - def test_shared_room_list_after_leave(self): + def test_shared_room_list_after_leave(self) -> None: """ A room should no longer be considered shared if the other user has left it. diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py index 658c21b2a..b7d0f42da 100644 --- a/tests/rest/client/test_upgrade_room.py +++ b/tests/rest/client/test_upgrade_room.py @@ -13,11 +13,14 @@ # limitations under the License. from typing import Optional +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventContentFields, EventTypes, RoomTypes from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.rest import admin from synapse.rest.client import login, room, room_upgrade_rest_servlet from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.server import FakeChannel @@ -31,7 +34,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): room_upgrade_rest_servlet.register_servlets, ] - def prepare(self, reactor, clock, hs: "HomeServer"): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.creator = self.register_user("creator", "pass") @@ -60,7 +63,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): access_token=token or self.creator_token, ) - def test_upgrade(self): + def test_upgrade(self) -> None: """ Upgrading a room should work fine. """ @@ -68,7 +71,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, channel.result) self.assertIn("replacement_room", channel.json_body) - def test_not_in_room(self): + def test_not_in_room(self) -> None: """ Upgrading a room should work fine. """ @@ -79,7 +82,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): channel = self._upgrade_room(roomless_token) self.assertEqual(403, channel.code, channel.result) - def test_power_levels(self): + def test_power_levels(self) -> None: """ Another user can upgrade the room if their power level is increased. """ @@ -105,7 +108,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): channel = self._upgrade_room(self.other_token) self.assertEqual(200, channel.code, channel.result) - def test_power_levels_user_default(self): + def test_power_levels_user_default(self) -> None: """ Another user can upgrade the room if the default power level for users is increased. """ @@ -131,7 +134,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): channel = self._upgrade_room(self.other_token) self.assertEqual(200, channel.code, channel.result) - def test_power_levels_tombstone(self): + def test_power_levels_tombstone(self) -> None: """ Another user can upgrade the room if they can send the tombstone event. """ @@ -164,7 +167,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): ) self.assertNotIn(self.other, power_levels["users"]) - def test_space(self): + def test_space(self) -> None: """Test upgrading a space.""" # Create a space. diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 46cd5f70a..28663826f 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -41,6 +41,7 @@ from twisted.web.resource import Resource from twisted.web.server import Site from synapse.api.constants import Membership +from synapse.server import HomeServer from synapse.types import JsonDict from tests.server import FakeChannel, FakeSite, make_request @@ -48,15 +49,15 @@ from tests.test_utils import FakeResponse from tests.test_utils.html_parsers import TestHtmlParser -@attr.s +@attr.s(auto_attribs=True) class RestHelper: """Contains extra helper functions to quickly and clearly perform a given REST action, which isn't the focus of the test. """ - hs = attr.ib() - site = attr.ib(type=Site) - auth_user_id = attr.ib() + hs: HomeServer + site: Site + auth_user_id: Optional[str] @overload def create_room_as( @@ -145,7 +146,7 @@ class RestHelper: def invite( self, - room: Optional[str] = None, + room: str, src: Optional[str] = None, targ: Optional[str] = None, expect_code: int = HTTPStatus.OK, @@ -216,7 +217,7 @@ class RestHelper: def leave( self, - room: Optional[str] = None, + room: str, user: Optional[str] = None, expect_code: int = HTTPStatus.OK, tok: Optional[str] = None, @@ -230,14 +231,22 @@ class RestHelper: expect_code=expect_code, ) - def ban(self, room: str, src: str, targ: str, **kwargs: object) -> None: + def ban( + self, + room: str, + src: str, + targ: str, + expect_code: int = HTTPStatus.OK, + tok: Optional[str] = None, + ) -> None: """A convenience helper: `change_membership` with `membership` preset to "ban".""" self.change_membership( room=room, src=src, targ=targ, + tok=tok, membership=Membership.BAN, - **kwargs, + expect_code=expect_code, ) def change_membership( @@ -378,7 +387,7 @@ class RestHelper: room_id: str, event_type: str, body: Optional[Dict[str, Any]], - tok: str, + tok: Optional[str], expect_code: int = HTTPStatus.OK, state_key: str = "", method: str = "GET", @@ -458,7 +467,7 @@ class RestHelper: room_id: str, event_type: str, body: Dict[str, Any], - tok: str, + tok: Optional[str], expect_code: int = HTTPStatus.OK, state_key: str = "", ) -> JsonDict: @@ -658,7 +667,12 @@ class RestHelper: (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict), ] - async def mock_req(method: str, uri: str, data=None, headers=None): + async def mock_req( + method: str, + uri: str, + data: Optional[dict] = None, + headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, + ): (expected_uri, resp_obj) = expected_requests.pop(0) assert uri == expected_uri resp = FakeResponse( From 1866fb39d7ffc86d7374a9aed916f70a91ec65fa Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 28 Feb 2022 13:29:09 -0500 Subject: [PATCH 65/84] Move experimental support for MSC3440 to /versions. (#12099) Instead of being part of /capabilities, this matches a change to MSC3440 to properly use these endpoints. --- changelog.d/12099.misc | 1 + synapse/rest/client/capabilities.py | 3 --- synapse/rest/client/versions.py | 2 ++ 3 files changed, 3 insertions(+), 3 deletions(-) create mode 100644 changelog.d/12099.misc diff --git a/changelog.d/12099.misc b/changelog.d/12099.misc new file mode 100644 index 000000000..0553825db --- /dev/null +++ b/changelog.d/12099.misc @@ -0,0 +1 @@ +Move experimental support for [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440) to /versions. diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py index b80fdd371..4237071c6 100644 --- a/synapse/rest/client/capabilities.py +++ b/synapse/rest/client/capabilities.py @@ -72,9 +72,6 @@ class CapabilitiesRestServlet(RestServlet): "org.matrix.msc3244.room_capabilities" ] = MSC3244_CAPABILITIES - if self.config.experimental.msc3440_enabled: - response["capabilities"]["io.element.thread"] = {"enabled": True} - if self.config.experimental.msc3720_enabled: response["capabilities"]["org.matrix.msc3720.account_status"] = { "enabled": True, diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 00f29344a..2e5d0e4e2 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -99,6 +99,8 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc2716": self.config.experimental.msc2716_enabled, # Adds support for jump to date endpoints (/timestamp_to_event) as per MSC3030 "org.matrix.msc3030": self.config.experimental.msc3030_enabled, + # Adds support for thread relations, per MSC3440. + "org.matrix.msc3440": self.config.experimental.msc3440_enabled, }, }, ) From 7754af24ab163a3666bc04c7df409e59ace0d763 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 28 Feb 2022 13:33:00 -0500 Subject: [PATCH 66/84] Remove the unstable `/spaces` endpoint. (#12073) ...and various code supporting it. The /spaces endpoint was from an old version of MSC2946 and included both a Client-Server and Server-Server API. Note that the unstable /hierarchy endpoint (from the final version of MSC2946) is not yet removed. --- changelog.d/12073.removal | 1 + docs/workers.md | 2 - synapse/federation/federation_client.py | 226 ++---------- synapse/federation/transport/client.py | 33 -- .../federation/transport/server/federation.py | 76 ----- synapse/handlers/room_summary.py | 323 +----------------- synapse/rest/client/room.py | 68 ---- tests/handlers/test_room_summary.py | 119 +------ 8 files changed, 46 insertions(+), 802 deletions(-) create mode 100644 changelog.d/12073.removal diff --git a/changelog.d/12073.removal b/changelog.d/12073.removal new file mode 100644 index 000000000..1f3979271 --- /dev/null +++ b/changelog.d/12073.removal @@ -0,0 +1 @@ +Remove the unstable `/spaces` endpoint from [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). diff --git a/docs/workers.md b/docs/workers.md index b82a6900a..b0f8599ef 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -212,7 +212,6 @@ information. ^/_matrix/federation/v1/user/devices/ ^/_matrix/federation/v1/get_groups_publicised$ ^/_matrix/key/v2/query - ^/_matrix/federation/unstable/org.matrix.msc2946/spaces/ ^/_matrix/federation/(v1|unstable/org.matrix.msc2946)/hierarchy/ # Inbound federation transaction request @@ -225,7 +224,6 @@ information. ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/context/.*$ ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/members$ ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/state$ - ^/_matrix/client/unstable/org.matrix.msc2946/rooms/.*/spaces$ ^/_matrix/client/(v1|unstable/org.matrix.msc2946)/rooms/.*/hierarchy$ ^/_matrix/client/unstable/im.nheko.summary/rooms/.*/summary$ ^/_matrix/client/(r0|v3|unstable)/account/3pid$ diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index a4bae3c4c..64e595e74 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1362,61 +1362,6 @@ class FederationClient(FederationBase): # server doesn't give it to us. return None - async def get_space_summary( - self, - destinations: Iterable[str], - room_id: str, - suggested_only: bool, - max_rooms_per_space: Optional[int], - exclude_rooms: List[str], - ) -> "FederationSpaceSummaryResult": - """ - Call other servers to get a summary of the given space - - - Args: - destinations: The remote servers. We will try them in turn, omitting any - that have been blacklisted. - - room_id: ID of the space to be queried - - suggested_only: If true, ask the remote server to only return children - with the "suggested" flag set - - max_rooms_per_space: A limit on the number of children to return for each - space - - exclude_rooms: A list of room IDs to tell the remote server to skip - - Returns: - a parsed FederationSpaceSummaryResult - - Raises: - SynapseError if we were unable to get a valid summary from any of the - remote servers - """ - - async def send_request(destination: str) -> FederationSpaceSummaryResult: - res = await self.transport_layer.get_space_summary( - destination=destination, - room_id=room_id, - suggested_only=suggested_only, - max_rooms_per_space=max_rooms_per_space, - exclude_rooms=exclude_rooms, - ) - - try: - return FederationSpaceSummaryResult.from_json_dict(res) - except ValueError as e: - raise InvalidResponseError(str(e)) - - return await self._try_destination_list( - "fetch space summary", - destinations, - send_request, - failover_on_unknown_endpoint=True, - ) - async def get_room_hierarchy( self, destinations: Iterable[str], @@ -1488,10 +1433,8 @@ class FederationClient(FederationBase): if any(not isinstance(e, dict) for e in children_state): raise InvalidResponseError("Invalid event in 'children_state' list") try: - [ - FederationSpaceSummaryEventResult.from_json_dict(e) - for e in children_state - ] + for child_state in children_state: + _validate_hierarchy_event(child_state) except ValueError as e: raise InvalidResponseError(str(e)) @@ -1513,62 +1456,12 @@ class FederationClient(FederationBase): return room, children_state, children, inaccessible_children - try: - result = await self._try_destination_list( - "fetch room hierarchy", - destinations, - send_request, - failover_on_unknown_endpoint=True, - ) - except SynapseError as e: - # If an unexpected error occurred, re-raise it. - if e.code != 502: - raise - - logger.debug( - "Couldn't fetch room hierarchy, falling back to the spaces API" - ) - - # Fallback to the old federation API and translate the results if - # no servers implement the new API. - # - # The algorithm below is a bit inefficient as it only attempts to - # parse information for the requested room, but the legacy API may - # return additional layers. - legacy_result = await self.get_space_summary( - destinations, - room_id, - suggested_only, - max_rooms_per_space=None, - exclude_rooms=[], - ) - - # Find the requested room in the response (and remove it). - for _i, room in enumerate(legacy_result.rooms): - if room.get("room_id") == room_id: - break - else: - # The requested room was not returned, nothing we can do. - raise - requested_room = legacy_result.rooms.pop(_i) - - # Find any children events of the requested room. - children_events = [] - children_room_ids = set() - for event in legacy_result.events: - if event.room_id == room_id: - children_events.append(event.data) - children_room_ids.add(event.state_key) - - # Find the children rooms. - children = [] - for room in legacy_result.rooms: - if room.get("room_id") in children_room_ids: - children.append(room) - - # It isn't clear from the response whether some of the rooms are - # not accessible. - result = (requested_room, children_events, children, ()) + result = await self._try_destination_list( + "fetch room hierarchy", + destinations, + send_request, + failover_on_unknown_endpoint=True, + ) # Cache the result to avoid fetching data over federation every time. self._get_room_hierarchy_cache[(room_id, suggested_only)] = result @@ -1710,89 +1603,34 @@ class TimestampToEventResponse: return cls(event_id, origin_server_ts, d) -@attr.s(frozen=True, slots=True, auto_attribs=True) -class FederationSpaceSummaryEventResult: - """Represents a single event in the result of a successful get_space_summary call. +def _validate_hierarchy_event(d: JsonDict) -> None: + """Validate an event within the result of a /hierarchy request - It's essentially just a serialised event object, but we do a bit of parsing and - validation in `from_json_dict` and store some of the validated properties in - object attributes. + Args: + d: json object to be parsed + + Raises: + ValueError if d is not a valid event """ - event_type: str - room_id: str - state_key: str - via: Sequence[str] + event_type = d.get("type") + if not isinstance(event_type, str): + raise ValueError("Invalid event: 'event_type' must be a str") - # the raw data, including the above keys - data: JsonDict + room_id = d.get("room_id") + if not isinstance(room_id, str): + raise ValueError("Invalid event: 'room_id' must be a str") - @classmethod - def from_json_dict(cls, d: JsonDict) -> "FederationSpaceSummaryEventResult": - """Parse an event within the result of a /spaces/ request + state_key = d.get("state_key") + if not isinstance(state_key, str): + raise ValueError("Invalid event: 'state_key' must be a str") - Args: - d: json object to be parsed + content = d.get("content") + if not isinstance(content, dict): + raise ValueError("Invalid event: 'content' must be a dict") - Raises: - ValueError if d is not a valid event - """ - - event_type = d.get("type") - if not isinstance(event_type, str): - raise ValueError("Invalid event: 'event_type' must be a str") - - room_id = d.get("room_id") - if not isinstance(room_id, str): - raise ValueError("Invalid event: 'room_id' must be a str") - - state_key = d.get("state_key") - if not isinstance(state_key, str): - raise ValueError("Invalid event: 'state_key' must be a str") - - content = d.get("content") - if not isinstance(content, dict): - raise ValueError("Invalid event: 'content' must be a dict") - - via = content.get("via") - if not isinstance(via, Sequence): - raise ValueError("Invalid event: 'via' must be a list") - if any(not isinstance(v, str) for v in via): - raise ValueError("Invalid event: 'via' must be a list of strings") - - return cls(event_type, room_id, state_key, via, d) - - -@attr.s(frozen=True, slots=True, auto_attribs=True) -class FederationSpaceSummaryResult: - """Represents the data returned by a successful get_space_summary call.""" - - rooms: List[JsonDict] - events: Sequence[FederationSpaceSummaryEventResult] - - @classmethod - def from_json_dict(cls, d: JsonDict) -> "FederationSpaceSummaryResult": - """Parse the result of a /spaces/ request - - Args: - d: json object to be parsed - - Raises: - ValueError if d is not a valid /spaces/ response - """ - rooms = d.get("rooms") - if not isinstance(rooms, List): - raise ValueError("'rooms' must be a list") - if any(not isinstance(r, dict) for r in rooms): - raise ValueError("Invalid room in 'rooms' list") - - events = d.get("events") - if not isinstance(events, Sequence): - raise ValueError("'events' must be a list") - if any(not isinstance(e, dict) for e in events): - raise ValueError("Invalid event in 'events' list") - parsed_events = [ - FederationSpaceSummaryEventResult.from_json_dict(e) for e in events - ] - - return cls(rooms, parsed_events) + via = content.get("via") + if not isinstance(via, Sequence): + raise ValueError("Invalid event: 'via' must be a list") + if any(not isinstance(v, str) for v in via): + raise ValueError("Invalid event: 'via' must be a list of strings") diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 69998de52..de6e5f44f 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -1179,39 +1179,6 @@ class TransportLayerClient: return await self.client.get_json(destination=destination, path=path) - async def get_space_summary( - self, - destination: str, - room_id: str, - suggested_only: bool, - max_rooms_per_space: Optional[int], - exclude_rooms: List[str], - ) -> JsonDict: - """ - Args: - destination: The remote server - room_id: The room ID to ask about. - suggested_only: if True, only suggested rooms will be returned - max_rooms_per_space: an optional limit to the number of children to be - returned per space - exclude_rooms: a list of any rooms we can skip - """ - # TODO When switching to the stable endpoint, use GET instead of POST. - path = _create_path( - FEDERATION_UNSTABLE_PREFIX, "/org.matrix.msc2946/spaces/%s", room_id - ) - - params = { - "suggested_only": suggested_only, - "exclude_rooms": exclude_rooms, - } - if max_rooms_per_space is not None: - params["max_rooms_per_space"] = max_rooms_per_space - - return await self.client.post_json( - destination=destination, path=path, data=params - ) - async def get_room_hierarchy( self, destination: str, room_id: str, suggested_only: bool ) -> JsonDict: diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 23ce34305..aed3d5069 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -624,81 +624,6 @@ class FederationVersionServlet(BaseFederationServlet): ) -class FederationSpaceSummaryServlet(BaseFederationServlet): - PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946" - PATH = "/spaces/(?P[^/]*)" - - def __init__( - self, - hs: "HomeServer", - authenticator: Authenticator, - ratelimiter: FederationRateLimiter, - server_name: str, - ): - super().__init__(hs, authenticator, ratelimiter, server_name) - self.handler = hs.get_room_summary_handler() - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Mapping[bytes, Sequence[bytes]], - room_id: str, - ) -> Tuple[int, JsonDict]: - suggested_only = parse_boolean_from_args(query, "suggested_only", default=False) - - max_rooms_per_space = parse_integer_from_args(query, "max_rooms_per_space") - if max_rooms_per_space is not None and max_rooms_per_space < 0: - raise SynapseError( - 400, - "Value for 'max_rooms_per_space' must be a non-negative integer", - Codes.BAD_JSON, - ) - - exclude_rooms = parse_strings_from_args(query, "exclude_rooms", default=[]) - - return 200, await self.handler.federation_space_summary( - origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms - ) - - # TODO When switching to the stable endpoint, remove the POST handler. - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Mapping[bytes, Sequence[bytes]], - room_id: str, - ) -> Tuple[int, JsonDict]: - suggested_only = content.get("suggested_only", False) - if not isinstance(suggested_only, bool): - raise SynapseError( - 400, "'suggested_only' must be a boolean", Codes.BAD_JSON - ) - - exclude_rooms = content.get("exclude_rooms", []) - if not isinstance(exclude_rooms, list) or any( - not isinstance(x, str) for x in exclude_rooms - ): - raise SynapseError(400, "bad value for 'exclude_rooms'", Codes.BAD_JSON) - - max_rooms_per_space = content.get("max_rooms_per_space") - if max_rooms_per_space is not None: - if not isinstance(max_rooms_per_space, int): - raise SynapseError( - 400, "bad value for 'max_rooms_per_space'", Codes.BAD_JSON - ) - if max_rooms_per_space < 0: - raise SynapseError( - 400, - "Value for 'max_rooms_per_space' must be a non-negative integer", - Codes.BAD_JSON, - ) - - return 200, await self.handler.federation_space_summary( - origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms - ) - - class FederationRoomHierarchyServlet(BaseFederationServlet): PATH = "/hierarchy/(?P[^/]*)" @@ -826,7 +751,6 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( On3pidBindServlet, FederationVersionServlet, RoomComplexityServlet, - FederationSpaceSummaryServlet, FederationRoomHierarchyServlet, FederationRoomHierarchyUnstableServlet, FederationV1SendKnockServlet, diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index 2e61d1cbe..55c2cbdba 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -15,7 +15,6 @@ import itertools import logging import re -from collections import deque from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Set, Tuple import attr @@ -107,153 +106,6 @@ class RoomSummaryHandler: "get_room_hierarchy", ) - async def get_space_summary( - self, - requester: str, - room_id: str, - suggested_only: bool = False, - max_rooms_per_space: Optional[int] = None, - ) -> JsonDict: - """ - Implementation of the space summary C-S API - - Args: - requester: user id of the user making this request - - room_id: room id to start the summary at - - suggested_only: whether we should only return children with the "suggested" - flag set. - - max_rooms_per_space: an optional limit on the number of child rooms we will - return. This does not apply to the root room (ie, room_id), and - is overridden by MAX_ROOMS_PER_SPACE. - - Returns: - summary dict to return - """ - # First of all, check that the room is accessible. - if not await self._is_local_room_accessible(room_id, requester): - raise AuthError( - 403, - "User %s not in room %s, and room previews are disabled" - % (requester, room_id), - ) - - # the queue of rooms to process - room_queue = deque((_RoomQueueEntry(room_id, ()),)) - - # rooms we have already processed - processed_rooms: Set[str] = set() - - # events we have already processed. We don't necessarily have their event ids, - # so instead we key on (room id, state key) - processed_events: Set[Tuple[str, str]] = set() - - rooms_result: List[JsonDict] = [] - events_result: List[JsonDict] = [] - - if max_rooms_per_space is None or max_rooms_per_space > MAX_ROOMS_PER_SPACE: - max_rooms_per_space = MAX_ROOMS_PER_SPACE - - while room_queue and len(rooms_result) < MAX_ROOMS: - queue_entry = room_queue.popleft() - room_id = queue_entry.room_id - if room_id in processed_rooms: - # already done this room - continue - - logger.debug("Processing room %s", room_id) - - is_in_room = await self._store.is_host_joined(room_id, self._server_name) - - # The client-specified max_rooms_per_space limit doesn't apply to the - # room_id specified in the request, so we ignore it if this is the - # first room we are processing. - max_children = max_rooms_per_space if processed_rooms else MAX_ROOMS - - if is_in_room: - room_entry = await self._summarize_local_room( - requester, None, room_id, suggested_only, max_children - ) - - events: Sequence[JsonDict] = [] - if room_entry: - rooms_result.append(room_entry.room) - events = room_entry.children_state_events - - logger.debug( - "Query of local room %s returned events %s", - room_id, - ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events], - ) - else: - fed_rooms = await self._summarize_remote_room( - queue_entry, - suggested_only, - max_children, - exclude_rooms=processed_rooms, - ) - - # The results over federation might include rooms that the we, - # as the requesting server, are allowed to see, but the requesting - # user is not permitted see. - # - # Filter the returned results to only what is accessible to the user. - events = [] - for room_entry in fed_rooms: - room = room_entry.room - fed_room_id = room_entry.room_id - - # The user can see the room, include it! - if await self._is_remote_room_accessible( - requester, fed_room_id, room - ): - # Before returning to the client, remove the allowed_room_ids - # and allowed_spaces keys. - room.pop("allowed_room_ids", None) - room.pop("allowed_spaces", None) # historical - - rooms_result.append(room) - events.extend(room_entry.children_state_events) - - # All rooms returned don't need visiting again (even if the user - # didn't have access to them). - processed_rooms.add(fed_room_id) - - logger.debug( - "Query of %s returned rooms %s, events %s", - room_id, - [room_entry.room.get("room_id") for room_entry in fed_rooms], - ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events], - ) - - # the room we queried may or may not have been returned, but don't process - # it again, anyway. - processed_rooms.add(room_id) - - # XXX: is it ok that we blindly iterate through any events returned by - # a remote server, whether or not they actually link to any rooms in our - # tree? - for ev in events: - # remote servers might return events we have already processed - # (eg, Dendrite returns inward pointers as well as outward ones), so - # we need to filter them out, to avoid returning duplicate links to the - # client. - ev_key = (ev["room_id"], ev["state_key"]) - if ev_key in processed_events: - continue - events_result.append(ev) - - # add the child to the queue. we have already validated - # that the vias are a list of server names. - room_queue.append( - _RoomQueueEntry(ev["state_key"], ev["content"]["via"]) - ) - processed_events.add(ev_key) - - return {"rooms": rooms_result, "events": events_result} - async def get_room_hierarchy( self, requester: Requester, @@ -398,8 +250,6 @@ class RoomSummaryHandler: None, room_id, suggested_only, - # Do not limit the maximum children. - max_children=None, ) # Otherwise, attempt to use information for federation. @@ -488,74 +338,6 @@ class RoomSummaryHandler: return result - async def federation_space_summary( - self, - origin: str, - room_id: str, - suggested_only: bool, - max_rooms_per_space: Optional[int], - exclude_rooms: Iterable[str], - ) -> JsonDict: - """ - Implementation of the space summary Federation API - - Args: - origin: The server requesting the spaces summary. - - room_id: room id to start the summary at - - suggested_only: whether we should only return children with the "suggested" - flag set. - - max_rooms_per_space: an optional limit on the number of child rooms we will - return. Unlike the C-S API, this applies to the root room (room_id). - It is clipped to MAX_ROOMS_PER_SPACE. - - exclude_rooms: a list of rooms to skip over (presumably because the - calling server has already seen them). - - Returns: - summary dict to return - """ - # the queue of rooms to process - room_queue = deque((room_id,)) - - # the set of rooms that we should not walk further. Initialise it with the - # excluded-rooms list; we will add other rooms as we process them so that - # we do not loop. - processed_rooms: Set[str] = set(exclude_rooms) - - rooms_result: List[JsonDict] = [] - events_result: List[JsonDict] = [] - - # Set a limit on the number of rooms to return. - if max_rooms_per_space is None or max_rooms_per_space > MAX_ROOMS_PER_SPACE: - max_rooms_per_space = MAX_ROOMS_PER_SPACE - - while room_queue and len(rooms_result) < MAX_ROOMS: - room_id = room_queue.popleft() - if room_id in processed_rooms: - # already done this room - continue - - room_entry = await self._summarize_local_room( - None, origin, room_id, suggested_only, max_rooms_per_space - ) - - processed_rooms.add(room_id) - - if room_entry: - rooms_result.append(room_entry.room) - events_result.extend(room_entry.children_state_events) - - # add any children to the queue - room_queue.extend( - edge_event["state_key"] - for edge_event in room_entry.children_state_events - ) - - return {"rooms": rooms_result, "events": events_result} - async def get_federation_hierarchy( self, origin: str, @@ -579,7 +361,7 @@ class RoomSummaryHandler: The JSON hierarchy dictionary. """ root_room_entry = await self._summarize_local_room( - None, origin, requested_room_id, suggested_only, max_children=None + None, origin, requested_room_id, suggested_only ) if root_room_entry is None: # Room is inaccessible to the requesting server. @@ -600,7 +382,7 @@ class RoomSummaryHandler: continue room_entry = await self._summarize_local_room( - None, origin, room_id, suggested_only, max_children=0 + None, origin, room_id, suggested_only, include_children=False ) # If the room is accessible, include it in the results. # @@ -626,7 +408,7 @@ class RoomSummaryHandler: origin: Optional[str], room_id: str, suggested_only: bool, - max_children: Optional[int], + include_children: bool = True, ) -> Optional["_RoomEntry"]: """ Generate a room entry and a list of event entries for a given room. @@ -641,9 +423,8 @@ class RoomSummaryHandler: room_id: The room ID to summarize. suggested_only: True if only suggested children should be returned. Otherwise, all children are returned. - max_children: - The maximum number of children rooms to include. A value of None - means no limit. + include_children: + Whether to include the events of any children. Returns: A room entry if the room should be returned. None, otherwise. @@ -653,9 +434,8 @@ class RoomSummaryHandler: room_entry = await self._build_room_entry(room_id, for_federation=bool(origin)) - # If the room is not a space or the children don't matter, return just - # the room information. - if room_entry.get("room_type") != RoomTypes.SPACE or max_children == 0: + # If the room is not a space return just the room information. + if room_entry.get("room_type") != RoomTypes.SPACE or not include_children: return _RoomEntry(room_id, room_entry) # Otherwise, look for child rooms/spaces. @@ -665,14 +445,6 @@ class RoomSummaryHandler: # we only care about suggested children child_events = filter(_is_suggested_child_event, child_events) - # TODO max_children is legacy code for the /spaces endpoint. - if max_children is not None: - child_iter: Iterable[EventBase] = itertools.islice( - child_events, max_children - ) - else: - child_iter = child_events - stripped_events: List[JsonDict] = [ { "type": e.type, @@ -682,80 +454,10 @@ class RoomSummaryHandler: "sender": e.sender, "origin_server_ts": e.origin_server_ts, } - for e in child_iter + for e in child_events ] return _RoomEntry(room_id, room_entry, stripped_events) - async def _summarize_remote_room( - self, - room: "_RoomQueueEntry", - suggested_only: bool, - max_children: Optional[int], - exclude_rooms: Iterable[str], - ) -> Iterable["_RoomEntry"]: - """ - Request room entries and a list of event entries for a given room by querying a remote server. - - Args: - room: The room to summarize. - suggested_only: True if only suggested children should be returned. - Otherwise, all children are returned. - max_children: - The maximum number of children rooms to include. This is capped - to a server-set limit. - exclude_rooms: - Rooms IDs which do not need to be summarized. - - Returns: - An iterable of room entries. - """ - room_id = room.room_id - logger.info("Requesting summary for %s via %s", room_id, room.via) - - # we need to make the exclusion list json-serialisable - exclude_rooms = list(exclude_rooms) - - via = itertools.islice(room.via, MAX_SERVERS_PER_SPACE) - try: - res = await self._federation_client.get_space_summary( - via, - room_id, - suggested_only=suggested_only, - max_rooms_per_space=max_children, - exclude_rooms=exclude_rooms, - ) - except Exception as e: - logger.warning( - "Unable to get summary of %s via federation: %s", - room_id, - e, - exc_info=logger.isEnabledFor(logging.DEBUG), - ) - return () - - # Group the events by their room. - children_by_room: Dict[str, List[JsonDict]] = {} - for ev in res.events: - if ev.event_type == EventTypes.SpaceChild: - children_by_room.setdefault(ev.room_id, []).append(ev.data) - - # Generate the final results. - results = [] - for fed_room in res.rooms: - fed_room_id = fed_room.get("room_id") - if not fed_room_id or not isinstance(fed_room_id, str): - continue - - results.append( - _RoomEntry( - fed_room_id, - fed_room, - children_by_room.get(fed_room_id, []), - ) - ) - - return results - async def _summarize_remote_room_hierarchy( self, room: "_RoomQueueEntry", suggested_only: bool ) -> Tuple[Optional["_RoomEntry"], Dict[str, JsonDict], Set[str]]: @@ -958,9 +660,8 @@ class RoomSummaryHandler: ): return True - # Check if the user is a member of any of the allowed spaces - # from the response. - allowed_rooms = room.get("allowed_room_ids") or room.get("allowed_spaces") + # Check if the user is a member of any of the allowed rooms from the response. + allowed_rooms = room.get("allowed_room_ids") if allowed_rooms and isinstance(allowed_rooms, list): if await self._event_auth_handler.is_user_in_rooms( allowed_rooms, requester @@ -1028,8 +729,6 @@ class RoomSummaryHandler: ) if allowed_rooms: entry["allowed_room_ids"] = allowed_rooms - # TODO Remove this key once the API is stable. - entry["allowed_spaces"] = allowed_rooms # Filter out Nones – rather omit the field altogether room_entry = {k: v for k, v in entry.items() if v is not None} @@ -1094,7 +793,7 @@ class RoomSummaryHandler: room_id, # Suggested-only doesn't matter since no children are requested. suggested_only=False, - max_children=0, + include_children=False, ) if not room_entry: diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 5ccfe5a92..8a06ab8c5 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -1141,73 +1141,6 @@ class TimestampLookupRestServlet(RestServlet): } -class RoomSpaceSummaryRestServlet(RestServlet): - PATTERNS = ( - re.compile( - "^/_matrix/client/unstable/org.matrix.msc2946" - "/rooms/(?P[^/]*)/spaces$" - ), - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self._auth = hs.get_auth() - self._room_summary_handler = hs.get_room_summary_handler() - - async def on_GET( - self, request: SynapseRequest, room_id: str - ) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request, allow_guest=True) - - max_rooms_per_space = parse_integer(request, "max_rooms_per_space") - if max_rooms_per_space is not None and max_rooms_per_space < 0: - raise SynapseError( - 400, - "Value for 'max_rooms_per_space' must be a non-negative integer", - Codes.BAD_JSON, - ) - - return 200, await self._room_summary_handler.get_space_summary( - requester.user.to_string(), - room_id, - suggested_only=parse_boolean(request, "suggested_only", default=False), - max_rooms_per_space=max_rooms_per_space, - ) - - # TODO When switching to the stable endpoint, remove the POST handler. - async def on_POST( - self, request: SynapseRequest, room_id: str - ) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request, allow_guest=True) - content = parse_json_object_from_request(request) - - suggested_only = content.get("suggested_only", False) - if not isinstance(suggested_only, bool): - raise SynapseError( - 400, "'suggested_only' must be a boolean", Codes.BAD_JSON - ) - - max_rooms_per_space = content.get("max_rooms_per_space") - if max_rooms_per_space is not None: - if not isinstance(max_rooms_per_space, int): - raise SynapseError( - 400, "'max_rooms_per_space' must be an integer", Codes.BAD_JSON - ) - if max_rooms_per_space < 0: - raise SynapseError( - 400, - "Value for 'max_rooms_per_space' must be a non-negative integer", - Codes.BAD_JSON, - ) - - return 200, await self._room_summary_handler.get_space_summary( - requester.user.to_string(), - room_id, - suggested_only=suggested_only, - max_rooms_per_space=max_rooms_per_space, - ) - - class RoomHierarchyRestServlet(RestServlet): PATTERNS = ( re.compile( @@ -1301,7 +1234,6 @@ def register_servlets( RoomRedactEventRestServlet(hs).register(http_server) RoomTypingRestServlet(hs).register(http_server) RoomEventContextServlet(hs).register(http_server) - RoomSpaceSummaryRestServlet(hs).register(http_server) RoomHierarchyRestServlet(hs).register(http_server) if hs.config.experimental.msc3266_enabled: RoomSummaryRestServlet(hs).register(http_server) diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py index 51b22d299..b33ff94a3 100644 --- a/tests/handlers/test_room_summary.py +++ b/tests/handlers/test_room_summary.py @@ -157,35 +157,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): state_key=room_id, ) - def _assert_rooms( - self, result: JsonDict, rooms_and_children: Iterable[Tuple[str, Iterable[str]]] - ) -> None: - """ - Assert that the expected room IDs and events are in the response. - - Args: - result: The result from the API call. - rooms_and_children: An iterable of tuples where each tuple is: - The expected room ID. - The expected IDs of any children rooms. - """ - room_ids = [] - children_ids = [] - for room_id, children in rooms_and_children: - room_ids.append(room_id) - if children: - children_ids.extend([(room_id, child_id) for child_id in children]) - self.assertCountEqual( - [room.get("room_id") for room in result["rooms"]], room_ids - ) - self.assertCountEqual( - [ - (event.get("room_id"), event.get("state_key")) - for event in result["events"] - ], - children_ids, - ) - def _assert_hierarchy( self, result: JsonDict, rooms_and_children: Iterable[Tuple[str, Iterable[str]]] ) -> None: @@ -251,11 +222,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): def test_simple_space(self): """Test a simple space with a single room.""" - result = self.get_success(self.handler.get_space_summary(self.user, self.space)) # The result should have the space and the room in it, along with a link # from space -> room. expected = [(self.space, [self.room]), (self.room, ())] - self._assert_rooms(result, expected) result = self.get_success( self.handler.get_room_hierarchy(create_requester(self.user), self.space) @@ -271,12 +240,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self._add_child(self.space, room, self.token) rooms.append(room) - result = self.get_success(self.handler.get_space_summary(self.user, self.space)) - # The spaces result should have the space and the first 50 rooms in it, - # along with the links from space -> room for those 50 rooms. - expected = [(self.space, rooms[:50])] + [(room, []) for room in rooms[:49]] - self._assert_rooms(result, expected) - # The result should have the space and the rooms in it, along with the links # from space -> room. expected = [(self.space, rooms)] + [(room, []) for room in rooms] @@ -300,10 +263,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): token2 = self.login("user2", "pass") # The user can see the space since it is publicly joinable. - result = self.get_success(self.handler.get_space_summary(user2, self.space)) expected = [(self.space, [self.room]), (self.room, ())] - self._assert_rooms(result, expected) - result = self.get_success( self.handler.get_room_hierarchy(create_requester(user2), self.space) ) @@ -316,7 +276,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): body={"join_rule": JoinRules.INVITE}, tok=self.token, ) - self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) self.get_failure( self.handler.get_room_hierarchy(create_requester(user2), self.space), AuthError, @@ -329,9 +288,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): body={"history_visibility": HistoryVisibility.WORLD_READABLE}, tok=self.token, ) - result = self.get_success(self.handler.get_space_summary(user2, self.space)) - self._assert_rooms(result, expected) - result = self.get_success( self.handler.get_room_hierarchy(create_requester(user2), self.space) ) @@ -344,7 +300,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): body={"history_visibility": HistoryVisibility.JOINED}, tok=self.token, ) - self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) self.get_failure( self.handler.get_room_hierarchy(create_requester(user2), self.space), AuthError, @@ -353,19 +308,12 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # Join the space and results should be returned. self.helper.invite(self.space, targ=user2, tok=self.token) self.helper.join(self.space, user2, tok=token2) - result = self.get_success(self.handler.get_space_summary(user2, self.space)) - self._assert_rooms(result, expected) - result = self.get_success( self.handler.get_room_hierarchy(create_requester(user2), self.space) ) self._assert_hierarchy(result, expected) # Attempting to view an unknown room returns the same error. - self.get_failure( - self.handler.get_space_summary(user2, "#not-a-space:" + self.hs.hostname), - AuthError, - ) self.get_failure( self.handler.get_room_hierarchy( create_requester(user2), "#not-a-space:" + self.hs.hostname @@ -496,7 +444,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # Join the space. self.helper.join(self.space, user2, tok=token2) - result = self.get_success(self.handler.get_space_summary(user2, self.space)) expected = [ ( self.space, @@ -520,7 +467,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): (world_readable_room, ()), (joined_room, ()), ] - self._assert_rooms(result, expected) result = self.get_success( self.handler.get_room_hierarchy(create_requester(user2), self.space) @@ -554,8 +500,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self._add_child(subspace, self.room, token=self.token) self._add_child(subspace, room2, self.token) - result = self.get_success(self.handler.get_space_summary(self.user, self.space)) - # The result should include each room a single time and each link. expected = [ (self.space, [self.room, room2, subspace]), @@ -563,7 +507,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): (subspace, [subroom, self.room, room2]), (subroom, ()), ] - self._assert_rooms(result, expected) result = self.get_success( self.handler.get_room_hierarchy(create_requester(self.user), self.space) @@ -728,10 +671,8 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ) ) - result = self.get_success(self.handler.get_space_summary(self.user, self.space)) # The result should have only the space, along with a link from space -> room. expected = [(self.space, [self.room])] - self._assert_rooms(result, expected) result = self.get_success( self.handler.get_room_hierarchy(create_requester(self.user), self.space) @@ -775,41 +716,18 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): "world_readable": True, } - async def summarize_remote_room( - _self, room, suggested_only, max_children, exclude_rooms - ): - return [ - requested_room_entry, - _RoomEntry( - subroom, - { - "room_id": subroom, - "world_readable": True, - }, - ), - ] - async def summarize_remote_room_hierarchy(_self, room, suggested_only): return requested_room_entry, {subroom: child_room}, set() # Add a room to the space which is on another server. self._add_child(self.space, subspace, self.token) - with mock.patch( - "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room", - new=summarize_remote_room, - ): - result = self.get_success( - self.handler.get_space_summary(self.user, self.space) - ) - expected = [ (self.space, [self.room, subspace]), (self.room, ()), (subspace, [subroom]), (subroom, ()), ] - self._assert_rooms(result, expected) with mock.patch( "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy", @@ -881,7 +799,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): "room_id": restricted_room, "world_readable": False, "join_rules": JoinRules.RESTRICTED, - "allowed_spaces": [], + "allowed_room_ids": [], }, ), ( @@ -890,7 +808,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): "room_id": restricted_accessible_room, "world_readable": False, "join_rules": JoinRules.RESTRICTED, - "allowed_spaces": [self.room], + "allowed_room_ids": [self.room], }, ), ( @@ -929,30 +847,12 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ], ) - async def summarize_remote_room( - _self, room, suggested_only, max_children, exclude_rooms - ): - return [subspace_room_entry] + [ - # A copy is made of the room data since the allowed_spaces key - # is removed. - _RoomEntry(child_room[0], dict(child_room[1])) - for child_room in children_rooms - ] - async def summarize_remote_room_hierarchy(_self, room, suggested_only): return subspace_room_entry, dict(children_rooms), set() # Add a room to the space which is on another server. self._add_child(self.space, subspace, self.token) - with mock.patch( - "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room", - new=summarize_remote_room, - ): - result = self.get_success( - self.handler.get_space_summary(self.user, self.space) - ) - expected = [ (self.space, [self.room, subspace]), (self.room, ()), @@ -976,7 +876,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): (world_readable_room, ()), (joined_room, ()), ] - self._assert_rooms(result, expected) with mock.patch( "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy", @@ -1010,31 +909,17 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): }, ) - async def summarize_remote_room( - _self, room, suggested_only, max_children, exclude_rooms - ): - return [fed_room_entry] - async def summarize_remote_room_hierarchy(_self, room, suggested_only): return fed_room_entry, {}, set() # Add a room to the space which is on another server. self._add_child(self.space, fed_room, self.token) - with mock.patch( - "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room", - new=summarize_remote_room, - ): - result = self.get_success( - self.handler.get_space_summary(self.user, self.space) - ) - expected = [ (self.space, [self.room, fed_room]), (self.room, ()), (fed_room, ()), ] - self._assert_rooms(result, expected) with mock.patch( "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy", From 952efd0bca967bc2fcabe5c3f1f58e14ddc41686 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Mon, 28 Feb 2022 19:59:00 +0100 Subject: [PATCH 67/84] Add type hints to `tests/rest/client` (#12094) * Add type hints to `tests/rest/client` * update `mypy.ini` * newsfile * add `test_register.py` --- changelog.d/12094.misc | 1 + mypy.ini | 3 - tests/rest/client/test_events.py | 20 +++--- tests/rest/client/test_groups.py | 2 +- tests/rest/client/test_register.py | 110 +++++++++++++++-------------- 5 files changed, 72 insertions(+), 64 deletions(-) create mode 100644 changelog.d/12094.misc diff --git a/changelog.d/12094.misc b/changelog.d/12094.misc new file mode 100644 index 000000000..0360dbd61 --- /dev/null +++ b/changelog.d/12094.misc @@ -0,0 +1 @@ +Add type hints to `tests/rest/client`. diff --git a/mypy.ini b/mypy.ini index bd75905c8..38ff78760 100644 --- a/mypy.ini +++ b/mypy.ini @@ -75,10 +75,7 @@ exclude = (?x) |tests/push/test_presentable_names.py |tests/push/test_push_rule_evaluator.py |tests/rest/client/test_account.py - |tests/rest/client/test_events.py |tests/rest/client/test_filter.py - |tests/rest/client/test_groups.py - |tests/rest/client/test_register.py |tests/rest/client/test_report_event.py |tests/rest/client/test_rooms.py |tests/rest/client/test_third_party_rules.py diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py index 145f24783..1b1392fa2 100644 --- a/tests/rest/client/test_events.py +++ b/tests/rest/client/test_events.py @@ -16,8 +16,12 @@ from unittest.mock import Mock +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.rest.client import events, login, room +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -32,7 +36,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["enable_registration_captcha"] = False @@ -41,11 +45,11 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): hs = self.setup_test_homeserver(config=config) - hs.get_federation_handler = Mock() + hs.get_federation_handler = Mock() # type: ignore[assignment] return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # register an account self.user_id = self.register_user("sid1", "pass") @@ -55,7 +59,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): self.other_user = self.register_user("other2", "pass") self.other_token = self.login(self.other_user, "pass") - def test_stream_basic_permissions(self): + def test_stream_basic_permissions(self) -> None: # invalid token, expect 401 # note: this is in violation of the original v1 spec, which expected # 403. However, since the v1 spec no longer exists and the v1 @@ -76,7 +80,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): self.assertTrue("start" in channel.json_body) self.assertTrue("end" in channel.json_body) - def test_stream_room_permissions(self): + def test_stream_room_permissions(self) -> None: room_id = self.helper.create_room_as(self.other_user, tok=self.other_token) self.helper.send(room_id, tok=self.other_token) @@ -111,7 +115,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): # left to room (expect no content for room) - def TODO_test_stream_items(self): + def TODO_test_stream_items(self) -> None: # new user, no content # join room, expect 1 item (join) @@ -136,7 +140,7 @@ class GetEventsTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, hs, reactor, clock): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # register an account self.user_id = self.register_user("sid1", "pass") @@ -144,7 +148,7 @@ class GetEventsTestCase(unittest.HomeserverTestCase): self.room_id = self.helper.create_room_as(self.user_id, tok=self.token) - def test_get_event_via_events(self): + def test_get_event_via_events(self) -> None: resp = self.helper.send(self.room_id, tok=self.token) event_id = resp["event_id"] diff --git a/tests/rest/client/test_groups.py b/tests/rest/client/test_groups.py index c99f54cf4..e067cf825 100644 --- a/tests/rest/client/test_groups.py +++ b/tests/rest/client/test_groups.py @@ -25,7 +25,7 @@ class GroupsTestCase(unittest.HomeserverTestCase): servlets = [room.register_servlets, groups.register_servlets] @override_config({"enable_group_creation": True}) - def test_rooms_limited_by_visibility(self): + def test_rooms_limited_by_visibility(self) -> None: group_id = "+spqr:test" # Alice creates a group diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index 4b95b8541..9aebf1735 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -16,15 +16,21 @@ import datetime import json import os +from typing import Any, Dict, List, Tuple import pkg_resources +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType from synapse.api.errors import Codes from synapse.appservice import ApplicationService from synapse.rest.client import account, account_validity, login, logout, register, sync +from synapse.server import HomeServer from synapse.storage._base import db_to_json +from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest from tests.unittest import override_config @@ -39,12 +45,12 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ] url = b"/_matrix/client/r0/register" - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["allow_guest_access"] = True return config - def test_POST_appservice_registration_valid(self): + def test_POST_appservice_registration_valid(self) -> None: user_id = "@as_user_kermit:test" as_token = "i_am_an_app_service" @@ -69,7 +75,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): det_data = {"user_id": user_id, "home_server": self.hs.hostname} self.assertDictContainsSubset(det_data, channel.json_body) - def test_POST_appservice_registration_no_type(self): + def test_POST_appservice_registration_no_type(self) -> None: as_token = "i_am_an_app_service" appservice = ApplicationService( @@ -89,7 +95,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"400", channel.result) - def test_POST_appservice_registration_invalid(self): + def test_POST_appservice_registration_invalid(self) -> None: self.appservice = None # no application service exists request_data = json.dumps( {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE} @@ -100,21 +106,21 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"401", channel.result) - def test_POST_bad_password(self): + def test_POST_bad_password(self) -> None: request_data = json.dumps({"username": "kermit", "password": 666}) channel = self.make_request(b"POST", self.url, request_data) self.assertEqual(channel.result["code"], b"400", channel.result) self.assertEqual(channel.json_body["error"], "Invalid password") - def test_POST_bad_username(self): + def test_POST_bad_username(self) -> None: request_data = json.dumps({"username": 777, "password": "monkey"}) channel = self.make_request(b"POST", self.url, request_data) self.assertEqual(channel.result["code"], b"400", channel.result) self.assertEqual(channel.json_body["error"], "Invalid username") - def test_POST_user_valid(self): + def test_POST_user_valid(self) -> None: user_id = "@kermit:test" device_id = "frogfone" params = { @@ -135,7 +141,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertDictContainsSubset(det_data, channel.json_body) @override_config({"enable_registration": False}) - def test_POST_disabled_registration(self): + def test_POST_disabled_registration(self) -> None: request_data = json.dumps({"username": "kermit", "password": "monkey"}) self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) @@ -145,7 +151,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["error"], "Registration has been disabled") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - def test_POST_guest_registration(self): + def test_POST_guest_registration(self) -> None: self.hs.config.key.macaroon_secret_key = "test" self.hs.config.registration.allow_guest_access = True @@ -155,7 +161,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"200", channel.result) self.assertDictContainsSubset(det_data, channel.json_body) - def test_POST_disabled_guest_registration(self): + def test_POST_disabled_guest_registration(self) -> None: self.hs.config.registration.allow_guest_access = False channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") @@ -164,7 +170,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["error"], "Guest access is disabled") @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) - def test_POST_ratelimiting_guest(self): + def test_POST_ratelimiting_guest(self) -> None: for i in range(0, 6): url = self.url + b"?kind=guest" channel = self.make_request(b"POST", url, b"{}") @@ -182,7 +188,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"200", channel.result) @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) - def test_POST_ratelimiting(self): + def test_POST_ratelimiting(self) -> None: for i in range(0, 6): params = { "username": "kermit" + str(i), @@ -206,7 +212,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"200", channel.result) @override_config({"registration_requires_token": True}) - def test_POST_registration_requires_token(self): + def test_POST_registration_requires_token(self) -> None: username = "kermit" device_id = "frogfone" token = "abcd" @@ -223,7 +229,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): }, ) ) - params = { + params: JsonDict = { "username": username, "password": "monkey", "device_id": device_id, @@ -280,8 +286,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(res["pending"], 0) @override_config({"registration_requires_token": True}) - def test_POST_registration_token_invalid(self): - params = { + def test_POST_registration_token_invalid(self) -> None: + params: JsonDict = { "username": "kermit", "password": "monkey", } @@ -314,7 +320,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["completed"], []) @override_config({"registration_requires_token": True}) - def test_POST_registration_token_limit_uses(self): + def test_POST_registration_token_limit_uses(self) -> None: token = "abcd" store = self.hs.get_datastores().main # Create token that can be used once @@ -330,8 +336,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): }, ) ) - params1 = {"username": "bert", "password": "monkey"} - params2 = {"username": "ernie", "password": "monkey"} + params1: JsonDict = {"username": "bert", "password": "monkey"} + params2: JsonDict = {"username": "ernie", "password": "monkey"} # Do 2 requests without auth to get two session IDs channel1 = self.make_request(b"POST", self.url, json.dumps(params1)) session1 = channel1.json_body["session"] @@ -388,7 +394,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["completed"], []) @override_config({"registration_requires_token": True}) - def test_POST_registration_token_expiry(self): + def test_POST_registration_token_expiry(self) -> None: token = "abcd" now = self.hs.get_clock().time_msec() store = self.hs.get_datastores().main @@ -405,7 +411,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): }, ) ) - params = {"username": "kermit", "password": "monkey"} + params: JsonDict = {"username": "kermit", "password": "monkey"} # Request without auth to get session channel = self.make_request(b"POST", self.url, json.dumps(params)) session = channel.json_body["session"] @@ -436,7 +442,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed) @override_config({"registration_requires_token": True}) - def test_POST_registration_token_session_expiry(self): + def test_POST_registration_token_session_expiry(self) -> None: """Test `pending` is decremented when an uncompleted session expires.""" token = "abcd" store = self.hs.get_datastores().main @@ -454,8 +460,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) # Do 2 requests without auth to get two session IDs - params1 = {"username": "bert", "password": "monkey"} - params2 = {"username": "ernie", "password": "monkey"} + params1: JsonDict = {"username": "bert", "password": "monkey"} + params2: JsonDict = {"username": "ernie", "password": "monkey"} channel1 = self.make_request(b"POST", self.url, json.dumps(params1)) session1 = channel1.json_body["session"] channel2 = self.make_request(b"POST", self.url, json.dumps(params2)) @@ -522,7 +528,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(pending, 0) @override_config({"registration_requires_token": True}) - def test_POST_registration_token_session_expiry_deleted_token(self): + def test_POST_registration_token_session_expiry_deleted_token(self) -> None: """Test session expiry doesn't break when the token is deleted. 1. Start but don't complete UIA with a registration token @@ -545,7 +551,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) # Do request without auth to get a session ID - params = {"username": "kermit", "password": "monkey"} + params: JsonDict = {"username": "kermit", "password": "monkey"} channel = self.make_request(b"POST", self.url, json.dumps(params)) session = channel.json_body["session"] @@ -570,7 +576,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec()) ) - def test_advertised_flows(self): + def test_advertised_flows(self) -> None: channel = self.make_request(b"POST", self.url, b"{}") self.assertEqual(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] @@ -593,7 +599,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): }, } ) - def test_advertised_flows_captcha_and_terms_and_3pids(self): + def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None: channel = self.make_request(b"POST", self.url, b"{}") self.assertEqual(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] @@ -625,7 +631,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): }, } ) - def test_advertised_flows_no_msisdn_email_required(self): + def test_advertised_flows_no_msisdn_email_required(self) -> None: channel = self.make_request(b"POST", self.url, b"{}") self.assertEqual(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] @@ -646,7 +652,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): }, } ) - def test_request_token_existing_email_inhibit_error(self): + def test_request_token_existing_email_inhibit_error(self) -> None: """Test that requesting a token via this endpoint doesn't leak existing associations if configured that way. """ @@ -685,7 +691,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): }, } ) - def test_reject_invalid_email(self): + def test_reject_invalid_email(self) -> None: """Check that bad emails are rejected""" # Test for email with multiple @ @@ -731,7 +737,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "inhibit_user_in_use_error": True, } ) - def test_inhibit_user_in_use_error(self): + def test_inhibit_user_in_use_error(self) -> None: """Tests that the 'inhibit_user_in_use_error' configuration flag behaves correctly. """ @@ -779,7 +785,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): account_validity.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() # Test for account expiring after a week. config["enable_registration"] = True @@ -791,7 +797,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): return self.hs - def test_validity_period(self): + def test_validity_period(self) -> None: self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") @@ -810,7 +816,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result ) - def test_manual_renewal(self): + def test_manual_renewal(self) -> None: user_id = self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") @@ -833,7 +839,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"GET", "/sync", access_token=tok) self.assertEqual(channel.result["code"], b"200", channel.result) - def test_manual_expire(self): + def test_manual_expire(self) -> None: user_id = self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") @@ -858,7 +864,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result ) - def test_logging_out_expired_user(self): + def test_logging_out_expired_user(self) -> None: user_id = self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") @@ -898,7 +904,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): account.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() # Test for account expiring after a week and renewal emails being sent 2 @@ -935,17 +941,17 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): self.hs = self.setup_test_homeserver(config=config) - async def sendmail(*args, **kwargs): + async def sendmail(*args: Any, **kwargs: Any) -> None: self.email_attempts.append((args, kwargs)) - self.email_attempts = [] + self.email_attempts: List[Tuple[Any, Any]] = [] self.hs.get_send_email_handler()._sendmail = sendmail self.store = self.hs.get_datastores().main return self.hs - def test_renewal_email(self): + def test_renewal_email(self) -> None: self.email_attempts = [] (user_id, tok) = self.create_user() @@ -999,7 +1005,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"GET", "/sync", access_token=tok) self.assertEqual(channel.result["code"], b"200", channel.result) - def test_renewal_invalid_token(self): + def test_renewal_invalid_token(self) -> None: # Hit the renewal endpoint with an invalid token and check that it behaves as # expected, i.e. that it responds with 404 Not Found and the correct HTML. url = "/_matrix/client/unstable/account_validity/renew?token=123" @@ -1019,7 +1025,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): channel.result["body"], expected_html.encode("utf8"), channel.result ) - def test_manual_email_send(self): + def test_manual_email_send(self) -> None: self.email_attempts = [] (user_id, tok) = self.create_user() @@ -1032,7 +1038,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): self.assertEqual(len(self.email_attempts), 1) - def test_deactivated_user(self): + def test_deactivated_user(self) -> None: self.email_attempts = [] (user_id, tok) = self.create_user() @@ -1056,7 +1062,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): self.assertEqual(len(self.email_attempts), 0) - def create_user(self): + def create_user(self) -> Tuple[str, str]: user_id = self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") # We need to manually add an email address otherwise the handler will do @@ -1073,7 +1079,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): ) return user_id, tok - def test_manual_email_send_expired_account(self): + def test_manual_email_send_expired_account(self) -> None: user_id = self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") @@ -1112,7 +1118,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.validity_period = 10 self.max_delta = self.validity_period * 10.0 / 100.0 @@ -1135,7 +1141,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): return self.hs - def test_background_job(self): + def test_background_job(self) -> None: """ Tests the same thing as test_background_job, except that it sets the startup_job_max_delta parameter and checks that the expiration date is within the @@ -1158,12 +1164,12 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): servlets = [register.register_servlets] url = "/_matrix/client/v1/register/m.login.registration_token/validity" - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["registration_requires_token"] = True return config - def test_GET_token_valid(self): + def test_GET_token_valid(self) -> None: token = "abcd" store = self.hs.get_datastores().main self.get_success( @@ -1186,7 +1192,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["valid"], True) - def test_GET_token_invalid(self): + def test_GET_token_invalid(self) -> None: token = "1234" channel = self.make_request( b"GET", @@ -1198,7 +1204,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): @override_config( {"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}} ) - def test_GET_ratelimiting(self): + def test_GET_ratelimiting(self) -> None: token = "1234" for i in range(0, 6): From 9d11fee8f223787c04c6574b8a30967e2b73cc35 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 1 Mar 2022 09:34:30 +0000 Subject: [PATCH 68/84] Improve exception handling for concurrent execution (#12109) * fix incorrect unwrapFirstError import this was being imported from the wrong place * Refactor `concurrently_execute` to use `yieldable_gather_results` * Improve exception handling in `yieldable_gather_results` Try to avoid swallowing so many stack traces. * mark unwrapFirstError deprecated * changelog --- changelog.d/12109.misc | 1 + synapse/handlers/message.py | 4 +- synapse/util/__init__.py | 4 +- synapse/util/async_helpers.py | 54 +++++++++------ tests/util/test_async_helpers.py | 115 ++++++++++++++++++++++++++++++- 5 files changed, 151 insertions(+), 27 deletions(-) create mode 100644 changelog.d/12109.misc diff --git a/changelog.d/12109.misc b/changelog.d/12109.misc new file mode 100644 index 000000000..3295e49f4 --- /dev/null +++ b/changelog.d/12109.misc @@ -0,0 +1 @@ +Improve exception handling for concurrent execution. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index a9c964cd7..ce1fa3c78 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -55,8 +55,8 @@ from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester -from synapse.util import json_decoder, json_encoder, log_failure -from synapse.util.async_helpers import Linearizer, gather_results, unwrapFirstError +from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError +from synapse.util.async_helpers import Linearizer, gather_results from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import measure_func from synapse.visibility import filter_events_for_client diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 511f52534..58b4220ff 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -81,7 +81,9 @@ json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json) def unwrapFirstError(failure: Failure) -> Failure: - # defer.gatherResults and DeferredLists wrap failures. + # Deprecated: you probably just want to catch defer.FirstError and reraise + # the subFailure's value, which will do a better job of preserving stacktraces. + # (actually, you probably want to use yieldable_gather_results anyway) failure.trap(defer.FirstError) return failure.value.subFailure # type: ignore[union-attr] # Issue in Twisted's annotations diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 3f7299aff..a83296a22 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -29,6 +29,7 @@ from typing import ( Hashable, Iterable, Iterator, + List, Optional, Set, Tuple, @@ -51,7 +52,7 @@ from synapse.logging.context import ( make_deferred_yieldable, run_in_background, ) -from synapse.util import Clock, unwrapFirstError +from synapse.util import Clock logger = logging.getLogger(__name__) @@ -193,9 +194,9 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]): T = TypeVar("T") -def concurrently_execute( +async def concurrently_execute( func: Callable[[T], Any], args: Iterable[T], limit: int -) -> defer.Deferred: +) -> None: """Executes the function with each argument concurrently while limiting the number of concurrent executions. @@ -221,20 +222,14 @@ def concurrently_execute( # We use `itertools.islice` to handle the case where the number of args is # less than the limit, avoiding needlessly spawning unnecessary background # tasks. - return make_deferred_yieldable( - defer.gatherResults( - [ - run_in_background(_concurrently_execute_inner, value) - for value in itertools.islice(it, limit) - ], - consumeErrors=True, - ) - ).addErrback(unwrapFirstError) + await yieldable_gather_results( + _concurrently_execute_inner, (value for value in itertools.islice(it, limit)) + ) -def yieldable_gather_results( - func: Callable, iter: Iterable, *args: Any, **kwargs: Any -) -> defer.Deferred: +async def yieldable_gather_results( + func: Callable[..., Awaitable[T]], iter: Iterable, *args: Any, **kwargs: Any +) -> List[T]: """Executes the function with each argument concurrently. Args: @@ -245,15 +240,30 @@ def yieldable_gather_results( **kwargs: Keyword arguments to be passed to each call to func Returns - Deferred[list]: Resolved when all functions have been invoked, or errors if - one of the function calls fails. + A list containing the results of the function """ - return make_deferred_yieldable( - defer.gatherResults( - [run_in_background(func, item, *args, **kwargs) for item in iter], - consumeErrors=True, + try: + return await make_deferred_yieldable( + defer.gatherResults( + [run_in_background(func, item, *args, **kwargs) for item in iter], + consumeErrors=True, + ) ) - ).addErrback(unwrapFirstError) + except defer.FirstError as dfe: + # unwrap the error from defer.gatherResults. + + # The raised exception's traceback only includes func() etc if + # the 'await' happens before the exception is thrown - ie if the failure + # happens *asynchronously* - otherwise Twisted throws away the traceback as it + # could be large. + # + # We could maybe reconstruct a fake traceback from Failure.frames. Or maybe + # we could throw Twisted into the fires of Mordor. + + # suppress exception chaining, because the FirstError doesn't tell us anything + # very interesting. + assert isinstance(dfe.subFailure.value, BaseException) + raise dfe.subFailure.value from None T1 = TypeVar("T1") diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py index ab89cab81..cce8d595f 100644 --- a/tests/util/test_async_helpers.py +++ b/tests/util/test_async_helpers.py @@ -11,9 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import traceback + from twisted.internet import defer -from twisted.internet.defer import CancelledError, Deferred +from twisted.internet.defer import CancelledError, Deferred, ensureDeferred from twisted.internet.task import Clock +from twisted.python.failure import Failure from synapse.logging.context import ( SENTINEL_CONTEXT, @@ -21,7 +24,11 @@ from synapse.logging.context import ( PreserveLoggingContext, current_context, ) -from synapse.util.async_helpers import ObservableDeferred, timeout_deferred +from synapse.util.async_helpers import ( + ObservableDeferred, + concurrently_execute, + timeout_deferred, +) from tests.unittest import TestCase @@ -171,3 +178,107 @@ class TimeoutDeferredTest(TestCase): ) self.failureResultOf(timing_out_d, defer.TimeoutError) self.assertIs(current_context(), context_one) + + +class _TestException(Exception): + pass + + +class ConcurrentlyExecuteTest(TestCase): + def test_limits_runners(self): + """If we have more tasks than runners, we should get the limit of runners""" + started = 0 + waiters = [] + processed = [] + + async def callback(v): + # when we first enter, bump the start count + nonlocal started + started += 1 + + # record the fact we got an item + processed.append(v) + + # wait for the goahead before returning + d2 = Deferred() + waiters.append(d2) + await d2 + + # set it going + d2 = ensureDeferred(concurrently_execute(callback, [1, 2, 3, 4, 5], 3)) + + # check we got exactly 3 processes + self.assertEqual(started, 3) + self.assertEqual(len(waiters), 3) + + # let one finish + waiters.pop().callback(0) + + # ... which should start another + self.assertEqual(started, 4) + self.assertEqual(len(waiters), 3) + + # we still shouldn't be done + self.assertNoResult(d2) + + # finish the job + while waiters: + waiters.pop().callback(0) + + # check everything got done + self.assertEqual(started, 5) + self.assertCountEqual(processed, [1, 2, 3, 4, 5]) + self.successResultOf(d2) + + def test_preserves_stacktraces(self): + """Test that the stacktrace from an exception thrown in the callback is preserved""" + d1 = Deferred() + + async def callback(v): + # alas, this doesn't work at all without an await here + await d1 + raise _TestException("bah") + + async def caller(): + try: + await concurrently_execute(callback, [1], 2) + except _TestException as e: + tb = traceback.extract_tb(e.__traceback__) + # we expect to see "caller", "concurrently_execute" and "callback". + self.assertEqual(tb[0].name, "caller") + self.assertEqual(tb[1].name, "concurrently_execute") + self.assertEqual(tb[-1].name, "callback") + else: + self.fail("No exception thrown") + + d2 = ensureDeferred(caller()) + d1.callback(0) + self.successResultOf(d2) + + def test_preserves_stacktraces_on_preformed_failure(self): + """Test that the stacktrace on a Failure returned by the callback is preserved""" + d1 = Deferred() + f = Failure(_TestException("bah")) + + async def callback(v): + # alas, this doesn't work at all without an await here + await d1 + await defer.fail(f) + + async def caller(): + try: + await concurrently_execute(callback, [1], 2) + except _TestException as e: + tb = traceback.extract_tb(e.__traceback__) + # we expect to see "caller", "concurrently_execute", "callback", + # and some magic from inside ensureDeferred that happens when .fail + # is called. + self.assertEqual(tb[0].name, "caller") + self.assertEqual(tb[1].name, "concurrently_execute") + self.assertEqual(tb[-2].name, "callback") + else: + self.fail("No exception thrown") + + d2 = ensureDeferred(caller()) + d1.callback(0) + self.successResultOf(d2) From 5458eb8551be676fea7ff21e2b0d3c3762c871a7 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 1 Mar 2022 09:51:38 +0000 Subject: [PATCH 69/84] Fix 'Unhandled error in Deferred' (#12089) * Fix 'Unhandled error in Deferred' Fixes a CRITICAL "Unhandled error in Deferred" log message which happened when a function wrapped with `@cachedList` failed * Minor optimisation to cachedListDescriptor we can avoid re-using `missing`, which saves looking up entries in `deferreds_map`, and means we don't need to copy it. * Improve type annotation on CachedListDescriptor --- changelog.d/12089.bugfix | 1 + synapse/util/caches/descriptors.py | 62 +++++++++++++-------------- tests/util/caches/test_descriptors.py | 10 ++--- 3 files changed, 37 insertions(+), 36 deletions(-) create mode 100644 changelog.d/12089.bugfix diff --git a/changelog.d/12089.bugfix b/changelog.d/12089.bugfix new file mode 100644 index 000000000..27172c482 --- /dev/null +++ b/changelog.d/12089.bugfix @@ -0,0 +1 @@ +Fix occasional 'Unhandled error in Deferred' error message. diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index df4fb156c..1cdead02f 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -18,6 +18,7 @@ import inspect import logging from typing import ( Any, + Awaitable, Callable, Dict, Generic, @@ -346,15 +347,15 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): """Wraps an existing cache to support bulk fetching of keys. Given an iterable of keys it looks in the cache to find any hits, then passes - the tuple of missing keys to the wrapped function. + the set of missing keys to the wrapped function. - Once wrapped, the function returns a Deferred which resolves to the list - of results. + Once wrapped, the function returns a Deferred which resolves to a Dict mapping from + input key to output value. """ def __init__( self, - orig: Callable[..., Any], + orig: Callable[..., Awaitable[Dict]], cached_method_name: str, list_name: str, num_args: Optional[int] = None, @@ -385,13 +386,13 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): def __get__( self, obj: Optional[Any], objtype: Optional[Type] = None - ) -> Callable[..., Any]: + ) -> Callable[..., "defer.Deferred[Dict[Hashable, Any]]"]: cached_method = getattr(obj, self.cached_method_name) cache: DeferredCache[CacheKey, Any] = cached_method.cache num_args = cached_method.num_args @functools.wraps(self.orig) - def wrapped(*args: Any, **kwargs: Any) -> Any: + def wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Dict]": # If we're passed a cache_context then we'll want to call its # invalidate() whenever we are invalidated invalidate_callback = kwargs.pop("on_invalidate", None) @@ -444,39 +445,38 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): deferred: "defer.Deferred[Any]" = defer.Deferred() deferreds_map[arg] = deferred key = arg_to_cache_key(arg) - cache.set(key, deferred, callback=invalidate_callback) + cached_defers.append( + cache.set(key, deferred, callback=invalidate_callback) + ) def complete_all(res: Dict[Hashable, Any]) -> None: - # the wrapped function has completed. It returns a - # a dict. We can now resolve the observable deferreds in - # the cache and update our own result map. - for e in missing: + # the wrapped function has completed. It returns a dict. + # We can now update our own result map, and then resolve the + # observable deferreds in the cache. + for e, d1 in deferreds_map.items(): val = res.get(e, None) - deferreds_map[e].callback(val) + # make sure we update the results map before running the + # deferreds, because as soon as we run the last deferred, the + # gatherResults() below will complete and return the result + # dict to our caller. results[e] = val + d1.callback(val) - def errback(f: Failure) -> Failure: - # the wrapped function has failed. Invalidate any cache - # entries we're supposed to be populating, and fail - # their deferreds. - for e in missing: - key = arg_to_cache_key(e) - cache.invalidate(key) - deferreds_map[e].errback(f) - - # return the failure, to propagate to our caller. - return f + def errback_all(f: Failure) -> None: + # the wrapped function has failed. Propagate the failure into + # the cache, which will invalidate the entry, and cause the + # relevant cached_deferreds to fail, which will propagate the + # failure to our caller. + for d1 in deferreds_map.values(): + d1.errback(f) args_to_call = dict(arg_dict) - # copy the missing set before sending it to the callee, to guard against - # modification. - args_to_call[self.list_name] = tuple(missing) + args_to_call[self.list_name] = missing - cached_defers.append( - defer.maybeDeferred( - preserve_fn(self.orig), **args_to_call - ).addCallbacks(complete_all, errback) - ) + # dispatch the call, and attach the two handlers + defer.maybeDeferred( + preserve_fn(self.orig), **args_to_call + ).addCallbacks(complete_all, errback_all) if cached_defers: d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks( diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index b92d3f0c1..19741ffcd 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -673,14 +673,14 @@ class CachedListDescriptorTestCase(unittest.TestCase): self.assertEqual(current_context(), SENTINEL_CONTEXT) r = yield d1 self.assertEqual(current_context(), c1) - obj.mock.assert_called_once_with((10, 20), 2) + obj.mock.assert_called_once_with({10, 20}, 2) self.assertEqual(r, {10: "fish", 20: "chips"}) obj.mock.reset_mock() # a call with different params should call the mock again obj.mock.return_value = {30: "peas"} r = yield obj.list_fn([20, 30], 2) - obj.mock.assert_called_once_with((30,), 2) + obj.mock.assert_called_once_with({30}, 2) self.assertEqual(r, {20: "chips", 30: "peas"}) obj.mock.reset_mock() @@ -701,7 +701,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): obj.mock.return_value = {40: "gravy"} iterable = (x for x in [10, 40, 40]) r = yield obj.list_fn(iterable, 2) - obj.mock.assert_called_once_with((40,), 2) + obj.mock.assert_called_once_with({40}, 2) self.assertEqual(r, {10: "fish", 40: "gravy"}) def test_concurrent_lookups(self): @@ -729,7 +729,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): d3 = obj.list_fn([10]) # the mock should have been called exactly once - obj.mock.assert_called_once_with((10,)) + obj.mock.assert_called_once_with({10}) obj.mock.reset_mock() # ... and none of the calls should yet be complete @@ -771,7 +771,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): # cache miss obj.mock.return_value = {10: "fish", 20: "chips"} r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0) - obj.mock.assert_called_once_with((10, 20), 2) + obj.mock.assert_called_once_with({10, 20}, 2) self.assertEqual(r1, {10: "fish", 20: "chips"}) obj.mock.reset_mock() From 4ccc2d09aae71da0be725ac177a9d4aced9a53c9 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 1 Mar 2022 12:35:32 +0000 Subject: [PATCH 70/84] Advertise Python 3.10 support in setup.py (#12111) --- changelog.d/12111.misc | 1 + setup.py | 1 + 2 files changed, 2 insertions(+) create mode 100644 changelog.d/12111.misc diff --git a/changelog.d/12111.misc b/changelog.d/12111.misc new file mode 100644 index 000000000..be84789c9 --- /dev/null +++ b/changelog.d/12111.misc @@ -0,0 +1 @@ +Advertise support for Python 3.10 in packaging files. \ No newline at end of file diff --git a/setup.py b/setup.py index c80cb6f20..26f465034 100755 --- a/setup.py +++ b/setup.py @@ -165,6 +165,7 @@ setup( "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", ], scripts=["synctl"] + glob.glob("scripts/*"), cmdclass={"test": TestCommand}, From e2e1d90a5e4030616a3de242cde26c0cfff4a6b5 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 1 Mar 2022 12:49:54 +0000 Subject: [PATCH 71/84] Faster joins: persist to database (#12012) When we get a partial_state response from send_join, store information in the database about it: * store a record about the room as a whole having partial state, and stash the list of member servers too. * flag the join event itself as having partial state * also, for any new events whose prev-events are partial-stated, note that they will *also* be partial-stated. We don't yet make any attempt to interpret this data, so API calls (and a bunch of other things) are just going to get incorrect data. --- changelog.d/12012.misc | 1 + synapse/events/snapshot.py | 9 +++ synapse/handlers/federation.py | 11 ++- synapse/handlers/federation_event.py | 13 +++- synapse/handlers/message.py | 2 + synapse/state/__init__.py | 31 +++++++- synapse/storage/databases/main/events.py | 25 +++++++ .../storage/databases/main/events_worker.py | 28 ++++++++ synapse/storage/databases/main/room.py | 37 ++++++++++ .../main/delta/68/04partial_state_rooms.sql | 41 +++++++++++ .../68/05partial_state_rooms_triggers.py | 72 +++++++++++++++++++ tests/test_state.py | 59 ++++++++------- 12 files changed, 297 insertions(+), 32 deletions(-) create mode 100644 changelog.d/12012.misc create mode 100644 synapse/storage/schema/main/delta/68/04partial_state_rooms.sql create mode 100644 synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py diff --git a/changelog.d/12012.misc b/changelog.d/12012.misc new file mode 100644 index 000000000..a473f41e7 --- /dev/null +++ b/changelog.d/12012.misc @@ -0,0 +1 @@ +Preparation for faster-room-join work: persist information on which events and rooms have partial state to the database. diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 5833fee25..46042b2bf 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -101,6 +101,9 @@ class EventContext: As with _current_state_ids, this is a private attribute. It should be accessed via get_prev_state_ids. + + partial_state: if True, we may be storing this event with a temporary, + incomplete state. """ rejected: Union[bool, str] = False @@ -113,12 +116,15 @@ class EventContext: _current_state_ids: Optional[StateMap[str]] = None _prev_state_ids: Optional[StateMap[str]] = None + partial_state: bool = False + @staticmethod def with_state( state_group: Optional[int], state_group_before_event: Optional[int], current_state_ids: Optional[StateMap[str]], prev_state_ids: Optional[StateMap[str]], + partial_state: bool, prev_group: Optional[int] = None, delta_ids: Optional[StateMap[str]] = None, ) -> "EventContext": @@ -129,6 +135,7 @@ class EventContext: state_group_before_event=state_group_before_event, prev_group=prev_group, delta_ids=delta_ids, + partial_state=partial_state, ) @staticmethod @@ -170,6 +177,7 @@ class EventContext: "prev_group": self.prev_group, "delta_ids": _encode_state_dict(self.delta_ids), "app_service_id": self.app_service.id if self.app_service else None, + "partial_state": self.partial_state, } @staticmethod @@ -196,6 +204,7 @@ class EventContext: prev_group=input["prev_group"], delta_ids=_decode_state_dict(input["delta_ids"]), rejected=input["rejected"], + partial_state=input.get("partial_state", False), ) app_service_id = input["app_service_id"] diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index c055c26ec..eb03a5acc 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -519,8 +519,17 @@ class FederationHandler: state_events=state, ) + if ret.partial_state: + await self.store.store_partial_state_room(room_id, ret.servers_in_room) + max_stream_id = await self._federation_event_handler.process_remote_join( - origin, room_id, auth_chain, state, event, room_version_obj + origin, + room_id, + auth_chain, + state, + event, + room_version_obj, + partial_state=ret.partial_state, ) # We wait here until this instance has seen the events come down diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 09d0de1ea..4bd87709f 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -397,6 +397,7 @@ class FederationEventHandler: state: List[EventBase], event: EventBase, room_version: RoomVersion, + partial_state: bool, ) -> int: """Persists the events returned by a send_join @@ -412,6 +413,7 @@ class FederationEventHandler: event room_version: The room version we expect this room to have, and will raise if it doesn't match the version in the create event. + partial_state: True if the state omits non-critical membership events Returns: The stream ID after which all events have been persisted. @@ -453,10 +455,14 @@ class FederationEventHandler: ) # and now persist the join event itself. - logger.info("Peristing join-via-remote %s", event) + logger.info( + "Peristing join-via-remote %s (partial_state: %s)", event, partial_state + ) with nested_logging_context(suffix=event.event_id): context = await self._state_handler.compute_event_context( - event, old_state=state + event, + old_state=state, + partial_state=partial_state, ) context = await self._check_event_auth(origin, event, context) @@ -698,6 +704,8 @@ class FederationEventHandler: try: state = await self._resolve_state_at_missing_prevs(origin, event) + # TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does + # not return partial state await self._process_received_pdu( origin, event, state=state, backfilled=backfilled ) @@ -1791,6 +1799,7 @@ class FederationEventHandler: prev_state_ids=prev_state_ids, prev_group=prev_group, delta_ids=state_updates, + partial_state=context.partial_state, ) async def _run_push_actions_and_persist_event( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index ce1fa3c78..61cb133ef 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -992,6 +992,8 @@ class EventCreationHandler: and full_state_ids_at_event and builder.internal_metadata.is_historical() ): + # TODO(faster_joins): figure out how this works, and make sure that the + # old state is complete. old_state = await self.store.get_events_as_list(full_state_ids_at_event) context = await self.state.compute_event_context(event, old_state=old_state) else: diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index fcc24ad12..6babd5963 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -258,7 +258,10 @@ class StateHandler: return await self.store.get_joined_hosts(room_id, entry) async def compute_event_context( - self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None + self, + event: EventBase, + old_state: Optional[Iterable[EventBase]] = None, + partial_state: bool = False, ) -> EventContext: """Build an EventContext structure for a non-outlier event. @@ -273,6 +276,8 @@ class StateHandler: calculated from existing events. This is normally only specified when receiving an event from federation where we don't have the prev events for, e.g. when backfilling. + partial_state: True if `old_state` is partial and omits non-critical + membership events Returns: The event context. """ @@ -295,8 +300,28 @@ class StateHandler: else: # otherwise, we'll need to resolve the state across the prev_events. - logger.debug("calling resolve_state_groups from compute_event_context") + # partial_state should not be set explicitly in this case: + # we work it out dynamically + assert not partial_state + + # if any of the prev-events have partial state, so do we. + # (This is slightly racy - the prev-events might get fixed up before we use + # their states - but I don't think that really matters; it just means we + # might redundantly recalculate the state for this event later.) + prev_event_ids = event.prev_event_ids() + incomplete_prev_events = await self.store.get_partial_state_events( + prev_event_ids + ) + if any(incomplete_prev_events.values()): + logger.debug( + "New/incoming event %s refers to prev_events %s with partial state", + event.event_id, + [k for (k, v) in incomplete_prev_events.items() if v], + ) + partial_state = True + + logger.debug("calling resolve_state_groups from compute_event_context") entry = await self.resolve_state_groups_for_events( event.room_id, event.prev_event_ids() ) @@ -342,6 +367,7 @@ class StateHandler: prev_state_ids=state_ids_before_event, prev_group=state_group_before_event_prev_group, delta_ids=deltas_to_state_group_before_event, + partial_state=partial_state, ) # @@ -373,6 +399,7 @@ class StateHandler: prev_state_ids=state_ids_before_event, prev_group=state_group_before_event, delta_ids=delta_ids, + partial_state=partial_state, ) @measure_func() diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 23fa089bc..ca2a9ba9d 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -2145,6 +2145,14 @@ class PersistEventsStore: state_groups = {} for event, context in events_and_contexts: if event.internal_metadata.is_outlier(): + # double-check that we don't have any events that claim to be outliers + # *and* have partial state (which is meaningless: we should have no + # state at all for an outlier) + if context.partial_state: + raise ValueError( + "Outlier event %s claims to have partial state", event.event_id + ) + continue # if the event was rejected, just give it the same state as its @@ -2155,6 +2163,23 @@ class PersistEventsStore: state_groups[event.event_id] = context.state_group + # if we have partial state for these events, record the fact. (This happens + # here rather than in _store_event_txn because it also needs to happen when + # we de-outlier an event.) + self.db_pool.simple_insert_many_txn( + txn, + table="partial_state_events", + keys=("room_id", "event_id"), + values=[ + ( + event.room_id, + event.event_id, + ) + for event, ctx in events_and_contexts + if ctx.partial_state + ], + ) + self.db_pool.simple_upsert_many_txn( txn, table="event_to_state_groups", diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 2a255d103..26784f755 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1953,3 +1953,31 @@ class EventsWorkerStore(SQLBaseStore): "get_event_id_for_timestamp_txn", get_event_id_for_timestamp_txn, ) + + @cachedList("is_partial_state_event", list_name="event_ids") + async def get_partial_state_events( + self, event_ids: Collection[str] + ) -> Dict[str, bool]: + """Checks which of the given events have partial state""" + result = await self.db_pool.simple_select_many_batch( + table="partial_state_events", + column="event_id", + iterable=event_ids, + retcols=["event_id"], + desc="get_partial_state_events", + ) + # convert the result to a dict, to make @cachedList work + partial = {r["event_id"] for r in result} + return {e_id: e_id in partial for e_id in event_ids} + + @cached() + async def is_partial_state_event(self, event_id: str) -> bool: + """Checks if the given event has partial state""" + result = await self.db_pool.simple_select_one_onecol( + table="partial_state_events", + keyvalues={"event_id": event_id}, + retcol="1", + allow_none=True, + desc="is_partial_state_event", + ) + return result is not None diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 0416df64c..94068940b 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -20,6 +20,7 @@ from typing import ( TYPE_CHECKING, Any, Awaitable, + Collection, Dict, List, Optional, @@ -1543,6 +1544,42 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): lock=False, ) + async def store_partial_state_room( + self, + room_id: str, + servers: Collection[str], + ) -> None: + """Mark the given room as containing events with partial state + + Args: + room_id: the ID of the room + servers: other servers known to be in the room + """ + await self.db_pool.runInteraction( + "store_partial_state_room", + self._store_partial_state_room_txn, + room_id, + servers, + ) + + @staticmethod + def _store_partial_state_room_txn( + txn: LoggingTransaction, room_id: str, servers: Collection[str] + ) -> None: + DatabasePool.simple_insert_txn( + txn, + table="partial_state_rooms", + values={ + "room_id": room_id, + }, + ) + DatabasePool.simple_insert_many_txn( + txn, + table="partial_state_rooms_servers", + keys=("room_id", "server_name"), + values=((room_id, s) for s in servers), + ) + async def maybe_store_room_on_outlier_membership( self, room_id: str, room_version: RoomVersion ) -> None: diff --git a/synapse/storage/schema/main/delta/68/04partial_state_rooms.sql b/synapse/storage/schema/main/delta/68/04partial_state_rooms.sql new file mode 100644 index 000000000..815c0cc39 --- /dev/null +++ b/synapse/storage/schema/main/delta/68/04partial_state_rooms.sql @@ -0,0 +1,41 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- rooms which we have done a partial-state-style join to +CREATE TABLE IF NOT EXISTS partial_state_rooms ( + room_id TEXT PRIMARY KEY, + FOREIGN KEY(room_id) REFERENCES rooms(room_id) +); + +-- a list of remote servers we believe are in the room +CREATE TABLE IF NOT EXISTS partial_state_rooms_servers ( + room_id TEXT NOT NULL REFERENCES partial_state_rooms(room_id), + server_name TEXT NOT NULL, + UNIQUE(room_id, server_name) +); + +-- a list of events with partial state. We can't store this in the `events` table +-- itself, because `events` is meant to be append-only. +CREATE TABLE IF NOT EXISTS partial_state_events ( + -- the room_id is denormalised for efficient indexing (the canonical source is `events`) + room_id TEXT NOT NULL REFERENCES partial_state_rooms(room_id), + event_id TEXT NOT NULL REFERENCES events(event_id), + UNIQUE(event_id) +); + +CREATE INDEX IF NOT EXISTS partial_state_events_room_id_idx + ON partial_state_events (room_id); + + diff --git a/synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py b/synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py new file mode 100644 index 000000000..a2ec4fc26 --- /dev/null +++ b/synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py @@ -0,0 +1,72 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This migration adds triggers to the partial_state_events tables to enforce uniqueness + +Triggers cannot be expressed in .sql files, so we have to use a separate file. +""" +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine +from synapse.storage.types import Cursor + + +def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs): + # complain if the room_id in partial_state_events doesn't match + # that in `events`. We already have a fk constraint which ensures that the event + # exists in `events`, so all we have to do is raise if there is a row with a + # matching stream_ordering but not a matching room_id. + if isinstance(database_engine, Sqlite3Engine): + cur.execute( + """ + CREATE TRIGGER IF NOT EXISTS partial_state_events_bad_room_id + BEFORE INSERT ON partial_state_events + FOR EACH ROW + BEGIN + SELECT RAISE(ABORT, 'Incorrect room_id in partial_state_events') + WHERE EXISTS ( + SELECT 1 FROM events + WHERE events.event_id = NEW.event_id + AND events.room_id != NEW.room_id + ); + END; + """ + ) + elif isinstance(database_engine, PostgresEngine): + cur.execute( + """ + CREATE OR REPLACE FUNCTION check_partial_state_events() RETURNS trigger AS $BODY$ + BEGIN + IF EXISTS ( + SELECT 1 FROM events + WHERE events.event_id = NEW.event_id + AND events.room_id != NEW.room_id + ) THEN + RAISE EXCEPTION 'Incorrect room_id in partial_state_events'; + END IF; + RETURN NEW; + END; + $BODY$ LANGUAGE plpgsql; + """ + ) + + cur.execute( + """ + CREATE TRIGGER check_partial_state_events BEFORE INSERT OR UPDATE ON partial_state_events + FOR EACH ROW + EXECUTE PROCEDURE check_partial_state_events() + """ + ) + else: + raise NotImplementedError("Unknown database engine") diff --git a/tests/test_state.py b/tests/test_state.py index 90800421f..e4baa6913 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional +from typing import Collection, Dict, List, Optional from unittest.mock import Mock from twisted.internet import defer @@ -70,7 +70,7 @@ def create_event( return event -class StateGroupStore: +class _DummyStore: def __init__(self): self._event_to_state_group = {} self._group_to_state = {} @@ -105,6 +105,11 @@ class StateGroupStore: if e_id in self._event_id_to_event } + async def get_partial_state_events( + self, event_ids: Collection[str] + ) -> Dict[str, bool]: + return {e: False for e in event_ids} + async def get_state_group_delta(self, name): return None, None @@ -157,8 +162,8 @@ class Graph: class StateTestCase(unittest.TestCase): def setUp(self): - self.store = StateGroupStore() - storage = Mock(main=self.store, state=self.store) + self.dummy_store = _DummyStore() + storage = Mock(main=self.dummy_store, state=self.dummy_store) hs = Mock( spec_set=[ "config", @@ -173,7 +178,7 @@ class StateTestCase(unittest.TestCase): ] ) hs.config = default_config("tesths", True) - hs.get_datastores.return_value = Mock(main=self.store) + hs.get_datastores.return_value = Mock(main=self.dummy_store) hs.get_state_handler.return_value = None hs.get_clock.return_value = MockClock() hs.get_auth.return_value = Auth(hs) @@ -198,7 +203,7 @@ class StateTestCase(unittest.TestCase): edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]}, ) - self.store.register_events(graph.walk()) + self.dummy_store.register_events(graph.walk()) context_store: dict[str, EventContext] = {} @@ -206,7 +211,7 @@ class StateTestCase(unittest.TestCase): context = yield defer.ensureDeferred( self.state.compute_event_context(event) ) - self.store.register_event_context(event, context) + self.dummy_store.register_event_context(event, context) context_store[event.event_id] = context ctx_c = context_store["C"] @@ -242,7 +247,7 @@ class StateTestCase(unittest.TestCase): edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]}, ) - self.store.register_events(graph.walk()) + self.dummy_store.register_events(graph.walk()) context_store = {} @@ -250,7 +255,7 @@ class StateTestCase(unittest.TestCase): context = yield defer.ensureDeferred( self.state.compute_event_context(event) ) - self.store.register_event_context(event, context) + self.dummy_store.register_event_context(event, context) context_store[event.event_id] = context # C ends up winning the resolution between B and C @@ -300,7 +305,7 @@ class StateTestCase(unittest.TestCase): edges={"A": ["START"], "B": ["A"], "C": ["B"], "D": ["B"], "E": ["C", "D"]}, ) - self.store.register_events(graph.walk()) + self.dummy_store.register_events(graph.walk()) context_store = {} @@ -308,7 +313,7 @@ class StateTestCase(unittest.TestCase): context = yield defer.ensureDeferred( self.state.compute_event_context(event) ) - self.store.register_event_context(event, context) + self.dummy_store.register_event_context(event, context) context_store[event.event_id] = context # C ends up winning the resolution between C and D because bans win over other @@ -375,7 +380,7 @@ class StateTestCase(unittest.TestCase): self._add_depths(nodes, edges) graph = Graph(nodes, edges) - self.store.register_events(graph.walk()) + self.dummy_store.register_events(graph.walk()) context_store = {} @@ -383,7 +388,7 @@ class StateTestCase(unittest.TestCase): context = yield defer.ensureDeferred( self.state.compute_event_context(event) ) - self.store.register_event_context(event, context) + self.dummy_store.register_event_context(event, context) context_store[event.event_id] = context # B ends up winning the resolution between B and C because power levels @@ -476,7 +481,7 @@ class StateTestCase(unittest.TestCase): ] group_name = yield defer.ensureDeferred( - self.store.store_state_group( + self.dummy_store.store_state_group( prev_event_id, event.room_id, None, @@ -484,7 +489,7 @@ class StateTestCase(unittest.TestCase): {(e.type, e.state_key): e.event_id for e in old_state}, ) ) - self.store.register_event_id_state_group(prev_event_id, group_name) + self.dummy_store.register_event_id_state_group(prev_event_id, group_name) context = yield defer.ensureDeferred(self.state.compute_event_context(event)) @@ -510,7 +515,7 @@ class StateTestCase(unittest.TestCase): ] group_name = yield defer.ensureDeferred( - self.store.store_state_group( + self.dummy_store.store_state_group( prev_event_id, event.room_id, None, @@ -518,7 +523,7 @@ class StateTestCase(unittest.TestCase): {(e.type, e.state_key): e.event_id for e in old_state}, ) ) - self.store.register_event_id_state_group(prev_event_id, group_name) + self.dummy_store.register_event_id_state_group(prev_event_id, group_name) context = yield defer.ensureDeferred(self.state.compute_event_context(event)) @@ -554,8 +559,8 @@ class StateTestCase(unittest.TestCase): create_event(type="test4", state_key=""), ] - self.store.register_events(old_state_1) - self.store.register_events(old_state_2) + self.dummy_store.register_events(old_state_1) + self.dummy_store.register_events(old_state_2) context = yield self._get_context( event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 @@ -594,10 +599,10 @@ class StateTestCase(unittest.TestCase): create_event(type="test4", state_key=""), ] - store = StateGroupStore() + store = _DummyStore() store.register_events(old_state_1) store.register_events(old_state_2) - self.store.get_events = store.get_events + self.dummy_store.get_events = store.get_events context = yield self._get_context( event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 @@ -649,10 +654,10 @@ class StateTestCase(unittest.TestCase): create_event(type="test1", state_key="1", depth=2), ] - store = StateGroupStore() + store = _DummyStore() store.register_events(old_state_1) store.register_events(old_state_2) - self.store.get_events = store.get_events + self.dummy_store.get_events = store.get_events context = yield self._get_context( event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 @@ -695,7 +700,7 @@ class StateTestCase(unittest.TestCase): self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2 ): sg1 = yield defer.ensureDeferred( - self.store.store_state_group( + self.dummy_store.store_state_group( prev_event_id_1, event.room_id, None, @@ -703,10 +708,10 @@ class StateTestCase(unittest.TestCase): {(e.type, e.state_key): e.event_id for e in old_state_1}, ) ) - self.store.register_event_id_state_group(prev_event_id_1, sg1) + self.dummy_store.register_event_id_state_group(prev_event_id_1, sg1) sg2 = yield defer.ensureDeferred( - self.store.store_state_group( + self.dummy_store.store_state_group( prev_event_id_2, event.room_id, None, @@ -714,7 +719,7 @@ class StateTestCase(unittest.TestCase): {(e.type, e.state_key): e.event_id for e in old_state_2}, ) ) - self.store.register_event_id_state_group(prev_event_id_2, sg2) + self.dummy_store.register_event_id_state_group(prev_event_id_2, sg2) result = yield defer.ensureDeferred(self.state.compute_event_context(event)) return result From c893632319f9bcd76d105573008e8cb0ec2fe7ce Mon Sep 17 00:00:00 2001 From: reivilibre Date: Tue, 1 Mar 2022 13:41:57 +0000 Subject: [PATCH 72/84] Order in-flight state group queries in biggest-first order (#11610) Co-authored-by: Patrick Cloke --- changelog.d/11610.misc | 1 + synapse/storage/databases/state/store.py | 30 +++++- tests/storage/databases/test_state_store.py | 104 +++++++++++++++++++- 3 files changed, 131 insertions(+), 4 deletions(-) create mode 100644 changelog.d/11610.misc diff --git a/changelog.d/11610.misc b/changelog.d/11610.misc new file mode 100644 index 000000000..3af049b96 --- /dev/null +++ b/changelog.d/11610.misc @@ -0,0 +1 @@ +Deduplicate in-flight requests in `_get_state_for_groups`. diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index b8016f679..dadf3d1e3 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -25,6 +25,7 @@ from typing import ( ) import attr +from sortedcontainers import SortedDict from twisted.internet import defer @@ -72,6 +73,24 @@ class _GetStateGroupDelta: return len(self.delta_ids) if self.delta_ids else 0 +def state_filter_rough_priority_comparator( + state_filter: StateFilter, +) -> Tuple[int, int]: + """ + Returns a comparable value that roughly indicates the relative size of this + state filter compared to others. + 'Larger' state filters should sort first when using ascending order, so + this is essentially the opposite of 'size'. + It should be treated as a rough guide only and should not be interpreted to + have any particular meaning. The representation may also change + + The current implementation returns a tuple of the form: + * -1 for include_others, 0 otherwise + * -(number of entries in state_filter.types) + """ + return -int(state_filter.include_others), -len(state_filter.types) + + class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): """A data store for fetching/storing state groups.""" @@ -127,7 +146,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): # Current ongoing get_state_for_groups in-flight requests # {group ID -> {StateFilter -> ObservableDeferred}} self._state_group_inflight_requests: Dict[ - int, Dict[StateFilter, AbstractObservableDeferred[StateMap[str]]] + int, SortedDict[StateFilter, AbstractObservableDeferred[StateMap[str]]] ] = {} def get_max_state_group_txn(txn: Cursor) -> int: @@ -279,7 +298,10 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): # The list of ongoing requests which will help narrow the current request. reusable_requests = [] - for (request_state_filter, request_deferred) in inflight_requests.items(): + + # Iterate over existing requests in roughly biggest-first order. + for request_state_filter in inflight_requests: + request_deferred = inflight_requests[request_state_filter] new_state_filter_left_over = state_filter_left_over.approx_difference( request_state_filter ) @@ -358,7 +380,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): observable_deferred = ObservableDeferred(request_deferred, consumeErrors=True) # Insert the ObservableDeferred into the cache - group_request_dict = self._state_group_inflight_requests.setdefault(group, {}) + group_request_dict = self._state_group_inflight_requests.setdefault( + group, SortedDict(state_filter_rough_priority_comparator) + ) group_request_dict[db_state_filter] = observable_deferred return await make_deferred_yieldable(observable_deferred.observe()) diff --git a/tests/storage/databases/test_state_store.py b/tests/storage/databases/test_state_store.py index 076b66080..2b484c95a 100644 --- a/tests/storage/databases/test_state_store.py +++ b/tests/storage/databases/test_state_store.py @@ -15,11 +15,16 @@ import typing from typing import Dict, List, Sequence, Tuple from unittest.mock import patch +from parameterized import parameterized + from twisted.internet.defer import Deferred, ensureDeferred from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes -from synapse.storage.databases.state.store import MAX_INFLIGHT_REQUESTS_PER_GROUP +from synapse.storage.databases.state.store import ( + MAX_INFLIGHT_REQUESTS_PER_GROUP, + state_filter_rough_priority_comparator, +) from synapse.storage.state import StateFilter from synapse.types import StateMap from synapse.util import Clock @@ -350,3 +355,100 @@ class StateGroupInflightCachingTestCase(HomeserverTestCase): self._complete_request_fake(groups, sf, d) self.assertTrue(reqs[CAP_COUNT].called) self.assertTrue(reqs[CAP_COUNT + 1].called) + + @parameterized.expand([(False,), (True,)]) + def test_ordering_of_request_reuse(self, reverse: bool) -> None: + """ + Tests that 'larger' in-flight requests are ordered first. + + This is mostly a design decision in order to prevent a request from + hanging on to multiple queries when it would have been sufficient to + hang on to only one bigger query. + + The 'size' of a state filter is a rough heuristic. + + - requests two pieces of state, one 'larger' than the other, but each + spawning a query + - requests a third piece of state + - completes the larger of the first two queries + - checks that the third request gets completed (and doesn't needlessly + wait for the other query) + + Parameters: + reverse: whether to reverse the order of the initial requests, to ensure + that the effect doesn't depend on the order of request submission. + """ + + # We add in an extra state type to make sure that both requests spawn + # queries which are not optimised out. + state_filters = [ + StateFilter.freeze( + {"state.type": {"A"}, "other.state.type": {"a"}}, include_others=False + ), + StateFilter.freeze( + { + "state.type": None, + "other.state.type": {"b"}, + # The current rough size comparator uses the number of state types + # as an indicator of size. + # To influence it to make this state filter bigger than the previous one, + # we add another dummy state type. + "extra.state.type": {"c"}, + }, + include_others=False, + ), + ] + + if reverse: + # For fairness, we perform one test run with the list reversed. + state_filters.reverse() + smallest_state_filter_idx = 1 + biggest_state_filter_idx = 0 + else: + smallest_state_filter_idx = 0 + biggest_state_filter_idx = 1 + + # This assertion is for our own sanity more than anything else. + self.assertLess( + state_filter_rough_priority_comparator( + state_filters[biggest_state_filter_idx] + ), + state_filter_rough_priority_comparator( + state_filters[smallest_state_filter_idx] + ), + "Test invalid: bigger state filter is not actually bigger.", + ) + + # Spawn the initial two requests + for state_filter in state_filters: + ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, + state_filter, + ) + ) + + # Spawn a third request + req = ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, + StateFilter.freeze( + { + "state.type": {"A"}, + }, + include_others=False, + ), + ) + ) + self.pump(by=0.1) + + self.assertFalse(req.called) + + # Complete the largest request's query to make sure that the final request + # only waits for that one (and doesn't needlessly wait for both queries) + self._complete_request_fake( + *self.get_state_group_calls[biggest_state_filter_idx] + ) + + # That should have been sufficient to complete the third request + self.assertTrue(req.called) From 91bc15c772d22fbe814170ab2e0fdbfa50f9c372 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Tue, 1 Mar 2022 13:51:03 +0000 Subject: [PATCH 73/84] Add `stop_cancellation` utility function (#12106) --- changelog.d/12106.misc | 1 + synapse/util/async_helpers.py | 19 ++++++++++++++ tests/util/test_async_helpers.py | 45 ++++++++++++++++++++++++++++++++ 3 files changed, 65 insertions(+) create mode 100644 changelog.d/12106.misc diff --git a/changelog.d/12106.misc b/changelog.d/12106.misc new file mode 100644 index 000000000..d918e9e3b --- /dev/null +++ b/changelog.d/12106.misc @@ -0,0 +1 @@ +Add `stop_cancellation` utility function to stop `Deferred`s from being cancelled. diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index a83296a22..81320b897 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -665,3 +665,22 @@ def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]: return value return DoneAwaitable(value) + + +def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]": + """Prevent a `Deferred` from being cancelled by wrapping it in another `Deferred`. + + Args: + deferred: The `Deferred` to protect against cancellation. Must not follow the + Synapse logcontext rules. + + Returns: + A new `Deferred`, which will contain the result of the original `Deferred`, + but will not propagate cancellation through to the original. When cancelled, + the new `Deferred` will fail with a `CancelledError` and will not follow the + Synapse logcontext rules. `make_deferred_yieldable` should be used to wrap + the new `Deferred`. + """ + new_deferred: defer.Deferred[T] = defer.Deferred() + deferred.chainDeferred(new_deferred) + return new_deferred diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py index cce8d595f..362014f4c 100644 --- a/tests/util/test_async_helpers.py +++ b/tests/util/test_async_helpers.py @@ -27,6 +27,7 @@ from synapse.logging.context import ( from synapse.util.async_helpers import ( ObservableDeferred, concurrently_execute, + stop_cancellation, timeout_deferred, ) @@ -282,3 +283,47 @@ class ConcurrentlyExecuteTest(TestCase): d2 = ensureDeferred(caller()) d1.callback(0) self.successResultOf(d2) + + +class StopCancellationTests(TestCase): + """Tests for the `stop_cancellation` function.""" + + def test_succeed(self): + """Test that the new `Deferred` receives the result.""" + deferred: "Deferred[str]" = Deferred() + wrapper_deferred = stop_cancellation(deferred) + + # Success should propagate through. + deferred.callback("success") + self.assertTrue(wrapper_deferred.called) + self.assertEqual("success", self.successResultOf(wrapper_deferred)) + + def test_failure(self): + """Test that the new `Deferred` receives the `Failure`.""" + deferred: "Deferred[str]" = Deferred() + wrapper_deferred = stop_cancellation(deferred) + + # Failure should propagate through. + deferred.errback(ValueError("abc")) + self.assertTrue(wrapper_deferred.called) + self.failureResultOf(wrapper_deferred, ValueError) + self.assertIsNone(deferred.result, "`Failure` was not consumed") + + def test_cancellation(self): + """Test that cancellation of the new `Deferred` leaves the original running.""" + deferred: "Deferred[str]" = Deferred() + wrapper_deferred = stop_cancellation(deferred) + + # Cancel the new `Deferred`. + wrapper_deferred.cancel() + self.assertTrue(wrapper_deferred.called) + self.failureResultOf(wrapper_deferred, CancelledError) + self.assertFalse( + deferred.called, "Original `Deferred` was unexpectedly cancelled." + ) + + # Now make the inner `Deferred` fail. + # The `Failure` must be consumed, otherwise unwanted tracebacks will be printed + # in logs. + deferred.errback(ValueError("abc")) + self.assertIsNone(deferred.result, "`Failure` was not consumed") From f26e390a40288be2801b3b9b3a99269b3f3ff81f Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 1 Mar 2022 13:55:18 +0000 Subject: [PATCH 74/84] Use Python 3.9 in Synapse dockerfiles by default (#12112) --- changelog.d/12112.docker | 1 + docker/Dockerfile | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 changelog.d/12112.docker diff --git a/changelog.d/12112.docker b/changelog.d/12112.docker new file mode 100644 index 000000000..b9e630653 --- /dev/null +++ b/changelog.d/12112.docker @@ -0,0 +1 @@ +Use Python 3.9 in Docker images by default. \ No newline at end of file diff --git a/docker/Dockerfile b/docker/Dockerfile index e4c1c19b8..a8bb9b0e7 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -11,10 +11,10 @@ # There is an optional PYTHON_VERSION build argument which sets the # version of python to build against: for example: # -# DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile --build-arg PYTHON_VERSION=3.9 . +# DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile --build-arg PYTHON_VERSION=3.10 . # -ARG PYTHON_VERSION=3.8 +ARG PYTHON_VERSION=3.9 ### ### Stage 0: builder From 300ed0b8a6050b5187a2a524a82cf87baad3ca73 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 1 Mar 2022 15:00:03 +0000 Subject: [PATCH 75/84] Add module callbacks called for reacting to deactivation status change and profile update (#12062) --- changelog.d/12062.feature | 1 + docs/modules/third_party_rules_callbacks.md | 56 +++++ synapse/events/third_party_rules.py | 56 ++++- synapse/handlers/deactivate_account.py | 20 +- synapse/handlers/profile.py | 14 ++ synapse/module_api/__init__.py | 1 + tests/rest/client/test_third_party_rules.py | 219 +++++++++++++++++++- 7 files changed, 360 insertions(+), 7 deletions(-) create mode 100644 changelog.d/12062.feature diff --git a/changelog.d/12062.feature b/changelog.d/12062.feature new file mode 100644 index 000000000..46a606709 --- /dev/null +++ b/changelog.d/12062.feature @@ -0,0 +1 @@ +Add module callbacks to react to user deactivation status changes (i.e. deactivations and reactivations) and profile updates. diff --git a/docs/modules/third_party_rules_callbacks.md b/docs/modules/third_party_rules_callbacks.md index a3a17096a..09ac83810 100644 --- a/docs/modules/third_party_rules_callbacks.md +++ b/docs/modules/third_party_rules_callbacks.md @@ -148,6 +148,62 @@ deny an incoming event, see [`check_event_for_spam`](spam_checker_callbacks.md#c If multiple modules implement this callback, Synapse runs them all in order. +### `on_profile_update` + +_First introduced in Synapse v1.54.0_ + +```python +async def on_profile_update( + user_id: str, + new_profile: "synapse.module_api.ProfileInfo", + by_admin: bool, + deactivation: bool, +) -> None: +``` + +Called after updating a local user's profile. The update can be triggered either by the +user themselves or a server admin. The update can also be triggered by a user being +deactivated (in which case their display name is set to an empty string (`""`) and the +avatar URL is set to `None`). The module is passed the Matrix ID of the user whose profile +has been updated, their new profile, as well as a `by_admin` boolean that is `True` if the +update was triggered by a server admin (and `False` otherwise), and a `deactivated` +boolean that is `True` if the update is a result of the user being deactivated. + +Note that the `by_admin` boolean is also `True` if the profile change happens as a result +of the user logging in through Single Sign-On, or if a server admin updates their own +profile. + +Per-room profile changes do not trigger this callback to be called. Synapse administrators +wishing this callback to be called on every profile change are encouraged to disable +per-room profiles globally using the `allow_per_room_profiles` configuration setting in +Synapse's configuration file. +This callback is not called when registering a user, even when setting it through the +[`get_displayname_for_registration`](https://matrix-org.github.io/synapse/latest/modules/password_auth_provider_callbacks.html#get_displayname_for_registration) +module callback. + +If multiple modules implement this callback, Synapse runs them all in order. + +### `on_user_deactivation_status_changed` + +_First introduced in Synapse v1.54.0_ + +```python +async def on_user_deactivation_status_changed( + user_id: str, deactivated: bool, by_admin: bool +) -> None: +``` + +Called after deactivating a local user, or reactivating them through the admin API. The +deactivation can be triggered either by the user themselves or a server admin. The module +is passed the Matrix ID of the user whose status is changed, as well as a `deactivated` +boolean that is `True` if the user is being deactivated and `False` if they're being +reactivated, and a `by_admin` boolean that is `True` if the deactivation was triggered by +a server admin (and `False` otherwise). This latter `by_admin` boolean is always `True` +if the user is being reactivated, as this operation can only be performed through the +admin API. + +If multiple modules implement this callback, Synapse runs them all in order. + ## Example The example below is a module that implements the third-party rules callback diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 71ec100a7..dd3104faf 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Optional, Tupl from synapse.api.errors import ModuleFailedException, SynapseError from synapse.events import EventBase from synapse.events.snapshot import EventContext +from synapse.storage.roommember import ProfileInfo from synapse.types import Requester, StateMap from synapse.util.async_helpers import maybe_awaitable @@ -37,6 +38,8 @@ CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[ [str, StateMap[EventBase], str], Awaitable[bool] ] ON_NEW_EVENT_CALLBACK = Callable[[EventBase, StateMap[EventBase]], Awaitable] +ON_PROFILE_UPDATE_CALLBACK = Callable[[str, ProfileInfo, bool, bool], Awaitable] +ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK = Callable[[str, bool, bool], Awaitable] def load_legacy_third_party_event_rules(hs: "HomeServer") -> None: @@ -154,6 +157,10 @@ class ThirdPartyEventRules: CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK ] = [] self._on_new_event_callbacks: List[ON_NEW_EVENT_CALLBACK] = [] + self._on_profile_update_callbacks: List[ON_PROFILE_UPDATE_CALLBACK] = [] + self._on_user_deactivation_status_changed_callbacks: List[ + ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK + ] = [] def register_third_party_rules_callbacks( self, @@ -166,6 +173,8 @@ class ThirdPartyEventRules: CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK ] = None, on_new_event: Optional[ON_NEW_EVENT_CALLBACK] = None, + on_profile_update: Optional[ON_PROFILE_UPDATE_CALLBACK] = None, + on_deactivation: Optional[ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK] = None, ) -> None: """Register callbacks from modules for each hook.""" if check_event_allowed is not None: @@ -187,6 +196,12 @@ class ThirdPartyEventRules: if on_new_event is not None: self._on_new_event_callbacks.append(on_new_event) + if on_profile_update is not None: + self._on_profile_update_callbacks.append(on_profile_update) + + if on_deactivation is not None: + self._on_user_deactivation_status_changed_callbacks.append(on_deactivation) + async def check_event_allowed( self, event: EventBase, context: EventContext ) -> Tuple[bool, Optional[dict]]: @@ -334,9 +349,6 @@ class ThirdPartyEventRules: Args: event_id: The ID of the event. - - Raises: - ModuleFailureError if a callback raised any exception. """ # Bail out early without hitting the store if we don't have any callbacks if len(self._on_new_event_callbacks) == 0: @@ -370,3 +382,41 @@ class ThirdPartyEventRules: state_events[key] = room_state_events[event_id] return state_events + + async def on_profile_update( + self, user_id: str, new_profile: ProfileInfo, by_admin: bool, deactivation: bool + ) -> None: + """Called after the global profile of a user has been updated. Does not include + per-room profile changes. + + Args: + user_id: The user whose profile was changed. + new_profile: The updated profile for the user. + by_admin: Whether the profile update was performed by a server admin. + deactivation: Whether this change was made while deactivating the user. + """ + for callback in self._on_profile_update_callbacks: + try: + await callback(user_id, new_profile, by_admin, deactivation) + except Exception as e: + logger.exception( + "Failed to run module API callback %s: %s", callback, e + ) + + async def on_user_deactivation_status_changed( + self, user_id: str, deactivated: bool, by_admin: bool + ) -> None: + """Called after a user has been deactivated or reactivated. + + Args: + user_id: The deactivated user. + deactivated: Whether the user is now deactivated. + by_admin: Whether the deactivation was performed by a server admin. + """ + for callback in self._on_user_deactivation_status_changed_callbacks: + try: + await callback(user_id, deactivated, by_admin) + except Exception as e: + logger.exception( + "Failed to run module API callback %s: %s", callback, e + ) diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index e4eae0305..76ae768e6 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -38,6 +38,7 @@ class DeactivateAccountHandler: self._profile_handler = hs.get_profile_handler() self.user_directory_handler = hs.get_user_directory_handler() self._server_name = hs.hostname + self._third_party_rules = hs.get_third_party_event_rules() # Flag that indicates whether the process to part users from rooms is running self._user_parter_running = False @@ -135,9 +136,13 @@ class DeactivateAccountHandler: if erase_data: user = UserID.from_string(user_id) # Remove avatar URL from this user - await self._profile_handler.set_avatar_url(user, requester, "", by_admin) + await self._profile_handler.set_avatar_url( + user, requester, "", by_admin, deactivation=True + ) # Remove displayname from this user - await self._profile_handler.set_displayname(user, requester, "", by_admin) + await self._profile_handler.set_displayname( + user, requester, "", by_admin, deactivation=True + ) logger.info("Marking %s as erased", user_id) await self.store.mark_user_erased(user_id) @@ -160,6 +165,13 @@ class DeactivateAccountHandler: # Remove account data (including ignored users and push rules). await self.store.purge_account_data_for_user(user_id) + # Let modules know the user has been deactivated. + await self._third_party_rules.on_user_deactivation_status_changed( + user_id, + True, + by_admin, + ) + return identity_server_supports_unbinding async def _reject_pending_invites_for_user(self, user_id: str) -> None: @@ -264,6 +276,10 @@ class DeactivateAccountHandler: # Mark the user as active. await self.store.set_user_deactivated_status(user_id, False) + await self._third_party_rules.on_user_deactivation_status_changed( + user_id, False, True + ) + # Add the user to the directory, if necessary. Note that # this must be done after the user is re-activated, because # deactivated users are excluded from the user directory. diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index dd27f0acc..6554c0d3c 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -71,6 +71,8 @@ class ProfileHandler: self.server_name = hs.config.server.server_name + self._third_party_rules = hs.get_third_party_event_rules() + if hs.config.worker.run_background_tasks: self.clock.looping_call( self._update_remote_profile_cache, self.PROFILE_UPDATE_MS @@ -171,6 +173,7 @@ class ProfileHandler: requester: Requester, new_displayname: str, by_admin: bool = False, + deactivation: bool = False, ) -> None: """Set the displayname of a user @@ -179,6 +182,7 @@ class ProfileHandler: requester: The user attempting to make this change. new_displayname: The displayname to give this user. by_admin: Whether this change was made by an administrator. + deactivation: Whether this change was made while deactivating the user. """ if not self.hs.is_mine(target_user): raise SynapseError(400, "User is not hosted on this homeserver") @@ -227,6 +231,10 @@ class ProfileHandler: target_user.to_string(), profile ) + await self._third_party_rules.on_profile_update( + target_user.to_string(), profile, by_admin, deactivation + ) + await self._update_join_states(requester, target_user) async def get_avatar_url(self, target_user: UserID) -> Optional[str]: @@ -261,6 +269,7 @@ class ProfileHandler: requester: Requester, new_avatar_url: str, by_admin: bool = False, + deactivation: bool = False, ) -> None: """Set a new avatar URL for a user. @@ -269,6 +278,7 @@ class ProfileHandler: requester: The user attempting to make this change. new_avatar_url: The avatar URL to give this user. by_admin: Whether this change was made by an administrator. + deactivation: Whether this change was made while deactivating the user. """ if not self.hs.is_mine(target_user): raise SynapseError(400, "User is not hosted on this homeserver") @@ -315,6 +325,10 @@ class ProfileHandler: target_user.to_string(), profile ) + await self._third_party_rules.on_profile_update( + target_user.to_string(), profile, by_admin, deactivation + ) + await self._update_join_states(requester, target_user) @cached() diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 902916d80..7e4693186 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -145,6 +145,7 @@ __all__ = [ "JsonDict", "EventBase", "StateMap", + "ProfileInfo", ] logger = logging.getLogger(__name__) diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index 9cca9edd3..bfc04785b 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -15,12 +15,12 @@ import threading from typing import TYPE_CHECKING, Dict, Optional, Tuple from unittest.mock import Mock -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventTypes, LoginType, Membership from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.rest import admin -from synapse.rest.client import login, room +from synapse.rest.client import account, login, profile, room from synapse.types import JsonDict, Requester, StateMap from synapse.util.frozenutils import unfreeze @@ -80,6 +80,8 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): admin.register_servlets, login.register_servlets, room.register_servlets, + profile.register_servlets, + account.register_servlets, ] def make_homeserver(self, reactor, clock): @@ -530,3 +532,216 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): }, tok=self.tok, ) + + def test_on_profile_update(self): + """Tests that the on_profile_update module callback is correctly called on + profile updates. + """ + displayname = "Foo" + avatar_url = "mxc://matrix.org/oWQDvfewxmlRaRCkVbfetyEo" + + # Register a mock callback. + m = Mock(return_value=make_awaitable(None)) + self.hs.get_third_party_event_rules()._on_profile_update_callbacks.append(m) + + # Change the display name. + channel = self.make_request( + "PUT", + "/_matrix/client/v3/profile/%s/displayname" % self.user_id, + {"displayname": displayname}, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the callback has been called once for our user. + m.assert_called_once() + args = m.call_args[0] + self.assertEqual(args[0], self.user_id) + + # Test that by_admin is False. + self.assertFalse(args[2]) + # Test that deactivation is False. + self.assertFalse(args[3]) + + # Check that we've got the right profile data. + profile_info = args[1] + self.assertEqual(profile_info.display_name, displayname) + self.assertIsNone(profile_info.avatar_url) + + # Change the avatar. + channel = self.make_request( + "PUT", + "/_matrix/client/v3/profile/%s/avatar_url" % self.user_id, + {"avatar_url": avatar_url}, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the callback has been called once for our user. + self.assertEqual(m.call_count, 2) + args = m.call_args[0] + self.assertEqual(args[0], self.user_id) + + # Test that by_admin is False. + self.assertFalse(args[2]) + # Test that deactivation is False. + self.assertFalse(args[3]) + + # Check that we've got the right profile data. + profile_info = args[1] + self.assertEqual(profile_info.display_name, displayname) + self.assertEqual(profile_info.avatar_url, avatar_url) + + def test_on_profile_update_admin(self): + """Tests that the on_profile_update module callback is correctly called on + profile updates triggered by a server admin. + """ + displayname = "Foo" + avatar_url = "mxc://matrix.org/oWQDvfewxmlRaRCkVbfetyEo" + + # Register a mock callback. + m = Mock(return_value=make_awaitable(None)) + self.hs.get_third_party_event_rules()._on_profile_update_callbacks.append(m) + + # Register an admin user. + self.register_user("admin", "password", admin=True) + admin_tok = self.login("admin", "password") + + # Change a user's profile. + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % self.user_id, + {"displayname": displayname, "avatar_url": avatar_url}, + access_token=admin_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the callback has been called twice (since we update the display name + # and avatar separately). + self.assertEqual(m.call_count, 2) + + # Get the arguments for the last call and check it's about the right user. + args = m.call_args[0] + self.assertEqual(args[0], self.user_id) + + # Check that by_admin is True. + self.assertTrue(args[2]) + # Test that deactivation is False. + self.assertFalse(args[3]) + + # Check that we've got the right profile data. + profile_info = args[1] + self.assertEqual(profile_info.display_name, displayname) + self.assertEqual(profile_info.avatar_url, avatar_url) + + def test_on_user_deactivation_status_changed(self): + """Tests that the on_user_deactivation_status_changed module callback is called + correctly when processing a user's deactivation. + """ + # Register a mocked callback. + deactivation_mock = Mock(return_value=make_awaitable(None)) + third_party_rules = self.hs.get_third_party_event_rules() + third_party_rules._on_user_deactivation_status_changed_callbacks.append( + deactivation_mock, + ) + # Also register a mocked callback for profile updates, to check that the + # deactivation code calls it in a way that let modules know the user is being + # deactivated. + profile_mock = Mock(return_value=make_awaitable(None)) + self.hs.get_third_party_event_rules()._on_profile_update_callbacks.append( + profile_mock, + ) + + # Register a user that we'll deactivate. + user_id = self.register_user("altan", "password") + tok = self.login("altan", "password") + + # Deactivate that user. + channel = self.make_request( + "POST", + "/_matrix/client/v3/account/deactivate", + { + "auth": { + "type": LoginType.PASSWORD, + "password": "password", + "identifier": { + "type": "m.id.user", + "user": user_id, + }, + }, + "erase": True, + }, + access_token=tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the mock was called once. + deactivation_mock.assert_called_once() + args = deactivation_mock.call_args[0] + + # Check that the mock was called with the right user ID, and with a True + # deactivated flag and a False by_admin flag. + self.assertEqual(args[0], user_id) + self.assertTrue(args[1]) + self.assertFalse(args[2]) + + # Check that the profile update callback was called twice (once for the display + # name and once for the avatar URL), and that the "deactivation" boolean is true. + self.assertEqual(profile_mock.call_count, 2) + args = profile_mock.call_args[0] + self.assertTrue(args[3]) + + def test_on_user_deactivation_status_changed_admin(self): + """Tests that the on_user_deactivation_status_changed module callback is called + correctly when processing a user's deactivation triggered by a server admin as + well as a reactivation. + """ + # Register a mock callback. + m = Mock(return_value=make_awaitable(None)) + third_party_rules = self.hs.get_third_party_event_rules() + third_party_rules._on_user_deactivation_status_changed_callbacks.append(m) + + # Register an admin user. + self.register_user("admin", "password", admin=True) + admin_tok = self.login("admin", "password") + + # Register a user that we'll deactivate. + user_id = self.register_user("altan", "password") + + # Deactivate the user. + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % user_id, + {"deactivated": True}, + access_token=admin_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the mock was called once. + m.assert_called_once() + args = m.call_args[0] + + # Check that the mock was called with the right user ID, and with True deactivated + # and by_admin flags. + self.assertEqual(args[0], user_id) + self.assertTrue(args[1]) + self.assertTrue(args[2]) + + # Reactivate the user. + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % user_id, + {"deactivated": False, "password": "hackme"}, + access_token=admin_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the mock was called once. + self.assertEqual(m.call_count, 2) + args = m.call_args[0] + + # Check that the mock was called with the right user ID, and with a False + # deactivated flag and a True by_admin flag. + self.assertEqual(args[0], user_id) + self.assertFalse(args[1]) + self.assertTrue(args[2]) From 4d6b6c17c860a6ef258e513d841dbda6ea151cbd Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Tue, 1 Mar 2022 15:27:15 +0000 Subject: [PATCH 76/84] Fix rare error in `ReadWriteLock` when writers complete immediately (#12105) Signed-off-by: Sean Quah --- changelog.d/12105.bugfix | 1 + synapse/util/async_helpers.py | 5 ++++- tests/util/test_rwlock.py | 30 ++++++++++++++++++++++++++++++ 3 files changed, 35 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12105.bugfix diff --git a/changelog.d/12105.bugfix b/changelog.d/12105.bugfix new file mode 100644 index 000000000..f42e63e01 --- /dev/null +++ b/changelog.d/12105.bugfix @@ -0,0 +1 @@ +Fix an extremely rare, long-standing bug in `ReadWriteLock` that would cause an error when a newly unblocked writer completes instantly. diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 81320b897..60c03a66f 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -555,7 +555,10 @@ class ReadWriteLock: finally: with PreserveLoggingContext(): new_defer.callback(None) - if self.key_to_current_writer[key] == new_defer: + # `self.key_to_current_writer[key]` may be missing if there was another + # writer waiting for us and it completed entirely within the + # `new_defer.callback()` call above. + if self.key_to_current_writer.get(key) == new_defer: self.key_to_current_writer.pop(key) return _ctx_manager() diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py index a10071c70..0774625b8 100644 --- a/tests/util/test_rwlock.py +++ b/tests/util/test_rwlock.py @@ -13,6 +13,7 @@ # limitations under the License. from twisted.internet import defer +from twisted.internet.defer import Deferred from synapse.util.async_helpers import ReadWriteLock @@ -83,3 +84,32 @@ class ReadWriteLockTestCase(unittest.TestCase): self.assertTrue(d.called) with d.result: pass + + def test_lock_handoff_to_nonblocking_writer(self): + """Test a writer handing the lock to another writer that completes instantly.""" + rwlock = ReadWriteLock() + key = "key" + + unblock: "Deferred[None]" = Deferred() + + async def blocking_write(): + with await rwlock.write(key): + await unblock + + async def nonblocking_write(): + with await rwlock.write(key): + pass + + d1 = defer.ensureDeferred(blocking_write()) + d2 = defer.ensureDeferred(nonblocking_write()) + self.assertFalse(d1.called) + self.assertFalse(d2.called) + + # Unblock the first writer. The second writer will complete without blocking. + unblock.callback(None) + self.assertTrue(d1.called) + self.assertTrue(d2.called) + + # The `ReadWriteLock` should operate as normal. + d3 = defer.ensureDeferred(nonblocking_write()) + self.assertTrue(d3.called) From 313581e4e9bc2ec3d59ccff86e3a0c02661f71c4 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 1 Mar 2022 17:44:41 +0000 Subject: [PATCH 77/84] Use importlib.metadata to read requirements (#12088) * Pull runtime dep checks into their own module * Reimplement `check_requirements` using `importlib` I've tried to make this clearer. We start by working out which of Synapse's requirements we need to be installed here and now. I was surprised that there wasn't an easier way to see which packages were installed by a given extra. I've pulled out the error messages into functions that deal with "is this for an extra or not". And I've rearranged the loop over two different sets of requirements into one loop with a "must be instaled" flag. I hope you agree that this is clearer. * Test cases --- changelog.d/12088.misc | 1 + synapse/app/__init__.py | 6 +- synapse/app/homeserver.py | 2 +- synapse/config/cache.py | 2 +- synapse/config/metrics.py | 2 +- synapse/config/oidc.py | 2 +- synapse/config/redis.py | 2 +- synapse/config/repository.py | 2 +- synapse/config/saml2.py | 2 +- synapse/config/tracer.py | 2 +- synapse/python_dependencies.py | 107 +--------------------- synapse/util/check_dependencies.py | 127 ++++++++++++++++++++++++++ tests/util/test_check_dependencies.py | 95 +++++++++++++++++++ 13 files changed, 237 insertions(+), 115 deletions(-) create mode 100644 changelog.d/12088.misc create mode 100644 synapse/util/check_dependencies.py create mode 100644 tests/util/test_check_dependencies.py diff --git a/changelog.d/12088.misc b/changelog.d/12088.misc new file mode 100644 index 000000000..ce4213650 --- /dev/null +++ b/changelog.d/12088.misc @@ -0,0 +1 @@ +Inspect application dependencies using `importlib.metadata` or its backport. \ No newline at end of file diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py index ee51480a9..334c3d2c1 100644 --- a/synapse/app/__init__.py +++ b/synapse/app/__init__.py @@ -15,13 +15,13 @@ import logging import sys from typing import Container -from synapse import python_dependencies # noqa: E402 +from synapse.util import check_dependencies logger = logging.getLogger(__name__) try: - python_dependencies.check_requirements() -except python_dependencies.DependencyException as e: + check_dependencies.check_requirements() +except check_dependencies.DependencyException as e: sys.stderr.writelines( e.message # noqa: B306, DependencyException.message is a property ) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index b9931001c..a6789a840 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -59,7 +59,6 @@ from synapse.http.server import ( from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy -from synapse.python_dependencies import check_requirements from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.rest import ClientRestResource @@ -70,6 +69,7 @@ from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.rest.well_known import well_known_resource from synapse.server import HomeServer from synapse.storage import DataStore +from synapse.util.check_dependencies import check_requirements from synapse.util.httpresourcetree import create_resource_tree from synapse.util.module_loader import load_module diff --git a/synapse/config/cache.py b/synapse/config/cache.py index 387ac6d11..9a68da9c3 100644 --- a/synapse/config/cache.py +++ b/synapse/config/cache.py @@ -20,7 +20,7 @@ from typing import Callable, Dict, Optional import attr -from synapse.python_dependencies import DependencyException, check_requirements +from synapse.util.check_dependencies import DependencyException, check_requirements from ._base import Config, ConfigError diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index 1cc26e757..f62292ecf 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -15,7 +15,7 @@ import attr -from synapse.python_dependencies import DependencyException, check_requirements +from synapse.util.check_dependencies import DependencyException, check_requirements from ._base import Config, ConfigError diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py index e783b1131..f7e4f9ef2 100644 --- a/synapse/config/oidc.py +++ b/synapse/config/oidc.py @@ -20,11 +20,11 @@ import attr from synapse.config._util import validate_config from synapse.config.sso import SsoAttributeRequirement -from synapse.python_dependencies import DependencyException, check_requirements from synapse.types import JsonDict from synapse.util.module_loader import load_module from synapse.util.stringutils import parse_and_validate_mxc_uri +from ..util.check_dependencies import DependencyException, check_requirements from ._base import Config, ConfigError, read_file DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.oidc.JinjaOidcMappingProvider" diff --git a/synapse/config/redis.py b/synapse/config/redis.py index 33104af73..bdb1aac3a 100644 --- a/synapse/config/redis.py +++ b/synapse/config/redis.py @@ -13,7 +13,7 @@ # limitations under the License. from synapse.config._base import Config -from synapse.python_dependencies import check_requirements +from synapse.util.check_dependencies import check_requirements class RedisConfig(Config): diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 1980351e7..0a0d901bf 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -20,8 +20,8 @@ from urllib.request import getproxies_environment # type: ignore import attr from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST, generate_ip_set -from synapse.python_dependencies import DependencyException, check_requirements from synapse.types import JsonDict +from synapse.util.check_dependencies import DependencyException, check_requirements from synapse.util.module_loader import load_module from ._base import Config, ConfigError diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py index ec9d9f65e..43c456d5c 100644 --- a/synapse/config/saml2.py +++ b/synapse/config/saml2.py @@ -17,8 +17,8 @@ import logging from typing import Any, List, Set from synapse.config.sso import SsoAttributeRequirement -from synapse.python_dependencies import DependencyException, check_requirements from synapse.types import JsonDict +from synapse.util.check_dependencies import DependencyException, check_requirements from synapse.util.module_loader import load_module, load_python_module from ._base import Config, ConfigError diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py index 21b9a8835..7aff618ea 100644 --- a/synapse/config/tracer.py +++ b/synapse/config/tracer.py @@ -14,7 +14,7 @@ from typing import Set -from synapse.python_dependencies import DependencyException, check_requirements +from synapse.util.check_dependencies import DependencyException, check_requirements from ._base import Config, ConfigError diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index f43fbb584..8f48a3393 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -17,14 +17,7 @@ import itertools import logging -from typing import List, Set - -from pkg_resources import ( - DistributionNotFound, - Requirement, - VersionConflict, - get_provider, -) +from typing import Set logger = logging.getLogger(__name__) @@ -90,6 +83,8 @@ REQUIREMENTS = [ # ijson 3.1.4 fixes a bug with "." in property names "ijson>=3.1.4", "matrix-common~=1.1.0", + # For runtime introspection of our dependencies + "packaging~=21.3", ] CONDITIONAL_REQUIREMENTS = { @@ -144,102 +139,6 @@ def list_requirements(): return list(set(REQUIREMENTS) | ALL_OPTIONAL_REQUIREMENTS) -class DependencyException(Exception): - @property - def message(self): - return "\n".join( - [ - "Missing Requirements: %s" % (", ".join(self.dependencies),), - "To install run:", - " pip install --upgrade --force %s" % (" ".join(self.dependencies),), - "", - ] - ) - - @property - def dependencies(self): - for i in self.args[0]: - yield '"' + i + '"' - - -def check_requirements(for_feature=None): - deps_needed = [] - errors = [] - - if for_feature: - reqs = CONDITIONAL_REQUIREMENTS[for_feature] - else: - reqs = REQUIREMENTS - - for dependency in reqs: - try: - _check_requirement(dependency) - except VersionConflict as e: - deps_needed.append(dependency) - errors.append( - "Needed %s, got %s==%s" - % ( - dependency, - e.dist.project_name, # type: ignore[attr-defined] # noqa - e.dist.version, # type: ignore[attr-defined] # noqa - ) - ) - except DistributionNotFound: - deps_needed.append(dependency) - if for_feature: - errors.append( - "Needed %s for the '%s' feature but it was not installed" - % (dependency, for_feature) - ) - else: - errors.append("Needed %s but it was not installed" % (dependency,)) - - if not for_feature: - # Check the optional dependencies are up to date. We allow them to not be - # installed. - OPTS: List[str] = sum(CONDITIONAL_REQUIREMENTS.values(), []) - - for dependency in OPTS: - try: - _check_requirement(dependency) - except VersionConflict as e: - deps_needed.append(dependency) - errors.append( - "Needed optional %s, got %s==%s" - % ( - dependency, - e.dist.project_name, # type: ignore[attr-defined] # noqa - e.dist.version, # type: ignore[attr-defined] # noqa - ) - ) - except DistributionNotFound: - # If it's not found, we don't care - pass - - if deps_needed: - for err in errors: - logging.error(err) - - raise DependencyException(deps_needed) - - -def _check_requirement(dependency_string): - """Parses a dependency string, and checks if the specified requirement is installed - - Raises: - VersionConflict if the requirement is installed, but with the the wrong version - DistributionNotFound if nothing is found to provide the requirement - """ - req = Requirement.parse(dependency_string) - - # first check if the markers specify that this requirement needs installing - if req.marker is not None and not req.marker.evaluate(): - # not required for this environment - return - - get_provider(req) - - if __name__ == "__main__": import sys diff --git a/synapse/util/check_dependencies.py b/synapse/util/check_dependencies.py new file mode 100644 index 000000000..3a1f6b3c7 --- /dev/null +++ b/synapse/util/check_dependencies.py @@ -0,0 +1,127 @@ +import logging +from typing import Iterable, NamedTuple, Optional + +from packaging.requirements import Requirement + +DISTRIBUTION_NAME = "matrix-synapse" + +try: + from importlib import metadata +except ImportError: + import importlib_metadata as metadata # type: ignore[no-redef] + + +class DependencyException(Exception): + @property + def message(self) -> str: + return "\n".join( + [ + "Missing Requirements: %s" % (", ".join(self.dependencies),), + "To install run:", + " pip install --upgrade --force %s" % (" ".join(self.dependencies),), + "", + ] + ) + + @property + def dependencies(self) -> Iterable[str]: + for i in self.args[0]: + yield '"' + i + '"' + + +EXTRAS = set(metadata.metadata(DISTRIBUTION_NAME).get_all("Provides-Extra")) + + +class Dependency(NamedTuple): + requirement: Requirement + must_be_installed: bool + + +def _generic_dependencies() -> Iterable[Dependency]: + """Yield pairs (requirement, must_be_installed).""" + requirements = metadata.requires(DISTRIBUTION_NAME) + assert requirements is not None + for raw_requirement in requirements: + req = Requirement(raw_requirement) + # https://packaging.pypa.io/en/latest/markers.html#usage notes that + # > Evaluating an extra marker with no environment is an error + # so we pass in a dummy empty extra value here. + must_be_installed = req.marker is None or req.marker.evaluate({"extra": ""}) + yield Dependency(req, must_be_installed) + + +def _dependencies_for_extra(extra: str) -> Iterable[Dependency]: + """Yield additional dependencies needed for a given `extra`.""" + requirements = metadata.requires(DISTRIBUTION_NAME) + assert requirements is not None + for raw_requirement in requirements: + req = Requirement(raw_requirement) + # Exclude mandatory deps by only selecting deps needed with this extra. + if ( + req.marker is not None + and req.marker.evaluate({"extra": extra}) + and not req.marker.evaluate({"extra": ""}) + ): + yield Dependency(req, True) + + +def _not_installed(requirement: Requirement, extra: Optional[str] = None) -> str: + if extra: + return f"Need {requirement.name} for {extra}, but it is not installed" + else: + return f"Need {requirement.name}, but it is not installed" + + +def _incorrect_version( + requirement: Requirement, got: str, extra: Optional[str] = None +) -> str: + if extra: + return f"Need {requirement} for {extra}, but got {requirement.name}=={got}" + else: + return f"Need {requirement}, but got {requirement.name}=={got}" + + +def check_requirements(extra: Optional[str] = None) -> None: + """Check Synapse's dependencies are present and correctly versioned. + + If provided, `extra` must be the name of an pacakging extra (e.g. "saml2" in + `pip install matrix-synapse[saml2]`). + + If `extra` is None, this function checks that + - all mandatory dependencies are installed and correctly versioned, and + - each optional dependency that's installed is correctly versioned. + + If `extra` is not None, this function checks that + - the dependencies needed for that extra are installed and correctly versioned. + + :raises DependencyException: if a dependency is missing or incorrectly versioned. + :raises ValueError: if this extra does not exist. + """ + # First work out which dependencies are required, and which are optional. + if extra is None: + dependencies = _generic_dependencies() + elif extra in EXTRAS: + dependencies = _dependencies_for_extra(extra) + else: + raise ValueError(f"Synapse does not provide the feature '{extra}'") + + deps_unfulfilled = [] + errors = [] + + for (requirement, must_be_installed) in dependencies: + try: + dist: metadata.Distribution = metadata.distribution(requirement.name) + except metadata.PackageNotFoundError: + if must_be_installed: + deps_unfulfilled.append(requirement.name) + errors.append(_not_installed(requirement, extra)) + else: + if not requirement.specifier.contains(dist.version): + deps_unfulfilled.append(requirement.name) + errors.append(_incorrect_version(requirement, dist.version, extra)) + + if deps_unfulfilled: + for err in errors: + logging.error(err) + + raise DependencyException(deps_unfulfilled) diff --git a/tests/util/test_check_dependencies.py b/tests/util/test_check_dependencies.py new file mode 100644 index 000000000..3c0725225 --- /dev/null +++ b/tests/util/test_check_dependencies.py @@ -0,0 +1,95 @@ +from contextlib import contextmanager +from typing import Generator, Optional +from unittest.mock import patch + +from synapse.util.check_dependencies import ( + DependencyException, + check_requirements, + metadata, +) + +from tests.unittest import TestCase + + +class DummyDistribution(metadata.Distribution): + def __init__(self, version: str): + self._version = version + + @property + def version(self): + return self._version + + def locate_file(self, path): + raise NotImplementedError() + + def read_text(self, filename): + raise NotImplementedError() + + +old = DummyDistribution("0.1.2") +new = DummyDistribution("1.2.3") + +# could probably use stdlib TestCase --- no need for twisted here + + +class TestDependencyChecker(TestCase): + @contextmanager + def mock_installed_package( + self, distribution: Optional[DummyDistribution] + ) -> Generator[None, None, None]: + """Pretend that looking up any distribution yields the given `distribution`.""" + + def mock_distribution(name: str): + if distribution is None: + raise metadata.PackageNotFoundError + else: + return distribution + + with patch( + "synapse.util.check_dependencies.metadata.distribution", + mock_distribution, + ): + yield + + def test_mandatory_dependency(self) -> None: + """Complain if a required package is missing or old.""" + with patch( + "synapse.util.check_dependencies.metadata.requires", + return_value=["dummypkg >= 1"], + ): + with self.mock_installed_package(None): + self.assertRaises(DependencyException, check_requirements) + with self.mock_installed_package(old): + self.assertRaises(DependencyException, check_requirements) + with self.mock_installed_package(new): + # should not raise + check_requirements() + + def test_generic_check_of_optional_dependency(self) -> None: + """Complain if an optional package is old.""" + with patch( + "synapse.util.check_dependencies.metadata.requires", + return_value=["dummypkg >= 1; extra == 'cool-extra'"], + ): + with self.mock_installed_package(None): + # should not raise + check_requirements() + with self.mock_installed_package(old): + self.assertRaises(DependencyException, check_requirements) + with self.mock_installed_package(new): + # should not raise + check_requirements() + + def test_check_for_extra_dependencies(self) -> None: + """Complain if a package required for an extra is missing or old.""" + with patch( + "synapse.util.check_dependencies.metadata.requires", + return_value=["dummypkg >= 1; extra == 'cool-extra'"], + ), patch("synapse.util.check_dependencies.EXTRAS", {"cool-extra"}): + with self.mock_installed_package(None): + self.assertRaises(DependencyException, check_requirements, "cool-extra") + with self.mock_installed_package(old): + self.assertRaises(DependencyException, check_requirements, "cool-extra") + with self.mock_installed_package(new): + # should not raise + check_requirements() From 5f62a094de10b4c4382908231128dace833a1195 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 1 Mar 2022 19:47:02 +0000 Subject: [PATCH 78/84] Detox, part 1 of N (#12119) * Don't use `tox` for `check-sampleconfig` * Don't use `tox` for check-newsfragment --- .github/workflows/tests.yml | 13 ++++++++++--- changelog.d/12119.misc | 1 + scripts-dev/check-newsfragment | 2 +- tox.ini | 10 ---------- 4 files changed, 12 insertions(+), 14 deletions(-) create mode 100644 changelog.d/12119.misc diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index bbf1033bd..e9e427732 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -10,12 +10,19 @@ concurrency: cancel-in-progress: true jobs: + check-sampleconfig: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + - run: pip install -e . + - run: scripts-dev/generate_sample_config --check + lint: runs-on: ubuntu-latest strategy: matrix: toxenv: - - "check-sampleconfig" - "check_codestyle" - "check_isort" - "mypy" @@ -43,7 +50,7 @@ jobs: ref: ${{ github.event.pull_request.head.sha }} fetch-depth: 0 - uses: actions/setup-python@v2 - - run: pip install tox + - run: "pip install 'towncrier>=18.6.0rc1'" - run: scripts-dev/check-newsfragment env: PULL_REQUEST_NUMBER: ${{ github.event.number }} @@ -51,7 +58,7 @@ jobs: # Dummy step to gate other tests on without repeating the whole list linting-done: if: ${{ !cancelled() }} # Run this even if prior jobs were skipped - needs: [lint, lint-crlf, lint-newsfile] + needs: [lint, lint-crlf, lint-newsfile, check-sampleconfig] runs-on: ubuntu-latest steps: - run: "true" diff --git a/changelog.d/12119.misc b/changelog.d/12119.misc new file mode 100644 index 000000000..f02d140f3 --- /dev/null +++ b/changelog.d/12119.misc @@ -0,0 +1 @@ +Move CI checks out of tox, to facilitate a move to using poetry. \ No newline at end of file diff --git a/scripts-dev/check-newsfragment b/scripts-dev/check-newsfragment index c764011d6..493558ad6 100755 --- a/scripts-dev/check-newsfragment +++ b/scripts-dev/check-newsfragment @@ -35,7 +35,7 @@ CONTRIBUTING_GUIDE_TEXT="!! Please see the contributing guide for help writing y https://github.com/matrix-org/synapse/blob/develop/CONTRIBUTING.md#changelog" # If check-newsfragment returns a non-zero exit code, print the contributing guide and exit -tox -qe check-newsfragment || (echo -e "$CONTRIBUTING_GUIDE_TEXT" >&2 && exit 1) +python -m towncrier.check --compare-with=origin/develop || (echo -e "$CONTRIBUTING_GUIDE_TEXT" >&2 && exit 1) echo echo "--------------------------" diff --git a/tox.ini b/tox.ini index 436ecf755..04b972e2c 100644 --- a/tox.ini +++ b/tox.ini @@ -168,16 +168,6 @@ commands = extras = lint commands = isort -c --df {[base]lint_targets} -[testenv:check-newsfragment] -skip_install = true -usedevelop = false -deps = towncrier>=18.6.0rc1 -commands = - python -m towncrier.check --compare-with=origin/develop - -[testenv:check-sampleconfig] -commands = {toxinidir}/scripts-dev/generate_sample_config --check - [testenv:combine] skip_install = true usedevelop = false From 8e56a1b73c9819ea4bddbe6a4734966e70b3b92c Mon Sep 17 00:00:00 2001 From: lukasdenk <63459921+lukasdenk@users.noreply.github.com> Date: Wed, 2 Mar 2022 11:35:34 +0100 Subject: [PATCH 79/84] Make get_room_version use cached get_room_version_id. (#11808) --- changelog.d/11808.misc | 1 + synapse/storage/databases/main/state.py | 27 ++++++++++++------------- tests/handlers/test_room_summary.py | 5 ++++- 3 files changed, 18 insertions(+), 15 deletions(-) create mode 100644 changelog.d/11808.misc diff --git a/changelog.d/11808.misc b/changelog.d/11808.misc new file mode 100644 index 000000000..cdc5fc75b --- /dev/null +++ b/changelog.d/11808.misc @@ -0,0 +1 @@ +Make method `get_room_version` use cached `get_room_version_id`. diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 2fb3e6519..417aef1db 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -42,6 +42,16 @@ logger = logging.getLogger(__name__) MAX_STATE_DELTA_HOPS = 100 +def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion: + v = KNOWN_ROOM_VERSIONS.get(room_version_id) + if not v: + raise UnsupportedRoomVersionError( + "Room %s uses a room version %s which is no longer supported" + % (room_id, room_version_id) + ) + return v + + # this inherits from EventsWorkerStore because it calls self.get_events class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """The parts of StateGroupStore that can be called from workers.""" @@ -62,11 +72,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): Typically this happens if support for the room's version has been removed from Synapse. """ - return await self.db_pool.runInteraction( - "get_room_version_txn", - self.get_room_version_txn, - room_id, - ) + room_version_id = await self.get_room_version_id(room_id) + return _retrieve_and_check_room_version(room_id, room_version_id) def get_room_version_txn( self, txn: LoggingTransaction, room_id: str @@ -82,15 +89,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): removed from Synapse. """ room_version_id = self.get_room_version_id_txn(txn, room_id) - v = KNOWN_ROOM_VERSIONS.get(room_version_id) - - if not v: - raise UnsupportedRoomVersionError( - "Room %s uses a room version %s which is no longer supported" - % (room_id, room_version_id) - ) - - return v + return _retrieve_and_check_room_version(room_id, room_version_id) @cached(max_entries=10000) async def get_room_version_id(self, room_id: str) -> str: diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py index b33ff94a3..cff07a897 100644 --- a/tests/handlers/test_room_summary.py +++ b/tests/handlers/test_room_summary.py @@ -658,7 +658,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): def test_unknown_room_version(self): """ - If an room with an unknown room version is encountered it should not cause + If a room with an unknown room version is encountered it should not cause the entire summary to skip. """ # Poke the database and update the room version to an unknown one. @@ -670,6 +670,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): desc="updated-room-version", ) ) + # Invalidate method so that it returns the currently updated version + # instead of the cached version. + self.hs.get_datastores().main.get_room_version_id.invalidate((self.room,)) # The result should have only the space, along with a link from space -> room. expected = [(self.space, [self.room])] From c7b2f1ccdc412c4f5f07f4fe630d2c2040caf93d Mon Sep 17 00:00:00 2001 From: reivilibre Date: Wed, 2 Mar 2022 10:37:04 +0000 Subject: [PATCH 80/84] Back out in-flight state caching changes. (#12126) --- changelog.d/10870.misc | 1 - changelog.d/11608.misc | 1 - changelog.d/11610.misc | 1 - changelog.d/12033.misc | 1 - changelog.d/12126.removal | 1 + synapse/storage/databases/state/store.py | 243 ++--------- tests/storage/databases/test_state_store.py | 454 -------------------- 7 files changed, 26 insertions(+), 676 deletions(-) delete mode 100644 changelog.d/10870.misc delete mode 100644 changelog.d/11608.misc delete mode 100644 changelog.d/11610.misc delete mode 100644 changelog.d/12033.misc create mode 100644 changelog.d/12126.removal delete mode 100644 tests/storage/databases/test_state_store.py diff --git a/changelog.d/10870.misc b/changelog.d/10870.misc deleted file mode 100644 index 3af049b96..000000000 --- a/changelog.d/10870.misc +++ /dev/null @@ -1 +0,0 @@ -Deduplicate in-flight requests in `_get_state_for_groups`. diff --git a/changelog.d/11608.misc b/changelog.d/11608.misc deleted file mode 100644 index 3af049b96..000000000 --- a/changelog.d/11608.misc +++ /dev/null @@ -1 +0,0 @@ -Deduplicate in-flight requests in `_get_state_for_groups`. diff --git a/changelog.d/11610.misc b/changelog.d/11610.misc deleted file mode 100644 index 3af049b96..000000000 --- a/changelog.d/11610.misc +++ /dev/null @@ -1 +0,0 @@ -Deduplicate in-flight requests in `_get_state_for_groups`. diff --git a/changelog.d/12033.misc b/changelog.d/12033.misc deleted file mode 100644 index 3af049b96..000000000 --- a/changelog.d/12033.misc +++ /dev/null @@ -1 +0,0 @@ -Deduplicate in-flight requests in `_get_state_for_groups`. diff --git a/changelog.d/12126.removal b/changelog.d/12126.removal new file mode 100644 index 000000000..8c8bf6ee7 --- /dev/null +++ b/changelog.d/12126.removal @@ -0,0 +1 @@ +Back out in-flight state caching changes. \ No newline at end of file diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index dadf3d1e3..7614d76ac 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -13,24 +13,11 @@ # limitations under the License. import logging -from typing import ( - TYPE_CHECKING, - Collection, - Dict, - Iterable, - Optional, - Sequence, - Set, - Tuple, -) +from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple import attr -from sortedcontainers import SortedDict - -from twisted.internet import defer from synapse.api.constants import EventTypes -from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -42,12 +29,6 @@ from synapse.storage.state import StateFilter from synapse.storage.types import Cursor from synapse.storage.util.sequence import build_sequence_generator from synapse.types import MutableStateMap, StateKey, StateMap -from synapse.util import unwrapFirstError -from synapse.util.async_helpers import ( - AbstractObservableDeferred, - ObservableDeferred, - yieldable_gather_results, -) from synapse.util.caches.descriptors import cached from synapse.util.caches.dictionary_cache import DictionaryCache @@ -56,8 +37,8 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) + MAX_STATE_DELTA_HOPS = 100 -MAX_INFLIGHT_REQUESTS_PER_GROUP = 5 @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -73,24 +54,6 @@ class _GetStateGroupDelta: return len(self.delta_ids) if self.delta_ids else 0 -def state_filter_rough_priority_comparator( - state_filter: StateFilter, -) -> Tuple[int, int]: - """ - Returns a comparable value that roughly indicates the relative size of this - state filter compared to others. - 'Larger' state filters should sort first when using ascending order, so - this is essentially the opposite of 'size'. - It should be treated as a rough guide only and should not be interpreted to - have any particular meaning. The representation may also change - - The current implementation returns a tuple of the form: - * -1 for include_others, 0 otherwise - * -(number of entries in state_filter.types) - """ - return -int(state_filter.include_others), -len(state_filter.types) - - class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): """A data store for fetching/storing state groups.""" @@ -143,12 +106,6 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): 500000, ) - # Current ongoing get_state_for_groups in-flight requests - # {group ID -> {StateFilter -> ObservableDeferred}} - self._state_group_inflight_requests: Dict[ - int, SortedDict[StateFilter, AbstractObservableDeferred[StateMap[str]]] - ] = {} - def get_max_state_group_txn(txn: Cursor) -> int: txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups") return txn.fetchone()[0] # type: ignore @@ -200,7 +157,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): ) async def _get_state_groups_from_groups( - self, groups: Sequence[int], state_filter: StateFilter + self, groups: List[int], state_filter: StateFilter ) -> Dict[int, StateMap[str]]: """Returns the state groups for a given set of groups from the database, filtering on types of state events. @@ -271,170 +228,6 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return state_filter.filter_state(state_dict_ids), not missing_types - def _get_state_for_group_gather_inflight_requests( - self, group: int, state_filter_left_over: StateFilter - ) -> Tuple[Sequence[AbstractObservableDeferred[StateMap[str]]], StateFilter]: - """ - Attempts to gather in-flight requests and re-use them to retrieve state - for the given state group, filtered with the given state filter. - - If there are more than MAX_INFLIGHT_REQUESTS_PER_GROUP in-flight requests, - and there *still* isn't enough information to complete the request by solely - reusing others, a full state filter will be requested to ensure that subsequent - requests can reuse this request. - - Used as part of _get_state_for_group_using_inflight_cache. - - Returns: - Tuple of two values: - A sequence of ObservableDeferreds to observe - A StateFilter representing what else needs to be requested to fulfill the request - """ - - inflight_requests = self._state_group_inflight_requests.get(group) - if inflight_requests is None: - # no requests for this group, need to retrieve it all ourselves - return (), state_filter_left_over - - # The list of ongoing requests which will help narrow the current request. - reusable_requests = [] - - # Iterate over existing requests in roughly biggest-first order. - for request_state_filter in inflight_requests: - request_deferred = inflight_requests[request_state_filter] - new_state_filter_left_over = state_filter_left_over.approx_difference( - request_state_filter - ) - if new_state_filter_left_over == state_filter_left_over: - # Reusing this request would not gain us anything, so don't bother. - continue - - reusable_requests.append(request_deferred) - state_filter_left_over = new_state_filter_left_over - if state_filter_left_over == StateFilter.none(): - # we have managed to collect enough of the in-flight requests - # to cover our StateFilter and give us the state we need. - break - - if ( - state_filter_left_over != StateFilter.none() - and len(inflight_requests) >= MAX_INFLIGHT_REQUESTS_PER_GROUP - ): - # There are too many requests for this group. - # To prevent even more from building up, we request the whole - # state filter to guarantee that we can be reused by any subsequent - # requests for this state group. - return (), StateFilter.all() - - return reusable_requests, state_filter_left_over - - async def _get_state_for_group_fire_request( - self, group: int, state_filter: StateFilter - ) -> StateMap[str]: - """ - Fires off a request to get the state at a state group, - potentially filtering by type and/or state key. - - This request will be tracked in the in-flight request cache and automatically - removed when it is finished. - - Used as part of _get_state_for_group_using_inflight_cache. - - Args: - group: ID of the state group for which we want to get state - state_filter: the state filter used to fetch state from the database - """ - cache_sequence_nm = self._state_group_cache.sequence - cache_sequence_m = self._state_group_members_cache.sequence - - # Help the cache hit ratio by expanding the filter a bit - db_state_filter = state_filter.return_expanded() - - async def _the_request() -> StateMap[str]: - group_to_state_dict = await self._get_state_groups_from_groups( - (group,), state_filter=db_state_filter - ) - - # Now let's update the caches - self._insert_into_cache( - group_to_state_dict, - db_state_filter, - cache_seq_num_members=cache_sequence_m, - cache_seq_num_non_members=cache_sequence_nm, - ) - - # Remove ourselves from the in-flight cache - group_request_dict = self._state_group_inflight_requests[group] - del group_request_dict[db_state_filter] - if not group_request_dict: - # If there are no more requests in-flight for this group, - # clean up the cache by removing the empty dictionary - del self._state_group_inflight_requests[group] - - return group_to_state_dict[group] - - # We don't immediately await the result, so must use run_in_background - # But we DO await the result before the current log context (request) - # finishes, so don't need to run it as a background process. - request_deferred = run_in_background(_the_request) - observable_deferred = ObservableDeferred(request_deferred, consumeErrors=True) - - # Insert the ObservableDeferred into the cache - group_request_dict = self._state_group_inflight_requests.setdefault( - group, SortedDict(state_filter_rough_priority_comparator) - ) - group_request_dict[db_state_filter] = observable_deferred - - return await make_deferred_yieldable(observable_deferred.observe()) - - async def _get_state_for_group_using_inflight_cache( - self, group: int, state_filter: StateFilter - ) -> MutableStateMap[str]: - """ - Gets the state at a state group, potentially filtering by type and/or - state key. - - 1. Calls _get_state_for_group_gather_inflight_requests to gather any - ongoing requests which might overlap with the current request. - 2. Fires a new request, using _get_state_for_group_fire_request, - for any state which cannot be gathered from ongoing requests. - - Args: - group: ID of the state group for which we want to get state - state_filter: the state filter used to fetch state from the database - Returns: - state map - """ - - # first, figure out whether we can re-use any in-flight requests - # (and if so, what would be left over) - ( - reusable_requests, - state_filter_left_over, - ) = self._get_state_for_group_gather_inflight_requests(group, state_filter) - - if state_filter_left_over != StateFilter.none(): - # Fetch remaining state - remaining = await self._get_state_for_group_fire_request( - group, state_filter_left_over - ) - assembled_state: MutableStateMap[str] = dict(remaining) - else: - assembled_state = {} - - gathered = await make_deferred_yieldable( - defer.gatherResults( - (r.observe() for r in reusable_requests), consumeErrors=True - ) - ).addErrback(unwrapFirstError) - - # assemble our result. - for result_piece in gathered: - assembled_state.update(result_piece) - - # Filter out any state that may be more than what we asked for. - return state_filter.filter_state(assembled_state) - async def _get_state_for_groups( self, groups: Iterable[int], state_filter: Optional[StateFilter] = None ) -> Dict[int, MutableStateMap[str]]: @@ -476,17 +269,31 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): if not incomplete_groups: return state - async def get_from_cache(group: int, state_filter: StateFilter) -> None: - state[group] = await self._get_state_for_group_using_inflight_cache( - group, state_filter - ) + cache_sequence_nm = self._state_group_cache.sequence + cache_sequence_m = self._state_group_members_cache.sequence - await yieldable_gather_results( - get_from_cache, - incomplete_groups, - state_filter, + # Help the cache hit ratio by expanding the filter a bit + db_state_filter = state_filter.return_expanded() + + group_to_state_dict = await self._get_state_groups_from_groups( + list(incomplete_groups), state_filter=db_state_filter ) + # Now lets update the caches + self._insert_into_cache( + group_to_state_dict, + db_state_filter, + cache_seq_num_members=cache_sequence_m, + cache_seq_num_non_members=cache_sequence_nm, + ) + + # And finally update the result dict, by filtering out any extra + # stuff we pulled out of the database. + for group, group_state_dict in group_to_state_dict.items(): + # We just replace any existing entries, as we will have loaded + # everything we need from the database anyway. + state[group] = state_filter.filter_state(group_state_dict) + return state def _get_state_for_groups_using_cache( diff --git a/tests/storage/databases/test_state_store.py b/tests/storage/databases/test_state_store.py deleted file mode 100644 index 2b484c95a..000000000 --- a/tests/storage/databases/test_state_store.py +++ /dev/null @@ -1,454 +0,0 @@ -# Copyright 2022 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import typing -from typing import Dict, List, Sequence, Tuple -from unittest.mock import patch - -from parameterized import parameterized - -from twisted.internet.defer import Deferred, ensureDeferred -from twisted.test.proto_helpers import MemoryReactor - -from synapse.api.constants import EventTypes -from synapse.storage.databases.state.store import ( - MAX_INFLIGHT_REQUESTS_PER_GROUP, - state_filter_rough_priority_comparator, -) -from synapse.storage.state import StateFilter -from synapse.types import StateMap -from synapse.util import Clock - -from tests.unittest import HomeserverTestCase - -if typing.TYPE_CHECKING: - from synapse.server import HomeServer - -# StateFilter for ALL non-m.room.member state events -ALL_NON_MEMBERS_STATE_FILTER = StateFilter.freeze( - types={EventTypes.Member: set()}, - include_others=True, -) - -FAKE_STATE = { - (EventTypes.Member, "@alice:test"): "join", - (EventTypes.Member, "@bob:test"): "leave", - (EventTypes.Member, "@charlie:test"): "invite", - ("test.type", "a"): "AAA", - ("test.type", "b"): "BBB", - ("other.event.type", "state.key"): "123", -} - - -class StateGroupInflightCachingTestCase(HomeserverTestCase): - def prepare( - self, reactor: MemoryReactor, clock: Clock, homeserver: "HomeServer" - ) -> None: - self.state_storage = homeserver.get_storage().state - self.state_datastore = homeserver.get_datastores().state - # Patch out the `_get_state_groups_from_groups`. - # This is useful because it lets us pretend we have a slow database. - get_state_groups_patch = patch.object( - self.state_datastore, - "_get_state_groups_from_groups", - self._fake_get_state_groups_from_groups, - ) - get_state_groups_patch.start() - - self.addCleanup(get_state_groups_patch.stop) - self.get_state_group_calls: List[ - Tuple[Tuple[int, ...], StateFilter, Deferred[Dict[int, StateMap[str]]]] - ] = [] - - def _fake_get_state_groups_from_groups( - self, groups: Sequence[int], state_filter: StateFilter - ) -> "Deferred[Dict[int, StateMap[str]]]": - d: Deferred[Dict[int, StateMap[str]]] = Deferred() - self.get_state_group_calls.append((tuple(groups), state_filter, d)) - return d - - def _complete_request_fake( - self, - groups: Tuple[int, ...], - state_filter: StateFilter, - d: "Deferred[Dict[int, StateMap[str]]]", - ) -> None: - """ - Assemble a fake database response and complete the database request. - """ - - # Return a filtered copy of the fake state - d.callback({group: state_filter.filter_state(FAKE_STATE) for group in groups}) - - def test_duplicate_requests_deduplicated(self) -> None: - """ - Tests that duplicate requests for state are deduplicated. - - This test: - - requests some state (state group 42, 'all' state filter) - - requests it again, before the first request finishes - - checks to see that only one database query was made - - completes the database query - - checks that both requests see the same retrieved state - """ - req1 = ensureDeferred( - self.state_datastore._get_state_for_group_using_inflight_cache( - 42, StateFilter.all() - ) - ) - self.pump(by=0.1) - - # This should have gone to the database - self.assertEqual(len(self.get_state_group_calls), 1) - self.assertFalse(req1.called) - - req2 = ensureDeferred( - self.state_datastore._get_state_for_group_using_inflight_cache( - 42, StateFilter.all() - ) - ) - self.pump(by=0.1) - - # No more calls should have gone to the database - self.assertEqual(len(self.get_state_group_calls), 1) - self.assertFalse(req1.called) - self.assertFalse(req2.called) - - groups, sf, d = self.get_state_group_calls[0] - self.assertEqual(groups, (42,)) - self.assertEqual(sf, StateFilter.all()) - - # Now we can complete the request - self._complete_request_fake(groups, sf, d) - - self.assertEqual(self.get_success(req1), FAKE_STATE) - self.assertEqual(self.get_success(req2), FAKE_STATE) - - def test_smaller_request_deduplicated(self) -> None: - """ - Tests that duplicate requests for state are deduplicated. - - This test: - - requests some state (state group 42, 'all' state filter) - - requests a subset of that state, before the first request finishes - - checks to see that only one database query was made - - completes the database query - - checks that both requests see the correct retrieved state - """ - req1 = ensureDeferred( - self.state_datastore._get_state_for_group_using_inflight_cache( - 42, StateFilter.from_types((("test.type", None),)) - ) - ) - self.pump(by=0.1) - - # This should have gone to the database - self.assertEqual(len(self.get_state_group_calls), 1) - self.assertFalse(req1.called) - - req2 = ensureDeferred( - self.state_datastore._get_state_for_group_using_inflight_cache( - 42, StateFilter.from_types((("test.type", "b"),)) - ) - ) - self.pump(by=0.1) - - # No more calls should have gone to the database, because the second - # request was already in the in-flight cache! - self.assertEqual(len(self.get_state_group_calls), 1) - self.assertFalse(req1.called) - self.assertFalse(req2.called) - - groups, sf, d = self.get_state_group_calls[0] - self.assertEqual(groups, (42,)) - # The state filter is expanded internally for increased cache hit rate, - # so we the database sees a wider state filter than requested. - self.assertEqual(sf, ALL_NON_MEMBERS_STATE_FILTER) - - # Now we can complete the request - self._complete_request_fake(groups, sf, d) - - self.assertEqual( - self.get_success(req1), - {("test.type", "a"): "AAA", ("test.type", "b"): "BBB"}, - ) - self.assertEqual(self.get_success(req2), {("test.type", "b"): "BBB"}) - - def test_partially_overlapping_request_deduplicated(self) -> None: - """ - Tests that partially-overlapping requests are partially deduplicated. - - This test: - - requests a single type of wildcard state - (This is internally expanded to be all non-member state) - - requests the entire state in parallel - - checks to see that two database queries were made, but that the second - one is only for member state. - - completes the database queries - - checks that both requests have the correct result. - """ - - req1 = ensureDeferred( - self.state_datastore._get_state_for_group_using_inflight_cache( - 42, StateFilter.from_types((("test.type", None),)) - ) - ) - self.pump(by=0.1) - - # This should have gone to the database - self.assertEqual(len(self.get_state_group_calls), 1) - self.assertFalse(req1.called) - - req2 = ensureDeferred( - self.state_datastore._get_state_for_group_using_inflight_cache( - 42, StateFilter.all() - ) - ) - self.pump(by=0.1) - - # Because it only partially overlaps, this also went to the database - self.assertEqual(len(self.get_state_group_calls), 2) - self.assertFalse(req1.called) - self.assertFalse(req2.called) - - # First request: - groups, sf, d = self.get_state_group_calls[0] - self.assertEqual(groups, (42,)) - # The state filter is expanded internally for increased cache hit rate, - # so we the database sees a wider state filter than requested. - self.assertEqual(sf, ALL_NON_MEMBERS_STATE_FILTER) - self._complete_request_fake(groups, sf, d) - - # Second request: - groups, sf, d = self.get_state_group_calls[1] - self.assertEqual(groups, (42,)) - # The state filter is narrowed to only request membership state, because - # the remainder of the state is already being queried in the first request! - self.assertEqual( - sf, StateFilter.freeze({EventTypes.Member: None}, include_others=False) - ) - self._complete_request_fake(groups, sf, d) - - # Check the results are correct - self.assertEqual( - self.get_success(req1), - {("test.type", "a"): "AAA", ("test.type", "b"): "BBB"}, - ) - self.assertEqual(self.get_success(req2), FAKE_STATE) - - def test_in_flight_requests_stop_being_in_flight(self) -> None: - """ - Tests that in-flight request deduplication doesn't somehow 'hold on' - to completed requests: once they're done, they're taken out of the - in-flight cache. - """ - req1 = ensureDeferred( - self.state_datastore._get_state_for_group_using_inflight_cache( - 42, StateFilter.all() - ) - ) - self.pump(by=0.1) - - # This should have gone to the database - self.assertEqual(len(self.get_state_group_calls), 1) - self.assertFalse(req1.called) - - # Complete the request right away. - self._complete_request_fake(*self.get_state_group_calls[0]) - self.assertTrue(req1.called) - - # Send off another request - req2 = ensureDeferred( - self.state_datastore._get_state_for_group_using_inflight_cache( - 42, StateFilter.all() - ) - ) - self.pump(by=0.1) - - # It should have gone to the database again, because the previous request - # isn't in-flight and therefore isn't available for deduplication. - self.assertEqual(len(self.get_state_group_calls), 2) - self.assertFalse(req2.called) - - # Complete the request right away. - self._complete_request_fake(*self.get_state_group_calls[1]) - self.assertTrue(req2.called) - groups, sf, d = self.get_state_group_calls[0] - - self.assertEqual(self.get_success(req1), FAKE_STATE) - self.assertEqual(self.get_success(req2), FAKE_STATE) - - def test_inflight_requests_capped(self) -> None: - """ - Tests that the number of in-flight requests is capped to 5. - - - requests several pieces of state separately - (5 to hit the limit, 1 to 'shunt out', another that comes after the - group has been 'shunted out') - - checks to see that the torrent of requests is shunted out by - rewriting one of the filters as the 'all' state filter - - requests after that one do not cause any additional queries - """ - # 5 at the time of writing. - CAP_COUNT = MAX_INFLIGHT_REQUESTS_PER_GROUP - - reqs = [] - - # Request 7 different keys (1 to 7) of the `some.state` type. - for req_id in range(CAP_COUNT + 2): - reqs.append( - ensureDeferred( - self.state_datastore._get_state_for_group_using_inflight_cache( - 42, - StateFilter.freeze( - {"some.state": {str(req_id + 1)}}, include_others=False - ), - ) - ) - ) - self.pump(by=0.1) - - # There should only be 6 calls to the database, not 7. - self.assertEqual(len(self.get_state_group_calls), CAP_COUNT + 1) - - # Assert that the first 5 are exact requests for the individual pieces - # wanted - for req_id in range(CAP_COUNT): - groups, sf, d = self.get_state_group_calls[req_id] - self.assertEqual( - sf, - StateFilter.freeze( - {"some.state": {str(req_id + 1)}}, include_others=False - ), - ) - - # The 6th request should be the 'all' state filter - groups, sf, d = self.get_state_group_calls[CAP_COUNT] - self.assertEqual(sf, StateFilter.all()) - - # Complete the queries and check which requests complete as a result - for req_id in range(CAP_COUNT): - # This request should not have been completed yet - self.assertFalse(reqs[req_id].called) - - groups, sf, d = self.get_state_group_calls[req_id] - self._complete_request_fake(groups, sf, d) - - # This should have only completed this one request - self.assertTrue(reqs[req_id].called) - - # Now complete the final query; the last 2 requests should complete - # as a result - self.assertFalse(reqs[CAP_COUNT].called) - self.assertFalse(reqs[CAP_COUNT + 1].called) - groups, sf, d = self.get_state_group_calls[CAP_COUNT] - self._complete_request_fake(groups, sf, d) - self.assertTrue(reqs[CAP_COUNT].called) - self.assertTrue(reqs[CAP_COUNT + 1].called) - - @parameterized.expand([(False,), (True,)]) - def test_ordering_of_request_reuse(self, reverse: bool) -> None: - """ - Tests that 'larger' in-flight requests are ordered first. - - This is mostly a design decision in order to prevent a request from - hanging on to multiple queries when it would have been sufficient to - hang on to only one bigger query. - - The 'size' of a state filter is a rough heuristic. - - - requests two pieces of state, one 'larger' than the other, but each - spawning a query - - requests a third piece of state - - completes the larger of the first two queries - - checks that the third request gets completed (and doesn't needlessly - wait for the other query) - - Parameters: - reverse: whether to reverse the order of the initial requests, to ensure - that the effect doesn't depend on the order of request submission. - """ - - # We add in an extra state type to make sure that both requests spawn - # queries which are not optimised out. - state_filters = [ - StateFilter.freeze( - {"state.type": {"A"}, "other.state.type": {"a"}}, include_others=False - ), - StateFilter.freeze( - { - "state.type": None, - "other.state.type": {"b"}, - # The current rough size comparator uses the number of state types - # as an indicator of size. - # To influence it to make this state filter bigger than the previous one, - # we add another dummy state type. - "extra.state.type": {"c"}, - }, - include_others=False, - ), - ] - - if reverse: - # For fairness, we perform one test run with the list reversed. - state_filters.reverse() - smallest_state_filter_idx = 1 - biggest_state_filter_idx = 0 - else: - smallest_state_filter_idx = 0 - biggest_state_filter_idx = 1 - - # This assertion is for our own sanity more than anything else. - self.assertLess( - state_filter_rough_priority_comparator( - state_filters[biggest_state_filter_idx] - ), - state_filter_rough_priority_comparator( - state_filters[smallest_state_filter_idx] - ), - "Test invalid: bigger state filter is not actually bigger.", - ) - - # Spawn the initial two requests - for state_filter in state_filters: - ensureDeferred( - self.state_datastore._get_state_for_group_using_inflight_cache( - 42, - state_filter, - ) - ) - - # Spawn a third request - req = ensureDeferred( - self.state_datastore._get_state_for_group_using_inflight_cache( - 42, - StateFilter.freeze( - { - "state.type": {"A"}, - }, - include_others=False, - ), - ) - ) - self.pump(by=0.1) - - self.assertFalse(req.called) - - # Complete the largest request's query to make sure that the final request - # only waits for that one (and doesn't needlessly wait for both queries) - self._complete_request_fake( - *self.get_state_group_calls[biggest_state_filter_idx] - ) - - # That should have been sufficient to complete the third request - self.assertTrue(req.called) From a43a5ea5bfefcb25d90209958fb014a3b5e0ead0 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 2 Mar 2022 10:38:10 +0000 Subject: [PATCH 81/84] Remove misleading newsfile from #12126 which backs out an unreleased change. --- changelog.d/12126.removal | 1 - 1 file changed, 1 deletion(-) delete mode 100644 changelog.d/12126.removal diff --git a/changelog.d/12126.removal b/changelog.d/12126.removal deleted file mode 100644 index 8c8bf6ee7..000000000 --- a/changelog.d/12126.removal +++ /dev/null @@ -1 +0,0 @@ -Back out in-flight state caching changes. \ No newline at end of file From 879e4a7bd7a90cda4c8ea908aede53d8e038ca7c Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 2 Mar 2022 10:45:16 +0000 Subject: [PATCH 82/84] 1.54.0rc1 --- CHANGES.md | 96 +++++++++++++++++++++++++++++++++++++++ changelog.d/11599.doc | 1 - changelog.d/11617.feature | 1 - changelog.d/11808.misc | 1 - changelog.d/11835.feature | 1 - changelog.d/11865.removal | 1 - changelog.d/11900.misc | 1 - changelog.d/11972.misc | 1 - changelog.d/11974.misc | 1 - changelog.d/11984.misc | 1 - changelog.d/11985.feature | 1 - changelog.d/11991.misc | 1 - changelog.d/11992.bugfix | 1 - changelog.d/11994.misc | 1 - changelog.d/11996.misc | 1 - changelog.d/11997.docker | 1 - changelog.d/11999.bugfix | 1 - changelog.d/12000.feature | 1 - changelog.d/12001.feature | 1 - changelog.d/12003.doc | 1 - changelog.d/12004.doc | 1 - changelog.d/12005.misc | 1 - changelog.d/12008.removal | 1 - changelog.d/12009.feature | 1 - changelog.d/12011.misc | 1 - changelog.d/12012.misc | 1 - changelog.d/12013.misc | 1 - changelog.d/12015.misc | 1 - changelog.d/12016.misc | 1 - changelog.d/12018.removal | 1 - changelog.d/12019.misc | 1 - changelog.d/12020.feature | 1 - changelog.d/12021.feature | 1 - changelog.d/12022.feature | 1 - changelog.d/12024.bugfix | 1 - changelog.d/12025.misc | 1 - changelog.d/12030.misc | 1 - changelog.d/12031.misc | 1 - changelog.d/12034.misc | 1 - changelog.d/12037.bugfix | 1 - changelog.d/12039.misc | 1 - changelog.d/12041.misc | 1 - changelog.d/12051.misc | 1 - changelog.d/12052.misc | 1 - changelog.d/12056.bugfix | 1 - changelog.d/12058.feature | 1 - changelog.d/12059.misc | 1 - changelog.d/12060.misc | 1 - changelog.d/12062.feature | 1 - changelog.d/12063.misc | 1 - changelog.d/12066.misc | 1 - changelog.d/12067.feature | 1 - changelog.d/12068.misc | 1 - changelog.d/12069.misc | 1 - changelog.d/12070.misc | 1 - changelog.d/12072.misc | 1 - changelog.d/12073.removal | 1 - changelog.d/12077.bugfix | 1 - changelog.d/12084.misc | 1 - changelog.d/12088.misc | 1 - changelog.d/12089.bugfix | 1 - changelog.d/12092.misc | 1 - changelog.d/12094.misc | 1 - changelog.d/12098.bugfix | 1 - changelog.d/12099.misc | 1 - changelog.d/12100.bugfix | 1 - changelog.d/12105.bugfix | 1 - changelog.d/12106.misc | 1 - changelog.d/12109.misc | 1 - changelog.d/12111.misc | 1 - changelog.d/12112.docker | 1 - changelog.d/12119.misc | 1 - debian/changelog | 6 +++ synapse/__init__.py | 2 +- 74 files changed, 103 insertions(+), 72 deletions(-) delete mode 100644 changelog.d/11599.doc delete mode 100644 changelog.d/11617.feature delete mode 100644 changelog.d/11808.misc delete mode 100644 changelog.d/11835.feature delete mode 100644 changelog.d/11865.removal delete mode 100644 changelog.d/11900.misc delete mode 100644 changelog.d/11972.misc delete mode 100644 changelog.d/11974.misc delete mode 100644 changelog.d/11984.misc delete mode 100644 changelog.d/11985.feature delete mode 100644 changelog.d/11991.misc delete mode 100644 changelog.d/11992.bugfix delete mode 100644 changelog.d/11994.misc delete mode 100644 changelog.d/11996.misc delete mode 100644 changelog.d/11997.docker delete mode 100644 changelog.d/11999.bugfix delete mode 100644 changelog.d/12000.feature delete mode 100644 changelog.d/12001.feature delete mode 100644 changelog.d/12003.doc delete mode 100644 changelog.d/12004.doc delete mode 100644 changelog.d/12005.misc delete mode 100644 changelog.d/12008.removal delete mode 100644 changelog.d/12009.feature delete mode 100644 changelog.d/12011.misc delete mode 100644 changelog.d/12012.misc delete mode 100644 changelog.d/12013.misc delete mode 100644 changelog.d/12015.misc delete mode 100644 changelog.d/12016.misc delete mode 100644 changelog.d/12018.removal delete mode 100644 changelog.d/12019.misc delete mode 100644 changelog.d/12020.feature delete mode 100644 changelog.d/12021.feature delete mode 100644 changelog.d/12022.feature delete mode 100644 changelog.d/12024.bugfix delete mode 100644 changelog.d/12025.misc delete mode 100644 changelog.d/12030.misc delete mode 100644 changelog.d/12031.misc delete mode 100644 changelog.d/12034.misc delete mode 100644 changelog.d/12037.bugfix delete mode 100644 changelog.d/12039.misc delete mode 100644 changelog.d/12041.misc delete mode 100644 changelog.d/12051.misc delete mode 100644 changelog.d/12052.misc delete mode 100644 changelog.d/12056.bugfix delete mode 100644 changelog.d/12058.feature delete mode 100644 changelog.d/12059.misc delete mode 100644 changelog.d/12060.misc delete mode 100644 changelog.d/12062.feature delete mode 100644 changelog.d/12063.misc delete mode 100644 changelog.d/12066.misc delete mode 100644 changelog.d/12067.feature delete mode 100644 changelog.d/12068.misc delete mode 100644 changelog.d/12069.misc delete mode 100644 changelog.d/12070.misc delete mode 100644 changelog.d/12072.misc delete mode 100644 changelog.d/12073.removal delete mode 100644 changelog.d/12077.bugfix delete mode 100644 changelog.d/12084.misc delete mode 100644 changelog.d/12088.misc delete mode 100644 changelog.d/12089.bugfix delete mode 100644 changelog.d/12092.misc delete mode 100644 changelog.d/12094.misc delete mode 100644 changelog.d/12098.bugfix delete mode 100644 changelog.d/12099.misc delete mode 100644 changelog.d/12100.bugfix delete mode 100644 changelog.d/12105.bugfix delete mode 100644 changelog.d/12106.misc delete mode 100644 changelog.d/12109.misc delete mode 100644 changelog.d/12111.misc delete mode 100644 changelog.d/12112.docker delete mode 100644 changelog.d/12119.misc diff --git a/CHANGES.md b/CHANGES.md index 81333097a..4f0318970 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,99 @@ +Synapse 1.54.0rc1 (2022-03-02) +============================== + +Features +-------- + +- Add support for MSC3202: sending one-time key counts and fallback key usage states to Application Services. ([\#11617](https://github.com/matrix-org/synapse/issues/11617)) +- Make a `POST` to `/rooms//receipt/m.read/` only trigger a push notification if the count of unread messages is different to the one in the last successfully sent push. ([\#11835](https://github.com/matrix-org/synapse/issues/11835)) +- Fetch images when previewing Twitter URLs. Contributed by @AndrewRyanChama. ([\#11985](https://github.com/matrix-org/synapse/issues/11985)) +- Track cache invalidations in Prometheus metrics, as already happens for cache eviction based on size or time. ([\#12000](https://github.com/matrix-org/synapse/issues/12000)) +- Implement experimental support for [MSC3720](https://github.com/matrix-org/matrix-doc/pull/3720) (account status endpoints). ([\#12001](https://github.com/matrix-org/synapse/issues/12001), [\#12067](https://github.com/matrix-org/synapse/issues/12067)) +- Enable modules to set a custom display name when registering a user. ([\#12009](https://github.com/matrix-org/synapse/issues/12009)) +- Advertise Matrix 1.1 support on `/_matrix/client/versions`. ([\#12020](https://github.com/matrix-org/synapse/issues/12020)) +- Support only the stable identifier for [MSC3069](https://github.com/matrix-org/matrix-doc/pull/3069)'s `is_guest` on `/_matrix/client/v3/account/whoami`. ([\#12021](https://github.com/matrix-org/synapse/issues/12021)) +- Advertise Matrix 1.2 support on `/_matrix/client/versions`. ([\#12022](https://github.com/matrix-org/synapse/issues/12022)) +- Use room version 9 as the default room version (per [MSC3589](https://github.com/matrix-org/matrix-doc/pull/3589)). ([\#12058](https://github.com/matrix-org/synapse/issues/12058)) +- Add module callbacks to react to user deactivation status changes (i.e. deactivations and reactivations) and profile updates. ([\#12062](https://github.com/matrix-org/synapse/issues/12062)) + + +Bugfixes +-------- + +- Fix a bug introduced in Synapse v1.48.0 where an edit of the latest event in a thread would not be properly applied to the thread summary. ([\#11992](https://github.com/matrix-org/synapse/issues/11992)) +- Fix long standing bug where `get_rooms_for_user` was not correctly invalidated for remote users when the server left a room. ([\#11999](https://github.com/matrix-org/synapse/issues/11999)) +- Fix 500 error with Postgres when looking backwards with the [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) `/timestamp_to_event?dir=b` endpoint. ([\#12024](https://github.com/matrix-org/synapse/issues/12024)) +- Properly fix a long-standing bug where wrong data could be inserted in the `event_search` table when using sqlite. This could block running `synapse_port_db` with an "argument of type 'int' is not iterable" error. This bug was partially fixed by a change in Synapse 1.44.0. ([\#12037](https://github.com/matrix-org/synapse/issues/12037)) +- Fix slow performance of `/logout` in some cases where refresh tokens are in use. The slowness existed since the initial implementation of refresh tokens. ([\#12056](https://github.com/matrix-org/synapse/issues/12056)) +- Fix a long-standing bug where Synapse would make additional failing requests over federation for missing data. ([\#12077](https://github.com/matrix-org/synapse/issues/12077)) +- Fix occasional 'Unhandled error in Deferred' error message. ([\#12089](https://github.com/matrix-org/synapse/issues/12089)) +- Fix a bug introduced in Synapse 1.51.0rc1 where incoming federation transactions containing at least one EDU would be dropped if debug logging was enabled for `synapse.8631_debug`. ([\#12098](https://github.com/matrix-org/synapse/issues/12098)) +- Fix a long-standing bug which could cause push notifications to malfunction if `use_frozen_dicts` was set in the configuration. ([\#12100](https://github.com/matrix-org/synapse/issues/12100)) +- Fix an extremely rare, long-standing bug in `ReadWriteLock` that would cause an error when a newly unblocked writer completes instantly. ([\#12105](https://github.com/matrix-org/synapse/issues/12105)) + + +Updates to the Docker image +--------------------------- + +- The docker image no longer automatically creates a temporary volume at `/data`. This is not expected to affect normal usage. ([\#11997](https://github.com/matrix-org/synapse/issues/11997)) +- Use Python 3.9 in Docker images by default. ([\#12112](https://github.com/matrix-org/synapse/issues/12112)) + + +Improved Documentation +---------------------- + +- Document support for the `to_device`, `account_data`, `receipts`, and `presence` stream writers for workers. ([\#11599](https://github.com/matrix-org/synapse/issues/11599)) +- Explain the meaning of spam checker callbacks' return values. ([\#12003](https://github.com/matrix-org/synapse/issues/12003)) +- Clarify information about external Identity Provider IDs. ([\#12004](https://github.com/matrix-org/synapse/issues/12004)) + + +Deprecations and Removals +------------------------- + +- Deprecate using `synctl` with the config option `synctl_cache_factor` and print a warning if a user still uses this option. ([\#11865](https://github.com/matrix-org/synapse/issues/11865)) +- Remove support for the legacy structured logging configuration (please see the the [upgrade notes](https://matrix-org.github.io/synapse/develop/upgrade#legacy-structured-logging-configuration-removal) if you are using `structured: true` in the Synapse configuration). ([\#12008](https://github.com/matrix-org/synapse/issues/12008)) +- Drop support for [MSC3283](https://github.com/matrix-org/matrix-doc/pull/3283) unstable flags now that the stable flags are supported. ([\#12018](https://github.com/matrix-org/synapse/issues/12018)) +- Remove the unstable `/spaces` endpoint from [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). ([\#12073](https://github.com/matrix-org/synapse/issues/12073)) + + +Internal Changes +---------------- + +- Make method `get_room_version` use cached `get_room_version_id`. ([\#11808](https://github.com/matrix-org/synapse/issues/11808)) +- Remove unnecessary condition on knock->leave auth rule check. ([\#11900](https://github.com/matrix-org/synapse/issues/11900)) +- Add tests for device list changes between local users. ([\#11972](https://github.com/matrix-org/synapse/issues/11972)) +- Optimise calculating device_list changes in `/sync`. ([\#11974](https://github.com/matrix-org/synapse/issues/11974)) +- Add missing type hints to storage classes. ([\#11984](https://github.com/matrix-org/synapse/issues/11984)) +- Refactor the search code for improved readability. ([\#11991](https://github.com/matrix-org/synapse/issues/11991)) +- Move common deduplication code down into `_auth_and_persist_outliers`. ([\#11994](https://github.com/matrix-org/synapse/issues/11994)) +- Limit concurrent joins from applications services. ([\#11996](https://github.com/matrix-org/synapse/issues/11996)) +- Preparation for faster-room-join work: when parsing the `send_join` response, get the `m.room.create` event from `state`, not `auth_chain`. ([\#12005](https://github.com/matrix-org/synapse/issues/12005), [\#12039](https://github.com/matrix-org/synapse/issues/12039)) +- Preparation for faster-room-join work: parse msc3706 fields in send_join response. ([\#12011](https://github.com/matrix-org/synapse/issues/12011)) +- Preparation for faster-room-join work: persist information on which events and rooms have partial state to the database. ([\#12012](https://github.com/matrix-org/synapse/issues/12012)) +- Preparation for faster-room-join work: Support for calling `/federation/v1/state` on a remote server. ([\#12013](https://github.com/matrix-org/synapse/issues/12013)) +- Configure `tox` to use `venv` rather than `virtualenv`. ([\#12015](https://github.com/matrix-org/synapse/issues/12015)) +- Fix bug in `StateFilter.return_expanded()` and add some tests. ([\#12016](https://github.com/matrix-org/synapse/issues/12016)) +- Use Matrix v1.1 endpoints (`/_matrix/client/v3/auth/...`) in fallback auth HTML forms. ([\#12019](https://github.com/matrix-org/synapse/issues/12019)) +- Update the `olddeps` CI job to use an old version of `markupsafe`. ([\#12025](https://github.com/matrix-org/synapse/issues/12025)) +- Upgrade mypy to version 0.931. ([\#12030](https://github.com/matrix-org/synapse/issues/12030)) +- Remove legacy `HomeServer.get_datastore()`. ([\#12031](https://github.com/matrix-org/synapse/issues/12031), [\#12070](https://github.com/matrix-org/synapse/issues/12070)) +- Minor typing fixes. ([\#12034](https://github.com/matrix-org/synapse/issues/12034), [\#12069](https://github.com/matrix-org/synapse/issues/12069)) +- After joining a room, create a dedicated logcontext to process the queued events. ([\#12041](https://github.com/matrix-org/synapse/issues/12041)) +- Tidy up GitHub Actions config which builds distributions for PyPI. ([\#12051](https://github.com/matrix-org/synapse/issues/12051)) +- Move configuration out of `setup.cfg`. ([\#12052](https://github.com/matrix-org/synapse/issues/12052), [\#12059](https://github.com/matrix-org/synapse/issues/12059)) +- Fix error message when a worker process fails to talk to another worker process. ([\#12060](https://github.com/matrix-org/synapse/issues/12060)) +- Fix using the complement.sh script without specifying a dir or a branch. Contributed by Nico on behalf of Famedly. ([\#12063](https://github.com/matrix-org/synapse/issues/12063)) +- Add type hints to `tests/rest/client`. ([\#12066](https://github.com/matrix-org/synapse/issues/12066), [\#12072](https://github.com/matrix-org/synapse/issues/12072), [\#12084](https://github.com/matrix-org/synapse/issues/12084), [\#12094](https://github.com/matrix-org/synapse/issues/12094)) +- Add some logging to `/sync` to try and track down #11916. ([\#12068](https://github.com/matrix-org/synapse/issues/12068)) +- Inspect application dependencies using `importlib.metadata` or its backport. ([\#12088](https://github.com/matrix-org/synapse/issues/12088)) +- User `assertEqual` instead of the deprecated `assertEquals` in test code. ([\#12092](https://github.com/matrix-org/synapse/issues/12092)) +- Move experimental support for [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440) to /versions. ([\#12099](https://github.com/matrix-org/synapse/issues/12099)) +- Add `stop_cancellation` utility function to stop `Deferred`s from being cancelled. ([\#12106](https://github.com/matrix-org/synapse/issues/12106)) +- Improve exception handling for concurrent execution. ([\#12109](https://github.com/matrix-org/synapse/issues/12109)) +- Advertise support for Python 3.10 in packaging files. ([\#12111](https://github.com/matrix-org/synapse/issues/12111)) +- Move CI checks out of tox, to facilitate a move to using poetry. ([\#12119](https://github.com/matrix-org/synapse/issues/12119)) + + Synapse 1.53.0 (2022-02-22) =========================== diff --git a/changelog.d/11599.doc b/changelog.d/11599.doc deleted file mode 100644 index f07cfbef4..000000000 --- a/changelog.d/11599.doc +++ /dev/null @@ -1 +0,0 @@ -Document support for the `to_device`, `account_data`, `receipts`, and `presence` stream writers for workers. diff --git a/changelog.d/11617.feature b/changelog.d/11617.feature deleted file mode 100644 index cf03f00e7..000000000 --- a/changelog.d/11617.feature +++ /dev/null @@ -1 +0,0 @@ -Add support for MSC3202: sending one-time key counts and fallback key usage states to Application Services. \ No newline at end of file diff --git a/changelog.d/11808.misc b/changelog.d/11808.misc deleted file mode 100644 index cdc5fc75b..000000000 --- a/changelog.d/11808.misc +++ /dev/null @@ -1 +0,0 @@ -Make method `get_room_version` use cached `get_room_version_id`. diff --git a/changelog.d/11835.feature b/changelog.d/11835.feature deleted file mode 100644 index 7cee39b08..000000000 --- a/changelog.d/11835.feature +++ /dev/null @@ -1 +0,0 @@ -Make a `POST` to `/rooms//receipt/m.read/` only trigger a push notification if the count of unread messages is different to the one in the last successfully sent push. diff --git a/changelog.d/11865.removal b/changelog.d/11865.removal deleted file mode 100644 index 9fcabfc72..000000000 --- a/changelog.d/11865.removal +++ /dev/null @@ -1 +0,0 @@ -Deprecate using `synctl` with the config option `synctl_cache_factor` and print a warning if a user still uses this option. diff --git a/changelog.d/11900.misc b/changelog.d/11900.misc deleted file mode 100644 index edd2852fd..000000000 --- a/changelog.d/11900.misc +++ /dev/null @@ -1 +0,0 @@ -Remove unnecessary condition on knock->leave auth rule check. \ No newline at end of file diff --git a/changelog.d/11972.misc b/changelog.d/11972.misc deleted file mode 100644 index 29c38bfd8..000000000 --- a/changelog.d/11972.misc +++ /dev/null @@ -1 +0,0 @@ -Add tests for device list changes between local users. \ No newline at end of file diff --git a/changelog.d/11974.misc b/changelog.d/11974.misc deleted file mode 100644 index 1debad236..000000000 --- a/changelog.d/11974.misc +++ /dev/null @@ -1 +0,0 @@ -Optimise calculating device_list changes in `/sync`. diff --git a/changelog.d/11984.misc b/changelog.d/11984.misc deleted file mode 100644 index 8e405b922..000000000 --- a/changelog.d/11984.misc +++ /dev/null @@ -1 +0,0 @@ -Add missing type hints to storage classes. \ No newline at end of file diff --git a/changelog.d/11985.feature b/changelog.d/11985.feature deleted file mode 100644 index 120d888a4..000000000 --- a/changelog.d/11985.feature +++ /dev/null @@ -1 +0,0 @@ -Fetch images when previewing Twitter URLs. Contributed by @AndrewRyanChama. diff --git a/changelog.d/11991.misc b/changelog.d/11991.misc deleted file mode 100644 index 34a3b3a6b..000000000 --- a/changelog.d/11991.misc +++ /dev/null @@ -1 +0,0 @@ -Refactor the search code for improved readability. diff --git a/changelog.d/11992.bugfix b/changelog.d/11992.bugfix deleted file mode 100644 index f73c86bb2..000000000 --- a/changelog.d/11992.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug introduced in Synapse v1.48.0 where an edit of the latest event in a thread would not be properly applied to the thread summary. diff --git a/changelog.d/11994.misc b/changelog.d/11994.misc deleted file mode 100644 index d64297dd7..000000000 --- a/changelog.d/11994.misc +++ /dev/null @@ -1 +0,0 @@ -Move common deduplication code down into `_auth_and_persist_outliers`. diff --git a/changelog.d/11996.misc b/changelog.d/11996.misc deleted file mode 100644 index 6c675fd19..000000000 --- a/changelog.d/11996.misc +++ /dev/null @@ -1 +0,0 @@ -Limit concurrent joins from applications services. \ No newline at end of file diff --git a/changelog.d/11997.docker b/changelog.d/11997.docker deleted file mode 100644 index 1b3271457..000000000 --- a/changelog.d/11997.docker +++ /dev/null @@ -1 +0,0 @@ -The docker image no longer automatically creates a temporary volume at `/data`. This is not expected to affect normal usage. diff --git a/changelog.d/11999.bugfix b/changelog.d/11999.bugfix deleted file mode 100644 index fd8409590..000000000 --- a/changelog.d/11999.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix long standing bug where `get_rooms_for_user` was not correctly invalidated for remote users when the server left a room. diff --git a/changelog.d/12000.feature b/changelog.d/12000.feature deleted file mode 100644 index 246cc87f0..000000000 --- a/changelog.d/12000.feature +++ /dev/null @@ -1 +0,0 @@ -Track cache invalidations in Prometheus metrics, as already happens for cache eviction based on size or time. diff --git a/changelog.d/12001.feature b/changelog.d/12001.feature deleted file mode 100644 index dc1153c49..000000000 --- a/changelog.d/12001.feature +++ /dev/null @@ -1 +0,0 @@ -Implement experimental support for [MSC3720](https://github.com/matrix-org/matrix-doc/pull/3720) (account status endpoints). diff --git a/changelog.d/12003.doc b/changelog.d/12003.doc deleted file mode 100644 index 1ac816355..000000000 --- a/changelog.d/12003.doc +++ /dev/null @@ -1 +0,0 @@ -Explain the meaning of spam checker callbacks' return values. diff --git a/changelog.d/12004.doc b/changelog.d/12004.doc deleted file mode 100644 index 0b4baef21..000000000 --- a/changelog.d/12004.doc +++ /dev/null @@ -1 +0,0 @@ -Clarify information about external Identity Provider IDs. diff --git a/changelog.d/12005.misc b/changelog.d/12005.misc deleted file mode 100644 index 45e21dbe5..000000000 --- a/changelog.d/12005.misc +++ /dev/null @@ -1 +0,0 @@ -Preparation for faster-room-join work: when parsing the `send_join` response, get the `m.room.create` event from `state`, not `auth_chain`. diff --git a/changelog.d/12008.removal b/changelog.d/12008.removal deleted file mode 100644 index 57599d9ee..000000000 --- a/changelog.d/12008.removal +++ /dev/null @@ -1 +0,0 @@ -Remove support for the legacy structured logging configuration (please see the the [upgrade notes](https://matrix-org.github.io/synapse/develop/upgrade#legacy-structured-logging-configuration-removal) if you are using `structured: true` in the Synapse configuration). diff --git a/changelog.d/12009.feature b/changelog.d/12009.feature deleted file mode 100644 index c8a531481..000000000 --- a/changelog.d/12009.feature +++ /dev/null @@ -1 +0,0 @@ -Enable modules to set a custom display name when registering a user. diff --git a/changelog.d/12011.misc b/changelog.d/12011.misc deleted file mode 100644 index 258b0e389..000000000 --- a/changelog.d/12011.misc +++ /dev/null @@ -1 +0,0 @@ -Preparation for faster-room-join work: parse msc3706 fields in send_join response. diff --git a/changelog.d/12012.misc b/changelog.d/12012.misc deleted file mode 100644 index a473f41e7..000000000 --- a/changelog.d/12012.misc +++ /dev/null @@ -1 +0,0 @@ -Preparation for faster-room-join work: persist information on which events and rooms have partial state to the database. diff --git a/changelog.d/12013.misc b/changelog.d/12013.misc deleted file mode 100644 index c0fca8dcc..000000000 --- a/changelog.d/12013.misc +++ /dev/null @@ -1 +0,0 @@ -Preparation for faster-room-join work: Support for calling `/federation/v1/state` on a remote server. diff --git a/changelog.d/12015.misc b/changelog.d/12015.misc deleted file mode 100644 index 3aa32ab4c..000000000 --- a/changelog.d/12015.misc +++ /dev/null @@ -1 +0,0 @@ -Configure `tox` to use `venv` rather than `virtualenv`. diff --git a/changelog.d/12016.misc b/changelog.d/12016.misc deleted file mode 100644 index 8856ef46a..000000000 --- a/changelog.d/12016.misc +++ /dev/null @@ -1 +0,0 @@ -Fix bug in `StateFilter.return_expanded()` and add some tests. \ No newline at end of file diff --git a/changelog.d/12018.removal b/changelog.d/12018.removal deleted file mode 100644 index e940b6222..000000000 --- a/changelog.d/12018.removal +++ /dev/null @@ -1 +0,0 @@ -Drop support for [MSC3283](https://github.com/matrix-org/matrix-doc/pull/3283) unstable flags now that the stable flags are supported. diff --git a/changelog.d/12019.misc b/changelog.d/12019.misc deleted file mode 100644 index b2186320e..000000000 --- a/changelog.d/12019.misc +++ /dev/null @@ -1 +0,0 @@ -Use Matrix v1.1 endpoints (`/_matrix/client/v3/auth/...`) in fallback auth HTML forms. \ No newline at end of file diff --git a/changelog.d/12020.feature b/changelog.d/12020.feature deleted file mode 100644 index 1ac9d2060..000000000 --- a/changelog.d/12020.feature +++ /dev/null @@ -1 +0,0 @@ -Advertise Matrix 1.1 support on `/_matrix/client/versions`. \ No newline at end of file diff --git a/changelog.d/12021.feature b/changelog.d/12021.feature deleted file mode 100644 index 01378df8c..000000000 --- a/changelog.d/12021.feature +++ /dev/null @@ -1 +0,0 @@ -Support only the stable identifier for [MSC3069](https://github.com/matrix-org/matrix-doc/pull/3069)'s `is_guest` on `/_matrix/client/v3/account/whoami`. \ No newline at end of file diff --git a/changelog.d/12022.feature b/changelog.d/12022.feature deleted file mode 100644 index 188fb1257..000000000 --- a/changelog.d/12022.feature +++ /dev/null @@ -1 +0,0 @@ -Advertise Matrix 1.2 support on `/_matrix/client/versions`. \ No newline at end of file diff --git a/changelog.d/12024.bugfix b/changelog.d/12024.bugfix deleted file mode 100644 index 59bcdb93a..000000000 --- a/changelog.d/12024.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix 500 error with Postgres when looking backwards with the [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) `/timestamp_to_event?dir=b` endpoint. diff --git a/changelog.d/12025.misc b/changelog.d/12025.misc deleted file mode 100644 index d9475a771..000000000 --- a/changelog.d/12025.misc +++ /dev/null @@ -1 +0,0 @@ -Update the `olddeps` CI job to use an old version of `markupsafe`. diff --git a/changelog.d/12030.misc b/changelog.d/12030.misc deleted file mode 100644 index 607ee97ce..000000000 --- a/changelog.d/12030.misc +++ /dev/null @@ -1 +0,0 @@ -Upgrade mypy to version 0.931. diff --git a/changelog.d/12031.misc b/changelog.d/12031.misc deleted file mode 100644 index d4bedc6b9..000000000 --- a/changelog.d/12031.misc +++ /dev/null @@ -1 +0,0 @@ -Remove legacy `HomeServer.get_datastore()`. diff --git a/changelog.d/12034.misc b/changelog.d/12034.misc deleted file mode 100644 index 8374a6322..000000000 --- a/changelog.d/12034.misc +++ /dev/null @@ -1 +0,0 @@ -Minor typing fixes. diff --git a/changelog.d/12037.bugfix b/changelog.d/12037.bugfix deleted file mode 100644 index 9295cb4dc..000000000 --- a/changelog.d/12037.bugfix +++ /dev/null @@ -1 +0,0 @@ -Properly fix a long-standing bug where wrong data could be inserted in the `event_search` table when using sqlite. This could block running `synapse_port_db` with an "argument of type 'int' is not iterable" error. This bug was partially fixed by a change in Synapse 1.44.0. diff --git a/changelog.d/12039.misc b/changelog.d/12039.misc deleted file mode 100644 index 45e21dbe5..000000000 --- a/changelog.d/12039.misc +++ /dev/null @@ -1 +0,0 @@ -Preparation for faster-room-join work: when parsing the `send_join` response, get the `m.room.create` event from `state`, not `auth_chain`. diff --git a/changelog.d/12041.misc b/changelog.d/12041.misc deleted file mode 100644 index e56dc093d..000000000 --- a/changelog.d/12041.misc +++ /dev/null @@ -1 +0,0 @@ -After joining a room, create a dedicated logcontext to process the queued events. diff --git a/changelog.d/12051.misc b/changelog.d/12051.misc deleted file mode 100644 index 995919135..000000000 --- a/changelog.d/12051.misc +++ /dev/null @@ -1 +0,0 @@ -Tidy up GitHub Actions config which builds distributions for PyPI. \ No newline at end of file diff --git a/changelog.d/12052.misc b/changelog.d/12052.misc deleted file mode 100644 index 11755ae61..000000000 --- a/changelog.d/12052.misc +++ /dev/null @@ -1 +0,0 @@ -Move configuration out of `setup.cfg`. diff --git a/changelog.d/12056.bugfix b/changelog.d/12056.bugfix deleted file mode 100644 index 210e30c63..000000000 --- a/changelog.d/12056.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix slow performance of `/logout` in some cases where refresh tokens are in use. The slowness existed since the initial implementation of refresh tokens. \ No newline at end of file diff --git a/changelog.d/12058.feature b/changelog.d/12058.feature deleted file mode 100644 index 7b7169222..000000000 --- a/changelog.d/12058.feature +++ /dev/null @@ -1 +0,0 @@ -Use room version 9 as the default room version (per [MSC3589](https://github.com/matrix-org/matrix-doc/pull/3589)). diff --git a/changelog.d/12059.misc b/changelog.d/12059.misc deleted file mode 100644 index 9ba4759d9..000000000 --- a/changelog.d/12059.misc +++ /dev/null @@ -1 +0,0 @@ -Move configuration out of `setup.cfg`. \ No newline at end of file diff --git a/changelog.d/12060.misc b/changelog.d/12060.misc deleted file mode 100644 index d771e6a1b..000000000 --- a/changelog.d/12060.misc +++ /dev/null @@ -1 +0,0 @@ -Fix error message when a worker process fails to talk to another worker process. diff --git a/changelog.d/12062.feature b/changelog.d/12062.feature deleted file mode 100644 index 46a606709..000000000 --- a/changelog.d/12062.feature +++ /dev/null @@ -1 +0,0 @@ -Add module callbacks to react to user deactivation status changes (i.e. deactivations and reactivations) and profile updates. diff --git a/changelog.d/12063.misc b/changelog.d/12063.misc deleted file mode 100644 index e48c5dd08..000000000 --- a/changelog.d/12063.misc +++ /dev/null @@ -1 +0,0 @@ -Fix using the complement.sh script without specifying a dir or a branch. Contributed by Nico on behalf of Famedly. diff --git a/changelog.d/12066.misc b/changelog.d/12066.misc deleted file mode 100644 index 0360dbd61..000000000 --- a/changelog.d/12066.misc +++ /dev/null @@ -1 +0,0 @@ -Add type hints to `tests/rest/client`. diff --git a/changelog.d/12067.feature b/changelog.d/12067.feature deleted file mode 100644 index dc1153c49..000000000 --- a/changelog.d/12067.feature +++ /dev/null @@ -1 +0,0 @@ -Implement experimental support for [MSC3720](https://github.com/matrix-org/matrix-doc/pull/3720) (account status endpoints). diff --git a/changelog.d/12068.misc b/changelog.d/12068.misc deleted file mode 100644 index 72b211e4f..000000000 --- a/changelog.d/12068.misc +++ /dev/null @@ -1 +0,0 @@ -Add some logging to `/sync` to try and track down #11916. diff --git a/changelog.d/12069.misc b/changelog.d/12069.misc deleted file mode 100644 index 8374a6322..000000000 --- a/changelog.d/12069.misc +++ /dev/null @@ -1 +0,0 @@ -Minor typing fixes. diff --git a/changelog.d/12070.misc b/changelog.d/12070.misc deleted file mode 100644 index d4bedc6b9..000000000 --- a/changelog.d/12070.misc +++ /dev/null @@ -1 +0,0 @@ -Remove legacy `HomeServer.get_datastore()`. diff --git a/changelog.d/12072.misc b/changelog.d/12072.misc deleted file mode 100644 index 0360dbd61..000000000 --- a/changelog.d/12072.misc +++ /dev/null @@ -1 +0,0 @@ -Add type hints to `tests/rest/client`. diff --git a/changelog.d/12073.removal b/changelog.d/12073.removal deleted file mode 100644 index 1f3979271..000000000 --- a/changelog.d/12073.removal +++ /dev/null @@ -1 +0,0 @@ -Remove the unstable `/spaces` endpoint from [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). diff --git a/changelog.d/12077.bugfix b/changelog.d/12077.bugfix deleted file mode 100644 index 1bce82082..000000000 --- a/changelog.d/12077.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a long-standing bug where Synapse would make additional failing requests over federation for missing data. diff --git a/changelog.d/12084.misc b/changelog.d/12084.misc deleted file mode 100644 index 0360dbd61..000000000 --- a/changelog.d/12084.misc +++ /dev/null @@ -1 +0,0 @@ -Add type hints to `tests/rest/client`. diff --git a/changelog.d/12088.misc b/changelog.d/12088.misc deleted file mode 100644 index ce4213650..000000000 --- a/changelog.d/12088.misc +++ /dev/null @@ -1 +0,0 @@ -Inspect application dependencies using `importlib.metadata` or its backport. \ No newline at end of file diff --git a/changelog.d/12089.bugfix b/changelog.d/12089.bugfix deleted file mode 100644 index 27172c482..000000000 --- a/changelog.d/12089.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix occasional 'Unhandled error in Deferred' error message. diff --git a/changelog.d/12092.misc b/changelog.d/12092.misc deleted file mode 100644 index 62653d6f8..000000000 --- a/changelog.d/12092.misc +++ /dev/null @@ -1 +0,0 @@ -User `assertEqual` instead of the deprecated `assertEquals` in test code. diff --git a/changelog.d/12094.misc b/changelog.d/12094.misc deleted file mode 100644 index 0360dbd61..000000000 --- a/changelog.d/12094.misc +++ /dev/null @@ -1 +0,0 @@ -Add type hints to `tests/rest/client`. diff --git a/changelog.d/12098.bugfix b/changelog.d/12098.bugfix deleted file mode 100644 index 6b696692e..000000000 --- a/changelog.d/12098.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug introduced in Synapse 1.51.0rc1 where incoming federation transactions containing at least one EDU would be dropped if debug logging was enabled for `synapse.8631_debug`. \ No newline at end of file diff --git a/changelog.d/12099.misc b/changelog.d/12099.misc deleted file mode 100644 index 0553825db..000000000 --- a/changelog.d/12099.misc +++ /dev/null @@ -1 +0,0 @@ -Move experimental support for [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440) to /versions. diff --git a/changelog.d/12100.bugfix b/changelog.d/12100.bugfix deleted file mode 100644 index 181095ad9..000000000 --- a/changelog.d/12100.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a long-standing bug which could cause push notifications to malfunction if `use_frozen_dicts` was set in the configuration. diff --git a/changelog.d/12105.bugfix b/changelog.d/12105.bugfix deleted file mode 100644 index f42e63e01..000000000 --- a/changelog.d/12105.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix an extremely rare, long-standing bug in `ReadWriteLock` that would cause an error when a newly unblocked writer completes instantly. diff --git a/changelog.d/12106.misc b/changelog.d/12106.misc deleted file mode 100644 index d918e9e3b..000000000 --- a/changelog.d/12106.misc +++ /dev/null @@ -1 +0,0 @@ -Add `stop_cancellation` utility function to stop `Deferred`s from being cancelled. diff --git a/changelog.d/12109.misc b/changelog.d/12109.misc deleted file mode 100644 index 3295e49f4..000000000 --- a/changelog.d/12109.misc +++ /dev/null @@ -1 +0,0 @@ -Improve exception handling for concurrent execution. diff --git a/changelog.d/12111.misc b/changelog.d/12111.misc deleted file mode 100644 index be84789c9..000000000 --- a/changelog.d/12111.misc +++ /dev/null @@ -1 +0,0 @@ -Advertise support for Python 3.10 in packaging files. \ No newline at end of file diff --git a/changelog.d/12112.docker b/changelog.d/12112.docker deleted file mode 100644 index b9e630653..000000000 --- a/changelog.d/12112.docker +++ /dev/null @@ -1 +0,0 @@ -Use Python 3.9 in Docker images by default. \ No newline at end of file diff --git a/changelog.d/12119.misc b/changelog.d/12119.misc deleted file mode 100644 index f02d140f3..000000000 --- a/changelog.d/12119.misc +++ /dev/null @@ -1 +0,0 @@ -Move CI checks out of tox, to facilitate a move to using poetry. \ No newline at end of file diff --git a/debian/changelog b/debian/changelog index 574930c08..df3db85b8 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.54.0~rc1) stable; urgency=medium + + * New synapse release 1.54.0~rc1. + + -- Synapse Packaging team Wed, 02 Mar 2022 10:43:22 +0000 + matrix-synapse-py3 (1.53.0) stable; urgency=medium * New synapse release 1.53.0. diff --git a/synapse/__init__.py b/synapse/__init__.py index 903f2e815..b21e1ed0f 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -47,7 +47,7 @@ try: except ImportError: pass -__version__ = "1.53.0" +__version__ = "1.54.0rc1" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when From d800108bb4e1272235aa6f5f80b2732cee9aa5bf Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 2 Mar 2022 10:54:52 +0000 Subject: [PATCH 83/84] Reword changelog --- CHANGES.md | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 4f0318970..5485e8d47 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,10 +1,14 @@ Synapse 1.54.0rc1 (2022-03-02) ============================== +Please note that this will be the last release of Synapse that is compatible with Mjolnir 1.3.1 and earlier. +Administrators of servers which have the Mjolnir module installed are advised to upgrade Mjolnir to version 1.3.2 or later. + + Features -------- -- Add support for MSC3202: sending one-time key counts and fallback key usage states to Application Services. ([\#11617](https://github.com/matrix-org/synapse/issues/11617)) +- Add support for [MSC3202](https://github.com/matrix-org/matrix-doc/pull/3202): sending one-time key counts and fallback key usage states to Application Services. ([\#11617](https://github.com/matrix-org/synapse/issues/11617)) - Make a `POST` to `/rooms//receipt/m.read/` only trigger a push notification if the count of unread messages is different to the one in the last successfully sent push. ([\#11835](https://github.com/matrix-org/synapse/issues/11835)) - Fetch images when previewing Twitter URLs. Contributed by @AndrewRyanChama. ([\#11985](https://github.com/matrix-org/synapse/issues/11985)) - Track cache invalidations in Prometheus metrics, as already happens for cache eviction based on size or time. ([\#12000](https://github.com/matrix-org/synapse/issues/12000)) @@ -22,8 +26,8 @@ Bugfixes - Fix a bug introduced in Synapse v1.48.0 where an edit of the latest event in a thread would not be properly applied to the thread summary. ([\#11992](https://github.com/matrix-org/synapse/issues/11992)) - Fix long standing bug where `get_rooms_for_user` was not correctly invalidated for remote users when the server left a room. ([\#11999](https://github.com/matrix-org/synapse/issues/11999)) -- Fix 500 error with Postgres when looking backwards with the [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) `/timestamp_to_event?dir=b` endpoint. ([\#12024](https://github.com/matrix-org/synapse/issues/12024)) -- Properly fix a long-standing bug where wrong data could be inserted in the `event_search` table when using sqlite. This could block running `synapse_port_db` with an "argument of type 'int' is not iterable" error. This bug was partially fixed by a change in Synapse 1.44.0. ([\#12037](https://github.com/matrix-org/synapse/issues/12037)) +- Fix a 500 error with Postgres when looking backwards with the [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) `/timestamp_to_event?dir=b` endpoint. ([\#12024](https://github.com/matrix-org/synapse/issues/12024)) +- Properly fix a long-standing bug where wrong data could be inserted into the `event_search` table when using SQLite. This could block running `synapse_port_db` with an "argument of type 'int' is not iterable" error. This bug was partially fixed by a change in Synapse 1.44.0. ([\#12037](https://github.com/matrix-org/synapse/issues/12037)) - Fix slow performance of `/logout` in some cases where refresh tokens are in use. The slowness existed since the initial implementation of refresh tokens. ([\#12056](https://github.com/matrix-org/synapse/issues/12056)) - Fix a long-standing bug where Synapse would make additional failing requests over federation for missing data. ([\#12077](https://github.com/matrix-org/synapse/issues/12077)) - Fix occasional 'Unhandled error in Deferred' error message. ([\#12089](https://github.com/matrix-org/synapse/issues/12089)) @@ -35,7 +39,7 @@ Bugfixes Updates to the Docker image --------------------------- -- The docker image no longer automatically creates a temporary volume at `/data`. This is not expected to affect normal usage. ([\#11997](https://github.com/matrix-org/synapse/issues/11997)) +- The Docker image no longer automatically creates a temporary volume at `/data`. This is not expected to affect normal usage. ([\#11997](https://github.com/matrix-org/synapse/issues/11997)) - Use Python 3.9 in Docker images by default. ([\#12112](https://github.com/matrix-org/synapse/issues/12112)) @@ -59,35 +63,35 @@ Deprecations and Removals Internal Changes ---------------- -- Make method `get_room_version` use cached `get_room_version_id`. ([\#11808](https://github.com/matrix-org/synapse/issues/11808)) -- Remove unnecessary condition on knock->leave auth rule check. ([\#11900](https://github.com/matrix-org/synapse/issues/11900)) +- Make the `get_room_version` method use `get_room_version_id` to benefit from caching. ([\#11808](https://github.com/matrix-org/synapse/issues/11808)) +- Remove unnecessary condition on knock -> leave auth rule check. ([\#11900](https://github.com/matrix-org/synapse/issues/11900)) - Add tests for device list changes between local users. ([\#11972](https://github.com/matrix-org/synapse/issues/11972)) -- Optimise calculating device_list changes in `/sync`. ([\#11974](https://github.com/matrix-org/synapse/issues/11974)) +- Optimise calculating `device_list` changes in `/sync`. ([\#11974](https://github.com/matrix-org/synapse/issues/11974)) - Add missing type hints to storage classes. ([\#11984](https://github.com/matrix-org/synapse/issues/11984)) - Refactor the search code for improved readability. ([\#11991](https://github.com/matrix-org/synapse/issues/11991)) - Move common deduplication code down into `_auth_and_persist_outliers`. ([\#11994](https://github.com/matrix-org/synapse/issues/11994)) - Limit concurrent joins from applications services. ([\#11996](https://github.com/matrix-org/synapse/issues/11996)) - Preparation for faster-room-join work: when parsing the `send_join` response, get the `m.room.create` event from `state`, not `auth_chain`. ([\#12005](https://github.com/matrix-org/synapse/issues/12005), [\#12039](https://github.com/matrix-org/synapse/issues/12039)) -- Preparation for faster-room-join work: parse msc3706 fields in send_join response. ([\#12011](https://github.com/matrix-org/synapse/issues/12011)) +- Preparation for faster-room-join work: parse MSC3706 fields in send_join response. ([\#12011](https://github.com/matrix-org/synapse/issues/12011)) - Preparation for faster-room-join work: persist information on which events and rooms have partial state to the database. ([\#12012](https://github.com/matrix-org/synapse/issues/12012)) - Preparation for faster-room-join work: Support for calling `/federation/v1/state` on a remote server. ([\#12013](https://github.com/matrix-org/synapse/issues/12013)) - Configure `tox` to use `venv` rather than `virtualenv`. ([\#12015](https://github.com/matrix-org/synapse/issues/12015)) - Fix bug in `StateFilter.return_expanded()` and add some tests. ([\#12016](https://github.com/matrix-org/synapse/issues/12016)) - Use Matrix v1.1 endpoints (`/_matrix/client/v3/auth/...`) in fallback auth HTML forms. ([\#12019](https://github.com/matrix-org/synapse/issues/12019)) - Update the `olddeps` CI job to use an old version of `markupsafe`. ([\#12025](https://github.com/matrix-org/synapse/issues/12025)) -- Upgrade mypy to version 0.931. ([\#12030](https://github.com/matrix-org/synapse/issues/12030)) +- Upgrade Mypy to version 0.931. ([\#12030](https://github.com/matrix-org/synapse/issues/12030)) - Remove legacy `HomeServer.get_datastore()`. ([\#12031](https://github.com/matrix-org/synapse/issues/12031), [\#12070](https://github.com/matrix-org/synapse/issues/12070)) - Minor typing fixes. ([\#12034](https://github.com/matrix-org/synapse/issues/12034), [\#12069](https://github.com/matrix-org/synapse/issues/12069)) - After joining a room, create a dedicated logcontext to process the queued events. ([\#12041](https://github.com/matrix-org/synapse/issues/12041)) - Tidy up GitHub Actions config which builds distributions for PyPI. ([\#12051](https://github.com/matrix-org/synapse/issues/12051)) - Move configuration out of `setup.cfg`. ([\#12052](https://github.com/matrix-org/synapse/issues/12052), [\#12059](https://github.com/matrix-org/synapse/issues/12059)) - Fix error message when a worker process fails to talk to another worker process. ([\#12060](https://github.com/matrix-org/synapse/issues/12060)) -- Fix using the complement.sh script without specifying a dir or a branch. Contributed by Nico on behalf of Famedly. ([\#12063](https://github.com/matrix-org/synapse/issues/12063)) +- Fix using the `complement.sh` script without specifying a directory or a branch. Contributed by Nico on behalf of Famedly. ([\#12063](https://github.com/matrix-org/synapse/issues/12063)) - Add type hints to `tests/rest/client`. ([\#12066](https://github.com/matrix-org/synapse/issues/12066), [\#12072](https://github.com/matrix-org/synapse/issues/12072), [\#12084](https://github.com/matrix-org/synapse/issues/12084), [\#12094](https://github.com/matrix-org/synapse/issues/12094)) - Add some logging to `/sync` to try and track down #11916. ([\#12068](https://github.com/matrix-org/synapse/issues/12068)) - Inspect application dependencies using `importlib.metadata` or its backport. ([\#12088](https://github.com/matrix-org/synapse/issues/12088)) -- User `assertEqual` instead of the deprecated `assertEquals` in test code. ([\#12092](https://github.com/matrix-org/synapse/issues/12092)) -- Move experimental support for [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440) to /versions. ([\#12099](https://github.com/matrix-org/synapse/issues/12099)) +- Use `assertEqual` instead of the deprecated `assertEquals` in test code. ([\#12092](https://github.com/matrix-org/synapse/issues/12092)) +- Move experimental support for [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440) to `/versions`. ([\#12099](https://github.com/matrix-org/synapse/issues/12099)) - Add `stop_cancellation` utility function to stop `Deferred`s from being cancelled. ([\#12106](https://github.com/matrix-org/synapse/issues/12106)) - Improve exception handling for concurrent execution. ([\#12109](https://github.com/matrix-org/synapse/issues/12109)) - Advertise support for Python 3.10 in packaging files. ([\#12111](https://github.com/matrix-org/synapse/issues/12111)) From 010457011c83d4db96d84c43687ee5e291f7685f Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 2 Mar 2022 11:17:36 +0000 Subject: [PATCH 84/84] Apply suggestions to changelog --- CHANGES.md | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 5485e8d47..a81d0a4b1 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -9,14 +9,12 @@ Features -------- - Add support for [MSC3202](https://github.com/matrix-org/matrix-doc/pull/3202): sending one-time key counts and fallback key usage states to Application Services. ([\#11617](https://github.com/matrix-org/synapse/issues/11617)) -- Make a `POST` to `/rooms//receipt/m.read/` only trigger a push notification if the count of unread messages is different to the one in the last successfully sent push. ([\#11835](https://github.com/matrix-org/synapse/issues/11835)) -- Fetch images when previewing Twitter URLs. Contributed by @AndrewRyanChama. ([\#11985](https://github.com/matrix-org/synapse/issues/11985)) +- Improve the preview that is produced when generating URL previews for some web pages. Contributed by @AndrewRyanChama. ([\#11985](https://github.com/matrix-org/synapse/issues/11985)) - Track cache invalidations in Prometheus metrics, as already happens for cache eviction based on size or time. ([\#12000](https://github.com/matrix-org/synapse/issues/12000)) - Implement experimental support for [MSC3720](https://github.com/matrix-org/matrix-doc/pull/3720) (account status endpoints). ([\#12001](https://github.com/matrix-org/synapse/issues/12001), [\#12067](https://github.com/matrix-org/synapse/issues/12067)) - Enable modules to set a custom display name when registering a user. ([\#12009](https://github.com/matrix-org/synapse/issues/12009)) -- Advertise Matrix 1.1 support on `/_matrix/client/versions`. ([\#12020](https://github.com/matrix-org/synapse/issues/12020)) +- Advertise Matrix 1.1 and 1.2 support on `/_matrix/client/versions`. ([\#12020](https://github.com/matrix-org/synapse/issues/12020), ([\#12022](https://github.com/matrix-org/synapse/issues/12022)) - Support only the stable identifier for [MSC3069](https://github.com/matrix-org/matrix-doc/pull/3069)'s `is_guest` on `/_matrix/client/v3/account/whoami`. ([\#12021](https://github.com/matrix-org/synapse/issues/12021)) -- Advertise Matrix 1.2 support on `/_matrix/client/versions`. ([\#12022](https://github.com/matrix-org/synapse/issues/12022)) - Use room version 9 as the default room version (per [MSC3589](https://github.com/matrix-org/matrix-doc/pull/3589)). ([\#12058](https://github.com/matrix-org/synapse/issues/12058)) - Add module callbacks to react to user deactivation status changes (i.e. deactivations and reactivations) and profile updates. ([\#12062](https://github.com/matrix-org/synapse/issues/12062)) @@ -24,16 +22,17 @@ Features Bugfixes -------- -- Fix a bug introduced in Synapse v1.48.0 where an edit of the latest event in a thread would not be properly applied to the thread summary. ([\#11992](https://github.com/matrix-org/synapse/issues/11992)) -- Fix long standing bug where `get_rooms_for_user` was not correctly invalidated for remote users when the server left a room. ([\#11999](https://github.com/matrix-org/synapse/issues/11999)) +- Fix a bug introduced in Synapse 1.48.0 where an edit of the latest event in a thread would not be properly applied to the thread summary. ([\#11992](https://github.com/matrix-org/synapse/issues/11992)) +- Fix long-standing bug where `get_rooms_for_user` was not correctly invalidated for remote users when the server left a room. ([\#11999](https://github.com/matrix-org/synapse/issues/11999)) - Fix a 500 error with Postgres when looking backwards with the [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) `/timestamp_to_event?dir=b` endpoint. ([\#12024](https://github.com/matrix-org/synapse/issues/12024)) -- Properly fix a long-standing bug where wrong data could be inserted into the `event_search` table when using SQLite. This could block running `synapse_port_db` with an "argument of type 'int' is not iterable" error. This bug was partially fixed by a change in Synapse 1.44.0. ([\#12037](https://github.com/matrix-org/synapse/issues/12037)) -- Fix slow performance of `/logout` in some cases where refresh tokens are in use. The slowness existed since the initial implementation of refresh tokens. ([\#12056](https://github.com/matrix-org/synapse/issues/12056)) +- Properly fix a long-standing bug where wrong data could be inserted into the `event_search` table when using SQLite. This could block running `synapse_port_db` with an `argument of type 'int' is not iterable` error. This bug was partially fixed by a change in Synapse 1.44.0. ([\#12037](https://github.com/matrix-org/synapse/issues/12037)) +- Fix slow performance of `/logout` in some cases where refresh tokens are in use. The slowness existed since the initial implementation of refresh tokens in version 1.38.0. ([\#12056](https://github.com/matrix-org/synapse/issues/12056)) - Fix a long-standing bug where Synapse would make additional failing requests over federation for missing data. ([\#12077](https://github.com/matrix-org/synapse/issues/12077)) -- Fix occasional 'Unhandled error in Deferred' error message. ([\#12089](https://github.com/matrix-org/synapse/issues/12089)) -- Fix a bug introduced in Synapse 1.51.0rc1 where incoming federation transactions containing at least one EDU would be dropped if debug logging was enabled for `synapse.8631_debug`. ([\#12098](https://github.com/matrix-org/synapse/issues/12098)) +- Fix occasional `Unhandled error in Deferred` error message. ([\#12089](https://github.com/matrix-org/synapse/issues/12089)) +- Fix a bug introduced in Synapse 1.51.0 where incoming federation transactions containing at least one EDU would be dropped if debug logging was enabled for `synapse.8631_debug`. ([\#12098](https://github.com/matrix-org/synapse/issues/12098)) - Fix a long-standing bug which could cause push notifications to malfunction if `use_frozen_dicts` was set in the configuration. ([\#12100](https://github.com/matrix-org/synapse/issues/12100)) - Fix an extremely rare, long-standing bug in `ReadWriteLock` that would cause an error when a newly unblocked writer completes instantly. ([\#12105](https://github.com/matrix-org/synapse/issues/12105)) +- Make a `POST` to `/rooms//receipt/m.read/` only trigger a push notification if the count of unread messages is different to the one in the last successfully sent push. This reduces server load and load on the receiving device. ([\#11835](https://github.com/matrix-org/synapse/issues/11835)) Updates to the Docker image