Properly typecheck tests.api (#14983)

This commit is contained in:
David Robertson 2023-02-03 20:03:23 +00:00 committed by GitHub
parent b2d97bac09
commit 6e6edea6c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 141 additions and 111 deletions

1
changelog.d/14983.misc Normal file
View File

@ -0,0 +1 @@
Improve type hints.

View File

@ -32,7 +32,6 @@ exclude = (?x)
|synapse/storage/databases/main/cache.py |synapse/storage/databases/main/cache.py
|synapse/storage/schema/ |synapse/storage/schema/
|tests/api/test_auth.py
|tests/appservice/test_scheduler.py |tests/appservice/test_scheduler.py
|tests/federation/test_federation_catch_up.py |tests/federation/test_federation_catch_up.py
|tests/federation/test_federation_sender.py |tests/federation/test_federation_sender.py
@ -73,6 +72,9 @@ disallow_untyped_defs = False
[mypy-tests.*] [mypy-tests.*]
disallow_untyped_defs = False disallow_untyped_defs = False
[mypy-tests.api.*]
disallow_untyped_defs = True
[mypy-tests.app.*] [mypy-tests.app.*]
disallow_untyped_defs = True disallow_untyped_defs = True

View File

@ -252,9 +252,9 @@ class FilterCollection:
return self._room_timeline_filter.unread_thread_notifications return self._room_timeline_filter.unread_thread_notifications
async def filter_presence( async def filter_presence(
self, events: Iterable[UserPresenceState] self, presence_states: Iterable[UserPresenceState]
) -> List[UserPresenceState]: ) -> List[UserPresenceState]:
return await self._presence_filter.filter(events) return await self._presence_filter.filter(presence_states)
async def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]: async def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
return await self._account_data.filter(events) return await self._account_data.filter(events)

View File

