Add type hints for event streams. (#10856)

This commit is contained in:
Patrick Cloke 2021-09-21 13:34:26 -04:00 committed by GitHub
parent b25a494779
commit 4054dfa409
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 169 additions and 60 deletions

1
changelog.d/10856.misc Normal file
View File

@ -0,0 +1 @@
Add missing type hints to handlers.

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import random import random
from typing import TYPE_CHECKING, Any, List, Tuple from typing import TYPE_CHECKING, Collection, List, Optional, Tuple
from synapse.replication.http.account_data import ( from synapse.replication.http.account_data import (
ReplicationAddTagRestServlet, ReplicationAddTagRestServlet,
@ -21,6 +21,7 @@ from synapse.replication.http.account_data import (
ReplicationRoomAccountDataRestServlet, ReplicationRoomAccountDataRestServlet,
ReplicationUserAccountDataRestServlet, ReplicationUserAccountDataRestServlet,
) )
from synapse.streams import EventSource
from synapse.types import JsonDict, UserID from synapse.types import JsonDict, UserID
if TYPE_CHECKING: if TYPE_CHECKING:
@ -163,7 +164,7 @@ class AccountDataHandler:
return response["max_stream_id"] return response["max_stream_id"]
class AccountDataEventSource: class AccountDataEventSource(EventSource[int, JsonDict]):
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -171,7 +172,13 @@ class AccountDataEventSource:
return self.store.get_max_account_data_stream_id() return self.store.get_max_account_data_stream_id()
async def get_new_events( async def get_new_events(
self, user: UserID, from_key: int, **kwargs: Any self,
user: UserID,
from_key: int,
limit: Optional[int],
room_ids: Collection[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[JsonDict], int]: ) -> Tuple[List[JsonDict], int]:
user_id = user.to_string() user_id = user.to_string()
last_stream_id = from_key last_stream_id = from_key

View File

@ -254,7 +254,7 @@ class ApplicationServicesHandler:
async def _handle_typing( async def _handle_typing(
self, service: ApplicationService, new_token: int self, service: ApplicationService, new_token: int
) -> List[JsonDict]: ) -> List[JsonDict]:
typing_source = self.event_sources.sources["typing"] typing_source = self.event_sources.sources.typing
# Get the typing events from just before current # Get the typing events from just before current
typing, _ = await typing_source.get_new_events_as( typing, _ = await typing_source.get_new_events_as(
service=service, service=service,
@ -269,7 +269,7 @@ class ApplicationServicesHandler:
from_key = await self.store.get_type_stream_id_for_appservice( from_key = await self.store.get_type_stream_id_for_appservice(
service, "read_receipt" service, "read_receipt"
) )
receipts_source = self.event_sources.sources["receipt"] receipts_source = self.event_sources.sources.receipt
receipts, _ = await receipts_source.get_new_events_as( receipts, _ = await receipts_source.get_new_events_as(
service=service, from_key=from_key service=service, from_key=from_key
) )
@ -279,7 +279,7 @@ class ApplicationServicesHandler:
self, service: ApplicationService, users: Collection[Union[str, UserID]] self, service: ApplicationService, users: Collection[Union[str, UserID]]
) -> List[JsonDict]: ) -> List[JsonDict]:
events: List[JsonDict] = [] events: List[JsonDict] = []
presence_source = self.event_sources.sources["presence"] presence_source = self.event_sources.sources.presence
from_key = await self.store.get_type_stream_id_for_appservice( from_key = await self.store.get_type_stream_id_for_appservice(
service, "presence" service, "presence"
) )

View File

@ -125,7 +125,7 @@ class InitialSyncHandler(BaseHandler):
now_token = self.hs.get_event_sources().get_current_token() now_token = self.hs.get_event_sources().get_current_token()
presence_stream = self.hs.get_event_sources().sources["presence"] presence_stream = self.hs.get_event_sources().sources.presence
presence, _ = await presence_stream.get_new_events( presence, _ = await presence_stream.get_new_events(
user, from_key=None, include_offline=False user, from_key=None, include_offline=False
) )

View File

@ -65,6 +65,7 @@ from synapse.replication.http.streams import ReplicationGetStreamUpdates
from synapse.replication.tcp.commands import ClearUserSyncsCommand from synapse.replication.tcp.commands import ClearUserSyncsCommand
from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream
from synapse.storage.databases.main import DataStore from synapse.storage.databases.main import DataStore
from synapse.streams import EventSource
from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches.descriptors import _CacheContext, cached from synapse.util.caches.descriptors import _CacheContext, cached
@ -1500,7 +1501,7 @@ def format_user_presence_state(
return content return content
class PresenceEventSource: class PresenceEventSource(EventSource[int, UserPresenceState]):
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
# We can't call get_presence_handler here because there's a cycle: # We can't call get_presence_handler here because there's a cycle:
# #
@ -1519,10 +1520,11 @@ class PresenceEventSource:
self, self,
user: UserID, user: UserID,
from_key: Optional[int], from_key: Optional[int],
limit: Optional[int] = None,
room_ids: Optional[List[str]] = None, room_ids: Optional[List[str]] = None,
include_offline: bool = True, is_guest: bool = False,
explicit_room_id: Optional[str] = None, explicit_room_id: Optional[str] = None,
**kwargs: Any, include_offline: bool = True,
) -> Tuple[List[UserPresenceState], int]: ) -> Tuple[List[UserPresenceState], int]:
# The process for getting presence events are: # The process for getting presence events are:
# 1. Get the rooms the user is in. # 1. Get the rooms the user is in.

View File

@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Any, List, Optional, Tuple from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
from synapse.api.constants import ReadReceiptEventFields from synapse.api.constants import ReadReceiptEventFields
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.handlers._base import BaseHandler from synapse.handlers._base import BaseHandler
from synapse.streams import EventSource
from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id
if TYPE_CHECKING: if TYPE_CHECKING:
@ -162,7 +163,7 @@ class ReceiptsHandler(BaseHandler):
await self.federation_sender.send_read_receipt(receipt) await self.federation_sender.send_read_receipt(receipt)
class ReceiptEventSource: class ReceiptEventSource(EventSource[int, JsonDict]):
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.config = hs.config self.config = hs.config
@ -216,7 +217,13 @@ class ReceiptEventSource:
return visible_events return visible_events
async def get_new_events( async def get_new_events(
self, from_key: int, room_ids: List[str], user: UserID, **kwargs: Any self,
user: UserID,
from_key: int,
limit: Optional[int],
room_ids: Iterable[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[JsonDict], int]: ) -> Tuple[List[JsonDict], int]:
from_key = int(from_key) from_key = int(from_key)
to_key = self.get_current_key() to_key = self.get_current_key()

View File

@ -20,7 +20,16 @@ import math
import random import random
import string import string
from collections import OrderedDict from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Collection,
Dict,
List,
Optional,
Tuple,
)
from synapse.api.constants import ( from synapse.api.constants import (
EventContentFields, EventContentFields,
@ -47,6 +56,7 @@ from synapse.events import EventBase
from synapse.events.utils import copy_power_levels_contents from synapse.events.utils import copy_power_levels_contents
from synapse.rest.admin._base import assert_user_is_admin from synapse.rest.admin._base import assert_user_is_admin
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.streams import EventSource
from synapse.types import ( from synapse.types import (
JsonDict, JsonDict,
MutableStateMap, MutableStateMap,
@ -1173,7 +1183,7 @@ class RoomContextHandler:
return results return results
class RoomEventSource: class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -1181,8 +1191,8 @@ class RoomEventSource:
self, self,
user: UserID, user: UserID,
from_key: RoomStreamToken, from_key: RoomStreamToken,
limit: int, limit: Optional[int],
room_ids: List[str], room_ids: Collection[str],
is_guest: bool, is_guest: bool,
explicit_room_id: Optional[str] = None, explicit_room_id: Optional[str] = None,
) -> Tuple[List[EventBase], RoomStreamToken]: ) -> Tuple[List[EventBase], RoomStreamToken]:

View File

@ -443,7 +443,7 @@ class SyncHandler:
room_ids = sync_result_builder.joined_room_ids room_ids = sync_result_builder.joined_room_ids
typing_source = self.event_sources.sources["typing"] typing_source = self.event_sources.sources.typing
typing, typing_key = await typing_source.get_new_events( typing, typing_key = await typing_source.get_new_events(
user=sync_config.user, user=sync_config.user,
from_key=typing_key, from_key=typing_key,
@ -465,7 +465,7 @@ class SyncHandler:
receipt_key = since_token.receipt_key if since_token else 0 receipt_key = since_token.receipt_key if since_token else 0
receipt_source = self.event_sources.sources["receipt"] receipt_source = self.event_sources.sources.receipt
receipts, receipt_key = await receipt_source.get_new_events( receipts, receipt_key = await receipt_source.get_new_events(
user=sync_config.user, user=sync_config.user,
from_key=receipt_key, from_key=receipt_key,
@ -1415,7 +1415,7 @@ class SyncHandler:
sync_config = sync_result_builder.sync_config sync_config = sync_result_builder.sync_config
user = sync_result_builder.sync_config.user user = sync_result_builder.sync_config.user
presence_source = self.event_sources.sources["presence"] presence_source = self.event_sources.sources.presence
since_token = sync_result_builder.since_token since_token = sync_result_builder.since_token
presence_key = None presence_key = None

View File

@ -14,7 +14,7 @@
import logging import logging
import random import random
from collections import namedtuple from collections import namedtuple
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
from synapse.api.errors import AuthError, ShadowBanError, SynapseError from synapse.api.errors import AuthError, ShadowBanError, SynapseError
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
@ -23,6 +23,7 @@ from synapse.metrics.background_process_metrics import (
wrap_as_background_process, wrap_as_background_process,
) )
from synapse.replication.tcp.streams import TypingStream from synapse.replication.tcp.streams import TypingStream
from synapse.streams import EventSource
from synapse.types import JsonDict, Requester, UserID, get_domain_from_id from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -439,7 +440,7 @@ class TypingWriterHandler(FollowerTypingHandler):
raise Exception("Typing writer instance got typing info over replication") raise Exception("Typing writer instance got typing info over replication")
class TypingNotificationEventSource: class TypingNotificationEventSource(EventSource[int, JsonDict]):
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.clock = hs.get_clock() self.clock = hs.get_clock()
@ -485,7 +486,13 @@ class TypingNotificationEventSource:
return (events, handler._latest_room_serial) return (events, handler._latest_room_serial)
async def get_new_events( async def get_new_events(
self, from_key: int, room_ids: Iterable[str], **kwargs: Any self,
user: UserID,
from_key: int,
limit: Optional[int],
room_ids: Iterable[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[JsonDict], int]: ) -> Tuple[List[JsonDict], int]:
with Measure(self.clock, "typing.get_new_events"): with Measure(self.clock, "typing.get_new_events"):
from_key = int(from_key) from_key = int(from_key)

View File

@ -91,7 +91,7 @@ class ModuleApi:
self._auth = hs.get_auth() self._auth = hs.get_auth()
self._auth_handler = auth_handler self._auth_handler = auth_handler
self._server_name = hs.hostname self._server_name = hs.hostname
self._presence_stream = hs.get_event_sources().sources["presence"] self._presence_stream = hs.get_event_sources().sources.presence
self._state = hs.get_state_handler() self._state = hs.get_state_handler()
self._clock: Clock = hs.get_clock() self._clock: Clock = hs.get_clock()
self._send_email_handler = hs.get_send_email_handler() self._send_email_handler = hs.get_send_email_handler()

View File

@ -584,7 +584,7 @@ class Notifier:
events: List[EventBase] = [] events: List[EventBase] = []
end_token = from_token end_token = from_token
for name, source in self.event_sources.sources.items(): for name, source in self.event_sources.sources.get_sources():
keyname = "%s_key" % name keyname = "%s_key" % name
before_id = getattr(before_token, keyname) before_id = getattr(before_token, keyname)
after_id = getattr(after_token, keyname) after_id = getattr(after_token, keyname)

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, Iterable, List, Optional, Tuple
from twisted.internet import defer from twisted.internet import defer
@ -153,12 +153,12 @@ class ReceiptsWorkerStore(SQLBaseStore):
} }
async def get_linearized_receipts_for_rooms( async def get_linearized_receipts_for_rooms(
self, room_ids: List[str], to_key: int, from_key: Optional[int] = None self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None
) -> List[dict]: ) -> List[dict]:
"""Get receipts for multiple rooms for sending to clients. """Get receipts for multiple rooms for sending to clients.
Args: Args:
room_id: List of room_ids. room_id: The room IDs to fetch receipts of.
to_key: Max stream id to fetch receipts up to. to_key: Max stream id to fetch receipts up to.
from_key: Min stream id to fetch receipts from. None fetches from_key: Min stream id to fetch receipts from. None fetches
from the start. from the start.

View File

@ -11,3 +11,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Collection, Generic, List, Optional, Tuple, TypeVar
from synapse.types import UserID
# The key, this is either a stream token or int.
K = TypeVar("K")
# The return type.
R = TypeVar("R")
class EventSource(Generic[K, R]):
async def get_new_events(
self,
user: UserID,
from_key: K,
limit: Optional[int],
room_ids: Collection[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[R], K]:
...

View File

@ -12,29 +12,40 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Dict from typing import TYPE_CHECKING, Iterator, Tuple
import attr
from synapse.handlers.account_data import AccountDataEventSource from synapse.handlers.account_data import AccountDataEventSource
from synapse.handlers.presence import PresenceEventSource from synapse.handlers.presence import PresenceEventSource
from synapse.handlers.receipts import ReceiptEventSource from synapse.handlers.receipts import ReceiptEventSource
from synapse.handlers.room import RoomEventSource from synapse.handlers.room import RoomEventSource
from synapse.handlers.typing import TypingNotificationEventSource from synapse.handlers.typing import TypingNotificationEventSource
from synapse.streams import EventSource
from synapse.types import StreamToken from synapse.types import StreamToken
if TYPE_CHECKING:
from synapse.server import HomeServer
@attr.s(frozen=True, slots=True, auto_attribs=True)
class _EventSourcesInner:
room: RoomEventSource
presence: PresenceEventSource
typing: TypingNotificationEventSource
receipt: ReceiptEventSource
account_data: AccountDataEventSource
def get_sources(self) -> Iterator[Tuple[str, EventSource]]:
for attribute in _EventSourcesInner.__attrs_attrs__: # type: ignore[attr-defined]
yield attribute.name, getattr(self, attribute.name)
class EventSources: class EventSources:
SOURCE_TYPES = { def __init__(self, hs: "HomeServer"):
"room": RoomEventSource, self.sources = _EventSourcesInner(
"presence": PresenceEventSource, *(attribute.type(hs) for attribute in _EventSourcesInner.__attrs_attrs__) # type: ignore[attr-defined]
"typing": TypingNotificationEventSource, )
"receipt": ReceiptEventSource,
"account_data": AccountDataEventSource,
}
def __init__(self, hs):
self.sources: Dict[str, Any] = {
name: cls(hs) for name, cls in EventSources.SOURCE_TYPES.items()
}
self.store = hs.get_datastore() self.store = hs.get_datastore()
def get_current_token(self) -> StreamToken: def get_current_token(self) -> StreamToken:
@ -44,11 +55,11 @@ class EventSources:
groups_key = self.store.get_group_stream_token() groups_key = self.store.get_group_stream_token()
token = StreamToken( token = StreamToken(
room_key=self.sources["room"].get_current_key(), room_key=self.sources.room.get_current_key(),
presence_key=self.sources["presence"].get_current_key(), presence_key=self.sources.presence.get_current_key(),
typing_key=self.sources["typing"].get_current_key(), typing_key=self.sources.typing.get_current_key(),
receipt_key=self.sources["receipt"].get_current_key(), receipt_key=self.sources.receipt.get_current_key(),
account_data_key=self.sources["account_data"].get_current_key(), account_data_key=self.sources.account_data.get_current_key(),
push_rules_key=push_rules_key, push_rules_key=push_rules_key,
to_device_key=to_device_key, to_device_key=to_device_key,
device_list_key=device_list_key, device_list_key=device_list_key,
@ -67,7 +78,7 @@ class EventSources:
The current token for pagination. The current token for pagination.
""" """
token = StreamToken( token = StreamToken(
room_key=self.sources["room"].get_current_key(), room_key=self.sources.room.get_current_key(),
presence_key=0, presence_key=0,
typing_key=0, typing_key=0,
receipt_key=0, receipt_key=0,

View File

@ -23,7 +23,7 @@ from tests import unittest
class ReceiptsTestCase(unittest.HomeserverTestCase): class ReceiptsTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs): def prepare(self, reactor, clock, hs):
self.event_source = hs.get_event_sources().sources["receipt"] self.event_source = hs.get_event_sources().sources.receipt
# In the first param of _test_filters_hidden we use "hidden" instead of # In the first param of _test_filters_hidden we use "hidden" instead of
# ReadReceiptEventFields.MSC2285_HIDDEN. We do this because we're mocking # ReadReceiptEventFields.MSC2285_HIDDEN. We do this because we're mocking

View File

@ -89,7 +89,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.handler = hs.get_typing_handler() self.handler = hs.get_typing_handler()
self.event_source = hs.get_event_sources().sources["typing"] self.event_source = hs.get_event_sources().sources.typing
self.datastore = hs.get_datastore() self.datastore = hs.get_datastore()
self.datastore.get_destination_retry_timings = Mock( self.datastore.get_destination_retry_timings = Mock(
@ -171,7 +171,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = self.get_success( events = self.get_success(
self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) self.event_source.get_new_events(
user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False
)
) )
self.assertEquals( self.assertEquals(
events[0], events[0],
@ -239,7 +241,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = self.get_success( events = self.get_success(
self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) self.event_source.get_new_events(
user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False
)
) )
self.assertEquals( self.assertEquals(
events[0], events[0],
@ -276,7 +280,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 0) self.assertEquals(self.event_source.get_current_key(), 0)
events = self.get_success( events = self.get_success(
self.event_source.get_new_events(room_ids=[OTHER_ROOM_ID], from_key=0) self.event_source.get_new_events(
user=U_APPLE,
from_key=0,
limit=None,
room_ids=[OTHER_ROOM_ID],
is_guest=False,
)
) )
self.assertEquals(events[0], []) self.assertEquals(events[0], [])
self.assertEquals(events[1], 0) self.assertEquals(events[1], 0)
@ -324,7 +334,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = self.get_success( events = self.get_success(
self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) self.event_source.get_new_events(
user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False
)
) )
self.assertEquals( self.assertEquals(
events[0], events[0],
@ -350,7 +362,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = self.get_success( events = self.get_success(
self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) self.event_source.get_new_events(
user=U_APPLE,
from_key=0,
limit=None,
room_ids=[ROOM_ID],
is_guest=False,
)
) )
self.assertEquals( self.assertEquals(
events[0], events[0],
@ -369,7 +387,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 2) self.assertEquals(self.event_source.get_current_key(), 2)
events = self.get_success( events = self.get_success(
self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1) self.event_source.get_new_events(
user=U_APPLE,
from_key=1,
limit=None,
room_ids=[ROOM_ID],
is_guest=False,
)
) )
self.assertEquals( self.assertEquals(
events[0], events[0],
@ -392,7 +416,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 3) self.assertEquals(self.event_source.get_current_key(), 3)
events = self.get_success( events = self.get_success(
self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) self.event_source.get_new_events(
user=U_APPLE,
from_key=0,
limit=None,
room_ids=[ROOM_ID],
is_guest=False,
)
) )
self.assertEquals( self.assertEquals(
events[0], events[0],

View File

@ -193,7 +193,7 @@ class RoomTestCase(_ShadowBannedBase):
self.assertEquals(200, channel.code) self.assertEquals(200, channel.code)
# There should be no typing events. # There should be no typing events.
event_source = self.hs.get_event_sources().sources["typing"] event_source = self.hs.get_event_sources().sources.typing
self.assertEquals(event_source.get_current_key(), 0) self.assertEquals(event_source.get_current_key(), 0)
# The other user can join and send typing events. # The other user can join and send typing events.
@ -210,7 +210,13 @@ class RoomTestCase(_ShadowBannedBase):
# These appear in the room. # These appear in the room.
self.assertEquals(event_source.get_current_key(), 1) self.assertEquals(event_source.get_current_key(), 1)
events = self.get_success( events = self.get_success(
event_source.get_new_events(from_key=0, room_ids=[room_id]) event_source.get_new_events(
user=UserID.from_string(self.other_user_id),
from_key=0,
limit=None,
room_ids=[room_id],
is_guest=False,
)
) )
self.assertEquals( self.assertEquals(
events[0], events[0],

View File

@ -41,7 +41,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
federation_client=Mock(), federation_client=Mock(),
) )
self.event_source = hs.get_event_sources().sources["typing"] self.event_source = hs.get_event_sources().sources.typing
hs.get_federation_handler = Mock() hs.get_federation_handler = Mock()
@ -76,7 +76,13 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = self.get_success( events = self.get_success(
self.event_source.get_new_events(from_key=0, room_ids=[self.room_id]) self.event_source.get_new_events(
user=UserID.from_string(self.user_id),
from_key=0,
limit=None,
room_ids=[self.room_id],
is_guest=False,
)
) )
self.assertEquals( self.assertEquals(
events[0], events[0],