mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Add type hints to tests/rest/client
(#12108)
* Add type hints to `tests/rest/client` * newsfile * fix imports * add `test_account.py` * Remove one type hint in `test_report_event.py` * change `on_create_room` to `async` * update new functions in `test_third_party_rules.py` * Add `test_filter.py` * add `test_rooms.py` * change to `assertEquals` to `assertEqual` * lint
This commit is contained in:
parent
b4461e7d8a
commit
2ffaf30803
1
changelog.d/12108.misc
Normal file
1
changelog.d/12108.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add type hints to `tests/rest/client`.
|
6
mypy.ini
6
mypy.ini
@ -78,13 +78,7 @@ exclude = (?x)
|
|||||||
|tests/push/test_http.py
|
|tests/push/test_http.py
|
||||||
|tests/push/test_presentable_names.py
|
|tests/push/test_presentable_names.py
|
||||||
|tests/push/test_push_rule_evaluator.py
|
|tests/push/test_push_rule_evaluator.py
|
||||||
|tests/rest/client/test_account.py
|
|
||||||
|tests/rest/client/test_filter.py
|
|
||||||
|tests/rest/client/test_report_event.py
|
|
||||||
|tests/rest/client/test_rooms.py
|
|
||||||
|tests/rest/client/test_third_party_rules.py
|
|
||||||
|tests/rest/client/test_transactions.py
|
|tests/rest/client/test_transactions.py
|
||||||
|tests/rest/client/test_typing.py
|
|
||||||
|tests/rest/key/v2/test_remote_key_resource.py
|
|tests/rest/key/v2/test_remote_key_resource.py
|
||||||
|tests/rest/media/v1/test_base.py
|
|tests/rest/media/v1/test_base.py
|
||||||
|tests/rest/media/v1/test_media_storage.py
|
|tests/rest/media/v1/test_media_storage.py
|
||||||
|
@ -15,11 +15,12 @@ import json
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from email.parser import Parser
|
from email.parser import Parser
|
||||||
from typing import Dict, List, Optional
|
from typing import Any, Dict, List, Optional, Union
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
|
|
||||||
|
from twisted.internet.interfaces import IReactorTCP
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
import synapse.rest.admin
|
import synapse.rest.admin
|
||||||
@ -30,6 +31,7 @@ from synapse.rest import admin
|
|||||||
from synapse.rest.client import account, login, register, room
|
from synapse.rest.client import account, login, register, room
|
||||||
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
|
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
from synapse.types import JsonDict
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
@ -46,7 +48,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
|||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
config = self.default_config()
|
config = self.default_config()
|
||||||
|
|
||||||
# Email config.
|
# Email config.
|
||||||
@ -67,20 +69,27 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
|||||||
hs = self.setup_test_homeserver(config=config)
|
hs = self.setup_test_homeserver(config=config)
|
||||||
|
|
||||||
async def sendmail(
|
async def sendmail(
|
||||||
reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs
|
reactor: IReactorTCP,
|
||||||
):
|
smtphost: str,
|
||||||
self.email_attempts.append(msg)
|
smtpport: int,
|
||||||
|
from_addr: str,
|
||||||
|
to_addr: str,
|
||||||
|
msg_bytes: bytes,
|
||||||
|
*args: Any,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
self.email_attempts.append(msg_bytes)
|
||||||
|
|
||||||
self.email_attempts = []
|
self.email_attempts: List[bytes] = []
|
||||||
hs.get_send_email_handler()._sendmail = sendmail
|
hs.get_send_email_handler()._sendmail = sendmail
|
||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.submit_token_resource = PasswordResetSubmitTokenResource(hs)
|
self.submit_token_resource = PasswordResetSubmitTokenResource(hs)
|
||||||
|
|
||||||
def test_basic_password_reset(self):
|
def test_basic_password_reset(self) -> None:
|
||||||
"""Test basic password reset flow"""
|
"""Test basic password reset flow"""
|
||||||
old_password = "monkey"
|
old_password = "monkey"
|
||||||
new_password = "kangeroo"
|
new_password = "kangeroo"
|
||||||
@ -118,7 +127,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
|||||||
self.attempt_wrong_password_login("kermit", old_password)
|
self.attempt_wrong_password_login("kermit", old_password)
|
||||||
|
|
||||||
@override_config({"rc_3pid_validation": {"burst_count": 3}})
|
@override_config({"rc_3pid_validation": {"burst_count": 3}})
|
||||||
def test_ratelimit_by_email(self):
|
def test_ratelimit_by_email(self) -> None:
|
||||||
"""Test that we ratelimit /requestToken for the same email."""
|
"""Test that we ratelimit /requestToken for the same email."""
|
||||||
old_password = "monkey"
|
old_password = "monkey"
|
||||||
new_password = "kangeroo"
|
new_password = "kangeroo"
|
||||||
@ -139,7 +148,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def reset(ip):
|
def reset(ip: str) -> None:
|
||||||
client_secret = "foobar"
|
client_secret = "foobar"
|
||||||
session_id = self._request_token(email, client_secret, ip)
|
session_id = self._request_token(email, client_secret, ip)
|
||||||
|
|
||||||
@ -166,7 +175,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self.assertEqual(cm.exception.code, 429)
|
self.assertEqual(cm.exception.code, 429)
|
||||||
|
|
||||||
def test_basic_password_reset_canonicalise_email(self):
|
def test_basic_password_reset_canonicalise_email(self) -> None:
|
||||||
"""Test basic password reset flow
|
"""Test basic password reset flow
|
||||||
Request password reset with different spelling
|
Request password reset with different spelling
|
||||||
"""
|
"""
|
||||||
@ -206,7 +215,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
|||||||
# Assert we can't log in with the old password
|
# Assert we can't log in with the old password
|
||||||
self.attempt_wrong_password_login("kermit", old_password)
|
self.attempt_wrong_password_login("kermit", old_password)
|
||||||
|
|
||||||
def test_cant_reset_password_without_clicking_link(self):
|
def test_cant_reset_password_without_clicking_link(self) -> None:
|
||||||
"""Test that we do actually need to click the link in the email"""
|
"""Test that we do actually need to click the link in the email"""
|
||||||
old_password = "monkey"
|
old_password = "monkey"
|
||||||
new_password = "kangeroo"
|
new_password = "kangeroo"
|
||||||
@ -241,7 +250,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
|||||||
# Assert we can't log in with the new password
|
# Assert we can't log in with the new password
|
||||||
self.attempt_wrong_password_login("kermit", new_password)
|
self.attempt_wrong_password_login("kermit", new_password)
|
||||||
|
|
||||||
def test_no_valid_token(self):
|
def test_no_valid_token(self) -> None:
|
||||||
"""Test that we do actually need to request a token and can't just
|
"""Test that we do actually need to request a token and can't just
|
||||||
make a session up.
|
make a session up.
|
||||||
"""
|
"""
|
||||||
@ -277,7 +286,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
|||||||
self.attempt_wrong_password_login("kermit", new_password)
|
self.attempt_wrong_password_login("kermit", new_password)
|
||||||
|
|
||||||
@unittest.override_config({"request_token_inhibit_3pid_errors": True})
|
@unittest.override_config({"request_token_inhibit_3pid_errors": True})
|
||||||
def test_password_reset_bad_email_inhibit_error(self):
|
def test_password_reset_bad_email_inhibit_error(self) -> None:
|
||||||
"""Test that triggering a password reset with an email address that isn't bound
|
"""Test that triggering a password reset with an email address that isn't bound
|
||||||
to an account doesn't leak the lack of binding for that address if configured
|
to an account doesn't leak the lack of binding for that address if configured
|
||||||
that way.
|
that way.
|
||||||
@ -292,7 +301,12 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self.assertIsNotNone(session_id)
|
self.assertIsNotNone(session_id)
|
||||||
|
|
||||||
def _request_token(self, email, client_secret, ip="127.0.0.1"):
|
def _request_token(
|
||||||
|
self,
|
||||||
|
email: str,
|
||||||
|
client_secret: str,
|
||||||
|
ip: str = "127.0.0.1",
|
||||||
|
) -> str:
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"POST",
|
"POST",
|
||||||
b"account/password/email/requestToken",
|
b"account/password/email/requestToken",
|
||||||
@ -309,7 +323,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
return channel.json_body["sid"]
|
return channel.json_body["sid"]
|
||||||
|
|
||||||
def _validate_token(self, link):
|
def _validate_token(self, link: str) -> None:
|
||||||
# Remove the host
|
# Remove the host
|
||||||
path = link.replace("https://example.com", "")
|
path = link.replace("https://example.com", "")
|
||||||
|
|
||||||
@ -339,7 +353,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(200, channel.code, channel.result)
|
self.assertEqual(200, channel.code, channel.result)
|
||||||
|
|
||||||
def _get_link_from_email(self):
|
def _get_link_from_email(self) -> str:
|
||||||
assert self.email_attempts, "No emails have been sent"
|
assert self.email_attempts, "No emails have been sent"
|
||||||
|
|
||||||
raw_msg = self.email_attempts[-1].decode("UTF-8")
|
raw_msg = self.email_attempts[-1].decode("UTF-8")
|
||||||
@ -354,14 +368,19 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
|||||||
if not text:
|
if not text:
|
||||||
self.fail("Could not find text portion of email to parse")
|
self.fail("Could not find text portion of email to parse")
|
||||||
|
|
||||||
|
assert text is not None
|
||||||
match = re.search(r"https://example.com\S+", text)
|
match = re.search(r"https://example.com\S+", text)
|
||||||
assert match, "Could not find link in email"
|
assert match, "Could not find link in email"
|
||||||
|
|
||||||
return match.group(0)
|
return match.group(0)
|
||||||
|
|
||||||
def _reset_password(
|
def _reset_password(
|
||||||
self, new_password, session_id, client_secret, expected_code=200
|
self,
|
||||||
):
|
new_password: str,
|
||||||
|
session_id: str,
|
||||||
|
client_secret: str,
|
||||||
|
expected_code: int = 200,
|
||||||
|
) -> None:
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"POST",
|
"POST",
|
||||||
b"account/password",
|
b"account/password",
|
||||||
@ -388,11 +407,11 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
|
|||||||
room.register_servlets,
|
room.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
self.hs = self.setup_test_homeserver()
|
self.hs = self.setup_test_homeserver()
|
||||||
return self.hs
|
return self.hs
|
||||||
|
|
||||||
def test_deactivate_account(self):
|
def test_deactivate_account(self) -> None:
|
||||||
user_id = self.register_user("kermit", "test")
|
user_id = self.register_user("kermit", "test")
|
||||||
tok = self.login("kermit", "test")
|
tok = self.login("kermit", "test")
|
||||||
|
|
||||||
@ -407,7 +426,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
|
|||||||
channel = self.make_request("GET", "account/whoami", access_token=tok)
|
channel = self.make_request("GET", "account/whoami", access_token=tok)
|
||||||
self.assertEqual(channel.code, 401)
|
self.assertEqual(channel.code, 401)
|
||||||
|
|
||||||
def test_pending_invites(self):
|
def test_pending_invites(self) -> None:
|
||||||
"""Tests that deactivating a user rejects every pending invite for them."""
|
"""Tests that deactivating a user rejects every pending invite for them."""
|
||||||
store = self.hs.get_datastores().main
|
store = self.hs.get_datastores().main
|
||||||
|
|
||||||
@ -448,7 +467,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
|
|||||||
self.assertEqual(len(memberships), 1, memberships)
|
self.assertEqual(len(memberships), 1, memberships)
|
||||||
self.assertEqual(memberships[0].room_id, room_id, memberships)
|
self.assertEqual(memberships[0].room_id, room_id, memberships)
|
||||||
|
|
||||||
def deactivate(self, user_id, tok):
|
def deactivate(self, user_id: str, tok: str) -> None:
|
||||||
request_data = json.dumps(
|
request_data = json.dumps(
|
||||||
{
|
{
|
||||||
"auth": {
|
"auth": {
|
||||||
@ -474,12 +493,12 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
|
|||||||
register.register_servlets,
|
register.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def default_config(self):
|
def default_config(self) -> Dict[str, Any]:
|
||||||
config = super().default_config()
|
config = super().default_config()
|
||||||
config["allow_guest_access"] = True
|
config["allow_guest_access"] = True
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def test_GET_whoami(self):
|
def test_GET_whoami(self) -> None:
|
||||||
device_id = "wouldgohere"
|
device_id = "wouldgohere"
|
||||||
user_id = self.register_user("kermit", "test")
|
user_id = self.register_user("kermit", "test")
|
||||||
tok = self.login("kermit", "test", device_id=device_id)
|
tok = self.login("kermit", "test", device_id=device_id)
|
||||||
@ -496,7 +515,7 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_GET_whoami_guests(self):
|
def test_GET_whoami_guests(self) -> None:
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
b"POST", b"/_matrix/client/r0/register?kind=guest", b"{}"
|
b"POST", b"/_matrix/client/r0/register?kind=guest", b"{}"
|
||||||
)
|
)
|
||||||
@ -516,7 +535,7 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_GET_whoami_appservices(self):
|
def test_GET_whoami_appservices(self) -> None:
|
||||||
user_id = "@as:test"
|
user_id = "@as:test"
|
||||||
as_token = "i_am_an_app_service"
|
as_token = "i_am_an_app_service"
|
||||||
|
|
||||||
@ -541,7 +560,7 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
self.assertFalse(hasattr(whoami, "device_id"))
|
self.assertFalse(hasattr(whoami, "device_id"))
|
||||||
|
|
||||||
def _whoami(self, tok):
|
def _whoami(self, tok: str) -> JsonDict:
|
||||||
channel = self.make_request("GET", "account/whoami", {}, access_token=tok)
|
channel = self.make_request("GET", "account/whoami", {}, access_token=tok)
|
||||||
self.assertEqual(channel.code, 200)
|
self.assertEqual(channel.code, 200)
|
||||||
return channel.json_body
|
return channel.json_body
|
||||||
@ -555,7 +574,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
|||||||
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
||||||
]
|
]
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
config = self.default_config()
|
config = self.default_config()
|
||||||
|
|
||||||
# Email config.
|
# Email config.
|
||||||
@ -576,16 +595,23 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
|||||||
self.hs = self.setup_test_homeserver(config=config)
|
self.hs = self.setup_test_homeserver(config=config)
|
||||||
|
|
||||||
async def sendmail(
|
async def sendmail(
|
||||||
reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs
|
reactor: IReactorTCP,
|
||||||
):
|
smtphost: str,
|
||||||
self.email_attempts.append(msg)
|
smtpport: int,
|
||||||
|
from_addr: str,
|
||||||
|
to_addr: str,
|
||||||
|
msg_bytes: bytes,
|
||||||
|
*args: Any,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
self.email_attempts.append(msg_bytes)
|
||||||
|
|
||||||
self.email_attempts = []
|
self.email_attempts: List[bytes] = []
|
||||||
self.hs.get_send_email_handler()._sendmail = sendmail
|
self.hs.get_send_email_handler()._sendmail = sendmail
|
||||||
|
|
||||||
return self.hs
|
return self.hs
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
|
|
||||||
self.user_id = self.register_user("kermit", "test")
|
self.user_id = self.register_user("kermit", "test")
|
||||||
@ -593,83 +619,73 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
|||||||
self.email = "test@example.com"
|
self.email = "test@example.com"
|
||||||
self.url_3pid = b"account/3pid"
|
self.url_3pid = b"account/3pid"
|
||||||
|
|
||||||
def test_add_valid_email(self):
|
def test_add_valid_email(self) -> None:
|
||||||
self.get_success(self._add_email(self.email, self.email))
|
self._add_email(self.email, self.email)
|
||||||
|
|
||||||
def test_add_valid_email_second_time(self):
|
def test_add_valid_email_second_time(self) -> None:
|
||||||
self.get_success(self._add_email(self.email, self.email))
|
self._add_email(self.email, self.email)
|
||||||
self.get_success(
|
|
||||||
self._request_token_invalid_email(
|
self._request_token_invalid_email(
|
||||||
self.email,
|
self.email,
|
||||||
expected_errcode=Codes.THREEPID_IN_USE,
|
expected_errcode=Codes.THREEPID_IN_USE,
|
||||||
expected_error="Email is already in use",
|
expected_error="Email is already in use",
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
def test_add_valid_email_second_time_canonicalise(self):
|
def test_add_valid_email_second_time_canonicalise(self) -> None:
|
||||||
self.get_success(self._add_email(self.email, self.email))
|
self._add_email(self.email, self.email)
|
||||||
self.get_success(
|
|
||||||
self._request_token_invalid_email(
|
self._request_token_invalid_email(
|
||||||
"TEST@EXAMPLE.COM",
|
"TEST@EXAMPLE.COM",
|
||||||
expected_errcode=Codes.THREEPID_IN_USE,
|
expected_errcode=Codes.THREEPID_IN_USE,
|
||||||
expected_error="Email is already in use",
|
expected_error="Email is already in use",
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
def test_add_email_no_at(self):
|
def test_add_email_no_at(self) -> None:
|
||||||
self.get_success(
|
|
||||||
self._request_token_invalid_email(
|
self._request_token_invalid_email(
|
||||||
"address-without-at.bar",
|
"address-without-at.bar",
|
||||||
expected_errcode=Codes.UNKNOWN,
|
expected_errcode=Codes.UNKNOWN,
|
||||||
expected_error="Unable to parse email address",
|
expected_error="Unable to parse email address",
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
def test_add_email_two_at(self):
|
def test_add_email_two_at(self) -> None:
|
||||||
self.get_success(
|
|
||||||
self._request_token_invalid_email(
|
self._request_token_invalid_email(
|
||||||
"foo@foo@test.bar",
|
"foo@foo@test.bar",
|
||||||
expected_errcode=Codes.UNKNOWN,
|
expected_errcode=Codes.UNKNOWN,
|
||||||
expected_error="Unable to parse email address",
|
expected_error="Unable to parse email address",
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
def test_add_email_bad_format(self):
|
def test_add_email_bad_format(self) -> None:
|
||||||
self.get_success(
|
|
||||||
self._request_token_invalid_email(
|
self._request_token_invalid_email(
|
||||||
"user@bad.example.net@good.example.com",
|
"user@bad.example.net@good.example.com",
|
||||||
expected_errcode=Codes.UNKNOWN,
|
expected_errcode=Codes.UNKNOWN,
|
||||||
expected_error="Unable to parse email address",
|
expected_error="Unable to parse email address",
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
def test_add_email_domain_to_lower(self):
|
def test_add_email_domain_to_lower(self) -> None:
|
||||||
self.get_success(self._add_email("foo@TEST.BAR", "foo@test.bar"))
|
self._add_email("foo@TEST.BAR", "foo@test.bar")
|
||||||
|
|
||||||
def test_add_email_domain_with_umlaut(self):
|
def test_add_email_domain_with_umlaut(self) -> None:
|
||||||
self.get_success(self._add_email("foo@Öumlaut.com", "foo@öumlaut.com"))
|
self._add_email("foo@Öumlaut.com", "foo@öumlaut.com")
|
||||||
|
|
||||||
def test_add_email_address_casefold(self):
|
def test_add_email_address_casefold(self) -> None:
|
||||||
self.get_success(self._add_email("Strauß@Example.com", "strauss@example.com"))
|
self._add_email("Strauß@Example.com", "strauss@example.com")
|
||||||
|
|
||||||
def test_address_trim(self):
|
def test_address_trim(self) -> None:
|
||||||
self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar"))
|
self._add_email(" foo@test.bar ", "foo@test.bar")
|
||||||
|
|
||||||
@override_config({"rc_3pid_validation": {"burst_count": 3}})
|
@override_config({"rc_3pid_validation": {"burst_count": 3}})
|
||||||
def test_ratelimit_by_ip(self):
|
def test_ratelimit_by_ip(self) -> None:
|
||||||
"""Tests that adding emails is ratelimited by IP"""
|
"""Tests that adding emails is ratelimited by IP"""
|
||||||
|
|
||||||
# We expect to be able to set three emails before getting ratelimited.
|
# We expect to be able to set three emails before getting ratelimited.
|
||||||
self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar"))
|
self._add_email("foo1@test.bar", "foo1@test.bar")
|
||||||
self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar"))
|
self._add_email("foo2@test.bar", "foo2@test.bar")
|
||||||
self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar"))
|
self._add_email("foo3@test.bar", "foo3@test.bar")
|
||||||
|
|
||||||
with self.assertRaises(HttpResponseException) as cm:
|
with self.assertRaises(HttpResponseException) as cm:
|
||||||
self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar"))
|
self._add_email("foo4@test.bar", "foo4@test.bar")
|
||||||
|
|
||||||
self.assertEqual(cm.exception.code, 429)
|
self.assertEqual(cm.exception.code, 429)
|
||||||
|
|
||||||
def test_add_email_if_disabled(self):
|
def test_add_email_if_disabled(self) -> None:
|
||||||
"""Test adding email to profile when doing so is disallowed"""
|
"""Test adding email to profile when doing so is disallowed"""
|
||||||
self.hs.config.registration.enable_3pid_changes = False
|
self.hs.config.registration.enable_3pid_changes = False
|
||||||
|
|
||||||
@ -695,7 +711,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
|||||||
},
|
},
|
||||||
access_token=self.user_id_tok,
|
access_token=self.user_id_tok,
|
||||||
)
|
)
|
||||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(400, channel.code, msg=channel.result["body"])
|
||||||
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
|
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
|
||||||
|
|
||||||
# Get user
|
# Get user
|
||||||
@ -705,10 +721,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
|||||||
access_token=self.user_id_tok,
|
access_token=self.user_id_tok,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||||
self.assertFalse(channel.json_body["threepids"])
|
self.assertFalse(channel.json_body["threepids"])
|
||||||
|
|
||||||
def test_delete_email(self):
|
def test_delete_email(self) -> None:
|
||||||
"""Test deleting an email from profile"""
|
"""Test deleting an email from profile"""
|
||||||
# Add a threepid
|
# Add a threepid
|
||||||
self.get_success(
|
self.get_success(
|
||||||
@ -727,7 +743,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
|||||||
{"medium": "email", "address": self.email},
|
{"medium": "email", "address": self.email},
|
||||||
access_token=self.user_id_tok,
|
access_token=self.user_id_tok,
|
||||||
)
|
)
|
||||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||||
|
|
||||||
# Get user
|
# Get user
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
@ -736,10 +752,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
|||||||
access_token=self.user_id_tok,
|
access_token=self.user_id_tok,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||||
self.assertFalse(channel.json_body["threepids"])
|
self.assertFalse(channel.json_body["threepids"])
|
||||||
|
|
||||||
def test_delete_email_if_disabled(self):
|
def test_delete_email_if_disabled(self) -> None:
|
||||||
"""Test deleting an email from profile when disallowed"""
|
"""Test deleting an email from profile when disallowed"""
|
||||||
self.hs.config.registration.enable_3pid_changes = False
|
self.hs.config.registration.enable_3pid_changes = False
|
||||||
|
|
||||||
@ -761,7 +777,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
|||||||
access_token=self.user_id_tok,
|
access_token=self.user_id_tok,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(400, channel.code, msg=channel.result["body"])
|
||||||
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
|
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
|
||||||
|
|
||||||
# Get user
|
# Get user
|
||||||
@ -771,11 +787,11 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
|||||||
access_token=self.user_id_tok,
|
access_token=self.user_id_tok,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||||
self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
|
self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
|
||||||
|
|
||||||
def test_cant_add_email_without_clicking_link(self):
|
def test_cant_add_email_without_clicking_link(self) -> None:
|
||||||
"""Test that we do actually need to click the link in the email"""
|
"""Test that we do actually need to click the link in the email"""
|
||||||
client_secret = "foobar"
|
client_secret = "foobar"
|
||||||
session_id = self._request_token(self.email, client_secret)
|
session_id = self._request_token(self.email, client_secret)
|
||||||
@ -797,7 +813,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
|||||||
},
|
},
|
||||||
access_token=self.user_id_tok,
|
access_token=self.user_id_tok,
|
||||||
)
|
)
|
||||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(400, channel.code, msg=channel.result["body"])
|
||||||
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
|
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
|
||||||
|
|
||||||
# Get user
|
# Get user
|
||||||
@ -807,10 +823,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
|||||||
access_token=self.user_id_tok,
|
access_token=self.user_id_tok,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||||
self.assertFalse(channel.json_body["threepids"])
|
self.assertFalse(channel.json_body["threepids"])
|
||||||
|
|
||||||
def test_no_valid_token(self):
|
def test_no_valid_token(self) -> None:
|
||||||
"""Test that we do actually need to request a token and can't just
|
"""Test that we do actually need to request a token and can't just
|
||||||
make a session up.
|
make a session up.
|
||||||
"""
|
"""
|
||||||
@ -832,7 +848,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
|||||||
},
|
},
|
||||||
access_token=self.user_id_tok,
|
access_token=self.user_id_tok,
|
||||||
)
|
)
|
||||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(400, channel.code, msg=channel.result["body"])
|
||||||
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
|
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
|
||||||
|
|
||||||
# Get user
|
# Get user
|
||||||
@ -842,11 +858,11 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
|||||||
access_token=self.user_id_tok,
|
access_token=self.user_id_tok,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||||
self.assertFalse(channel.json_body["threepids"])
|
self.assertFalse(channel.json_body["threepids"])
|
||||||
|
|
||||||
@override_config({"next_link_domain_whitelist": None})
|
@override_config({"next_link_domain_whitelist": None})
|
||||||
def test_next_link(self):
|
def test_next_link(self) -> None:
|
||||||
"""Tests a valid next_link parameter value with no whitelist (good case)"""
|
"""Tests a valid next_link parameter value with no whitelist (good case)"""
|
||||||
self._request_token(
|
self._request_token(
|
||||||
"something@example.com",
|
"something@example.com",
|
||||||
@ -856,7 +872,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@override_config({"next_link_domain_whitelist": None})
|
@override_config({"next_link_domain_whitelist": None})
|
||||||
def test_next_link_exotic_protocol(self):
|
def test_next_link_exotic_protocol(self) -> None:
|
||||||
"""Tests using a esoteric protocol as a next_link parameter value.
|
"""Tests using a esoteric protocol as a next_link parameter value.
|
||||||
Someone may be hosting a client on IPFS etc.
|
Someone may be hosting a client on IPFS etc.
|
||||||
"""
|
"""
|
||||||
@ -868,7 +884,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@override_config({"next_link_domain_whitelist": None})
|
@override_config({"next_link_domain_whitelist": None})
|
||||||
def test_next_link_file_uri(self):
|
def test_next_link_file_uri(self) -> None:
|
||||||
"""Tests next_link parameters cannot be file URI"""
|
"""Tests next_link parameters cannot be file URI"""
|
||||||
# Attempt to use a next_link value that points to the local disk
|
# Attempt to use a next_link value that points to the local disk
|
||||||
self._request_token(
|
self._request_token(
|
||||||
@ -879,7 +895,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@override_config({"next_link_domain_whitelist": ["example.com", "example.org"]})
|
@override_config({"next_link_domain_whitelist": ["example.com", "example.org"]})
|
||||||
def test_next_link_domain_whitelist(self):
|
def test_next_link_domain_whitelist(self) -> None:
|
||||||
"""Tests next_link parameters must fit the whitelist if provided"""
|
"""Tests next_link parameters must fit the whitelist if provided"""
|
||||||
|
|
||||||
# Ensure not providing a next_link parameter still works
|
# Ensure not providing a next_link parameter still works
|
||||||
@ -912,7 +928,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@override_config({"next_link_domain_whitelist": []})
|
@override_config({"next_link_domain_whitelist": []})
|
||||||
def test_empty_next_link_domain_whitelist(self):
|
def test_empty_next_link_domain_whitelist(self) -> None:
|
||||||
"""Tests an empty next_lint_domain_whitelist value, meaning next_link is essentially
|
"""Tests an empty next_lint_domain_whitelist value, meaning next_link is essentially
|
||||||
disallowed
|
disallowed
|
||||||
"""
|
"""
|
||||||
@ -962,28 +978,28 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
def _request_token_invalid_email(
|
def _request_token_invalid_email(
|
||||||
self,
|
self,
|
||||||
email,
|
email: str,
|
||||||
expected_errcode,
|
expected_errcode: str,
|
||||||
expected_error,
|
expected_error: str,
|
||||||
client_secret="foobar",
|
client_secret: str = "foobar",
|
||||||
):
|
) -> None:
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"POST",
|
"POST",
|
||||||
b"account/3pid/email/requestToken",
|
b"account/3pid/email/requestToken",
|
||||||
{"client_secret": client_secret, "email": email, "send_attempt": 1},
|
{"client_secret": client_secret, "email": email, "send_attempt": 1},
|
||||||
)
|
)
|
||||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(400, channel.code, msg=channel.result["body"])
|
||||||
self.assertEqual(expected_errcode, channel.json_body["errcode"])
|
self.assertEqual(expected_errcode, channel.json_body["errcode"])
|
||||||
self.assertEqual(expected_error, channel.json_body["error"])
|
self.assertEqual(expected_error, channel.json_body["error"])
|
||||||
|
|
||||||
def _validate_token(self, link):
|
def _validate_token(self, link: str) -> None:
|
||||||
# Remove the host
|
# Remove the host
|
||||||
path = link.replace("https://example.com", "")
|
path = link.replace("https://example.com", "")
|
||||||
|
|
||||||
channel = self.make_request("GET", path, shorthand=False)
|
channel = self.make_request("GET", path, shorthand=False)
|
||||||
self.assertEqual(200, channel.code, channel.result)
|
self.assertEqual(200, channel.code, channel.result)
|
||||||
|
|
||||||
def _get_link_from_email(self):
|
def _get_link_from_email(self) -> str:
|
||||||
assert self.email_attempts, "No emails have been sent"
|
assert self.email_attempts, "No emails have been sent"
|
||||||
|
|
||||||
raw_msg = self.email_attempts[-1].decode("UTF-8")
|
raw_msg = self.email_attempts[-1].decode("UTF-8")
|
||||||
@ -998,12 +1014,13 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
|||||||
if not text:
|
if not text:
|
||||||
self.fail("Could not find text portion of email to parse")
|
self.fail("Could not find text portion of email to parse")
|
||||||
|
|
||||||
|
assert text is not None
|
||||||
match = re.search(r"https://example.com\S+", text)
|
match = re.search(r"https://example.com\S+", text)
|
||||||
assert match, "Could not find link in email"
|
assert match, "Could not find link in email"
|
||||||
|
|
||||||
return match.group(0)
|
return match.group(0)
|
||||||
|
|
||||||
def _add_email(self, request_email, expected_email):
|
def _add_email(self, request_email: str, expected_email: str) -> None:
|
||||||
"""Test adding an email to profile"""
|
"""Test adding an email to profile"""
|
||||||
previous_email_attempts = len(self.email_attempts)
|
previous_email_attempts = len(self.email_attempts)
|
||||||
|
|
||||||
@ -1030,7 +1047,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
|||||||
access_token=self.user_id_tok,
|
access_token=self.user_id_tok,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||||
|
|
||||||
# Get user
|
# Get user
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
@ -1039,7 +1056,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
|||||||
access_token=self.user_id_tok,
|
access_token=self.user_id_tok,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||||
|
|
||||||
threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
|
threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
|
||||||
@ -1055,18 +1072,18 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
url = "/_matrix/client/unstable/org.matrix.msc3720/account_status"
|
url = "/_matrix/client/unstable/org.matrix.msc3720/account_status"
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
config = self.default_config()
|
config = self.default_config()
|
||||||
config["experimental_features"] = {"msc3720_enabled": True}
|
config["experimental_features"] = {"msc3720_enabled": True}
|
||||||
|
|
||||||
return self.setup_test_homeserver(config=config)
|
return self.setup_test_homeserver(config=config)
|
||||||
|
|
||||||
def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.requester = self.register_user("requester", "password")
|
self.requester = self.register_user("requester", "password")
|
||||||
self.requester_tok = self.login("requester", "password")
|
self.requester_tok = self.login("requester", "password")
|
||||||
self.server_name = homeserver.config.server.server_name
|
self.server_name = hs.config.server.server_name
|
||||||
|
|
||||||
def test_missing_mxid(self):
|
def test_missing_mxid(self) -> None:
|
||||||
"""Tests that not providing any MXID raises an error."""
|
"""Tests that not providing any MXID raises an error."""
|
||||||
self._test_status(
|
self._test_status(
|
||||||
users=None,
|
users=None,
|
||||||
@ -1074,7 +1091,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
|
|||||||
expected_errcode=Codes.MISSING_PARAM,
|
expected_errcode=Codes.MISSING_PARAM,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_invalid_mxid(self):
|
def test_invalid_mxid(self) -> None:
|
||||||
"""Tests that providing an invalid MXID raises an error."""
|
"""Tests that providing an invalid MXID raises an error."""
|
||||||
self._test_status(
|
self._test_status(
|
||||||
users=["bad:test"],
|
users=["bad:test"],
|
||||||
@ -1082,7 +1099,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
|
|||||||
expected_errcode=Codes.INVALID_PARAM,
|
expected_errcode=Codes.INVALID_PARAM,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_local_user_not_exists(self):
|
def test_local_user_not_exists(self) -> None:
|
||||||
"""Tests that the account status endpoints correctly reports that a user doesn't
|
"""Tests that the account status endpoints correctly reports that a user doesn't
|
||||||
exist.
|
exist.
|
||||||
"""
|
"""
|
||||||
@ -1098,7 +1115,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
|
|||||||
expected_failures=[],
|
expected_failures=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_local_user_exists(self):
|
def test_local_user_exists(self) -> None:
|
||||||
"""Tests that the account status endpoint correctly reports that a user doesn't
|
"""Tests that the account status endpoint correctly reports that a user doesn't
|
||||||
exist.
|
exist.
|
||||||
"""
|
"""
|
||||||
@ -1115,7 +1132,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
|
|||||||
expected_failures=[],
|
expected_failures=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_local_user_deactivated(self):
|
def test_local_user_deactivated(self) -> None:
|
||||||
"""Tests that the account status endpoint correctly reports a deactivated user."""
|
"""Tests that the account status endpoint correctly reports a deactivated user."""
|
||||||
user = self.register_user("someuser", "password")
|
user = self.register_user("someuser", "password")
|
||||||
self.get_success(
|
self.get_success(
|
||||||
@ -1135,7 +1152,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
|
|||||||
expected_failures=[],
|
expected_failures=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_mixed_local_and_remote_users(self):
|
def test_mixed_local_and_remote_users(self) -> None:
|
||||||
"""Tests that if some users are remote the account status endpoint correctly
|
"""Tests that if some users are remote the account status endpoint correctly
|
||||||
merges the remote responses with the local result.
|
merges the remote responses with the local result.
|
||||||
"""
|
"""
|
||||||
@ -1150,7 +1167,13 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
|
|||||||
"@bad:badremote",
|
"@bad:badremote",
|
||||||
]
|
]
|
||||||
|
|
||||||
async def post_json(destination, path, data, *a, **kwa):
|
async def post_json(
|
||||||
|
destination: str,
|
||||||
|
path: str,
|
||||||
|
data: Optional[JsonDict] = None,
|
||||||
|
*a: Any,
|
||||||
|
**kwa: Any,
|
||||||
|
) -> Union[JsonDict, list]:
|
||||||
if destination == "remote":
|
if destination == "remote":
|
||||||
return {
|
return {
|
||||||
"account_statuses": {
|
"account_statuses": {
|
||||||
@ -1160,9 +1183,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if destination == "otherremote":
|
elif destination == "badremote":
|
||||||
return {}
|
|
||||||
if destination == "badremote":
|
|
||||||
# badremote tries to overwrite the status of a user that doesn't belong
|
# badremote tries to overwrite the status of a user that doesn't belong
|
||||||
# to it (i.e. users[1]) with false data, which Synapse is expected to
|
# to it (i.e. users[1]) with false data, which Synapse is expected to
|
||||||
# ignore.
|
# ignore.
|
||||||
@ -1176,6 +1197,9 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
# if destination == "otherremote"
|
||||||
|
else:
|
||||||
|
return {}
|
||||||
|
|
||||||
# Register a mock that will return the expected result depending on the remote.
|
# Register a mock that will return the expected result depending on the remote.
|
||||||
self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json)
|
self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json)
|
||||||
@ -1205,7 +1229,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
|
|||||||
expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None,
|
expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None,
|
||||||
expected_failures: Optional[List[str]] = None,
|
expected_failures: Optional[List[str]] = None,
|
||||||
expected_errcode: Optional[str] = None,
|
expected_errcode: Optional[str] = None,
|
||||||
):
|
) -> None:
|
||||||
"""Send a request to the account status endpoint and check that the response
|
"""Send a request to the account status endpoint and check that the response
|
||||||
matches with what's expected.
|
matches with what's expected.
|
||||||
|
|
||||||
|
@ -12,10 +12,12 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.errors import Codes
|
from synapse.api.errors import Codes
|
||||||
from synapse.rest.client import filter
|
from synapse.rest.client import filter
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
@ -30,11 +32,11 @@ class FilterTestCase(unittest.HomeserverTestCase):
|
|||||||
EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}'
|
EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}'
|
||||||
servlets = [filter.register_servlets]
|
servlets = [filter.register_servlets]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.filtering = hs.get_filtering()
|
self.filtering = hs.get_filtering()
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
|
|
||||||
def test_add_filter(self):
|
def test_add_filter(self) -> None:
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"POST",
|
"POST",
|
||||||
"/_matrix/client/r0/user/%s/filter" % (self.user_id),
|
"/_matrix/client/r0/user/%s/filter" % (self.user_id),
|
||||||
@ -43,11 +45,13 @@ class FilterTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self.assertEqual(channel.result["code"], b"200")
|
self.assertEqual(channel.result["code"], b"200")
|
||||||
self.assertEqual(channel.json_body, {"filter_id": "0"})
|
self.assertEqual(channel.json_body, {"filter_id": "0"})
|
||||||
filter = self.store.get_user_filter(user_localpart="apple", filter_id=0)
|
filter = self.get_success(
|
||||||
|
self.store.get_user_filter(user_localpart="apple", filter_id=0)
|
||||||
|
)
|
||||||
self.pump()
|
self.pump()
|
||||||
self.assertEqual(filter.result, self.EXAMPLE_FILTER)
|
self.assertEqual(filter, self.EXAMPLE_FILTER)
|
||||||
|
|
||||||
def test_add_filter_for_other_user(self):
|
def test_add_filter_for_other_user(self) -> None:
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"POST",
|
"POST",
|
||||||
"/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"),
|
"/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"),
|
||||||
@ -57,7 +61,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
|
|||||||
self.assertEqual(channel.result["code"], b"403")
|
self.assertEqual(channel.result["code"], b"403")
|
||||||
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
|
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
|
||||||
|
|
||||||
def test_add_filter_non_local_user(self):
|
def test_add_filter_non_local_user(self) -> None:
|
||||||
_is_mine = self.hs.is_mine
|
_is_mine = self.hs.is_mine
|
||||||
self.hs.is_mine = lambda target_user: False
|
self.hs.is_mine = lambda target_user: False
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
@ -70,14 +74,13 @@ class FilterTestCase(unittest.HomeserverTestCase):
|
|||||||
self.assertEqual(channel.result["code"], b"403")
|
self.assertEqual(channel.result["code"], b"403")
|
||||||
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
|
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
|
||||||
|
|
||||||
def test_get_filter(self):
|
def test_get_filter(self) -> None:
|
||||||
filter_id = defer.ensureDeferred(
|
filter_id = self.get_success(
|
||||||
self.filtering.add_user_filter(
|
self.filtering.add_user_filter(
|
||||||
user_localpart="apple", user_filter=self.EXAMPLE_FILTER
|
user_localpart="apple", user_filter=self.EXAMPLE_FILTER
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.reactor.advance(1)
|
self.reactor.advance(1)
|
||||||
filter_id = filter_id.result
|
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id)
|
"GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id)
|
||||||
)
|
)
|
||||||
@ -85,7 +88,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
|
|||||||
self.assertEqual(channel.result["code"], b"200")
|
self.assertEqual(channel.result["code"], b"200")
|
||||||
self.assertEqual(channel.json_body, self.EXAMPLE_FILTER)
|
self.assertEqual(channel.json_body, self.EXAMPLE_FILTER)
|
||||||
|
|
||||||
def test_get_filter_non_existant(self):
|
def test_get_filter_non_existant(self) -> None:
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id)
|
"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id)
|
||||||
)
|
)
|
||||||
@ -95,7 +98,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
# Currently invalid params do not have an appropriate errcode
|
# Currently invalid params do not have an appropriate errcode
|
||||||
# in errors.py
|
# in errors.py
|
||||||
def test_get_filter_invalid_id(self):
|
def test_get_filter_invalid_id(self) -> None:
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id)
|
"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id)
|
||||||
)
|
)
|
||||||
@ -103,7 +106,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
|
|||||||
self.assertEqual(channel.result["code"], b"400")
|
self.assertEqual(channel.result["code"], b"400")
|
||||||
|
|
||||||
# No ID also returns an invalid_id error
|
# No ID also returns an invalid_id error
|
||||||
def test_get_filter_no_id(self):
|
def test_get_filter_no_id(self) -> None:
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id)
|
"GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id)
|
||||||
)
|
)
|
||||||
|
@ -15,7 +15,7 @@
|
|||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
@ -45,7 +45,7 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase):
|
|||||||
]
|
]
|
||||||
hijack_auth = False
|
hijack_auth = False
|
||||||
|
|
||||||
def default_config(self) -> dict:
|
def default_config(self) -> Dict[str, Any]:
|
||||||
# We need to enable msc1849 support for aggregations
|
# We need to enable msc1849 support for aggregations
|
||||||
config = super().default_config()
|
config = super().default_config()
|
||||||
|
|
||||||
|
@ -14,8 +14,13 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
import synapse.rest.admin
|
import synapse.rest.admin
|
||||||
from synapse.rest.client import login, report_event, room
|
from synapse.rest.client import login, report_event, room
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.types import JsonDict
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
@ -28,7 +33,7 @@ class ReportEventTestCase(unittest.HomeserverTestCase):
|
|||||||
report_event.register_servlets,
|
report_event.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.admin_user = self.register_user("admin", "pass", admin=True)
|
self.admin_user = self.register_user("admin", "pass", admin=True)
|
||||||
self.admin_user_tok = self.login("admin", "pass")
|
self.admin_user_tok = self.login("admin", "pass")
|
||||||
self.other_user = self.register_user("user", "pass")
|
self.other_user = self.register_user("user", "pass")
|
||||||
@ -42,35 +47,35 @@ class ReportEventTestCase(unittest.HomeserverTestCase):
|
|||||||
self.event_id = resp["event_id"]
|
self.event_id = resp["event_id"]
|
||||||
self.report_path = f"rooms/{self.room_id}/report/{self.event_id}"
|
self.report_path = f"rooms/{self.room_id}/report/{self.event_id}"
|
||||||
|
|
||||||
def test_reason_str_and_score_int(self):
|
def test_reason_str_and_score_int(self) -> None:
|
||||||
data = {"reason": "this makes me sad", "score": -100}
|
data = {"reason": "this makes me sad", "score": -100}
|
||||||
self._assert_status(200, data)
|
self._assert_status(200, data)
|
||||||
|
|
||||||
def test_no_reason(self):
|
def test_no_reason(self) -> None:
|
||||||
data = {"score": 0}
|
data = {"score": 0}
|
||||||
self._assert_status(200, data)
|
self._assert_status(200, data)
|
||||||
|
|
||||||
def test_no_score(self):
|
def test_no_score(self) -> None:
|
||||||
data = {"reason": "this makes me sad"}
|
data = {"reason": "this makes me sad"}
|
||||||
self._assert_status(200, data)
|
self._assert_status(200, data)
|
||||||
|
|
||||||
def test_no_reason_and_no_score(self):
|
def test_no_reason_and_no_score(self) -> None:
|
||||||
data = {}
|
data: JsonDict = {}
|
||||||
self._assert_status(200, data)
|
self._assert_status(200, data)
|
||||||
|
|
||||||
def test_reason_int_and_score_str(self):
|
def test_reason_int_and_score_str(self) -> None:
|
||||||
data = {"reason": 10, "score": "string"}
|
data = {"reason": 10, "score": "string"}
|
||||||
self._assert_status(400, data)
|
self._assert_status(400, data)
|
||||||
|
|
||||||
def test_reason_zero_and_score_blank(self):
|
def test_reason_zero_and_score_blank(self) -> None:
|
||||||
data = {"reason": 0, "score": ""}
|
data = {"reason": 0, "score": ""}
|
||||||
self._assert_status(400, data)
|
self._assert_status(400, data)
|
||||||
|
|
||||||
def test_reason_and_score_null(self):
|
def test_reason_and_score_null(self) -> None:
|
||||||
data = {"reason": None, "score": None}
|
data = {"reason": None, "score": None}
|
||||||
self._assert_status(400, data)
|
self._assert_status(400, data)
|
||||||
|
|
||||||
def _assert_status(self, response_status, data):
|
def _assert_status(self, response_status: int, data: JsonDict) -> None:
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"POST",
|
"POST",
|
||||||
self.report_path,
|
self.report_path,
|
||||||
|
@ -18,11 +18,12 @@
|
|||||||
"""Tests REST events for /rooms paths."""
|
"""Tests REST events for /rooms paths."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import Iterable, List
|
from typing import Any, Dict, Iterable, List, Optional
|
||||||
from unittest.mock import Mock, call
|
from unittest.mock import Mock, call
|
||||||
from urllib import parse as urlparse
|
from urllib import parse as urlparse
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
import synapse.rest.admin
|
import synapse.rest.admin
|
||||||
from synapse.api.constants import (
|
from synapse.api.constants import (
|
||||||
@ -35,7 +36,9 @@ from synapse.api.errors import Codes, HttpResponseException
|
|||||||
from synapse.handlers.pagination import PurgeStatus
|
from synapse.handlers.pagination import PurgeStatus
|
||||||
from synapse.rest import admin
|
from synapse.rest import admin
|
||||||
from synapse.rest.client import account, directory, login, profile, room, sync
|
from synapse.rest.client import account, directory, login, profile, room, sync
|
||||||
|
from synapse.server import HomeServer
|
||||||
from synapse.types import JsonDict, RoomAlias, UserID, create_requester
|
from synapse.types import JsonDict, RoomAlias, UserID, create_requester
|
||||||
|
from synapse.util import Clock
|
||||||
from synapse.util.stringutils import random_string
|
from synapse.util.stringutils import random_string
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
@ -45,11 +48,11 @@ PATH_PREFIX = b"/_matrix/client/api/v1"
|
|||||||
|
|
||||||
|
|
||||||
class RoomBase(unittest.HomeserverTestCase):
|
class RoomBase(unittest.HomeserverTestCase):
|
||||||
rmcreator_id = None
|
rmcreator_id: Optional[str] = None
|
||||||
|
|
||||||
servlets = [room.register_servlets, room.register_deprecated_servlets]
|
servlets = [room.register_servlets, room.register_deprecated_servlets]
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
|
|
||||||
self.hs = self.setup_test_homeserver(
|
self.hs = self.setup_test_homeserver(
|
||||||
"red",
|
"red",
|
||||||
@ -57,15 +60,15 @@ class RoomBase(unittest.HomeserverTestCase):
|
|||||||
federation_client=Mock(),
|
federation_client=Mock(),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.hs.get_federation_handler = Mock()
|
self.hs.get_federation_handler = Mock() # type: ignore[assignment]
|
||||||
self.hs.get_federation_handler.return_value.maybe_backfill = Mock(
|
self.hs.get_federation_handler.return_value.maybe_backfill = Mock(
|
||||||
return_value=make_awaitable(None)
|
return_value=make_awaitable(None)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _insert_client_ip(*args, **kwargs):
|
async def _insert_client_ip(*args: Any, **kwargs: Any) -> None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
self.hs.get_datastores().main.insert_client_ip = _insert_client_ip
|
self.hs.get_datastores().main.insert_client_ip = _insert_client_ip # type: ignore[assignment]
|
||||||
|
|
||||||
return self.hs
|
return self.hs
|
||||||
|
|
||||||
@ -76,7 +79,7 @@ class RoomPermissionsTestCase(RoomBase):
|
|||||||
user_id = "@sid1:red"
|
user_id = "@sid1:red"
|
||||||
rmcreator_id = "@notme:red"
|
rmcreator_id = "@notme:red"
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
|
|
||||||
self.helper.auth_user_id = self.rmcreator_id
|
self.helper.auth_user_id = self.rmcreator_id
|
||||||
# create some rooms under the name rmcreator_id
|
# create some rooms under the name rmcreator_id
|
||||||
@ -108,12 +111,12 @@ class RoomPermissionsTestCase(RoomBase):
|
|||||||
# auth as user_id now
|
# auth as user_id now
|
||||||
self.helper.auth_user_id = self.user_id
|
self.helper.auth_user_id = self.user_id
|
||||||
|
|
||||||
def test_can_do_action(self):
|
def test_can_do_action(self) -> None:
|
||||||
msg_content = b'{"msgtype":"m.text","body":"hello"}'
|
msg_content = b'{"msgtype":"m.text","body":"hello"}'
|
||||||
|
|
||||||
seq = iter(range(100))
|
seq = iter(range(100))
|
||||||
|
|
||||||
def send_msg_path():
|
def send_msg_path() -> str:
|
||||||
return "/rooms/%s/send/m.room.message/mid%s" % (
|
return "/rooms/%s/send/m.room.message/mid%s" % (
|
||||||
self.created_rmid,
|
self.created_rmid,
|
||||||
str(next(seq)),
|
str(next(seq)),
|
||||||
@ -148,7 +151,7 @@ class RoomPermissionsTestCase(RoomBase):
|
|||||||
channel = self.make_request("PUT", send_msg_path(), msg_content)
|
channel = self.make_request("PUT", send_msg_path(), msg_content)
|
||||||
self.assertEqual(403, channel.code, msg=channel.result["body"])
|
self.assertEqual(403, channel.code, msg=channel.result["body"])
|
||||||
|
|
||||||
def test_topic_perms(self):
|
def test_topic_perms(self) -> None:
|
||||||
topic_content = b'{"topic":"My Topic Name"}'
|
topic_content = b'{"topic":"My Topic Name"}'
|
||||||
topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid
|
topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid
|
||||||
|
|
||||||
@ -214,14 +217,14 @@ class RoomPermissionsTestCase(RoomBase):
|
|||||||
self.assertEqual(403, channel.code, msg=channel.result["body"])
|
self.assertEqual(403, channel.code, msg=channel.result["body"])
|
||||||
|
|
||||||
def _test_get_membership(
|
def _test_get_membership(
|
||||||
self, room=None, members: Iterable = frozenset(), expect_code=None
|
self, room: str, members: Iterable = frozenset(), expect_code: int = 200
|
||||||
):
|
) -> None:
|
||||||
for member in members:
|
for member in members:
|
||||||
path = "/rooms/%s/state/m.room.member/%s" % (room, member)
|
path = "/rooms/%s/state/m.room.member/%s" % (room, member)
|
||||||
channel = self.make_request("GET", path)
|
channel = self.make_request("GET", path)
|
||||||
self.assertEqual(expect_code, channel.code)
|
self.assertEqual(expect_code, channel.code)
|
||||||
|
|
||||||
def test_membership_basic_room_perms(self):
|
def test_membership_basic_room_perms(self) -> None:
|
||||||
# === room does not exist ===
|
# === room does not exist ===
|
||||||
room = self.uncreated_rmid
|
room = self.uncreated_rmid
|
||||||
# get membership of self, get membership of other, uncreated room
|
# get membership of self, get membership of other, uncreated room
|
||||||
@ -241,7 +244,7 @@ class RoomPermissionsTestCase(RoomBase):
|
|||||||
self.helper.join(room=room, user=usr, expect_code=404)
|
self.helper.join(room=room, user=usr, expect_code=404)
|
||||||
self.helper.leave(room=room, user=usr, expect_code=404)
|
self.helper.leave(room=room, user=usr, expect_code=404)
|
||||||
|
|
||||||
def test_membership_private_room_perms(self):
|
def test_membership_private_room_perms(self) -> None:
|
||||||
room = self.created_rmid
|
room = self.created_rmid
|
||||||
# get membership of self, get membership of other, private room + invite
|
# get membership of self, get membership of other, private room + invite
|
||||||
# expect all 403s
|
# expect all 403s
|
||||||
@ -264,7 +267,7 @@ class RoomPermissionsTestCase(RoomBase):
|
|||||||
members=[self.user_id, self.rmcreator_id], room=room, expect_code=200
|
members=[self.user_id, self.rmcreator_id], room=room, expect_code=200
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_membership_public_room_perms(self):
|
def test_membership_public_room_perms(self) -> None:
|
||||||
room = self.created_public_rmid
|
room = self.created_public_rmid
|
||||||
# get membership of self, get membership of other, public room + invite
|
# get membership of self, get membership of other, public room + invite
|
||||||
# expect 403
|
# expect 403
|
||||||
@ -287,7 +290,7 @@ class RoomPermissionsTestCase(RoomBase):
|
|||||||
members=[self.user_id, self.rmcreator_id], room=room, expect_code=200
|
members=[self.user_id, self.rmcreator_id], room=room, expect_code=200
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_invited_permissions(self):
|
def test_invited_permissions(self) -> None:
|
||||||
room = self.created_rmid
|
room = self.created_rmid
|
||||||
self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
|
self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
|
||||||
|
|
||||||
@ -310,7 +313,7 @@ class RoomPermissionsTestCase(RoomBase):
|
|||||||
expect_code=403,
|
expect_code=403,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_joined_permissions(self):
|
def test_joined_permissions(self) -> None:
|
||||||
room = self.created_rmid
|
room = self.created_rmid
|
||||||
self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
|
self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
|
||||||
self.helper.join(room=room, user=self.user_id)
|
self.helper.join(room=room, user=self.user_id)
|
||||||
@ -348,7 +351,7 @@ class RoomPermissionsTestCase(RoomBase):
|
|||||||
# set left of self, expect 200
|
# set left of self, expect 200
|
||||||
self.helper.leave(room=room, user=self.user_id)
|
self.helper.leave(room=room, user=self.user_id)
|
||||||
|
|
||||||
def test_leave_permissions(self):
|
def test_leave_permissions(self) -> None:
|
||||||
room = self.created_rmid
|
room = self.created_rmid
|
||||||
self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
|
self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
|
||||||
self.helper.join(room=room, user=self.user_id)
|
self.helper.join(room=room, user=self.user_id)
|
||||||
@ -383,7 +386,7 @@ class RoomPermissionsTestCase(RoomBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# tests the "from banned" line from the table in https://spec.matrix.org/unstable/client-server-api/#mroommember
|
# tests the "from banned" line from the table in https://spec.matrix.org/unstable/client-server-api/#mroommember
|
||||||
def test_member_event_from_ban(self):
|
def test_member_event_from_ban(self) -> None:
|
||||||
room = self.created_rmid
|
room = self.created_rmid
|
||||||
self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
|
self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
|
||||||
self.helper.join(room=room, user=self.user_id)
|
self.helper.join(room=room, user=self.user_id)
|
||||||
@ -475,21 +478,21 @@ class RoomsMemberListTestCase(RoomBase):
|
|||||||
|
|
||||||
user_id = "@sid1:red"
|
user_id = "@sid1:red"
|
||||||
|
|
||||||
def test_get_member_list(self):
|
def test_get_member_list(self) -> None:
|
||||||
room_id = self.helper.create_room_as(self.user_id)
|
room_id = self.helper.create_room_as(self.user_id)
|
||||||
channel = self.make_request("GET", "/rooms/%s/members" % room_id)
|
channel = self.make_request("GET", "/rooms/%s/members" % room_id)
|
||||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||||
|
|
||||||
def test_get_member_list_no_room(self):
|
def test_get_member_list_no_room(self) -> None:
|
||||||
channel = self.make_request("GET", "/rooms/roomdoesnotexist/members")
|
channel = self.make_request("GET", "/rooms/roomdoesnotexist/members")
|
||||||
self.assertEqual(403, channel.code, msg=channel.result["body"])
|
self.assertEqual(403, channel.code, msg=channel.result["body"])
|
||||||
|
|
||||||
def test_get_member_list_no_permission(self):
|
def test_get_member_list_no_permission(self) -> None:
|
||||||
room_id = self.helper.create_room_as("@some_other_guy:red")
|
room_id = self.helper.create_room_as("@some_other_guy:red")
|
||||||
channel = self.make_request("GET", "/rooms/%s/members" % room_id)
|
channel = self.make_request("GET", "/rooms/%s/members" % room_id)
|
||||||
self.assertEqual(403, channel.code, msg=channel.result["body"])
|
self.assertEqual(403, channel.code, msg=channel.result["body"])
|
||||||
|
|
||||||
def test_get_member_list_no_permission_with_at_token(self):
|
def test_get_member_list_no_permission_with_at_token(self) -> None:
|
||||||
"""
|
"""
|
||||||
Tests that a stranger to the room cannot get the member list
|
Tests that a stranger to the room cannot get the member list
|
||||||
(in the case that they use an at token).
|
(in the case that they use an at token).
|
||||||
@ -509,7 +512,7 @@ class RoomsMemberListTestCase(RoomBase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(403, channel.code, msg=channel.result["body"])
|
self.assertEqual(403, channel.code, msg=channel.result["body"])
|
||||||
|
|
||||||
def test_get_member_list_no_permission_former_member(self):
|
def test_get_member_list_no_permission_former_member(self) -> None:
|
||||||
"""
|
"""
|
||||||
Tests that a former member of the room can not get the member list.
|
Tests that a former member of the room can not get the member list.
|
||||||
"""
|
"""
|
||||||
@ -529,7 +532,7 @@ class RoomsMemberListTestCase(RoomBase):
|
|||||||
channel = self.make_request("GET", "/rooms/%s/members" % room_id)
|
channel = self.make_request("GET", "/rooms/%s/members" % room_id)
|
||||||
self.assertEqual(403, channel.code, msg=channel.result["body"])
|
self.assertEqual(403, channel.code, msg=channel.result["body"])
|
||||||
|
|
||||||
def test_get_member_list_no_permission_former_member_with_at_token(self):
|
def test_get_member_list_no_permission_former_member_with_at_token(self) -> None:
|
||||||
"""
|
"""
|
||||||
Tests that a former member of the room can not get the member list
|
Tests that a former member of the room can not get the member list
|
||||||
(in the case that they use an at token).
|
(in the case that they use an at token).
|
||||||
@ -569,7 +572,7 @@ class RoomsMemberListTestCase(RoomBase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(403, channel.code, msg=channel.result["body"])
|
self.assertEqual(403, channel.code, msg=channel.result["body"])
|
||||||
|
|
||||||
def test_get_member_list_mixed_memberships(self):
|
def test_get_member_list_mixed_memberships(self) -> None:
|
||||||
room_creator = "@some_other_guy:red"
|
room_creator = "@some_other_guy:red"
|
||||||
room_id = self.helper.create_room_as(room_creator)
|
room_id = self.helper.create_room_as(room_creator)
|
||||||
room_path = "/rooms/%s/members" % room_id
|
room_path = "/rooms/%s/members" % room_id
|
||||||
@ -594,26 +597,26 @@ class RoomsCreateTestCase(RoomBase):
|
|||||||
|
|
||||||
user_id = "@sid1:red"
|
user_id = "@sid1:red"
|
||||||
|
|
||||||
def test_post_room_no_keys(self):
|
def test_post_room_no_keys(self) -> None:
|
||||||
# POST with no config keys, expect new room id
|
# POST with no config keys, expect new room id
|
||||||
channel = self.make_request("POST", "/createRoom", "{}")
|
channel = self.make_request("POST", "/createRoom", "{}")
|
||||||
|
|
||||||
self.assertEqual(200, channel.code, channel.result)
|
self.assertEqual(200, channel.code, channel.result)
|
||||||
self.assertTrue("room_id" in channel.json_body)
|
self.assertTrue("room_id" in channel.json_body)
|
||||||
|
|
||||||
def test_post_room_visibility_key(self):
|
def test_post_room_visibility_key(self) -> None:
|
||||||
# POST with visibility config key, expect new room id
|
# POST with visibility config key, expect new room id
|
||||||
channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}')
|
channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}')
|
||||||
self.assertEqual(200, channel.code)
|
self.assertEqual(200, channel.code)
|
||||||
self.assertTrue("room_id" in channel.json_body)
|
self.assertTrue("room_id" in channel.json_body)
|
||||||
|
|
||||||
def test_post_room_custom_key(self):
|
def test_post_room_custom_key(self) -> None:
|
||||||
# POST with custom config keys, expect new room id
|
# POST with custom config keys, expect new room id
|
||||||
channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}')
|
channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}')
|
||||||
self.assertEqual(200, channel.code)
|
self.assertEqual(200, channel.code)
|
||||||
self.assertTrue("room_id" in channel.json_body)
|
self.assertTrue("room_id" in channel.json_body)
|
||||||
|
|
||||||
def test_post_room_known_and_unknown_keys(self):
|
def test_post_room_known_and_unknown_keys(self) -> None:
|
||||||
# POST with custom + known config keys, expect new room id
|
# POST with custom + known config keys, expect new room id
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"POST", "/createRoom", b'{"visibility":"private","custom":"things"}'
|
"POST", "/createRoom", b'{"visibility":"private","custom":"things"}'
|
||||||
@ -621,7 +624,7 @@ class RoomsCreateTestCase(RoomBase):
|
|||||||
self.assertEqual(200, channel.code)
|
self.assertEqual(200, channel.code)
|
||||||
self.assertTrue("room_id" in channel.json_body)
|
self.assertTrue("room_id" in channel.json_body)
|
||||||
|
|
||||||
def test_post_room_invalid_content(self):
|
def test_post_room_invalid_content(self) -> None:
|
||||||
# POST with invalid content / paths, expect 400
|
# POST with invalid content / paths, expect 400
|
||||||
channel = self.make_request("POST", "/createRoom", b'{"visibili')
|
channel = self.make_request("POST", "/createRoom", b'{"visibili')
|
||||||
self.assertEqual(400, channel.code)
|
self.assertEqual(400, channel.code)
|
||||||
@ -629,7 +632,7 @@ class RoomsCreateTestCase(RoomBase):
|
|||||||
channel = self.make_request("POST", "/createRoom", b'["hello"]')
|
channel = self.make_request("POST", "/createRoom", b'["hello"]')
|
||||||
self.assertEqual(400, channel.code)
|
self.assertEqual(400, channel.code)
|
||||||
|
|
||||||
def test_post_room_invitees_invalid_mxid(self):
|
def test_post_room_invitees_invalid_mxid(self) -> None:
|
||||||
# POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088
|
# POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088
|
||||||
# Note the trailing space in the MXID here!
|
# Note the trailing space in the MXID here!
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
@ -638,7 +641,7 @@ class RoomsCreateTestCase(RoomBase):
|
|||||||
self.assertEqual(400, channel.code)
|
self.assertEqual(400, channel.code)
|
||||||
|
|
||||||
@unittest.override_config({"rc_invites": {"per_room": {"burst_count": 3}}})
|
@unittest.override_config({"rc_invites": {"per_room": {"burst_count": 3}}})
|
||||||
def test_post_room_invitees_ratelimit(self):
|
def test_post_room_invitees_ratelimit(self) -> None:
|
||||||
"""Test that invites sent when creating a room are ratelimited by a RateLimiter,
|
"""Test that invites sent when creating a room are ratelimited by a RateLimiter,
|
||||||
which ratelimits them correctly, including by not limiting when the requester is
|
which ratelimits them correctly, including by not limiting when the requester is
|
||||||
exempt from ratelimiting.
|
exempt from ratelimiting.
|
||||||
@ -674,7 +677,7 @@ class RoomsCreateTestCase(RoomBase):
|
|||||||
channel = self.make_request("POST", "/createRoom", content)
|
channel = self.make_request("POST", "/createRoom", content)
|
||||||
self.assertEqual(200, channel.code)
|
self.assertEqual(200, channel.code)
|
||||||
|
|
||||||
def test_spam_checker_may_join_room(self):
|
def test_spam_checker_may_join_room(self) -> None:
|
||||||
"""Tests that the user_may_join_room spam checker callback is correctly bypassed
|
"""Tests that the user_may_join_room spam checker callback is correctly bypassed
|
||||||
when creating a new room.
|
when creating a new room.
|
||||||
"""
|
"""
|
||||||
@ -704,12 +707,12 @@ class RoomTopicTestCase(RoomBase):
|
|||||||
|
|
||||||
user_id = "@sid1:red"
|
user_id = "@sid1:red"
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
# create the room
|
# create the room
|
||||||
self.room_id = self.helper.create_room_as(self.user_id)
|
self.room_id = self.helper.create_room_as(self.user_id)
|
||||||
self.path = "/rooms/%s/state/m.room.topic" % (self.room_id,)
|
self.path = "/rooms/%s/state/m.room.topic" % (self.room_id,)
|
||||||
|
|
||||||
def test_invalid_puts(self):
|
def test_invalid_puts(self) -> None:
|
||||||
# missing keys or invalid json
|
# missing keys or invalid json
|
||||||
channel = self.make_request("PUT", self.path, "{}")
|
channel = self.make_request("PUT", self.path, "{}")
|
||||||
self.assertEqual(400, channel.code, msg=channel.result["body"])
|
self.assertEqual(400, channel.code, msg=channel.result["body"])
|
||||||
@ -736,7 +739,7 @@ class RoomTopicTestCase(RoomBase):
|
|||||||
channel = self.make_request("PUT", self.path, content)
|
channel = self.make_request("PUT", self.path, content)
|
||||||
self.assertEqual(400, channel.code, msg=channel.result["body"])
|
self.assertEqual(400, channel.code, msg=channel.result["body"])
|
||||||
|
|
||||||
def test_rooms_topic(self):
|
def test_rooms_topic(self) -> None:
|
||||||
# nothing should be there
|
# nothing should be there
|
||||||
channel = self.make_request("GET", self.path)
|
channel = self.make_request("GET", self.path)
|
||||||
self.assertEqual(404, channel.code, msg=channel.result["body"])
|
self.assertEqual(404, channel.code, msg=channel.result["body"])
|
||||||
@ -751,7 +754,7 @@ class RoomTopicTestCase(RoomBase):
|
|||||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||||
self.assert_dict(json.loads(content), channel.json_body)
|
self.assert_dict(json.loads(content), channel.json_body)
|
||||||
|
|
||||||
def test_rooms_topic_with_extra_keys(self):
|
def test_rooms_topic_with_extra_keys(self) -> None:
|
||||||
# valid put with extra keys
|
# valid put with extra keys
|
||||||
content = '{"topic":"Seasons","subtopic":"Summer"}'
|
content = '{"topic":"Seasons","subtopic":"Summer"}'
|
||||||
channel = self.make_request("PUT", self.path, content)
|
channel = self.make_request("PUT", self.path, content)
|
||||||
@ -768,10 +771,10 @@ class RoomMemberStateTestCase(RoomBase):
|
|||||||
|
|
||||||
user_id = "@sid1:red"
|
user_id = "@sid1:red"
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.room_id = self.helper.create_room_as(self.user_id)
|
self.room_id = self.helper.create_room_as(self.user_id)
|
||||||
|
|
||||||
def test_invalid_puts(self):
|
def test_invalid_puts(self) -> None:
|
||||||
path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id)
|
path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id)
|
||||||
# missing keys or invalid json
|
# missing keys or invalid json
|
||||||
channel = self.make_request("PUT", path, "{}")
|
channel = self.make_request("PUT", path, "{}")
|
||||||
@ -801,7 +804,7 @@ class RoomMemberStateTestCase(RoomBase):
|
|||||||
channel = self.make_request("PUT", path, content.encode("ascii"))
|
channel = self.make_request("PUT", path, content.encode("ascii"))
|
||||||
self.assertEqual(400, channel.code, msg=channel.result["body"])
|
self.assertEqual(400, channel.code, msg=channel.result["body"])
|
||||||
|
|
||||||
def test_rooms_members_self(self):
|
def test_rooms_members_self(self) -> None:
|
||||||
path = "/rooms/%s/state/m.room.member/%s" % (
|
path = "/rooms/%s/state/m.room.member/%s" % (
|
||||||
urlparse.quote(self.room_id),
|
urlparse.quote(self.room_id),
|
||||||
self.user_id,
|
self.user_id,
|
||||||
@ -812,13 +815,13 @@ class RoomMemberStateTestCase(RoomBase):
|
|||||||
channel = self.make_request("PUT", path, content.encode("ascii"))
|
channel = self.make_request("PUT", path, content.encode("ascii"))
|
||||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||||
|
|
||||||
channel = self.make_request("GET", path, None)
|
channel = self.make_request("GET", path, content=b"")
|
||||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||||
|
|
||||||
expected_response = {"membership": Membership.JOIN}
|
expected_response = {"membership": Membership.JOIN}
|
||||||
self.assertEqual(expected_response, channel.json_body)
|
self.assertEqual(expected_response, channel.json_body)
|
||||||
|
|
||||||
def test_rooms_members_other(self):
|
def test_rooms_members_other(self) -> None:
|
||||||
self.other_id = "@zzsid1:red"
|
self.other_id = "@zzsid1:red"
|
||||||
path = "/rooms/%s/state/m.room.member/%s" % (
|
path = "/rooms/%s/state/m.room.member/%s" % (
|
||||||
urlparse.quote(self.room_id),
|
urlparse.quote(self.room_id),
|
||||||
@ -830,11 +833,11 @@ class RoomMemberStateTestCase(RoomBase):
|
|||||||
channel = self.make_request("PUT", path, content)
|
channel = self.make_request("PUT", path, content)
|
||||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||||
|
|
||||||
channel = self.make_request("GET", path, None)
|
channel = self.make_request("GET", path, content=b"")
|
||||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||||
self.assertEqual(json.loads(content), channel.json_body)
|
self.assertEqual(json.loads(content), channel.json_body)
|
||||||
|
|
||||||
def test_rooms_members_other_custom_keys(self):
|
def test_rooms_members_other_custom_keys(self) -> None:
|
||||||
self.other_id = "@zzsid1:red"
|
self.other_id = "@zzsid1:red"
|
||||||
path = "/rooms/%s/state/m.room.member/%s" % (
|
path = "/rooms/%s/state/m.room.member/%s" % (
|
||||||
urlparse.quote(self.room_id),
|
urlparse.quote(self.room_id),
|
||||||
@ -849,7 +852,7 @@ class RoomMemberStateTestCase(RoomBase):
|
|||||||
channel = self.make_request("PUT", path, content)
|
channel = self.make_request("PUT", path, content)
|
||||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||||
|
|
||||||
channel = self.make_request("GET", path, None)
|
channel = self.make_request("GET", path, content=b"")
|
||||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||||
self.assertEqual(json.loads(content), channel.json_body)
|
self.assertEqual(json.loads(content), channel.json_body)
|
||||||
|
|
||||||
@ -866,7 +869,7 @@ class RoomInviteRatelimitTestCase(RoomBase):
|
|||||||
@unittest.override_config(
|
@unittest.override_config(
|
||||||
{"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}}
|
{"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}}
|
||||||
)
|
)
|
||||||
def test_invites_by_rooms_ratelimit(self):
|
def test_invites_by_rooms_ratelimit(self) -> None:
|
||||||
"""Tests that invites in a room are actually rate-limited."""
|
"""Tests that invites in a room are actually rate-limited."""
|
||||||
room_id = self.helper.create_room_as(self.user_id)
|
room_id = self.helper.create_room_as(self.user_id)
|
||||||
|
|
||||||
@ -878,7 +881,7 @@ class RoomInviteRatelimitTestCase(RoomBase):
|
|||||||
@unittest.override_config(
|
@unittest.override_config(
|
||||||
{"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
|
{"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
|
||||||
)
|
)
|
||||||
def test_invites_by_users_ratelimit(self):
|
def test_invites_by_users_ratelimit(self) -> None:
|
||||||
"""Tests that invites to a specific user are actually rate-limited."""
|
"""Tests that invites to a specific user are actually rate-limited."""
|
||||||
|
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
@ -897,7 +900,7 @@ class RoomJoinTestCase(RoomBase):
|
|||||||
room.register_servlets,
|
room.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, homeserver):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.user1 = self.register_user("thomas", "hackme")
|
self.user1 = self.register_user("thomas", "hackme")
|
||||||
self.tok1 = self.login("thomas", "hackme")
|
self.tok1 = self.login("thomas", "hackme")
|
||||||
|
|
||||||
@ -908,7 +911,7 @@ class RoomJoinTestCase(RoomBase):
|
|||||||
self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
|
self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
|
||||||
self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
|
self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
|
||||||
|
|
||||||
def test_spam_checker_may_join_room(self):
|
def test_spam_checker_may_join_room(self) -> None:
|
||||||
"""Tests that the user_may_join_room spam checker callback is correctly called
|
"""Tests that the user_may_join_room spam checker callback is correctly called
|
||||||
and blocks room joins when needed.
|
and blocks room joins when needed.
|
||||||
"""
|
"""
|
||||||
@ -975,8 +978,8 @@ class RoomJoinRatelimitTestCase(RoomBase):
|
|||||||
room.register_servlets,
|
room.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, homeserver):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
super().prepare(reactor, clock, homeserver)
|
super().prepare(reactor, clock, hs)
|
||||||
# profile changes expect that the user is actually registered
|
# profile changes expect that the user is actually registered
|
||||||
user = UserID.from_string(self.user_id)
|
user = UserID.from_string(self.user_id)
|
||||||
self.get_success(self.register_user(user.localpart, "supersecretpassword"))
|
self.get_success(self.register_user(user.localpart, "supersecretpassword"))
|
||||||
@ -984,7 +987,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
|
|||||||
@unittest.override_config(
|
@unittest.override_config(
|
||||||
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
|
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
|
||||||
)
|
)
|
||||||
def test_join_local_ratelimit(self):
|
def test_join_local_ratelimit(self) -> None:
|
||||||
"""Tests that local joins are actually rate-limited."""
|
"""Tests that local joins are actually rate-limited."""
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
self.helper.create_room_as(self.user_id)
|
self.helper.create_room_as(self.user_id)
|
||||||
@ -994,7 +997,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
|
|||||||
@unittest.override_config(
|
@unittest.override_config(
|
||||||
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
|
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
|
||||||
)
|
)
|
||||||
def test_join_local_ratelimit_profile_change(self):
|
def test_join_local_ratelimit_profile_change(self) -> None:
|
||||||
"""Tests that sending a profile update into all of the user's joined rooms isn't
|
"""Tests that sending a profile update into all of the user's joined rooms isn't
|
||||||
rate-limited by the rate-limiter on joins."""
|
rate-limited by the rate-limiter on joins."""
|
||||||
|
|
||||||
@ -1031,7 +1034,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
|
|||||||
@unittest.override_config(
|
@unittest.override_config(
|
||||||
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
|
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
|
||||||
)
|
)
|
||||||
def test_join_local_ratelimit_idempotent(self):
|
def test_join_local_ratelimit_idempotent(self) -> None:
|
||||||
"""Tests that the room join endpoints remain idempotent despite rate-limiting
|
"""Tests that the room join endpoints remain idempotent despite rate-limiting
|
||||||
on room joins."""
|
on room joins."""
|
||||||
room_id = self.helper.create_room_as(self.user_id)
|
room_id = self.helper.create_room_as(self.user_id)
|
||||||
@ -1056,7 +1059,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
|
|||||||
"autocreate_auto_join_rooms": True,
|
"autocreate_auto_join_rooms": True,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
def test_autojoin_rooms(self):
|
def test_autojoin_rooms(self) -> None:
|
||||||
user_id = self.register_user("testuser", "password")
|
user_id = self.register_user("testuser", "password")
|
||||||
|
|
||||||
# Check that the new user successfully joined the four rooms
|
# Check that the new user successfully joined the four rooms
|
||||||
@ -1071,10 +1074,10 @@ class RoomMessagesTestCase(RoomBase):
|
|||||||
|
|
||||||
user_id = "@sid1:red"
|
user_id = "@sid1:red"
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.room_id = self.helper.create_room_as(self.user_id)
|
self.room_id = self.helper.create_room_as(self.user_id)
|
||||||
|
|
||||||
def test_invalid_puts(self):
|
def test_invalid_puts(self) -> None:
|
||||||
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
|
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
|
||||||
# missing keys or invalid json
|
# missing keys or invalid json
|
||||||
channel = self.make_request("PUT", path, b"{}")
|
channel = self.make_request("PUT", path, b"{}")
|
||||||
@ -1095,7 +1098,7 @@ class RoomMessagesTestCase(RoomBase):
|
|||||||
channel = self.make_request("PUT", path, b"")
|
channel = self.make_request("PUT", path, b"")
|
||||||
self.assertEqual(400, channel.code, msg=channel.result["body"])
|
self.assertEqual(400, channel.code, msg=channel.result["body"])
|
||||||
|
|
||||||
def test_rooms_messages_sent(self):
|
def test_rooms_messages_sent(self) -> None:
|
||||||
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
|
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
|
||||||
|
|
||||||
content = b'{"body":"test","msgtype":{"type":"a"}}'
|
content = b'{"body":"test","msgtype":{"type":"a"}}'
|
||||||
@ -1119,11 +1122,11 @@ class RoomInitialSyncTestCase(RoomBase):
|
|||||||
|
|
||||||
user_id = "@sid1:red"
|
user_id = "@sid1:red"
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
# create the room
|
# create the room
|
||||||
self.room_id = self.helper.create_room_as(self.user_id)
|
self.room_id = self.helper.create_room_as(self.user_id)
|
||||||
|
|
||||||
def test_initial_sync(self):
|
def test_initial_sync(self) -> None:
|
||||||
channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id)
|
channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id)
|
||||||
self.assertEqual(200, channel.code)
|
self.assertEqual(200, channel.code)
|
||||||
|
|
||||||
@ -1131,7 +1134,7 @@ class RoomInitialSyncTestCase(RoomBase):
|
|||||||
self.assertEqual("join", channel.json_body["membership"])
|
self.assertEqual("join", channel.json_body["membership"])
|
||||||
|
|
||||||
# Room state is easier to assert on if we unpack it into a dict
|
# Room state is easier to assert on if we unpack it into a dict
|
||||||
state = {}
|
state: JsonDict = {}
|
||||||
for event in channel.json_body["state"]:
|
for event in channel.json_body["state"]:
|
||||||
if "state_key" not in event:
|
if "state_key" not in event:
|
||||||
continue
|
continue
|
||||||
@ -1160,10 +1163,10 @@ class RoomMessageListTestCase(RoomBase):
|
|||||||
|
|
||||||
user_id = "@sid1:red"
|
user_id = "@sid1:red"
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.room_id = self.helper.create_room_as(self.user_id)
|
self.room_id = self.helper.create_room_as(self.user_id)
|
||||||
|
|
||||||
def test_topo_token_is_accepted(self):
|
def test_topo_token_is_accepted(self) -> None:
|
||||||
token = "t1-0_0_0_0_0_0_0_0_0"
|
token = "t1-0_0_0_0_0_0_0_0_0"
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
|
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
|
||||||
@ -1174,7 +1177,7 @@ class RoomMessageListTestCase(RoomBase):
|
|||||||
self.assertTrue("chunk" in channel.json_body)
|
self.assertTrue("chunk" in channel.json_body)
|
||||||
self.assertTrue("end" in channel.json_body)
|
self.assertTrue("end" in channel.json_body)
|
||||||
|
|
||||||
def test_stream_token_is_accepted_for_fwd_pagianation(self):
|
def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None:
|
||||||
token = "s0_0_0_0_0_0_0_0_0"
|
token = "s0_0_0_0_0_0_0_0_0"
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
|
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
|
||||||
@ -1185,7 +1188,7 @@ class RoomMessageListTestCase(RoomBase):
|
|||||||
self.assertTrue("chunk" in channel.json_body)
|
self.assertTrue("chunk" in channel.json_body)
|
||||||
self.assertTrue("end" in channel.json_body)
|
self.assertTrue("end" in channel.json_body)
|
||||||
|
|
||||||
def test_room_messages_purge(self):
|
def test_room_messages_purge(self) -> None:
|
||||||
store = self.hs.get_datastores().main
|
store = self.hs.get_datastores().main
|
||||||
pagination_handler = self.hs.get_pagination_handler()
|
pagination_handler = self.hs.get_pagination_handler()
|
||||||
|
|
||||||
@ -1278,10 +1281,10 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
|
|||||||
user_id = True
|
user_id = True
|
||||||
hijack_auth = False
|
hijack_auth = False
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
|
|
||||||
# Register the user who does the searching
|
# Register the user who does the searching
|
||||||
self.user_id = self.register_user("user", "pass")
|
self.user_id2 = self.register_user("user", "pass")
|
||||||
self.access_token = self.login("user", "pass")
|
self.access_token = self.login("user", "pass")
|
||||||
|
|
||||||
# Register the user who sends the message
|
# Register the user who sends the message
|
||||||
@ -1289,12 +1292,12 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
|
|||||||
self.other_access_token = self.login("otheruser", "pass")
|
self.other_access_token = self.login("otheruser", "pass")
|
||||||
|
|
||||||
# Create a room
|
# Create a room
|
||||||
self.room = self.helper.create_room_as(self.user_id, tok=self.access_token)
|
self.room = self.helper.create_room_as(self.user_id2, tok=self.access_token)
|
||||||
|
|
||||||
# Invite the other person
|
# Invite the other person
|
||||||
self.helper.invite(
|
self.helper.invite(
|
||||||
room=self.room,
|
room=self.room,
|
||||||
src=self.user_id,
|
src=self.user_id2,
|
||||||
tok=self.access_token,
|
tok=self.access_token,
|
||||||
targ=self.other_user_id,
|
targ=self.other_user_id,
|
||||||
)
|
)
|
||||||
@ -1304,7 +1307,7 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
|
|||||||
room=self.room, user=self.other_user_id, tok=self.other_access_token
|
room=self.room, user=self.other_user_id, tok=self.other_access_token
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_finds_message(self):
|
def test_finds_message(self) -> None:
|
||||||
"""
|
"""
|
||||||
The search functionality will search for content in messages if asked to
|
The search functionality will search for content in messages if asked to
|
||||||
do so.
|
do so.
|
||||||
@ -1333,7 +1336,7 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
|
|||||||
# No context was requested, so we should get none.
|
# No context was requested, so we should get none.
|
||||||
self.assertEqual(results["results"][0]["context"], {})
|
self.assertEqual(results["results"][0]["context"], {})
|
||||||
|
|
||||||
def test_include_context(self):
|
def test_include_context(self) -> None:
|
||||||
"""
|
"""
|
||||||
When event_context includes include_profile, profile information will be
|
When event_context includes include_profile, profile information will be
|
||||||
included in the search response.
|
included in the search response.
|
||||||
@ -1379,7 +1382,7 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
|
|||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
|
|
||||||
self.url = b"/_matrix/client/r0/publicRooms"
|
self.url = b"/_matrix/client/r0/publicRooms"
|
||||||
|
|
||||||
@ -1389,11 +1392,11 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
return self.hs
|
return self.hs
|
||||||
|
|
||||||
def test_restricted_no_auth(self):
|
def test_restricted_no_auth(self) -> None:
|
||||||
channel = self.make_request("GET", self.url)
|
channel = self.make_request("GET", self.url)
|
||||||
self.assertEqual(channel.code, 401, channel.result)
|
self.assertEqual(channel.code, 401, channel.result)
|
||||||
|
|
||||||
def test_restricted_auth(self):
|
def test_restricted_auth(self) -> None:
|
||||||
self.register_user("user", "pass")
|
self.register_user("user", "pass")
|
||||||
tok = self.login("user", "pass")
|
tok = self.login("user", "pass")
|
||||||
|
|
||||||
@ -1412,19 +1415,19 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
|
|||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
return self.setup_test_homeserver(federation_client=Mock())
|
return self.setup_test_homeserver(federation_client=Mock())
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.register_user("user", "pass")
|
self.register_user("user", "pass")
|
||||||
self.token = self.login("user", "pass")
|
self.token = self.login("user", "pass")
|
||||||
|
|
||||||
self.federation_client = hs.get_federation_client()
|
self.federation_client = hs.get_federation_client()
|
||||||
|
|
||||||
def test_simple(self):
|
def test_simple(self) -> None:
|
||||||
"Simple test for searching rooms over federation"
|
"Simple test for searching rooms over federation"
|
||||||
self.federation_client.get_public_rooms.side_effect = (
|
self.federation_client.get_public_rooms.side_effect = lambda *a, **k: defer.succeed( # type: ignore[attr-defined]
|
||||||
lambda *a, **k: defer.succeed({})
|
{}
|
||||||
)
|
)
|
||||||
|
|
||||||
search_filter = {"generic_search_term": "foobar"}
|
search_filter = {"generic_search_term": "foobar"}
|
||||||
@ -1437,7 +1440,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
|
|
||||||
self.federation_client.get_public_rooms.assert_called_once_with(
|
self.federation_client.get_public_rooms.assert_called_once_with( # type: ignore[attr-defined]
|
||||||
"testserv",
|
"testserv",
|
||||||
limit=100,
|
limit=100,
|
||||||
since_token=None,
|
since_token=None,
|
||||||
@ -1446,12 +1449,12 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
|
|||||||
third_party_instance_id=None,
|
third_party_instance_id=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_fallback(self):
|
def test_fallback(self) -> None:
|
||||||
"Test that searching public rooms over federation falls back if it gets a 404"
|
"Test that searching public rooms over federation falls back if it gets a 404"
|
||||||
|
|
||||||
# The `get_public_rooms` should be called again if the first call fails
|
# The `get_public_rooms` should be called again if the first call fails
|
||||||
# with a 404, when using search filters.
|
# with a 404, when using search filters.
|
||||||
self.federation_client.get_public_rooms.side_effect = (
|
self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined]
|
||||||
HttpResponseException(404, "Not Found", b""),
|
HttpResponseException(404, "Not Found", b""),
|
||||||
defer.succeed({}),
|
defer.succeed({}),
|
||||||
)
|
)
|
||||||
@ -1466,7 +1469,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
|
|
||||||
self.federation_client.get_public_rooms.assert_has_calls(
|
self.federation_client.get_public_rooms.assert_has_calls( # type: ignore[attr-defined]
|
||||||
[
|
[
|
||||||
call(
|
call(
|
||||||
"testserv",
|
"testserv",
|
||||||
@ -1497,14 +1500,14 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
|
|||||||
profile.register_servlets,
|
profile.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
config = self.default_config()
|
config = self.default_config()
|
||||||
config["allow_per_room_profiles"] = False
|
config["allow_per_room_profiles"] = False
|
||||||
self.hs = self.setup_test_homeserver(config=config)
|
self.hs = self.setup_test_homeserver(config=config)
|
||||||
|
|
||||||
return self.hs
|
return self.hs
|
||||||
|
|
||||||
def prepare(self, reactor, clock, homeserver):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.user_id = self.register_user("test", "test")
|
self.user_id = self.register_user("test", "test")
|
||||||
self.tok = self.login("test", "test")
|
self.tok = self.login("test", "test")
|
||||||
|
|
||||||
@ -1522,7 +1525,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
|
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
|
||||||
|
|
||||||
def test_per_room_profile_forbidden(self):
|
def test_per_room_profile_forbidden(self) -> None:
|
||||||
data = {"membership": "join", "displayname": "other test user"}
|
data = {"membership": "join", "displayname": "other test user"}
|
||||||
request_data = json.dumps(data)
|
request_data = json.dumps(data)
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
@ -1557,7 +1560,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
|
|||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, homeserver):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.creator = self.register_user("creator", "test")
|
self.creator = self.register_user("creator", "test")
|
||||||
self.creator_tok = self.login("creator", "test")
|
self.creator_tok = self.login("creator", "test")
|
||||||
|
|
||||||
@ -1566,7 +1569,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok)
|
self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok)
|
||||||
|
|
||||||
def test_join_reason(self):
|
def test_join_reason(self) -> None:
|
||||||
reason = "hello"
|
reason = "hello"
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"POST",
|
"POST",
|
||||||
@ -1578,7 +1581,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self._check_for_reason(reason)
|
self._check_for_reason(reason)
|
||||||
|
|
||||||
def test_leave_reason(self):
|
def test_leave_reason(self) -> None:
|
||||||
self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
|
self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
|
||||||
|
|
||||||
reason = "hello"
|
reason = "hello"
|
||||||
@ -1592,7 +1595,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self._check_for_reason(reason)
|
self._check_for_reason(reason)
|
||||||
|
|
||||||
def test_kick_reason(self):
|
def test_kick_reason(self) -> None:
|
||||||
self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
|
self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
|
||||||
|
|
||||||
reason = "hello"
|
reason = "hello"
|
||||||
@ -1606,7 +1609,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self._check_for_reason(reason)
|
self._check_for_reason(reason)
|
||||||
|
|
||||||
def test_ban_reason(self):
|
def test_ban_reason(self) -> None:
|
||||||
self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
|
self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
|
||||||
|
|
||||||
reason = "hello"
|
reason = "hello"
|
||||||
@ -1620,7 +1623,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self._check_for_reason(reason)
|
self._check_for_reason(reason)
|
||||||
|
|
||||||
def test_unban_reason(self):
|
def test_unban_reason(self) -> None:
|
||||||
reason = "hello"
|
reason = "hello"
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"POST",
|
"POST",
|
||||||
@ -1632,7 +1635,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self._check_for_reason(reason)
|
self._check_for_reason(reason)
|
||||||
|
|
||||||
def test_invite_reason(self):
|
def test_invite_reason(self) -> None:
|
||||||
reason = "hello"
|
reason = "hello"
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"POST",
|
"POST",
|
||||||
@ -1644,7 +1647,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self._check_for_reason(reason)
|
self._check_for_reason(reason)
|
||||||
|
|
||||||
def test_reject_invite_reason(self):
|
def test_reject_invite_reason(self) -> None:
|
||||||
self.helper.invite(
|
self.helper.invite(
|
||||||
self.room_id,
|
self.room_id,
|
||||||
src=self.creator,
|
src=self.creator,
|
||||||
@ -1663,7 +1666,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self._check_for_reason(reason)
|
self._check_for_reason(reason)
|
||||||
|
|
||||||
def _check_for_reason(self, reason):
|
def _check_for_reason(self, reason: str) -> None:
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"GET",
|
"GET",
|
||||||
"/_matrix/client/r0/rooms/{}/state/m.room.member/{}".format(
|
"/_matrix/client/r0/rooms/{}/state/m.room.member/{}".format(
|
||||||
@ -1704,12 +1707,12 @@ class LabelsTestCase(unittest.HomeserverTestCase):
|
|||||||
"org.matrix.not_labels": ["#notfun"],
|
"org.matrix.not_labels": ["#notfun"],
|
||||||
}
|
}
|
||||||
|
|
||||||
def prepare(self, reactor, clock, homeserver):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.user_id = self.register_user("test", "test")
|
self.user_id = self.register_user("test", "test")
|
||||||
self.tok = self.login("test", "test")
|
self.tok = self.login("test", "test")
|
||||||
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
|
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
|
||||||
|
|
||||||
def test_context_filter_labels(self):
|
def test_context_filter_labels(self) -> None:
|
||||||
"""Test that we can filter by a label on a /context request."""
|
"""Test that we can filter by a label on a /context request."""
|
||||||
event_id = self._send_labelled_messages_in_room()
|
event_id = self._send_labelled_messages_in_room()
|
||||||
|
|
||||||
@ -1739,7 +1742,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
|
|||||||
events_after[0]["content"]["body"], "with right label", events_after[0]
|
events_after[0]["content"]["body"], "with right label", events_after[0]
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_context_filter_not_labels(self):
|
def test_context_filter_not_labels(self) -> None:
|
||||||
"""Test that we can filter by the absence of a label on a /context request."""
|
"""Test that we can filter by the absence of a label on a /context request."""
|
||||||
event_id = self._send_labelled_messages_in_room()
|
event_id = self._send_labelled_messages_in_room()
|
||||||
|
|
||||||
@ -1772,7 +1775,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
|
|||||||
events_after[1]["content"]["body"], "with two wrong labels", events_after[1]
|
events_after[1]["content"]["body"], "with two wrong labels", events_after[1]
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_context_filter_labels_not_labels(self):
|
def test_context_filter_labels_not_labels(self) -> None:
|
||||||
"""Test that we can filter by both a label and the absence of another label on a
|
"""Test that we can filter by both a label and the absence of another label on a
|
||||||
/context request.
|
/context request.
|
||||||
"""
|
"""
|
||||||
@ -1801,7 +1804,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
|
|||||||
events_after[0]["content"]["body"], "with wrong label", events_after[0]
|
events_after[0]["content"]["body"], "with wrong label", events_after[0]
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_messages_filter_labels(self):
|
def test_messages_filter_labels(self) -> None:
|
||||||
"""Test that we can filter by a label on a /messages request."""
|
"""Test that we can filter by a label on a /messages request."""
|
||||||
self._send_labelled_messages_in_room()
|
self._send_labelled_messages_in_room()
|
||||||
|
|
||||||
@ -1818,7 +1821,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
|
|||||||
self.assertEqual(events[0]["content"]["body"], "with right label", events[0])
|
self.assertEqual(events[0]["content"]["body"], "with right label", events[0])
|
||||||
self.assertEqual(events[1]["content"]["body"], "with right label", events[1])
|
self.assertEqual(events[1]["content"]["body"], "with right label", events[1])
|
||||||
|
|
||||||
def test_messages_filter_not_labels(self):
|
def test_messages_filter_not_labels(self) -> None:
|
||||||
"""Test that we can filter by the absence of a label on a /messages request."""
|
"""Test that we can filter by the absence of a label on a /messages request."""
|
||||||
self._send_labelled_messages_in_room()
|
self._send_labelled_messages_in_room()
|
||||||
|
|
||||||
@ -1839,7 +1842,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
|
|||||||
events[3]["content"]["body"], "with two wrong labels", events[3]
|
events[3]["content"]["body"], "with two wrong labels", events[3]
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_messages_filter_labels_not_labels(self):
|
def test_messages_filter_labels_not_labels(self) -> None:
|
||||||
"""Test that we can filter by both a label and the absence of another label on a
|
"""Test that we can filter by both a label and the absence of another label on a
|
||||||
/messages request.
|
/messages request.
|
||||||
"""
|
"""
|
||||||
@ -1862,7 +1865,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
|
|||||||
self.assertEqual(len(events), 1, [event["content"] for event in events])
|
self.assertEqual(len(events), 1, [event["content"] for event in events])
|
||||||
self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0])
|
self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0])
|
||||||
|
|
||||||
def test_search_filter_labels(self):
|
def test_search_filter_labels(self) -> None:
|
||||||
"""Test that we can filter by a label on a /search request."""
|
"""Test that we can filter by a label on a /search request."""
|
||||||
request_data = json.dumps(
|
request_data = json.dumps(
|
||||||
{
|
{
|
||||||
@ -1899,7 +1902,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
|
|||||||
results[1]["result"]["content"]["body"],
|
results[1]["result"]["content"]["body"],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_search_filter_not_labels(self):
|
def test_search_filter_not_labels(self) -> None:
|
||||||
"""Test that we can filter by the absence of a label on a /search request."""
|
"""Test that we can filter by the absence of a label on a /search request."""
|
||||||
request_data = json.dumps(
|
request_data = json.dumps(
|
||||||
{
|
{
|
||||||
@ -1946,7 +1949,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
|
|||||||
results[3]["result"]["content"]["body"],
|
results[3]["result"]["content"]["body"],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_search_filter_labels_not_labels(self):
|
def test_search_filter_labels_not_labels(self) -> None:
|
||||||
"""Test that we can filter by both a label and the absence of another label on a
|
"""Test that we can filter by both a label and the absence of another label on a
|
||||||
/search request.
|
/search request.
|
||||||
"""
|
"""
|
||||||
@ -1980,7 +1983,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
|
|||||||
results[0]["result"]["content"]["body"],
|
results[0]["result"]["content"]["body"],
|
||||||
)
|
)
|
||||||
|
|
||||||
def _send_labelled_messages_in_room(self):
|
def _send_labelled_messages_in_room(self) -> str:
|
||||||
"""Sends several messages to a room with different labels (or without any) to test
|
"""Sends several messages to a room with different labels (or without any) to test
|
||||||
filtering by label.
|
filtering by label.
|
||||||
Returns:
|
Returns:
|
||||||
@ -2056,12 +2059,12 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
|||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def default_config(self):
|
def default_config(self) -> Dict[str, Any]:
|
||||||
config = super().default_config()
|
config = super().default_config()
|
||||||
config["experimental_features"] = {"msc3440_enabled": True}
|
config["experimental_features"] = {"msc3440_enabled": True}
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def prepare(self, reactor, clock, homeserver):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.user_id = self.register_user("test", "test")
|
self.user_id = self.register_user("test", "test")
|
||||||
self.tok = self.login("test", "test")
|
self.tok = self.login("test", "test")
|
||||||
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
|
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
|
||||||
@ -2136,7 +2139,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
return channel.json_body["chunk"]
|
return channel.json_body["chunk"]
|
||||||
|
|
||||||
def test_filter_relation_senders(self):
|
def test_filter_relation_senders(self) -> None:
|
||||||
# Messages which second user reacted to.
|
# Messages which second user reacted to.
|
||||||
filter = {"io.element.relation_senders": [self.second_user_id]}
|
filter = {"io.element.relation_senders": [self.second_user_id]}
|
||||||
chunk = self._filter_messages(filter)
|
chunk = self._filter_messages(filter)
|
||||||
@ -2159,7 +2162,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
|||||||
[c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2]
|
[c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2]
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_filter_relation_type(self):
|
def test_filter_relation_type(self) -> None:
|
||||||
# Messages which have annotations.
|
# Messages which have annotations.
|
||||||
filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]}
|
filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]}
|
||||||
chunk = self._filter_messages(filter)
|
chunk = self._filter_messages(filter)
|
||||||
@ -2185,7 +2188,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
|||||||
[c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2]
|
[c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2]
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_filter_relation_senders_and_type(self):
|
def test_filter_relation_senders_and_type(self) -> None:
|
||||||
# Messages which second user reacted to.
|
# Messages which second user reacted to.
|
||||||
filter = {
|
filter = {
|
||||||
"io.element.relation_senders": [self.second_user_id],
|
"io.element.relation_senders": [self.second_user_id],
|
||||||
@ -2205,7 +2208,7 @@ class ContextTestCase(unittest.HomeserverTestCase):
|
|||||||
account.register_servlets,
|
account.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, homeserver):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.user_id = self.register_user("user", "password")
|
self.user_id = self.register_user("user", "password")
|
||||||
self.tok = self.login("user", "password")
|
self.tok = self.login("user", "password")
|
||||||
self.room_id = self.helper.create_room_as(
|
self.room_id = self.helper.create_room_as(
|
||||||
@ -2218,7 +2221,7 @@ class ContextTestCase(unittest.HomeserverTestCase):
|
|||||||
self.helper.invite(self.room_id, self.user_id, self.other_user_id, tok=self.tok)
|
self.helper.invite(self.room_id, self.user_id, self.other_user_id, tok=self.tok)
|
||||||
self.helper.join(self.room_id, self.other_user_id, tok=self.other_tok)
|
self.helper.join(self.room_id, self.other_user_id, tok=self.other_tok)
|
||||||
|
|
||||||
def test_erased_sender(self):
|
def test_erased_sender(self) -> None:
|
||||||
"""Test that an erasure request results in the requester's events being hidden
|
"""Test that an erasure request results in the requester's events being hidden
|
||||||
from any new member of the room.
|
from any new member of the room.
|
||||||
"""
|
"""
|
||||||
@ -2332,7 +2335,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
|
|||||||
room.register_servlets,
|
room.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, homeserver):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.room_owner = self.register_user("room_owner", "test")
|
self.room_owner = self.register_user("room_owner", "test")
|
||||||
self.room_owner_tok = self.login("room_owner", "test")
|
self.room_owner_tok = self.login("room_owner", "test")
|
||||||
|
|
||||||
@ -2340,17 +2343,17 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
|
|||||||
self.room_owner, tok=self.room_owner_tok
|
self.room_owner, tok=self.room_owner_tok
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_no_aliases(self):
|
def test_no_aliases(self) -> None:
|
||||||
res = self._get_aliases(self.room_owner_tok)
|
res = self._get_aliases(self.room_owner_tok)
|
||||||
self.assertEqual(res["aliases"], [])
|
self.assertEqual(res["aliases"], [])
|
||||||
|
|
||||||
def test_not_in_room(self):
|
def test_not_in_room(self) -> None:
|
||||||
self.register_user("user", "test")
|
self.register_user("user", "test")
|
||||||
user_tok = self.login("user", "test")
|
user_tok = self.login("user", "test")
|
||||||
res = self._get_aliases(user_tok, expected_code=403)
|
res = self._get_aliases(user_tok, expected_code=403)
|
||||||
self.assertEqual(res["errcode"], "M_FORBIDDEN")
|
self.assertEqual(res["errcode"], "M_FORBIDDEN")
|
||||||
|
|
||||||
def test_admin_user(self):
|
def test_admin_user(self) -> None:
|
||||||
alias1 = self._random_alias()
|
alias1 = self._random_alias()
|
||||||
self._set_alias_via_directory(alias1)
|
self._set_alias_via_directory(alias1)
|
||||||
|
|
||||||
@ -2360,7 +2363,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
|
|||||||
res = self._get_aliases(user_tok)
|
res = self._get_aliases(user_tok)
|
||||||
self.assertEqual(res["aliases"], [alias1])
|
self.assertEqual(res["aliases"], [alias1])
|
||||||
|
|
||||||
def test_with_aliases(self):
|
def test_with_aliases(self) -> None:
|
||||||
alias1 = self._random_alias()
|
alias1 = self._random_alias()
|
||||||
alias2 = self._random_alias()
|
alias2 = self._random_alias()
|
||||||
|
|
||||||
@ -2370,7 +2373,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
|
|||||||
res = self._get_aliases(self.room_owner_tok)
|
res = self._get_aliases(self.room_owner_tok)
|
||||||
self.assertEqual(set(res["aliases"]), {alias1, alias2})
|
self.assertEqual(set(res["aliases"]), {alias1, alias2})
|
||||||
|
|
||||||
def test_peekable_room(self):
|
def test_peekable_room(self) -> None:
|
||||||
alias1 = self._random_alias()
|
alias1 = self._random_alias()
|
||||||
self._set_alias_via_directory(alias1)
|
self._set_alias_via_directory(alias1)
|
||||||
|
|
||||||
@ -2404,7 +2407,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
|
|||||||
def _random_alias(self) -> str:
|
def _random_alias(self) -> str:
|
||||||
return RoomAlias(random_string(5), self.hs.hostname).to_string()
|
return RoomAlias(random_string(5), self.hs.hostname).to_string()
|
||||||
|
|
||||||
def _set_alias_via_directory(self, alias: str, expected_code: int = 200):
|
def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None:
|
||||||
url = "/_matrix/client/r0/directory/room/" + alias
|
url = "/_matrix/client/r0/directory/room/" + alias
|
||||||
data = {"room_id": self.room_id}
|
data = {"room_id": self.room_id}
|
||||||
request_data = json.dumps(data)
|
request_data = json.dumps(data)
|
||||||
@ -2423,7 +2426,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
|
|||||||
room.register_servlets,
|
room.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, homeserver):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.room_owner = self.register_user("room_owner", "test")
|
self.room_owner = self.register_user("room_owner", "test")
|
||||||
self.room_owner_tok = self.login("room_owner", "test")
|
self.room_owner_tok = self.login("room_owner", "test")
|
||||||
|
|
||||||
@ -2434,7 +2437,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
|
|||||||
self.alias = "#alias:test"
|
self.alias = "#alias:test"
|
||||||
self._set_alias_via_directory(self.alias)
|
self._set_alias_via_directory(self.alias)
|
||||||
|
|
||||||
def _set_alias_via_directory(self, alias: str, expected_code: int = 200):
|
def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None:
|
||||||
url = "/_matrix/client/r0/directory/room/" + alias
|
url = "/_matrix/client/r0/directory/room/" + alias
|
||||||
data = {"room_id": self.room_id}
|
data = {"room_id": self.room_id}
|
||||||
request_data = json.dumps(data)
|
request_data = json.dumps(data)
|
||||||
@ -2456,7 +2459,9 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
|
|||||||
self.assertIsInstance(res, dict)
|
self.assertIsInstance(res, dict)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict:
|
def _set_canonical_alias(
|
||||||
|
self, content: JsonDict, expected_code: int = 200
|
||||||
|
) -> JsonDict:
|
||||||
"""Calls the endpoint under test. returns the json response object."""
|
"""Calls the endpoint under test. returns the json response object."""
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"PUT",
|
"PUT",
|
||||||
@ -2469,7 +2474,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
|
|||||||
self.assertIsInstance(res, dict)
|
self.assertIsInstance(res, dict)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def test_canonical_alias(self):
|
def test_canonical_alias(self) -> None:
|
||||||
"""Test a basic alias message."""
|
"""Test a basic alias message."""
|
||||||
# There is no canonical alias to start with.
|
# There is no canonical alias to start with.
|
||||||
self._get_canonical_alias(expected_code=404)
|
self._get_canonical_alias(expected_code=404)
|
||||||
@ -2488,7 +2493,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
|
|||||||
res = self._get_canonical_alias()
|
res = self._get_canonical_alias()
|
||||||
self.assertEqual(res, {})
|
self.assertEqual(res, {})
|
||||||
|
|
||||||
def test_alt_aliases(self):
|
def test_alt_aliases(self) -> None:
|
||||||
"""Test a canonical alias message with alt_aliases."""
|
"""Test a canonical alias message with alt_aliases."""
|
||||||
# Create an alias.
|
# Create an alias.
|
||||||
self._set_canonical_alias({"alt_aliases": [self.alias]})
|
self._set_canonical_alias({"alt_aliases": [self.alias]})
|
||||||
@ -2504,7 +2509,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
|
|||||||
res = self._get_canonical_alias()
|
res = self._get_canonical_alias()
|
||||||
self.assertEqual(res, {})
|
self.assertEqual(res, {})
|
||||||
|
|
||||||
def test_alias_alt_aliases(self):
|
def test_alias_alt_aliases(self) -> None:
|
||||||
"""Test a canonical alias message with an alias and alt_aliases."""
|
"""Test a canonical alias message with an alias and alt_aliases."""
|
||||||
# Create an alias.
|
# Create an alias.
|
||||||
self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
|
self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
|
||||||
@ -2520,7 +2525,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
|
|||||||
res = self._get_canonical_alias()
|
res = self._get_canonical_alias()
|
||||||
self.assertEqual(res, {})
|
self.assertEqual(res, {})
|
||||||
|
|
||||||
def test_partial_modify(self):
|
def test_partial_modify(self) -> None:
|
||||||
"""Test removing only the alt_aliases."""
|
"""Test removing only the alt_aliases."""
|
||||||
# Create an alias.
|
# Create an alias.
|
||||||
self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
|
self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
|
||||||
@ -2536,7 +2541,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
|
|||||||
res = self._get_canonical_alias()
|
res = self._get_canonical_alias()
|
||||||
self.assertEqual(res, {"alias": self.alias})
|
self.assertEqual(res, {"alias": self.alias})
|
||||||
|
|
||||||
def test_add_alias(self):
|
def test_add_alias(self) -> None:
|
||||||
"""Test removing only the alt_aliases."""
|
"""Test removing only the alt_aliases."""
|
||||||
# Create an additional alias.
|
# Create an additional alias.
|
||||||
second_alias = "#second:test"
|
second_alias = "#second:test"
|
||||||
@ -2556,7 +2561,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
|
|||||||
res, {"alias": self.alias, "alt_aliases": [self.alias, second_alias]}
|
res, {"alias": self.alias, "alt_aliases": [self.alias, second_alias]}
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_bad_data(self):
|
def test_bad_data(self) -> None:
|
||||||
"""Invalid data for alt_aliases should cause errors."""
|
"""Invalid data for alt_aliases should cause errors."""
|
||||||
self._set_canonical_alias({"alt_aliases": "@bad:test"}, expected_code=400)
|
self._set_canonical_alias({"alt_aliases": "@bad:test"}, expected_code=400)
|
||||||
self._set_canonical_alias({"alt_aliases": None}, expected_code=400)
|
self._set_canonical_alias({"alt_aliases": None}, expected_code=400)
|
||||||
@ -2566,7 +2571,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
|
|||||||
self._set_canonical_alias({"alt_aliases": True}, expected_code=400)
|
self._set_canonical_alias({"alt_aliases": True}, expected_code=400)
|
||||||
self._set_canonical_alias({"alt_aliases": {}}, expected_code=400)
|
self._set_canonical_alias({"alt_aliases": {}}, expected_code=400)
|
||||||
|
|
||||||
def test_bad_alias(self):
|
def test_bad_alias(self) -> None:
|
||||||
"""An alias which does not point to the room raises a SynapseError."""
|
"""An alias which does not point to the room raises a SynapseError."""
|
||||||
self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400)
|
self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400)
|
||||||
self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400)
|
self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400)
|
||||||
@ -2580,13 +2585,13 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
|
|||||||
room.register_servlets,
|
room.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, homeserver):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.user_id = self.register_user("thomas", "hackme")
|
self.user_id = self.register_user("thomas", "hackme")
|
||||||
self.tok = self.login("thomas", "hackme")
|
self.tok = self.login("thomas", "hackme")
|
||||||
|
|
||||||
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
|
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
|
||||||
|
|
||||||
def test_threepid_invite_spamcheck(self):
|
def test_threepid_invite_spamcheck(self) -> None:
|
||||||
# Mock a few functions to prevent the test from failing due to failing to talk to
|
# Mock a few functions to prevent the test from failing due to failing to talk to
|
||||||
# a remote IS. We keep the mock for _mock_make_and_store_3pid_invite around so we
|
# a remote IS. We keep the mock for _mock_make_and_store_3pid_invite around so we
|
||||||
# can check its call_count later on during the test.
|
# can check its call_count later on during the test.
|
||||||
|
@ -12,16 +12,22 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import threading
|
import threading
|
||||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, LoginType, Membership
|
from synapse.api.constants import EventTypes, LoginType, Membership
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
|
from synapse.api.room_versions import RoomVersion
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
|
from synapse.events.snapshot import EventContext
|
||||||
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
|
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
|
||||||
from synapse.rest import admin
|
from synapse.rest import admin
|
||||||
from synapse.rest.client import account, login, profile, room
|
from synapse.rest.client import account, login, profile, room
|
||||||
|
from synapse.server import HomeServer
|
||||||
from synapse.types import JsonDict, Requester, StateMap
|
from synapse.types import JsonDict, Requester, StateMap
|
||||||
|
from synapse.util import Clock
|
||||||
from synapse.util.frozenutils import unfreeze
|
from synapse.util.frozenutils import unfreeze
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
@ -34,7 +40,7 @@ thread_local = threading.local()
|
|||||||
|
|
||||||
|
|
||||||
class LegacyThirdPartyRulesTestModule:
|
class LegacyThirdPartyRulesTestModule:
|
||||||
def __init__(self, config: Dict, module_api: "ModuleApi"):
|
def __init__(self, config: Dict, module_api: "ModuleApi") -> None:
|
||||||
# keep a record of the "current" rules module, so that the test can patch
|
# keep a record of the "current" rules module, so that the test can patch
|
||||||
# it if desired.
|
# it if desired.
|
||||||
thread_local.rules_module = self
|
thread_local.rules_module = self
|
||||||
@ -42,32 +48,36 @@ class LegacyThirdPartyRulesTestModule:
|
|||||||
|
|
||||||
async def on_create_room(
|
async def on_create_room(
|
||||||
self, requester: Requester, config: dict, is_requester_admin: bool
|
self, requester: Requester, config: dict, is_requester_admin: bool
|
||||||
):
|
) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
|
async def check_event_allowed(
|
||||||
|
self, event: EventBase, state: StateMap[EventBase]
|
||||||
|
) -> Union[bool, dict]:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_config(config):
|
def parse_config(config: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule):
|
class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule):
|
||||||
def __init__(self, config: Dict, module_api: "ModuleApi"):
|
def __init__(self, config: Dict, module_api: "ModuleApi") -> None:
|
||||||
super().__init__(config, module_api)
|
super().__init__(config, module_api)
|
||||||
|
|
||||||
def on_create_room(
|
async def on_create_room(
|
||||||
self, requester: Requester, config: dict, is_requester_admin: bool
|
self, requester: Requester, config: dict, is_requester_admin: bool
|
||||||
):
|
) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
class LegacyChangeEvents(LegacyThirdPartyRulesTestModule):
|
class LegacyChangeEvents(LegacyThirdPartyRulesTestModule):
|
||||||
def __init__(self, config: Dict, module_api: "ModuleApi"):
|
def __init__(self, config: Dict, module_api: "ModuleApi") -> None:
|
||||||
super().__init__(config, module_api)
|
super().__init__(config, module_api)
|
||||||
|
|
||||||
async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
|
async def check_event_allowed(
|
||||||
|
self, event: EventBase, state: StateMap[EventBase]
|
||||||
|
) -> JsonDict:
|
||||||
d = event.get_dict()
|
d = event.get_dict()
|
||||||
content = unfreeze(event.content)
|
content = unfreeze(event.content)
|
||||||
content["foo"] = "bar"
|
content["foo"] = "bar"
|
||||||
@ -84,7 +94,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
|||||||
account.register_servlets,
|
account.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
hs = self.setup_test_homeserver()
|
hs = self.setup_test_homeserver()
|
||||||
|
|
||||||
load_legacy_third_party_event_rules(hs)
|
load_legacy_third_party_event_rules(hs)
|
||||||
@ -94,22 +104,30 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
|||||||
# Note that these checks are not relevant to this test case.
|
# Note that these checks are not relevant to this test case.
|
||||||
|
|
||||||
# Have this homeserver auto-approve all event signature checking.
|
# Have this homeserver auto-approve all event signature checking.
|
||||||
async def approve_all_signature_checking(_, pdu):
|
async def approve_all_signature_checking(
|
||||||
|
_: RoomVersion, pdu: EventBase
|
||||||
|
) -> EventBase:
|
||||||
return pdu
|
return pdu
|
||||||
|
|
||||||
hs.get_federation_server()._check_sigs_and_hash = approve_all_signature_checking
|
hs.get_federation_server()._check_sigs_and_hash = approve_all_signature_checking # type: ignore[assignment]
|
||||||
|
|
||||||
# Have this homeserver skip event auth checks. This is necessary due to
|
# Have this homeserver skip event auth checks. This is necessary due to
|
||||||
# event auth checks ensuring that events were signed by the sender's homeserver.
|
# event auth checks ensuring that events were signed by the sender's homeserver.
|
||||||
async def _check_event_auth(origin, event, context, *args, **kwargs):
|
async def _check_event_auth(
|
||||||
|
origin: str,
|
||||||
|
event: EventBase,
|
||||||
|
context: EventContext,
|
||||||
|
*args: Any,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> EventContext:
|
||||||
return context
|
return context
|
||||||
|
|
||||||
hs.get_federation_event_handler()._check_event_auth = _check_event_auth
|
hs.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[assignment]
|
||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
def prepare(self, reactor, clock, homeserver):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
super().prepare(reactor, clock, homeserver)
|
super().prepare(reactor, clock, hs)
|
||||||
# Create some users and a room to play with during the tests
|
# Create some users and a room to play with during the tests
|
||||||
self.user_id = self.register_user("kermit", "monkey")
|
self.user_id = self.register_user("kermit", "monkey")
|
||||||
self.invitee = self.register_user("invitee", "hackme")
|
self.invitee = self.register_user("invitee", "hackme")
|
||||||
@ -121,13 +139,15 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_third_party_rules(self):
|
def test_third_party_rules(self) -> None:
|
||||||
"""Tests that a forbidden event is forbidden from being sent, but an allowed one
|
"""Tests that a forbidden event is forbidden from being sent, but an allowed one
|
||||||
can be sent.
|
can be sent.
|
||||||
"""
|
"""
|
||||||
# patch the rules module with a Mock which will return False for some event
|
# patch the rules module with a Mock which will return False for some event
|
||||||
# types
|
# types
|
||||||
async def check(ev, state):
|
async def check(
|
||||||
|
ev: EventBase, state: StateMap[EventBase]
|
||||||
|
) -> Tuple[bool, Optional[JsonDict]]:
|
||||||
return ev.type != "foo.bar.forbidden", None
|
return ev.type != "foo.bar.forbidden", None
|
||||||
|
|
||||||
callback = Mock(spec=[], side_effect=check)
|
callback = Mock(spec=[], side_effect=check)
|
||||||
@ -161,7 +181,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(channel.result["code"], b"403", channel.result)
|
self.assertEqual(channel.result["code"], b"403", channel.result)
|
||||||
|
|
||||||
def test_third_party_rules_workaround_synapse_errors_pass_through(self):
|
def test_third_party_rules_workaround_synapse_errors_pass_through(self) -> None:
|
||||||
"""
|
"""
|
||||||
Tests that the workaround introduced by https://github.com/matrix-org/synapse/pull/11042
|
Tests that the workaround introduced by https://github.com/matrix-org/synapse/pull/11042
|
||||||
is functional: that SynapseErrors are passed through from check_event_allowed
|
is functional: that SynapseErrors are passed through from check_event_allowed
|
||||||
@ -172,7 +192,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
class NastyHackException(SynapseError):
|
class NastyHackException(SynapseError):
|
||||||
def error_dict(self):
|
def error_dict(self) -> JsonDict:
|
||||||
"""
|
"""
|
||||||
This overrides SynapseError's `error_dict` to nastily inject
|
This overrides SynapseError's `error_dict` to nastily inject
|
||||||
JSON into the error response.
|
JSON into the error response.
|
||||||
@ -182,7 +202,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
# add a callback that will raise our hacky exception
|
# add a callback that will raise our hacky exception
|
||||||
async def check(ev, state) -> Tuple[bool, Optional[JsonDict]]:
|
async def check(
|
||||||
|
ev: EventBase, state: StateMap[EventBase]
|
||||||
|
) -> Tuple[bool, Optional[JsonDict]]:
|
||||||
raise NastyHackException(429, "message")
|
raise NastyHackException(429, "message")
|
||||||
|
|
||||||
self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
|
self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
|
||||||
@ -202,11 +224,13 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
|||||||
{"errcode": "M_UNKNOWN", "error": "message", "nasty": "very"},
|
{"errcode": "M_UNKNOWN", "error": "message", "nasty": "very"},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_cannot_modify_event(self):
|
def test_cannot_modify_event(self) -> None:
|
||||||
"""cannot accidentally modify an event before it is persisted"""
|
"""cannot accidentally modify an event before it is persisted"""
|
||||||
|
|
||||||
# first patch the event checker so that it will try to modify the event
|
# first patch the event checker so that it will try to modify the event
|
||||||
async def check(ev: EventBase, state):
|
async def check(
|
||||||
|
ev: EventBase, state: StateMap[EventBase]
|
||||||
|
) -> Tuple[bool, Optional[JsonDict]]:
|
||||||
ev.content = {"x": "y"}
|
ev.content = {"x": "y"}
|
||||||
return True, None
|
return True, None
|
||||||
|
|
||||||
@ -223,10 +247,12 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
|||||||
# 500 Internal Server Error
|
# 500 Internal Server Error
|
||||||
self.assertEqual(channel.code, 500, channel.result)
|
self.assertEqual(channel.code, 500, channel.result)
|
||||||
|
|
||||||
def test_modify_event(self):
|
def test_modify_event(self) -> None:
|
||||||
"""The module can return a modified version of the event"""
|
"""The module can return a modified version of the event"""
|
||||||
# first patch the event checker so that it will modify the event
|
# first patch the event checker so that it will modify the event
|
||||||
async def check(ev: EventBase, state):
|
async def check(
|
||||||
|
ev: EventBase, state: StateMap[EventBase]
|
||||||
|
) -> Tuple[bool, Optional[JsonDict]]:
|
||||||
d = ev.get_dict()
|
d = ev.get_dict()
|
||||||
d["content"] = {"x": "y"}
|
d["content"] = {"x": "y"}
|
||||||
return True, d
|
return True, d
|
||||||
@ -253,10 +279,12 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
|||||||
ev = channel.json_body
|
ev = channel.json_body
|
||||||
self.assertEqual(ev["content"]["x"], "y")
|
self.assertEqual(ev["content"]["x"], "y")
|
||||||
|
|
||||||
def test_message_edit(self):
|
def test_message_edit(self) -> None:
|
||||||
"""Ensure that the module doesn't cause issues with edited messages."""
|
"""Ensure that the module doesn't cause issues with edited messages."""
|
||||||
# first patch the event checker so that it will modify the event
|
# first patch the event checker so that it will modify the event
|
||||||
async def check(ev: EventBase, state):
|
async def check(
|
||||||
|
ev: EventBase, state: StateMap[EventBase]
|
||||||
|
) -> Tuple[bool, Optional[JsonDict]]:
|
||||||
d = ev.get_dict()
|
d = ev.get_dict()
|
||||||
d["content"] = {
|
d["content"] = {
|
||||||
"msgtype": "m.text",
|
"msgtype": "m.text",
|
||||||
@ -315,7 +343,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
|||||||
ev = channel.json_body
|
ev = channel.json_body
|
||||||
self.assertEqual(ev["content"]["body"], "EDITED BODY")
|
self.assertEqual(ev["content"]["body"], "EDITED BODY")
|
||||||
|
|
||||||
def test_send_event(self):
|
def test_send_event(self) -> None:
|
||||||
"""Tests that a module can send an event into a room via the module api"""
|
"""Tests that a module can send an event into a room via the module api"""
|
||||||
content = {
|
content = {
|
||||||
"msgtype": "m.text",
|
"msgtype": "m.text",
|
||||||
@ -344,7 +372,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_legacy_check_event_allowed(self):
|
def test_legacy_check_event_allowed(self) -> None:
|
||||||
"""Tests that the wrapper for legacy check_event_allowed callbacks works
|
"""Tests that the wrapper for legacy check_event_allowed callbacks works
|
||||||
correctly.
|
correctly.
|
||||||
"""
|
"""
|
||||||
@ -379,13 +407,13 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_legacy_on_create_room(self):
|
def test_legacy_on_create_room(self) -> None:
|
||||||
"""Tests that the wrapper for legacy on_create_room callbacks works
|
"""Tests that the wrapper for legacy on_create_room callbacks works
|
||||||
correctly.
|
correctly.
|
||||||
"""
|
"""
|
||||||
self.helper.create_room_as(self.user_id, tok=self.tok, expect_code=403)
|
self.helper.create_room_as(self.user_id, tok=self.tok, expect_code=403)
|
||||||
|
|
||||||
def test_sent_event_end_up_in_room_state(self):
|
def test_sent_event_end_up_in_room_state(self) -> None:
|
||||||
"""Tests that a state event sent by a module while processing another state event
|
"""Tests that a state event sent by a module while processing another state event
|
||||||
doesn't get dropped from the state of the room. This is to guard against a bug
|
doesn't get dropped from the state of the room. This is to guard against a bug
|
||||||
where Synapse has been observed doing so, see https://github.com/matrix-org/synapse/issues/10830
|
where Synapse has been observed doing so, see https://github.com/matrix-org/synapse/issues/10830
|
||||||
@ -400,7 +428,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
|||||||
api = self.hs.get_module_api()
|
api = self.hs.get_module_api()
|
||||||
|
|
||||||
# Define a callback that sends a custom event on power levels update.
|
# Define a callback that sends a custom event on power levels update.
|
||||||
async def test_fn(event: EventBase, state_events):
|
async def test_fn(
|
||||||
|
event: EventBase, state_events: StateMap[EventBase]
|
||||||
|
) -> Tuple[bool, Optional[JsonDict]]:
|
||||||
if event.is_state and event.type == EventTypes.PowerLevels:
|
if event.is_state and event.type == EventTypes.PowerLevels:
|
||||||
await api.create_and_send_event_into_room(
|
await api.create_and_send_event_into_room(
|
||||||
{
|
{
|
||||||
@ -436,7 +466,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
|||||||
self.assertEqual(channel.code, 200)
|
self.assertEqual(channel.code, 200)
|
||||||
self.assertEqual(channel.json_body["i"], i)
|
self.assertEqual(channel.json_body["i"], i)
|
||||||
|
|
||||||
def test_on_new_event(self):
|
def test_on_new_event(self) -> None:
|
||||||
"""Test that the on_new_event callback is called on new events"""
|
"""Test that the on_new_event callback is called on new events"""
|
||||||
on_new_event = Mock(make_awaitable(None))
|
on_new_event = Mock(make_awaitable(None))
|
||||||
self.hs.get_third_party_event_rules()._on_new_event_callbacks.append(
|
self.hs.get_third_party_event_rules()._on_new_event_callbacks.append(
|
||||||
@ -501,7 +531,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
|||||||
|
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
|
|
||||||
def _update_power_levels(self, event_default: int = 0):
|
def _update_power_levels(self, event_default: int = 0) -> None:
|
||||||
"""Updates the room's power levels.
|
"""Updates the room's power levels.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -533,7 +563,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
|||||||
tok=self.tok,
|
tok=self.tok,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_on_profile_update(self):
|
def test_on_profile_update(self) -> None:
|
||||||
"""Tests that the on_profile_update module callback is correctly called on
|
"""Tests that the on_profile_update module callback is correctly called on
|
||||||
profile updates.
|
profile updates.
|
||||||
"""
|
"""
|
||||||
@ -592,7 +622,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
|||||||
self.assertEqual(profile_info.display_name, displayname)
|
self.assertEqual(profile_info.display_name, displayname)
|
||||||
self.assertEqual(profile_info.avatar_url, avatar_url)
|
self.assertEqual(profile_info.avatar_url, avatar_url)
|
||||||
|
|
||||||
def test_on_profile_update_admin(self):
|
def test_on_profile_update_admin(self) -> None:
|
||||||
"""Tests that the on_profile_update module callback is correctly called on
|
"""Tests that the on_profile_update module callback is correctly called on
|
||||||
profile updates triggered by a server admin.
|
profile updates triggered by a server admin.
|
||||||
"""
|
"""
|
||||||
@ -634,7 +664,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
|||||||
self.assertEqual(profile_info.display_name, displayname)
|
self.assertEqual(profile_info.display_name, displayname)
|
||||||
self.assertEqual(profile_info.avatar_url, avatar_url)
|
self.assertEqual(profile_info.avatar_url, avatar_url)
|
||||||
|
|
||||||
def test_on_user_deactivation_status_changed(self):
|
def test_on_user_deactivation_status_changed(self) -> None:
|
||||||
"""Tests that the on_user_deactivation_status_changed module callback is called
|
"""Tests that the on_user_deactivation_status_changed module callback is called
|
||||||
correctly when processing a user's deactivation.
|
correctly when processing a user's deactivation.
|
||||||
"""
|
"""
|
||||||
@ -691,7 +721,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
|||||||
args = profile_mock.call_args[0]
|
args = profile_mock.call_args[0]
|
||||||
self.assertTrue(args[3])
|
self.assertTrue(args[3])
|
||||||
|
|
||||||
def test_on_user_deactivation_status_changed_admin(self):
|
def test_on_user_deactivation_status_changed_admin(self) -> None:
|
||||||
"""Tests that the on_user_deactivation_status_changed module callback is called
|
"""Tests that the on_user_deactivation_status_changed module callback is called
|
||||||
correctly when processing a user's deactivation triggered by a server admin as
|
correctly when processing a user's deactivation triggered by a server admin as
|
||||||
well as a reactivation.
|
well as a reactivation.
|
||||||
|
@ -14,11 +14,16 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""Tests REST events for /rooms paths."""
|
"""Tests REST events for /rooms paths."""
|
||||||
|
from typing import Any
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.rest.client import room
|
from synapse.rest.client import room
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.storage.databases.main.registration import TokenLookupResult
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
@ -33,7 +38,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
|
|||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
servlets = [room.register_servlets]
|
servlets = [room.register_servlets]
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
|
|
||||||
hs = self.setup_test_homeserver(
|
hs = self.setup_test_homeserver(
|
||||||
"red",
|
"red",
|
||||||
@ -43,30 +48,34 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self.event_source = hs.get_event_sources().sources.typing
|
self.event_source = hs.get_event_sources().sources.typing
|
||||||
|
|
||||||
hs.get_federation_handler = Mock()
|
hs.get_federation_handler = Mock() # type: ignore[assignment]
|
||||||
|
|
||||||
async def get_user_by_access_token(token=None, allow_guest=False):
|
async def get_user_by_access_token(
|
||||||
return {
|
token: str,
|
||||||
"user": UserID.from_string(self.auth_user_id),
|
rights: str = "access",
|
||||||
"token_id": 1,
|
allow_expired: bool = False,
|
||||||
"is_guest": False,
|
) -> TokenLookupResult:
|
||||||
}
|
return TokenLookupResult(
|
||||||
|
user_id=self.user_id,
|
||||||
|
is_guest=False,
|
||||||
|
token_id=1,
|
||||||
|
)
|
||||||
|
|
||||||
hs.get_auth().get_user_by_access_token = get_user_by_access_token
|
hs.get_auth().get_user_by_access_token = get_user_by_access_token # type: ignore[assignment]
|
||||||
|
|
||||||
async def _insert_client_ip(*args, **kwargs):
|
async def _insert_client_ip(*args: Any, **kwargs: Any) -> None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
hs.get_datastores().main.insert_client_ip = _insert_client_ip
|
hs.get_datastores().main.insert_client_ip = _insert_client_ip # type: ignore[assignment]
|
||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.room_id = self.helper.create_room_as(self.user_id)
|
self.room_id = self.helper.create_room_as(self.user_id)
|
||||||
# Need another user to make notifications actually work
|
# Need another user to make notifications actually work
|
||||||
self.helper.join(self.room_id, user="@jim:red")
|
self.helper.join(self.room_id, user="@jim:red")
|
||||||
|
|
||||||
def test_set_typing(self):
|
def test_set_typing(self) -> None:
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"PUT",
|
"PUT",
|
||||||
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
|
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
|
||||||
@ -95,7 +104,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_set_not_typing(self):
|
def test_set_not_typing(self) -> None:
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"PUT",
|
"PUT",
|
||||||
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
|
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
|
||||||
@ -103,7 +112,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(200, channel.code)
|
self.assertEqual(200, channel.code)
|
||||||
|
|
||||||
def test_typing_timeout(self):
|
def test_typing_timeout(self) -> None:
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"PUT",
|
"PUT",
|
||||||
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
|
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
|
||||||
|
Loading…
Reference in New Issue
Block a user