@ -31,7 +31,7 @@ from synapse.api.errors import (
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.databases.main.registration import TokenLookupResult from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import Requester from synapse.types import Requester, UserID
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
@ -41,10 +41,12 @@ from tests.utils import mock_getRawHeaders
class AuthTestCase(unittest.HomeserverTestCase): class AuthTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = Mock() self.store = Mock()
hs.datastores.main = self.store # type-ignore: datastores is None until hs.setup() is called---but it'll
# have been called by the HomeserverTestCase machinery.
hs.datastores.main = self.store # type: ignore[union-attr]
hs.get_auth_handler().store = self.store hs.get_auth_handler().store = self.store
self.auth = Auth(hs) self.auth = Auth(hs)
@ -61,7 +63,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.store.insert_client_ip = simple_async_mock(None) self.store.insert_client_ip = simple_async_mock(None)
self.store.is_support_user = simple_async_mock(False) self.store.is_support_user = simple_async_mock(False)
def test_get_user_by_req_user_valid_token(self): def test_get_user_by_req_user_valid_token(self) -> None:
user_info = TokenLookupResult( user_info = TokenLookupResult(
user_id=self.test_user, token_id=5, device_id="device" user_id=self.test_user, token_id=5, device_id="device"
) )
@ -74,7 +76,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
requester = self.get_success(self.auth.get_user_by_req(request)) requester = self.get_success(self.auth.get_user_by_req(request))
self.assertEqual(requester.user.to_string(), self.test_user) self.assertEqual(requester.user.to_string(), self.test_user)
def test_get_user_by_req_user_bad_token(self): def test_get_user_by_req_user_bad_token(self) -> None:
self.store.get_user_by_access_token = simple_async_mock(None) self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={}) request = Mock(args={})
@ -86,7 +88,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(f.code, 401) self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_user_missing_token(self): def test_get_user_by_req_user_missing_token(self) -> None:
user_info = TokenLookupResult(user_id=self.test_user, token_id=5) user_info = TokenLookupResult(user_id=self.test_user, token_id=5)
self.store.get_user_by_access_token = simple_async_mock(user_info) self.store.get_user_by_access_token = simple_async_mock(user_info)
@ -98,7 +100,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(f.code, 401) self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_MISSING_TOKEN") self.assertEqual(f.errcode, "M_MISSING_TOKEN")
def test_get_user_by_req_appservice_valid_token(self): def test_get_user_by_req_appservice_valid_token(self) -> None:
app_service = Mock( app_service = Mock(
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
) )
@ -112,7 +114,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
requester = self.get_success(self.auth.get_user_by_req(request)) requester = self.get_success(self.auth.get_user_by_req(request))
self.assertEqual(requester.user.to_string(), self.test_user) self.assertEqual(requester.user.to_string(), self.test_user)
def test_get_user_by_req_appservice_valid_token_good_ip(self): def test_get_user_by_req_appservice_valid_token_good_ip(self) -> None:
from netaddr import IPSet from netaddr import IPSet
app_service = Mock( app_service = Mock(
@ -131,7 +133,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
requester = self.get_success(self.auth.get_user_by_req(request)) requester = self.get_success(self.auth.get_user_by_req(request))
self.assertEqual(requester.user.to_string(), self.test_user) self.assertEqual(requester.user.to_string(), self.test_user)
def test_get_user_by_req_appservice_valid_token_bad_ip(self): def test_get_user_by_req_appservice_valid_token_bad_ip(self) -> None:
from netaddr import IPSet from netaddr import IPSet
app_service = Mock( app_service = Mock(
@ -153,7 +155,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(f.code, 401) self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_appservice_bad_token(self): def test_get_user_by_req_appservice_bad_token(self) -> None:
self.store.get_app_service_by_token = Mock(return_value=None) self.store.get_app_service_by_token = Mock(return_value=None)
self.store.get_user_by_access_token = simple_async_mock(None) self.store.get_user_by_access_token = simple_async_mock(None)
@ -166,7 +168,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(f.code, 401) self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_appservice_missing_token(self): def test_get_user_by_req_appservice_missing_token(self) -> None:
app_service = Mock(token="foobar", url="a_url", sender=self.test_user) app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = simple_async_mock(None) self.store.get_user_by_access_token = simple_async_mock(None)
@ -179,7 +181,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(f.code, 401) self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_MISSING_TOKEN") self.assertEqual(f.errcode, "M_MISSING_TOKEN")
def test_get_user_by_req_appservice_valid_token_valid_user_id(self): def test_get_user_by_req_appservice_valid_token_valid_user_id(self) -> None:
masquerading_user_id = b"@doppelganger:matrix.org" masquerading_user_id = b"@doppelganger:matrix.org"
app_service = Mock( app_service = Mock(
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
@ -200,7 +202,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
requester.user.to_string(), masquerading_user_id.decode("utf8") requester.user.to_string(), masquerading_user_id.decode("utf8")
) )
def test_get_user_by_req_appservice_valid_token_bad_user_id(self): def test_get_user_by_req_appservice_valid_token_bad_user_id(self) -> None:
masquerading_user_id = b"@doppelganger:matrix.org" masquerading_user_id = b"@doppelganger:matrix.org"
app_service = Mock( app_service = Mock(
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
@ -217,7 +219,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.get_failure(self.auth.get_user_by_req(request), AuthError) self.get_failure(self.auth.get_user_by_req(request), AuthError)
@override_config({"experimental_features": {"msc3202_device_masquerading": True}}) @override_config({"experimental_features": {"msc3202_device_masquerading": True}})
def test_get_user_by_req_appservice_valid_token_valid_device_id(self): def test_get_user_by_req_appservice_valid_token_valid_device_id(self) -> None:
""" """
Tests that when an application service passes the device_id URL parameter Tests that when an application service passes the device_id URL parameter
with the ID of a valid device for the user in question, with the ID of a valid device for the user in question,
@ -249,7 +251,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(requester.device_id, masquerading_device_id.decode("utf8")) self.assertEqual(requester.device_id, masquerading_device_id.decode("utf8"))
@override_config({"experimental_features": {"msc3202_device_masquerading": True}}) @override_config({"experimental_features": {"msc3202_device_masquerading": True}})
def test_get_user_by_req_appservice_valid_token_invalid_device_id(self): def test_get_user_by_req_appservice_valid_token_invalid_device_id(self) -> None:
""" """
Tests that when an application service passes the device_id URL parameter Tests that when an application service passes the device_id URL parameter
with an ID that is not a valid device ID for the user in question, with an ID that is not a valid device ID for the user in question,
@ -279,7 +281,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(failure.value.code, 400) self.assertEqual(failure.value.code, 400)
self.assertEqual(failure.value.errcode, Codes.EXCLUSIVE) self.assertEqual(failure.value.errcode, Codes.EXCLUSIVE)
def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self): def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self) -> None:
self.store.get_user_by_access_token = simple_async_mock( self.store.get_user_by_access_token = simple_async_mock(
TokenLookupResult( TokenLookupResult(
user_id="@baldrick:matrix.org", user_id="@baldrick:matrix.org",
@ -298,7 +300,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.get_success(self.auth.get_user_by_req(request)) self.get_success(self.auth.get_user_by_req(request))
self.store.insert_client_ip.assert_called_once() self.store.insert_client_ip.assert_called_once()
def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self): def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self) -> None:
self.auth._track_puppeted_user_ips = True self.auth._track_puppeted_user_ips = True
self.store.get_user_by_access_token = simple_async_mock( self.store.get_user_by_access_token = simple_async_mock(
TokenLookupResult( TokenLookupResult(
@ -318,7 +320,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.get_success(self.auth.get_user_by_req(request)) self.get_success(self.auth.get_user_by_req(request))
self.assertEqual(self.store.insert_client_ip.call_count, 2) self.assertEqual(self.store.insert_client_ip.call_count, 2)
def test_get_user_from_macaroon(self): def test_get_user_from_macaroon(self) -> None:
self.store.get_user_by_access_token = simple_async_mock(None) self.store.get_user_by_access_token = simple_async_mock(None)
user_id = "@baldrick:matrix.org" user_id = "@baldrick:matrix.org"
@ -336,7 +338,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.auth.get_user_by_access_token(serialized), InvalidClientTokenError self.auth.get_user_by_access_token(serialized), InvalidClientTokenError
) )
def test_get_guest_user_from_macaroon(self): def test_get_guest_user_from_macaroon(self) -> None:
self.store.get_user_by_id = simple_async_mock({"is_guest": True}) self.store.get_user_by_id = simple_async_mock({"is_guest": True})
self.store.get_user_by_access_token = simple_async_mock(None) self.store.get_user_by_access_token = simple_async_mock(None)
@ -357,7 +359,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertTrue(user_info.is_guest) self.assertTrue(user_info.is_guest)
self.store.get_user_by_id.assert_called_with(user_id) self.store.get_user_by_id.assert_called_with(user_id)
def test_blocking_mau(self): def test_blocking_mau(self) -> None:
self.auth_blocking._limit_usage_by_mau = False self.auth_blocking._limit_usage_by_mau = False
self.auth_blocking._max_mau_value = 50 self.auth_blocking._max_mau_value = 50
lots_of_users = 100 lots_of_users = 100
@ -381,7 +383,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.store.get_monthly_active_count = simple_async_mock(small_number_of_users) self.store.get_monthly_active_count = simple_async_mock(small_number_of_users)
self.get_success(self.auth_blocking.check_auth_blocking()) self.get_success(self.auth_blocking.check_auth_blocking())
def test_blocking_mau__depending_on_user_type(self): def test_blocking_mau__depending_on_user_type(self) -> None:
self.auth_blocking._max_mau_value = 50 self.auth_blocking._max_mau_value = 50
self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
@ -400,7 +402,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
# Real users not allowed # Real users not allowed
self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError) self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError)
def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips(self): def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips(
self,
) -> None:
self.auth_blocking._max_mau_value = 50 self.auth_blocking._max_mau_value = 50
self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
self.auth_blocking._track_appservice_user_ips = False self.auth_blocking._track_appservice_user_ips = False
@ -418,7 +422,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
sender="@appservice:sender", sender="@appservice:sender",
) )
requester = Requester( requester = Requester(
user="@appservice:server", user=UserID.from_string("@appservice:server"),
access_token_id=None, access_token_id=None,
device_id="FOOBAR", device_id="FOOBAR",
is_guest=False, is_guest=False,
@ -428,7 +432,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
) )
self.get_success(self.auth_blocking.check_auth_blocking(requester=requester)) self.get_success(self.auth_blocking.check_auth_blocking(requester=requester))
def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips(self): def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips(
self,
) -> None:
self.auth_blocking._max_mau_value = 50 self.auth_blocking._max_mau_value = 50
self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
self.auth_blocking._track_appservice_user_ips = True self.auth_blocking._track_appservice_user_ips = True
@ -446,7 +452,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
sender="@appservice:sender", sender="@appservice:sender",
) )
requester = Requester( requester = Requester(
user="@appservice:server", user=UserID.from_string("@appservice:server"),
access_token_id=None, access_token_id=None,
device_id="FOOBAR", device_id="FOOBAR",
is_guest=False, is_guest=False,
@ -459,7 +465,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
ResourceLimitError, ResourceLimitError,
) )
def test_reserved_threepid(self): def test_reserved_threepid(self) -> None:
self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
self.auth_blocking._max_mau_value = 1 self.auth_blocking._max_mau_value = 1
self.store.get_monthly_active_count = simple_async_mock(2) self.store.get_monthly_active_count = simple_async_mock(2)
@ -476,7 +482,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.get_success(self.auth_blocking.check_auth_blocking(threepid=threepid)) self.get_success(self.auth_blocking.check_auth_blocking(threepid=threepid))
def test_hs_disabled(self): def test_hs_disabled(self) -> None:
self.auth_blocking._hs_disabled = True self.auth_blocking._hs_disabled = True
self.auth_blocking._hs_disabled_message = "Reason for being disabled" self.auth_blocking._hs_disabled_message = "Reason for being disabled"
e = self.get_failure( e = self.get_failure(
@ -486,7 +492,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEqual(e.value.code, 403) self.assertEqual(e.value.code, 403)
def test_hs_disabled_no_server_notices_user(self): def test_hs_disabled_no_server_notices_user(self) -> None:
"""Check that 'hs_disabled_message' works correctly when there is no """Check that 'hs_disabled_message' works correctly when there is no
server_notices user. server_notices user.
""" """
@ -503,7 +509,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEqual(e.value.code, 403) self.assertEqual(e.value.code, 403)
def test_server_notices_mxid_special_cased(self): def test_server_notices_mxid_special_cased(self) -> None:
self.auth_blocking._hs_disabled = True self.auth_blocking._hs_disabled = True
user = "@user:server" user = "@user:server"
self.auth_blocking._server_notices_mxid = user self.auth_blocking._server_notices_mxid = user

View File

@ -14,40 +14,36 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List
from unittest.mock import patch from unittest.mock import patch
import jsonschema import jsonschema
from frozendict import frozendict from frozendict import frozendict
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EduTypes, EventContentFields from synapse.api.constants import EduTypes, EventContentFields
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.events import make_event_from_dict from synapse.api.presence import UserPresenceState
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.events.test_utils import MockEvent
user_localpart = "test_user" user_localpart = "test_user"
def MockEvent(**kwargs):
if "event_id" not in kwargs:
kwargs["event_id"] = "fake_event_id"
if "type" not in kwargs:
kwargs["type"] = "fake_type"
if "content" not in kwargs:
kwargs["content"] = {}
return make_event_from_dict(kwargs)
class FilteringTestCase(unittest.HomeserverTestCase): class FilteringTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.filtering = hs.get_filtering() self.filtering = hs.get_filtering()
self.datastore = hs.get_datastores().main self.datastore = hs.get_datastores().main
def test_errors_on_invalid_filters(self): def test_errors_on_invalid_filters(self) -> None:
# See USER_FILTER_SCHEMA for the filter schema. # See USER_FILTER_SCHEMA for the filter schema.
invalid_filters = [ invalid_filters: List[JsonDict] = [
# `account_data` must be a dictionary # `account_data` must be a dictionary
{"account_data": "Hello World"}, {"account_data": "Hello World"},
# `event_fields` entries must not contain backslashes # `event_fields` entries must not contain backslashes
@ -63,10 +59,10 @@ class FilteringTestCase(unittest.HomeserverTestCase):
with self.assertRaises(SynapseError): with self.assertRaises(SynapseError):
self.filtering.check_valid_filter(filter) self.filtering.check_valid_filter(filter)
def test_ignores_unknown_filter_fields(self): def test_ignores_unknown_filter_fields(self) -> None:
# For forward compatibility, we must ignore unknown filter fields. # For forward compatibility, we must ignore unknown filter fields.
# See USER_FILTER_SCHEMA for the filter schema. # See USER_FILTER_SCHEMA for the filter schema.
filters = [ filters: List[JsonDict] = [
{"org.matrix.msc9999.future_option": True}, {"org.matrix.msc9999.future_option": True},
{"presence": {"org.matrix.msc9999.future_option": True}}, {"presence": {"org.matrix.msc9999.future_option": True}},
{"room": {"org.matrix.msc9999.future_option": True}}, {"room": {"org.matrix.msc9999.future_option": True}},
@ -76,8 +72,8 @@ class FilteringTestCase(unittest.HomeserverTestCase):
self.filtering.check_valid_filter(filter) self.filtering.check_valid_filter(filter)
# Must not raise. # Must not raise.
def test_valid_filters(self): def test_valid_filters(self) -> None:
valid_filters = [ valid_filters: List[JsonDict] = [
{ {
"room": { "room": {
"timeline": {"limit": 20}, "timeline": {"limit": 20},
@ -132,22 +128,22 @@ class FilteringTestCase(unittest.HomeserverTestCase):
except jsonschema.ValidationError as e: except jsonschema.ValidationError as e:
self.fail(e) self.fail(e)
def test_limits_are_applied(self): def test_limits_are_applied(self) -> None:
# TODO # TODO
pass pass
def test_definition_types_works_with_literals(self): def test_definition_types_works_with_literals(self) -> None:
definition = {"types": ["m.room.message", "org.matrix.foo.bar"]} definition = {"types": ["m.room.message", "org.matrix.foo.bar"]}
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar") event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
self.assertTrue(Filter(self.hs, definition)._check(event)) self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_types_works_with_wildcards(self): def test_definition_types_works_with_wildcards(self) -> None:
definition = {"types": ["m.*", "org.matrix.foo.bar"]} definition = {"types": ["m.*", "org.matrix.foo.bar"]}
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar") event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
self.assertTrue(Filter(self.hs, definition)._check(event)) self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_types_works_with_unknowns(self): def test_definition_types_works_with_unknowns(self) -> None:
definition = {"types": ["m.room.message", "org.matrix.foo.bar"]} definition = {"types": ["m.room.message", "org.matrix.foo.bar"]}
event = MockEvent( event = MockEvent(
sender="@foo:bar", sender="@foo:bar",
@ -156,24 +152,24 @@ class FilteringTestCase(unittest.HomeserverTestCase):
) )
self.assertFalse(Filter(self.hs, definition)._check(event)) self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_types_works_with_literals(self): def test_definition_not_types_works_with_literals(self) -> None:
definition = {"not_types": ["m.room.message", "org.matrix.foo.bar"]} definition = {"not_types": ["m.room.message", "org.matrix.foo.bar"]}
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar") event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
self.assertFalse(Filter(self.hs, definition)._check(event)) self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_types_works_with_wildcards(self): def test_definition_not_types_works_with_wildcards(self) -> None:
definition = {"not_types": ["m.room.message", "org.matrix.*"]} definition = {"not_types": ["m.room.message", "org.matrix.*"]}
event = MockEvent( event = MockEvent(
sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar" sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar"
) )
self.assertFalse(Filter(self.hs, definition)._check(event)) self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_types_works_with_unknowns(self): def test_definition_not_types_works_with_unknowns(self) -> None:
definition = {"not_types": ["m.*", "org.*"]} definition = {"not_types": ["m.*", "org.*"]}
event = MockEvent(sender="@foo:bar", type="com.nom.nom.nom", room_id="!foo:bar") event = MockEvent(sender="@foo:bar", type="com.nom.nom.nom", room_id="!foo:bar")
self.assertTrue(Filter(self.hs, definition)._check(event)) self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_not_types_takes_priority_over_types(self): def test_definition_not_types_takes_priority_over_types(self) -> None:
definition = { definition = {
"not_types": ["m.*", "org.*"], "not_types": ["m.*", "org.*"],
"types": ["m.room.message", "m.room.topic"], "types": ["m.room.message", "m.room.topic"],
@ -181,35 +177,35 @@ class FilteringTestCase(unittest.HomeserverTestCase):
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar") event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
self.assertFalse(Filter(self.hs, definition)._check(event)) self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_senders_works_with_literals(self): def test_definition_senders_works_with_literals(self) -> None:
definition = {"senders": ["@flibble:wibble"]} definition = {"senders": ["@flibble:wibble"]}
event = MockEvent( event = MockEvent(
sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar" sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar"
) )
self.assertTrue(Filter(self.hs, definition)._check(event)) self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_senders_works_with_unknowns(self): def test_definition_senders_works_with_unknowns(self) -> None:
definition = {"senders": ["@flibble:wibble"]} definition = {"senders": ["@flibble:wibble"]}
event = MockEvent( event = MockEvent(
sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar" sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar"
) )
self.assertFalse(Filter(self.hs, definition)._check(event)) self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_senders_works_with_literals(self): def test_definition_not_senders_works_with_literals(self) -> None:
definition = {"not_senders": ["@flibble:wibble"]} definition = {"not_senders": ["@flibble:wibble"]}
event = MockEvent( event = MockEvent(
sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar" sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar"
) )
self.assertFalse(Filter(self.hs, definition)._check(event)) self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_senders_works_with_unknowns(self): def test_definition_not_senders_works_with_unknowns(self) -> None:
definition = {"not_senders": ["@flibble:wibble"]} definition = {"not_senders": ["@flibble:wibble"]}
event = MockEvent( event = MockEvent(
sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar" sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar"
) )
self.assertTrue(Filter(self.hs, definition)._check(event)) self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_not_senders_takes_priority_over_senders(self): def test_definition_not_senders_takes_priority_over_senders(self) -> None:
definition = { definition = {
"not_senders": ["@misspiggy:muppets"], "not_senders": ["@misspiggy:muppets"],
"senders": ["@kermit:muppets", "@misspiggy:muppets"], "senders": ["@kermit:muppets", "@misspiggy:muppets"],
@ -219,14 +215,14 @@ class FilteringTestCase(unittest.HomeserverTestCase):
) )
self.assertFalse(Filter(self.hs, definition)._check(event)) self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_rooms_works_with_literals(self): def test_definition_rooms_works_with_literals(self) -> None:
definition = {"rooms": ["!secretbase:unknown"]} definition = {"rooms": ["!secretbase:unknown"]}
event = MockEvent( event = MockEvent(
sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown" sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown"
) )
self.assertTrue(Filter(self.hs, definition)._check(event)) self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_rooms_works_with_unknowns(self): def test_definition_rooms_works_with_unknowns(self) -> None:
definition = {"rooms": ["!secretbase:unknown"]} definition = {"rooms": ["!secretbase:unknown"]}
event = MockEvent( event = MockEvent(
sender="@foo:bar", sender="@foo:bar",
@ -235,7 +231,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
) )
self.assertFalse(Filter(self.hs, definition)._check(event)) self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_rooms_works_with_literals(self): def test_definition_not_rooms_works_with_literals(self) -> None:
definition = {"not_rooms": ["!anothersecretbase:unknown"]} definition = {"not_rooms": ["!anothersecretbase:unknown"]}
event = MockEvent( event = MockEvent(
sender="@foo:bar", sender="@foo:bar",
@ -244,7 +240,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
) )
self.assertFalse(Filter(self.hs, definition)._check(event)) self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_rooms_works_with_unknowns(self): def test_definition_not_rooms_works_with_unknowns(self) -> None:
definition = {"not_rooms": ["!secretbase:unknown"]} definition = {"not_rooms": ["!secretbase:unknown"]}
event = MockEvent( event = MockEvent(
sender="@foo:bar", sender="@foo:bar",
@ -253,7 +249,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
) )
self.assertTrue(Filter(self.hs, definition)._check(event)) self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_not_rooms_takes_priority_over_rooms(self): def test_definition_not_rooms_takes_priority_over_rooms(self) -> None:
definition = { definition = {
"not_rooms": ["!secretbase:unknown"], "not_rooms": ["!secretbase:unknown"],
"rooms": ["!secretbase:unknown"], "rooms": ["!secretbase:unknown"],
@ -263,7 +259,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
) )
self.assertFalse(Filter(self.hs, definition)._check(event)) self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_combined_event(self): def test_definition_combined_event(self) -> None:
definition = { definition = {
"not_senders": ["@misspiggy:muppets"], "not_senders": ["@misspiggy:muppets"],
"senders": ["@kermit:muppets"], "senders": ["@kermit:muppets"],
@ -279,7 +275,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
) )
self.assertTrue(Filter(self.hs, definition)._check(event)) self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_combined_event_bad_sender(self): def test_definition_combined_event_bad_sender(self) -> None:
definition = { definition = {
"not_senders": ["@misspiggy:muppets"], "not_senders": ["@misspiggy:muppets"],
"senders": ["@kermit:muppets"], "senders": ["@kermit:muppets"],
@ -295,7 +291,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
) )
self.assertFalse(Filter(self.hs, definition)._check(event)) self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_combined_event_bad_room(self): def test_definition_combined_event_bad_room(self) -> None:
definition = { definition = {
"not_senders": ["@misspiggy:muppets"], "not_senders": ["@misspiggy:muppets"],
"senders": ["@kermit:muppets"], "senders": ["@kermit:muppets"],
@ -311,7 +307,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
) )
self.assertFalse(Filter(self.hs, definition)._check(event)) self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_combined_event_bad_type(self): def test_definition_combined_event_bad_type(self) -> None:
definition = { definition = {
"not_senders": ["@misspiggy:muppets"], "not_senders": ["@misspiggy:muppets"],
"senders": ["@kermit:muppets"], "senders": ["@kermit:muppets"],
@ -327,7 +323,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
) )
self.assertFalse(Filter(self.hs, definition)._check(event)) self.assertFalse(Filter(self.hs, definition)._check(event))
def test_filter_labels(self): def test_filter_labels(self) -> None:
definition = {"org.matrix.labels": ["#fun"]} definition = {"org.matrix.labels": ["#fun"]}
event = MockEvent( event = MockEvent(
sender="@foo:bar", sender="@foo:bar",
@ -356,7 +352,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
) )
self.assertTrue(Filter(self.hs, definition)._check(event)) self.assertTrue(Filter(self.hs, definition)._check(event))
def test_filter_not_labels(self): def test_filter_not_labels(self) -> None:
definition = {"org.matrix.not_labels": ["#fun"]} definition = {"org.matrix.not_labels": ["#fun"]}
event = MockEvent( event = MockEvent(
sender="@foo:bar", sender="@foo:bar",
@ -377,7 +373,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
self.assertTrue(Filter(self.hs, definition)._check(event)) self.assertTrue(Filter(self.hs, definition)._check(event))
@unittest.override_config({"experimental_features": {"msc3874_enabled": True}}) @unittest.override_config({"experimental_features": {"msc3874_enabled": True}})
def test_filter_rel_type(self): def test_filter_rel_type(self) -> None:
definition = {"org.matrix.msc3874.rel_types": ["m.thread"]} definition = {"org.matrix.msc3874.rel_types": ["m.thread"]}
event = MockEvent( event = MockEvent(
sender="@foo:bar", sender="@foo:bar",
@ -407,7 +403,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
self.assertTrue(Filter(self.hs, definition)._check(event)) self.assertTrue(Filter(self.hs, definition)._check(event))
@unittest.override_config({"experimental_features": {"msc3874_enabled": True}}) @unittest.override_config({"experimental_features": {"msc3874_enabled": True}})
def test_filter_not_rel_type(self): def test_filter_not_rel_type(self) -> None:
definition = {"org.matrix.msc3874.not_rel_types": ["m.thread"]} definition = {"org.matrix.msc3874.not_rel_types": ["m.thread"]}
event = MockEvent( event = MockEvent(
sender="@foo:bar", sender="@foo:bar",
@ -436,15 +432,25 @@ class FilteringTestCase(unittest.HomeserverTestCase):
self.assertTrue(Filter(self.hs, definition)._check(event)) self.assertTrue(Filter(self.hs, definition)._check(event))
def test_filter_presence_match(self): def test_filter_presence_match(self) -> None:
user_filter_json = {"presence": {"types": ["m.*"]}} """Check that filter_presence return events which matches the filter."""
user_filter_json = {"presence": {"senders": ["@foo:bar"]}}
filter_id = self.get_success( filter_id = self.get_success(
self.datastore.add_user_filter( self.datastore.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json user_localpart=user_localpart, user_filter=user_filter_json
) )
) )
event = MockEvent(sender="@foo:bar", type="m.profile") presence_states = [
events = [event] UserPresenceState(
user_id="@foo:bar",
state="unavailable",
last_active_ts=0,
last_federation_update_ts=0,
last_user_sync_ts=0,
status_msg=None,
currently_active=False,
),
]
user_filter = self.get_success( user_filter = self.get_success(
self.filtering.get_user_filter( self.filtering.get_user_filter(
@ -452,23 +458,29 @@ class FilteringTestCase(unittest.HomeserverTestCase):
) )
) )
results = self.get_success(user_filter.filter_presence(events=events)) results = self.get_success(user_filter.filter_presence(presence_states))
self.assertEqual(events, results) self.assertEqual(presence_states, results)
def test_filter_presence_no_match(self): def test_filter_presence_no_match(self) -> None:
user_filter_json = {"presence": {"types": ["m.*"]}} """Check that filter_presence does not return events rejected by the filter."""
user_filter_json = {"presence": {"not_senders": ["@foo:bar"]}}
filter_id = self.get_success( filter_id = self.get_success(
self.datastore.add_user_filter( self.datastore.add_user_filter(
user_localpart=user_localpart + "2", user_filter=user_filter_json user_localpart=user_localpart + "2", user_filter=user_filter_json
) )
) )
event = MockEvent( presence_states = [
event_id="$asdasd:localhost", UserPresenceState(
sender="@foo:bar", user_id="@foo:bar",
type="custom.avatar.3d.crazy", state="unavailable",
) last_active_ts=0,
events = [event] last_federation_update_ts=0,
last_user_sync_ts=0,
status_msg=None,
currently_active=False,
),
]
user_filter = self.get_success( user_filter = self.get_success(
self.filtering.get_user_filter( self.filtering.get_user_filter(
@ -476,10 +488,10 @@ class FilteringTestCase(unittest.HomeserverTestCase):
) )
) )
results = self.get_success(user_filter.filter_presence(events=events)) results = self.get_success(user_filter.filter_presence(presence_states))
self.assertEqual([], results) self.assertEqual([], results)
def test_filter_room_state_match(self): def test_filter_room_state_match(self) -> None:
user_filter_json = {"room": {"state": {"types": ["m.*"]}}} user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
filter_id = self.get_success( filter_id = self.get_success(
self.datastore.add_user_filter( self.datastore.add_user_filter(
@ -498,7 +510,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
results = self.get_success(user_filter.filter_room_state(events=events)) results = self.get_success(user_filter.filter_room_state(events=events))
self.assertEqual(events, results) self.assertEqual(events, results)
def test_filter_room_state_no_match(self): def test_filter_room_state_no_match(self) -> None:
user_filter_json = {"room": {"state": {"types": ["m.*"]}}} user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
filter_id = self.get_success( filter_id = self.get_success(
self.datastore.add_user_filter( self.datastore.add_user_filter(
@ -519,7 +531,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
results = self.get_success(user_filter.filter_room_state(events)) results = self.get_success(user_filter.filter_room_state(events))
self.assertEqual([], results) self.assertEqual([], results)
def test_filter_rooms(self): def test_filter_rooms(self) -> None:
definition = { definition = {
"rooms": ["!allowed:example.com", "!excluded:example.com"], "rooms": ["!allowed:example.com", "!excluded:example.com"],
"not_rooms": ["!excluded:example.com"], "not_rooms": ["!excluded:example.com"],
@ -535,7 +547,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
self.assertEqual(filtered_room_ids, ["!allowed:example.com"]) self.assertEqual(filtered_room_ids, ["!allowed:example.com"])
def test_filter_relations(self): def test_filter_relations(self) -> None:
events = [ events = [
# An event without a relation. # An event without a relation.
MockEvent( MockEvent(
@ -551,9 +563,8 @@ class FilteringTestCase(unittest.HomeserverTestCase):
type="org.matrix.custom.event", type="org.matrix.custom.event",
room_id="!foo:bar", room_id="!foo:bar",
), ),
# Non-EventBase objects get passed through.
{},
] ]
jsondicts: List[JsonDict] = [{}]
# For the following tests we patch the datastore method (intead of injecting # For the following tests we patch the datastore method (intead of injecting
# events). This is a bit cheeky, but tests the logic of _check_event_relations. # events). This is a bit cheeky, but tests the logic of _check_event_relations.
@ -561,7 +572,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
# Filter for a particular sender. # Filter for a particular sender.
definition = {"related_by_senders": ["@foo:bar"]} definition = {"related_by_senders": ["@foo:bar"]}
async def events_have_relations(*args, **kwargs): async def events_have_relations(*args: object, **kwargs: object) -> List[str]:
return ["$with_relation"] return ["$with_relation"]
with patch.object( with patch.object(
@ -572,9 +583,17 @@ class FilteringTestCase(unittest.HomeserverTestCase):
Filter(self.hs, definition)._check_event_relations(events) Filter(self.hs, definition)._check_event_relations(events)
) )
) )
self.assertEqual(filtered_events, events[1:]) # Non-EventBase objects get passed through.
filtered_jsondicts = list(
self.get_success(
Filter(self.hs, definition)._check_event_relations(jsondicts)
)
)
def test_add_filter(self): self.assertEqual(filtered_events, events[1:])
self.assertEqual(filtered_jsondicts, [{}])
def test_add_filter(self) -> None:
user_filter_json = {"room": {"state": {"types": ["m.*"]}}} user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
filter_id = self.get_success( filter_id = self.get_success(
@ -595,7 +614,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
), ),
) )
def test_get_filter(self): def test_get_filter(self) -> None:
user_filter_json = {"room": {"state": {"types": ["m.*"]}}} user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
filter_id = self.get_success( filter_id = self.get_success(

View File

@ -6,7 +6,7 @@ from tests import unittest
class TestRatelimiter(unittest.HomeserverTestCase): class TestRatelimiter(unittest.HomeserverTestCase):
def test_allowed_via_can_do_action(self): def test_allowed_via_can_do_action(self) -> None:
limiter = Ratelimiter( limiter = Ratelimiter(
store=self.hs.get_datastores().main, store=self.hs.get_datastores().main,
clock=self.clock, clock=self.clock,
@ -31,7 +31,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
self.assertTrue(allowed) self.assertTrue(allowed)
self.assertEqual(20.0, time_allowed) self.assertEqual(20.0, time_allowed)
def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): def test_allowed_appservice_ratelimited_via_can_requester_do_action(self) -> None:
appservice = ApplicationService( appservice = ApplicationService(
token="fake_token", token="fake_token",
id="foo", id="foo",
@ -64,7 +64,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
self.assertTrue(allowed) self.assertTrue(allowed)
self.assertEqual(20.0, time_allowed) self.assertEqual(20.0, time_allowed)
def test_allowed_appservice_via_can_requester_do_action(self): def test_allowed_appservice_via_can_requester_do_action(self) -> None:
appservice = ApplicationService( appservice = ApplicationService(
token="fake_token", token="fake_token",
id="foo", id="foo",
@ -97,7 +97,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
self.assertTrue(allowed) self.assertTrue(allowed)
self.assertEqual(-1, time_allowed) self.assertEqual(-1, time_allowed)
def test_allowed_via_ratelimit(self): def test_allowed_via_ratelimit(self) -> None:
limiter = Ratelimiter( limiter = Ratelimiter(
store=self.hs.get_datastores().main, store=self.hs.get_datastores().main,
clock=self.clock, clock=self.clock,
@ -120,7 +120,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter.ratelimit(None, key="test_id", _time_now_s=10) limiter.ratelimit(None, key="test_id", _time_now_s=10)
) )
def test_allowed_via_can_do_action_and_overriding_parameters(self): def test_allowed_via_can_do_action_and_overriding_parameters(self) -> None:
"""Test that we can override options of can_do_action that would otherwise fail """Test that we can override options of can_do_action that would otherwise fail
an action an action
""" """
@ -169,7 +169,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
self.assertTrue(allowed) self.assertTrue(allowed)
self.assertEqual(1.0, time_allowed) self.assertEqual(1.0, time_allowed)
def test_allowed_via_ratelimit_and_overriding_parameters(self): def test_allowed_via_ratelimit_and_overriding_parameters(self) -> None:
"""Test that we can override options of the ratelimit method that would otherwise """Test that we can override options of the ratelimit method that would otherwise
fail an action fail an action
""" """
@ -204,7 +204,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter.ratelimit(None, key=("test_id",), _time_now_s=1, burst_count=10) limiter.ratelimit(None, key=("test_id",), _time_now_s=1, burst_count=10)
) )
def test_pruning(self): def test_pruning(self) -> None:
limiter = Ratelimiter( limiter = Ratelimiter(
store=self.hs.get_datastores().main, store=self.hs.get_datastores().main,
clock=self.clock, clock=self.clock,
@ -223,7 +223,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
self.assertNotIn("test_id_1", limiter.actions) self.assertNotIn("test_id_1", limiter.actions)
def test_db_user_override(self): def test_db_user_override(self) -> None:
"""Test that users that have ratelimiting disabled in the DB aren't """Test that users that have ratelimiting disabled in the DB aren't
ratelimited. ratelimited.
""" """
@ -250,7 +250,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
for _ in range(20): for _ in range(20):
self.get_success_or_raise(limiter.ratelimit(requester, _time_now_s=0)) self.get_success_or_raise(limiter.ratelimit(requester, _time_now_s=0))
def test_multiple_actions(self): def test_multiple_actions(self) -> None:
limiter = Ratelimiter( limiter = Ratelimiter(
store=self.hs.get_datastores().main, store=self.hs.get_datastores().main,
clock=self.clock, clock=self.clock,

View File

@ -35,6 +35,8 @@ def MockEvent(**kwargs: Any) -> EventBase:
kwargs["event_id"] = "fake_event_id" kwargs["event_id"] = "fake_event_id"
if "type" not in kwargs: if "type" not in kwargs:
kwargs["type"] = "fake_type" kwargs["type"] = "fake_type"
if "content" not in kwargs:
kwargs["content"] = {}
return make_event_from_dict(kwargs) return make_event_from_dict(kwargs)