Fix-up some type hints in the relations tests. (#11076)

This commit is contained in:
Patrick Cloke 2021-10-14 09:19:35 -04:00 committed by GitHub
parent 50d8601581
commit 1609ccf8fe
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 64 additions and 51 deletions

View file

@ -13,15 +13,15 @@
# limitations under the License.
import itertools
import json
import urllib
from typing import Optional
import urllib.parse
from typing import Dict, List, Optional, Tuple
from synapse.api.constants import EventTypes, RelationTypes
from synapse.rest import admin
from synapse.rest.client import login, register, relations, room
from tests import unittest
from tests.server import FakeChannel
class RelationsTestCase(unittest.HomeserverTestCase):
@ -34,16 +34,16 @@ class RelationsTestCase(unittest.HomeserverTestCase):
]
hijack_auth = False
def make_homeserver(self, reactor, clock):
def default_config(self) -> dict:
# We need to enable msc1849 support for aggregations
config = self.default_config()
config = super().default_config()
config["experimental_msc1849_support_enabled"] = True
# We enable frozen dicts as relations/edits change event contents, so we
# want to test that we don't modify the events in the caches.
config["use_frozen_dicts"] = True
return self.setup_test_homeserver(config=config)
return config
def prepare(self, reactor, clock, hs):
self.user_id, self.user_token = self._create_user("alice")
@ -146,8 +146,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
expected_event_ids.append(channel.json_body["event_id"])
prev_token = None
found_event_ids = []
prev_token: Optional[str] = None
found_event_ids: List[str] = []
for _ in range(20):
from_token = ""
if prev_token:
@ -203,8 +203,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
idx += 1
idx %= len(access_tokens)
prev_token = None
found_groups = {}
prev_token: Optional[str] = None
found_groups: Dict[str, int] = {}
for _ in range(20):
from_token = ""
if prev_token:
@ -270,8 +270,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
self.assertEquals(200, channel.code, channel.json_body)
prev_token = None
found_event_ids = []
prev_token: Optional[str] = None
found_event_ids: List[str] = []
encoded_key = urllib.parse.quote_plus("👍".encode())
for _ in range(20):
from_token = ""
@ -677,24 +677,23 @@ class RelationsTestCase(unittest.HomeserverTestCase):
def _send_relation(
self,
relation_type,
event_type,
key=None,
relation_type: str,
event_type: str,
key: Optional[str] = None,
content: Optional[dict] = None,
access_token=None,
parent_id=None,
):
access_token: Optional[str] = None,
parent_id: Optional[str] = None,
) -> FakeChannel:
"""Helper function to send a relation pointing at `self.parent_id`
Args:
relation_type (str): One of `RelationTypes`
event_type (str): The type of the event to create
parent_id (str): The event_id this relation relates to. If None, then self.parent_id
key (str|None): The aggregation key used for m.annotation relation
type.
content(dict|None): The content of the created event.
access_token (str|None): The access token used to send the relation,
defaults to `self.user_token`
relation_type: One of `RelationTypes`
event_type: The type of the event to create
key: The aggregation key used for m.annotation relation type.
content: The content of the created event.
access_token: The access token used to send the relation, defaults
to `self.user_token`
parent_id: The event_id this relation relates to. If None, then self.parent_id
Returns:
FakeChannel
@ -712,12 +711,12 @@ class RelationsTestCase(unittest.HomeserverTestCase):
"POST",
"/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
% (self.room, original_id, relation_type, event_type, query),
json.dumps(content or {}).encode("utf-8"),
content or {},
access_token=access_token,
)
return channel
def _create_user(self, localpart):
def _create_user(self, localpart: str) -> Tuple[str, str]:
user_id = self.register_user(localpart, "abc123")
access_token = self.login(localpart, "abc123")