mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Add type hints for event streams. (#10856)
This commit is contained in:
parent
b25a494779
commit
4054dfa409
1
changelog.d/10856.misc
Normal file
1
changelog.d/10856.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add missing type hints to handlers.
|
@ -13,7 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import random
|
import random
|
||||||
from typing import TYPE_CHECKING, Any, List, Tuple
|
from typing import TYPE_CHECKING, Collection, List, Optional, Tuple
|
||||||
|
|
||||||
from synapse.replication.http.account_data import (
|
from synapse.replication.http.account_data import (
|
||||||
ReplicationAddTagRestServlet,
|
ReplicationAddTagRestServlet,
|
||||||
@ -21,6 +21,7 @@ from synapse.replication.http.account_data import (
|
|||||||
ReplicationRoomAccountDataRestServlet,
|
ReplicationRoomAccountDataRestServlet,
|
||||||
ReplicationUserAccountDataRestServlet,
|
ReplicationUserAccountDataRestServlet,
|
||||||
)
|
)
|
||||||
|
from synapse.streams import EventSource
|
||||||
from synapse.types import JsonDict, UserID
|
from synapse.types import JsonDict, UserID
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -163,7 +164,7 @@ class AccountDataHandler:
|
|||||||
return response["max_stream_id"]
|
return response["max_stream_id"]
|
||||||
|
|
||||||
|
|
||||||
class AccountDataEventSource:
|
class AccountDataEventSource(EventSource[int, JsonDict]):
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
@ -171,7 +172,13 @@ class AccountDataEventSource:
|
|||||||
return self.store.get_max_account_data_stream_id()
|
return self.store.get_max_account_data_stream_id()
|
||||||
|
|
||||||
async def get_new_events(
|
async def get_new_events(
|
||||||
self, user: UserID, from_key: int, **kwargs: Any
|
self,
|
||||||
|
user: UserID,
|
||||||
|
from_key: int,
|
||||||
|
limit: Optional[int],
|
||||||
|
room_ids: Collection[str],
|
||||||
|
is_guest: bool,
|
||||||
|
explicit_room_id: Optional[str] = None,
|
||||||
) -> Tuple[List[JsonDict], int]:
|
) -> Tuple[List[JsonDict], int]:
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
last_stream_id = from_key
|
last_stream_id = from_key
|
||||||
|
@ -254,7 +254,7 @@ class ApplicationServicesHandler:
|
|||||||
async def _handle_typing(
|
async def _handle_typing(
|
||||||
self, service: ApplicationService, new_token: int
|
self, service: ApplicationService, new_token: int
|
||||||
) -> List[JsonDict]:
|
) -> List[JsonDict]:
|
||||||
typing_source = self.event_sources.sources["typing"]
|
typing_source = self.event_sources.sources.typing
|
||||||
# Get the typing events from just before current
|
# Get the typing events from just before current
|
||||||
typing, _ = await typing_source.get_new_events_as(
|
typing, _ = await typing_source.get_new_events_as(
|
||||||
service=service,
|
service=service,
|
||||||
@ -269,7 +269,7 @@ class ApplicationServicesHandler:
|
|||||||
from_key = await self.store.get_type_stream_id_for_appservice(
|
from_key = await self.store.get_type_stream_id_for_appservice(
|
||||||
service, "read_receipt"
|
service, "read_receipt"
|
||||||
)
|
)
|
||||||
receipts_source = self.event_sources.sources["receipt"]
|
receipts_source = self.event_sources.sources.receipt
|
||||||
receipts, _ = await receipts_source.get_new_events_as(
|
receipts, _ = await receipts_source.get_new_events_as(
|
||||||
service=service, from_key=from_key
|
service=service, from_key=from_key
|
||||||
)
|
)
|
||||||
@ -279,7 +279,7 @@ class ApplicationServicesHandler:
|
|||||||
self, service: ApplicationService, users: Collection[Union[str, UserID]]
|
self, service: ApplicationService, users: Collection[Union[str, UserID]]
|
||||||
) -> List[JsonDict]:
|
) -> List[JsonDict]:
|
||||||
events: List[JsonDict] = []
|
events: List[JsonDict] = []
|
||||||
presence_source = self.event_sources.sources["presence"]
|
presence_source = self.event_sources.sources.presence
|
||||||
from_key = await self.store.get_type_stream_id_for_appservice(
|
from_key = await self.store.get_type_stream_id_for_appservice(
|
||||||
service, "presence"
|
service, "presence"
|
||||||
)
|
)
|
||||||
|
@ -125,7 +125,7 @@ class InitialSyncHandler(BaseHandler):
|
|||||||
|
|
||||||
now_token = self.hs.get_event_sources().get_current_token()
|
now_token = self.hs.get_event_sources().get_current_token()
|
||||||
|
|
||||||
presence_stream = self.hs.get_event_sources().sources["presence"]
|
presence_stream = self.hs.get_event_sources().sources.presence
|
||||||
presence, _ = await presence_stream.get_new_events(
|
presence, _ = await presence_stream.get_new_events(
|
||||||
user, from_key=None, include_offline=False
|
user, from_key=None, include_offline=False
|
||||||
)
|
)
|
||||||
|
@ -65,6 +65,7 @@ from synapse.replication.http.streams import ReplicationGetStreamUpdates
|
|||||||
from synapse.replication.tcp.commands import ClearUserSyncsCommand
|
from synapse.replication.tcp.commands import ClearUserSyncsCommand
|
||||||
from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream
|
from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream
|
||||||
from synapse.storage.databases.main import DataStore
|
from synapse.storage.databases.main import DataStore
|
||||||
|
from synapse.streams import EventSource
|
||||||
from synapse.types import JsonDict, UserID, get_domain_from_id
|
from synapse.types import JsonDict, UserID, get_domain_from_id
|
||||||
from synapse.util.async_helpers import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
from synapse.util.caches.descriptors import _CacheContext, cached
|
from synapse.util.caches.descriptors import _CacheContext, cached
|
||||||
@ -1500,7 +1501,7 @@ def format_user_presence_state(
|
|||||||
return content
|
return content
|
||||||
|
|
||||||
|
|
||||||
class PresenceEventSource:
|
class PresenceEventSource(EventSource[int, UserPresenceState]):
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
# We can't call get_presence_handler here because there's a cycle:
|
# We can't call get_presence_handler here because there's a cycle:
|
||||||
#
|
#
|
||||||
@ -1519,10 +1520,11 @@ class PresenceEventSource:
|
|||||||
self,
|
self,
|
||||||
user: UserID,
|
user: UserID,
|
||||||
from_key: Optional[int],
|
from_key: Optional[int],
|
||||||
|
limit: Optional[int] = None,
|
||||||
room_ids: Optional[List[str]] = None,
|
room_ids: Optional[List[str]] = None,
|
||||||
include_offline: bool = True,
|
is_guest: bool = False,
|
||||||
explicit_room_id: Optional[str] = None,
|
explicit_room_id: Optional[str] = None,
|
||||||
**kwargs: Any,
|
include_offline: bool = True,
|
||||||
) -> Tuple[List[UserPresenceState], int]:
|
) -> Tuple[List[UserPresenceState], int]:
|
||||||
# The process for getting presence events are:
|
# The process for getting presence events are:
|
||||||
# 1. Get the rooms the user is in.
|
# 1. Get the rooms the user is in.
|
||||||
|
@ -12,11 +12,12 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
from synapse.api.constants import ReadReceiptEventFields
|
from synapse.api.constants import ReadReceiptEventFields
|
||||||
from synapse.appservice import ApplicationService
|
from synapse.appservice import ApplicationService
|
||||||
from synapse.handlers._base import BaseHandler
|
from synapse.handlers._base import BaseHandler
|
||||||
|
from synapse.streams import EventSource
|
||||||
from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id
|
from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -162,7 +163,7 @@ class ReceiptsHandler(BaseHandler):
|
|||||||
await self.federation_sender.send_read_receipt(receipt)
|
await self.federation_sender.send_read_receipt(receipt)
|
||||||
|
|
||||||
|
|
||||||
class ReceiptEventSource:
|
class ReceiptEventSource(EventSource[int, JsonDict]):
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.config = hs.config
|
self.config = hs.config
|
||||||
@ -216,7 +217,13 @@ class ReceiptEventSource:
|
|||||||
return visible_events
|
return visible_events
|
||||||
|
|
||||||
async def get_new_events(
|
async def get_new_events(
|
||||||
self, from_key: int, room_ids: List[str], user: UserID, **kwargs: Any
|
self,
|
||||||
|
user: UserID,
|
||||||
|
from_key: int,
|
||||||
|
limit: Optional[int],
|
||||||
|
room_ids: Iterable[str],
|
||||||
|
is_guest: bool,
|
||||||
|
explicit_room_id: Optional[str] = None,
|
||||||
) -> Tuple[List[JsonDict], int]:
|
) -> Tuple[List[JsonDict], int]:
|
||||||
from_key = int(from_key)
|
from_key = int(from_key)
|
||||||
to_key = self.get_current_key()
|
to_key = self.get_current_key()
|
||||||
|
@ -20,7 +20,16 @@ import math
|
|||||||
import random
|
import random
|
||||||
import string
|
import string
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Awaitable,
|
||||||
|
Collection,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
)
|
||||||
|
|
||||||
from synapse.api.constants import (
|
from synapse.api.constants import (
|
||||||
EventContentFields,
|
EventContentFields,
|
||||||
@ -47,6 +56,7 @@ from synapse.events import EventBase
|
|||||||
from synapse.events.utils import copy_power_levels_contents
|
from synapse.events.utils import copy_power_levels_contents
|
||||||
from synapse.rest.admin._base import assert_user_is_admin
|
from synapse.rest.admin._base import assert_user_is_admin
|
||||||
from synapse.storage.state import StateFilter
|
from synapse.storage.state import StateFilter
|
||||||
|
from synapse.streams import EventSource
|
||||||
from synapse.types import (
|
from synapse.types import (
|
||||||
JsonDict,
|
JsonDict,
|
||||||
MutableStateMap,
|
MutableStateMap,
|
||||||
@ -1173,7 +1183,7 @@ class RoomContextHandler:
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
class RoomEventSource:
|
class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
@ -1181,8 +1191,8 @@ class RoomEventSource:
|
|||||||
self,
|
self,
|
||||||
user: UserID,
|
user: UserID,
|
||||||
from_key: RoomStreamToken,
|
from_key: RoomStreamToken,
|
||||||
limit: int,
|
limit: Optional[int],
|
||||||
room_ids: List[str],
|
room_ids: Collection[str],
|
||||||
is_guest: bool,
|
is_guest: bool,
|
||||||
explicit_room_id: Optional[str] = None,
|
explicit_room_id: Optional[str] = None,
|
||||||
) -> Tuple[List[EventBase], RoomStreamToken]:
|
) -> Tuple[List[EventBase], RoomStreamToken]:
|
||||||
|
@ -443,7 +443,7 @@ class SyncHandler:
|
|||||||
|
|
||||||
room_ids = sync_result_builder.joined_room_ids
|
room_ids = sync_result_builder.joined_room_ids
|
||||||
|
|
||||||
typing_source = self.event_sources.sources["typing"]
|
typing_source = self.event_sources.sources.typing
|
||||||
typing, typing_key = await typing_source.get_new_events(
|
typing, typing_key = await typing_source.get_new_events(
|
||||||
user=sync_config.user,
|
user=sync_config.user,
|
||||||
from_key=typing_key,
|
from_key=typing_key,
|
||||||
@ -465,7 +465,7 @@ class SyncHandler:
|
|||||||
|
|
||||||
receipt_key = since_token.receipt_key if since_token else 0
|
receipt_key = since_token.receipt_key if since_token else 0
|
||||||
|
|
||||||
receipt_source = self.event_sources.sources["receipt"]
|
receipt_source = self.event_sources.sources.receipt
|
||||||
receipts, receipt_key = await receipt_source.get_new_events(
|
receipts, receipt_key = await receipt_source.get_new_events(
|
||||||
user=sync_config.user,
|
user=sync_config.user,
|
||||||
from_key=receipt_key,
|
from_key=receipt_key,
|
||||||
@ -1415,7 +1415,7 @@ class SyncHandler:
|
|||||||
sync_config = sync_result_builder.sync_config
|
sync_config = sync_result_builder.sync_config
|
||||||
user = sync_result_builder.sync_config.user
|
user = sync_result_builder.sync_config.user
|
||||||
|
|
||||||
presence_source = self.event_sources.sources["presence"]
|
presence_source = self.event_sources.sources.presence
|
||||||
|
|
||||||
since_token = sync_result_builder.since_token
|
since_token = sync_result_builder.since_token
|
||||||
presence_key = None
|
presence_key = None
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
|
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
|
||||||
|
|
||||||
from synapse.api.errors import AuthError, ShadowBanError, SynapseError
|
from synapse.api.errors import AuthError, ShadowBanError, SynapseError
|
||||||
from synapse.appservice import ApplicationService
|
from synapse.appservice import ApplicationService
|
||||||
@ -23,6 +23,7 @@ from synapse.metrics.background_process_metrics import (
|
|||||||
wrap_as_background_process,
|
wrap_as_background_process,
|
||||||
)
|
)
|
||||||
from synapse.replication.tcp.streams import TypingStream
|
from synapse.replication.tcp.streams import TypingStream
|
||||||
|
from synapse.streams import EventSource
|
||||||
from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
|
from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
|
||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
@ -439,7 +440,7 @@ class TypingWriterHandler(FollowerTypingHandler):
|
|||||||
raise Exception("Typing writer instance got typing info over replication")
|
raise Exception("Typing writer instance got typing info over replication")
|
||||||
|
|
||||||
|
|
||||||
class TypingNotificationEventSource:
|
class TypingNotificationEventSource(EventSource[int, JsonDict]):
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
@ -485,7 +486,13 @@ class TypingNotificationEventSource:
|
|||||||
return (events, handler._latest_room_serial)
|
return (events, handler._latest_room_serial)
|
||||||
|
|
||||||
async def get_new_events(
|
async def get_new_events(
|
||||||
self, from_key: int, room_ids: Iterable[str], **kwargs: Any
|
self,
|
||||||
|
user: UserID,
|
||||||
|
from_key: int,
|
||||||
|
limit: Optional[int],
|
||||||
|
room_ids: Iterable[str],
|
||||||
|
is_guest: bool,
|
||||||
|
explicit_room_id: Optional[str] = None,
|
||||||
) -> Tuple[List[JsonDict], int]:
|
) -> Tuple[List[JsonDict], int]:
|
||||||
with Measure(self.clock, "typing.get_new_events"):
|
with Measure(self.clock, "typing.get_new_events"):
|
||||||
from_key = int(from_key)
|
from_key = int(from_key)
|
||||||
|
@ -91,7 +91,7 @@ class ModuleApi:
|
|||||||
self._auth = hs.get_auth()
|
self._auth = hs.get_auth()
|
||||||
self._auth_handler = auth_handler
|
self._auth_handler = auth_handler
|
||||||
self._server_name = hs.hostname
|
self._server_name = hs.hostname
|
||||||
self._presence_stream = hs.get_event_sources().sources["presence"]
|
self._presence_stream = hs.get_event_sources().sources.presence
|
||||||
self._state = hs.get_state_handler()
|
self._state = hs.get_state_handler()
|
||||||
self._clock: Clock = hs.get_clock()
|
self._clock: Clock = hs.get_clock()
|
||||||
self._send_email_handler = hs.get_send_email_handler()
|
self._send_email_handler = hs.get_send_email_handler()
|
||||||
|
@ -584,7 +584,7 @@ class Notifier:
|
|||||||
events: List[EventBase] = []
|
events: List[EventBase] = []
|
||||||
end_token = from_token
|
end_token = from_token
|
||||||
|
|
||||||
for name, source in self.event_sources.sources.items():
|
for name, source in self.event_sources.sources.get_sources():
|
||||||
keyname = "%s_key" % name
|
keyname = "%s_key" % name
|
||||||
before_id = getattr(before_token, keyname)
|
before_id = getattr(before_token, keyname)
|
||||||
after_id = getattr(after_token, keyname)
|
after_id = getattr(after_token, keyname)
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
@ -153,12 +153,12 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
|||||||
}
|
}
|
||||||
|
|
||||||
async def get_linearized_receipts_for_rooms(
|
async def get_linearized_receipts_for_rooms(
|
||||||
self, room_ids: List[str], to_key: int, from_key: Optional[int] = None
|
self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None
|
||||||
) -> List[dict]:
|
) -> List[dict]:
|
||||||
"""Get receipts for multiple rooms for sending to clients.
|
"""Get receipts for multiple rooms for sending to clients.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room_id: List of room_ids.
|
room_id: The room IDs to fetch receipts of.
|
||||||
to_key: Max stream id to fetch receipts up to.
|
to_key: Max stream id to fetch receipts up to.
|
||||||
from_key: Min stream id to fetch receipts from. None fetches
|
from_key: Min stream id to fetch receipts from. None fetches
|
||||||
from the start.
|
from the start.
|
||||||
|
@ -11,3 +11,25 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Collection, Generic, List, Optional, Tuple, TypeVar
|
||||||
|
|
||||||
|
from synapse.types import UserID
|
||||||
|
|
||||||
|
# The key, this is either a stream token or int.
|
||||||
|
K = TypeVar("K")
|
||||||
|
# The return type.
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
|
||||||
|
class EventSource(Generic[K, R]):
|
||||||
|
async def get_new_events(
|
||||||
|
self,
|
||||||
|
user: UserID,
|
||||||
|
from_key: K,
|
||||||
|
limit: Optional[int],
|
||||||
|
room_ids: Collection[str],
|
||||||
|
is_guest: bool,
|
||||||
|
explicit_room_id: Optional[str] = None,
|
||||||
|
) -> Tuple[List[R], K]:
|
||||||
|
...
|
||||||
|
@ -12,29 +12,40 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Any, Dict
|
from typing import TYPE_CHECKING, Iterator, Tuple
|
||||||
|
|
||||||
|
import attr
|
||||||
|
|
||||||
from synapse.handlers.account_data import AccountDataEventSource
|
from synapse.handlers.account_data import AccountDataEventSource
|
||||||
from synapse.handlers.presence import PresenceEventSource
|
from synapse.handlers.presence import PresenceEventSource
|
||||||
from synapse.handlers.receipts import ReceiptEventSource
|
from synapse.handlers.receipts import ReceiptEventSource
|
||||||
from synapse.handlers.room import RoomEventSource
|
from synapse.handlers.room import RoomEventSource
|
||||||
from synapse.handlers.typing import TypingNotificationEventSource
|
from synapse.handlers.typing import TypingNotificationEventSource
|
||||||
|
from synapse.streams import EventSource
|
||||||
from synapse.types import StreamToken
|
from synapse.types import StreamToken
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(frozen=True, slots=True, auto_attribs=True)
|
||||||
|
class _EventSourcesInner:
|
||||||
|
room: RoomEventSource
|
||||||
|
presence: PresenceEventSource
|
||||||
|
typing: TypingNotificationEventSource
|
||||||
|
receipt: ReceiptEventSource
|
||||||
|
account_data: AccountDataEventSource
|
||||||
|
|
||||||
|
def get_sources(self) -> Iterator[Tuple[str, EventSource]]:
|
||||||
|
for attribute in _EventSourcesInner.__attrs_attrs__: # type: ignore[attr-defined]
|
||||||
|
yield attribute.name, getattr(self, attribute.name)
|
||||||
|
|
||||||
|
|
||||||
class EventSources:
|
class EventSources:
|
||||||
SOURCE_TYPES = {
|
def __init__(self, hs: "HomeServer"):
|
||||||
"room": RoomEventSource,
|
self.sources = _EventSourcesInner(
|
||||||
"presence": PresenceEventSource,
|
*(attribute.type(hs) for attribute in _EventSourcesInner.__attrs_attrs__) # type: ignore[attr-defined]
|
||||||
"typing": TypingNotificationEventSource,
|
)
|
||||||
"receipt": ReceiptEventSource,
|
|
||||||
"account_data": AccountDataEventSource,
|
|
||||||
}
|
|
||||||
|
|
||||||
def __init__(self, hs):
|
|
||||||
self.sources: Dict[str, Any] = {
|
|
||||||
name: cls(hs) for name, cls in EventSources.SOURCE_TYPES.items()
|
|
||||||
}
|
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
def get_current_token(self) -> StreamToken:
|
def get_current_token(self) -> StreamToken:
|
||||||
@ -44,11 +55,11 @@ class EventSources:
|
|||||||
groups_key = self.store.get_group_stream_token()
|
groups_key = self.store.get_group_stream_token()
|
||||||
|
|
||||||
token = StreamToken(
|
token = StreamToken(
|
||||||
room_key=self.sources["room"].get_current_key(),
|
room_key=self.sources.room.get_current_key(),
|
||||||
presence_key=self.sources["presence"].get_current_key(),
|
presence_key=self.sources.presence.get_current_key(),
|
||||||
typing_key=self.sources["typing"].get_current_key(),
|
typing_key=self.sources.typing.get_current_key(),
|
||||||
receipt_key=self.sources["receipt"].get_current_key(),
|
receipt_key=self.sources.receipt.get_current_key(),
|
||||||
account_data_key=self.sources["account_data"].get_current_key(),
|
account_data_key=self.sources.account_data.get_current_key(),
|
||||||
push_rules_key=push_rules_key,
|
push_rules_key=push_rules_key,
|
||||||
to_device_key=to_device_key,
|
to_device_key=to_device_key,
|
||||||
device_list_key=device_list_key,
|
device_list_key=device_list_key,
|
||||||
@ -67,7 +78,7 @@ class EventSources:
|
|||||||
The current token for pagination.
|
The current token for pagination.
|
||||||
"""
|
"""
|
||||||
token = StreamToken(
|
token = StreamToken(
|
||||||
room_key=self.sources["room"].get_current_key(),
|
room_key=self.sources.room.get_current_key(),
|
||||||
presence_key=0,
|
presence_key=0,
|
||||||
typing_key=0,
|
typing_key=0,
|
||||||
receipt_key=0,
|
receipt_key=0,
|
||||||
|
@ -23,7 +23,7 @@ from tests import unittest
|
|||||||
|
|
||||||
class ReceiptsTestCase(unittest.HomeserverTestCase):
|
class ReceiptsTestCase(unittest.HomeserverTestCase):
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor, clock, hs):
|
||||||
self.event_source = hs.get_event_sources().sources["receipt"]
|
self.event_source = hs.get_event_sources().sources.receipt
|
||||||
|
|
||||||
# In the first param of _test_filters_hidden we use "hidden" instead of
|
# In the first param of _test_filters_hidden we use "hidden" instead of
|
||||||
# ReadReceiptEventFields.MSC2285_HIDDEN. We do this because we're mocking
|
# ReadReceiptEventFields.MSC2285_HIDDEN. We do this because we're mocking
|
||||||
|
@ -89,7 +89,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self.handler = hs.get_typing_handler()
|
self.handler = hs.get_typing_handler()
|
||||||
|
|
||||||
self.event_source = hs.get_event_sources().sources["typing"]
|
self.event_source = hs.get_event_sources().sources.typing
|
||||||
|
|
||||||
self.datastore = hs.get_datastore()
|
self.datastore = hs.get_datastore()
|
||||||
self.datastore.get_destination_retry_timings = Mock(
|
self.datastore.get_destination_retry_timings = Mock(
|
||||||
@ -171,7 +171,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||||
events = self.get_success(
|
events = self.get_success(
|
||||||
self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
|
self.event_source.get_new_events(
|
||||||
|
user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
events[0],
|
events[0],
|
||||||
@ -239,7 +241,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||||
events = self.get_success(
|
events = self.get_success(
|
||||||
self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
|
self.event_source.get_new_events(
|
||||||
|
user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
events[0],
|
events[0],
|
||||||
@ -276,7 +280,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 0)
|
self.assertEquals(self.event_source.get_current_key(), 0)
|
||||||
events = self.get_success(
|
events = self.get_success(
|
||||||
self.event_source.get_new_events(room_ids=[OTHER_ROOM_ID], from_key=0)
|
self.event_source.get_new_events(
|
||||||
|
user=U_APPLE,
|
||||||
|
from_key=0,
|
||||||
|
limit=None,
|
||||||
|
room_ids=[OTHER_ROOM_ID],
|
||||||
|
is_guest=False,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self.assertEquals(events[0], [])
|
self.assertEquals(events[0], [])
|
||||||
self.assertEquals(events[1], 0)
|
self.assertEquals(events[1], 0)
|
||||||
@ -324,7 +334,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||||
events = self.get_success(
|
events = self.get_success(
|
||||||
self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
|
self.event_source.get_new_events(
|
||||||
|
user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
events[0],
|
events[0],
|
||||||
@ -350,7 +362,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||||
events = self.get_success(
|
events = self.get_success(
|
||||||
self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
|
self.event_source.get_new_events(
|
||||||
|
user=U_APPLE,
|
||||||
|
from_key=0,
|
||||||
|
limit=None,
|
||||||
|
room_ids=[ROOM_ID],
|
||||||
|
is_guest=False,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
events[0],
|
events[0],
|
||||||
@ -369,7 +387,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 2)
|
self.assertEquals(self.event_source.get_current_key(), 2)
|
||||||
events = self.get_success(
|
events = self.get_success(
|
||||||
self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1)
|
self.event_source.get_new_events(
|
||||||
|
user=U_APPLE,
|
||||||
|
from_key=1,
|
||||||
|
limit=None,
|
||||||
|
room_ids=[ROOM_ID],
|
||||||
|
is_guest=False,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
events[0],
|
events[0],
|
||||||
@ -392,7 +416,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 3)
|
self.assertEquals(self.event_source.get_current_key(), 3)
|
||||||
events = self.get_success(
|
events = self.get_success(
|
||||||
self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
|
self.event_source.get_new_events(
|
||||||
|
user=U_APPLE,
|
||||||
|
from_key=0,
|
||||||
|
limit=None,
|
||||||
|
room_ids=[ROOM_ID],
|
||||||
|
is_guest=False,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
events[0],
|
events[0],
|
||||||
|
@ -193,7 +193,7 @@ class RoomTestCase(_ShadowBannedBase):
|
|||||||
self.assertEquals(200, channel.code)
|
self.assertEquals(200, channel.code)
|
||||||
|
|
||||||
# There should be no typing events.
|
# There should be no typing events.
|
||||||
event_source = self.hs.get_event_sources().sources["typing"]
|
event_source = self.hs.get_event_sources().sources.typing
|
||||||
self.assertEquals(event_source.get_current_key(), 0)
|
self.assertEquals(event_source.get_current_key(), 0)
|
||||||
|
|
||||||
# The other user can join and send typing events.
|
# The other user can join and send typing events.
|
||||||
@ -210,7 +210,13 @@ class RoomTestCase(_ShadowBannedBase):
|
|||||||
# These appear in the room.
|
# These appear in the room.
|
||||||
self.assertEquals(event_source.get_current_key(), 1)
|
self.assertEquals(event_source.get_current_key(), 1)
|
||||||
events = self.get_success(
|
events = self.get_success(
|
||||||
event_source.get_new_events(from_key=0, room_ids=[room_id])
|
event_source.get_new_events(
|
||||||
|
user=UserID.from_string(self.other_user_id),
|
||||||
|
from_key=0,
|
||||||
|
limit=None,
|
||||||
|
room_ids=[room_id],
|
||||||
|
is_guest=False,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
events[0],
|
events[0],
|
||||||
|
@ -41,7 +41,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
|
|||||||
federation_client=Mock(),
|
federation_client=Mock(),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.event_source = hs.get_event_sources().sources["typing"]
|
self.event_source = hs.get_event_sources().sources.typing
|
||||||
|
|
||||||
hs.get_federation_handler = Mock()
|
hs.get_federation_handler = Mock()
|
||||||
|
|
||||||
@ -76,7 +76,13 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||||
events = self.get_success(
|
events = self.get_success(
|
||||||
self.event_source.get_new_events(from_key=0, room_ids=[self.room_id])
|
self.event_source.get_new_events(
|
||||||
|
user=UserID.from_string(self.user_id),
|
||||||
|
from_key=0,
|
||||||
|
limit=None,
|
||||||
|
room_ids=[self.room_id],
|
||||||
|
is_guest=False,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
events[0],
|
events[0],
|
||||||
|
Loading…
Reference in New Issue
Block a user