Add typing info to Notifier (#8058)

This commit is contained in:
Erik Johnston 2020-08-11 19:40:02 +01:00 committed by GitHub
parent a0f574f3c2
commit a1e9bb9eae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 91 additions and 52 deletions

1
changelog.d/8058.misc Normal file
View File

@ -0,0 +1 @@
Add type hints to `Notifier`.

View File

@ -57,13 +57,10 @@ class EventStreamHandler(BaseHandler):
timeout=0, timeout=0,
as_client_event=True, as_client_event=True,
affect_presence=True, affect_presence=True,
only_keys=None,
room_id=None, room_id=None,
is_guest=False, is_guest=False,
): ):
"""Fetches the events stream for a given user. """Fetches the events stream for a given user.
If `only_keys` is not None, events from keys will be sent down.
""" """
if room_id: if room_id:
@ -93,7 +90,6 @@ class EventStreamHandler(BaseHandler):
auth_user, auth_user,
pagin_config, pagin_config,
timeout, timeout,
only_keys=only_keys,
is_guest=is_guest, is_guest=is_guest,
explicit_room_id=room_id, explicit_room_id=room_id,
) )

View File

@ -15,7 +15,17 @@
import logging import logging
from collections import namedtuple from collections import namedtuple
from typing import Callable, Iterable, List, TypeVar from typing import (
Awaitable,
Callable,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
TypeVar,
)
from prometheus_client import Counter from prometheus_client import Counter
@ -24,12 +34,14 @@ from twisted.internet import defer
import synapse.server import synapse.server
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.events import EventBase
from synapse.handlers.presence import format_user_presence_state from synapse.handlers.presence import format_user_presence_state
from synapse.logging.context import PreserveLoggingContext from synapse.logging.context import PreserveLoggingContext
from synapse.logging.utils import log_function from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import StreamToken from synapse.streams.config import PaginationConfig
from synapse.types import Collection, StreamToken, UserID
from synapse.util.async_helpers import ObservableDeferred, timeout_deferred from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
@ -77,7 +89,13 @@ class _NotifierUserStream(object):
so that it can remove itself from the indexes in the Notifier class. so that it can remove itself from the indexes in the Notifier class.
""" """
def __init__(self, user_id, rooms, current_token, time_now_ms): def __init__(
self,
user_id: str,
rooms: Collection[str],
current_token: StreamToken,
time_now_ms: int,
):
self.user_id = user_id self.user_id = user_id
self.rooms = set(rooms) self.rooms = set(rooms)
self.current_token = current_token self.current_token = current_token
@ -93,13 +111,13 @@ class _NotifierUserStream(object):
with PreserveLoggingContext(): with PreserveLoggingContext():
self.notify_deferred = ObservableDeferred(defer.Deferred()) self.notify_deferred = ObservableDeferred(defer.Deferred())
def notify(self, stream_key, stream_id, time_now_ms): def notify(self, stream_key: str, stream_id: int, time_now_ms: int):
"""Notify any listeners for this user of a new event from an """Notify any listeners for this user of a new event from an
event source. event source.
Args: Args:
stream_key(str): The stream the event came from. stream_key: The stream the event came from.
stream_id(str): The new id for the stream the event came from. stream_id: The new id for the stream the event came from.
time_now_ms(int): The current time in milliseconds. time_now_ms: The current time in milliseconds.
""" """
self.current_token = self.current_token.copy_and_advance(stream_key, stream_id) self.current_token = self.current_token.copy_and_advance(stream_key, stream_id)
self.last_notified_token = self.current_token self.last_notified_token = self.current_token
@ -112,7 +130,7 @@ class _NotifierUserStream(object):
self.notify_deferred = ObservableDeferred(defer.Deferred()) self.notify_deferred = ObservableDeferred(defer.Deferred())
noify_deferred.callback(self.current_token) noify_deferred.callback(self.current_token)
def remove(self, notifier): def remove(self, notifier: "Notifier"):
""" Remove this listener from all the indexes in the Notifier """ Remove this listener from all the indexes in the Notifier
it knows about. it knows about.
""" """
@ -123,10 +141,10 @@ class _NotifierUserStream(object):
notifier.user_to_user_stream.pop(self.user_id) notifier.user_to_user_stream.pop(self.user_id)
def count_listeners(self): def count_listeners(self) -> int:
return len(self.notify_deferred.observers()) return len(self.notify_deferred.observers())
def new_listener(self, token): def new_listener(self, token: StreamToken) -> _NotificationListener:
"""Returns a deferred that is resolved when there is a new token """Returns a deferred that is resolved when there is a new token
greater than the given token. greater than the given token.
@ -159,14 +177,16 @@ class Notifier(object):
UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000 UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000
def __init__(self, hs: "synapse.server.HomeServer"): def __init__(self, hs: "synapse.server.HomeServer"):
self.user_to_user_stream = {} self.user_to_user_stream = {} # type: Dict[str, _NotifierUserStream]
self.room_to_user_streams = {} self.room_to_user_streams = {} # type: Dict[str, Set[_NotifierUserStream]]
self.hs = hs self.hs = hs
self.storage = hs.get_storage() self.storage = hs.get_storage()
self.event_sources = hs.get_event_sources() self.event_sources = hs.get_event_sources()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.pending_new_room_events = [] self.pending_new_room_events = (
[]
) # type: List[Tuple[int, EventBase, Collection[str]]]
# Called when there are new things to stream over replication # Called when there are new things to stream over replication
self.replication_callbacks = [] # type: List[Callable[[], None]] self.replication_callbacks = [] # type: List[Callable[[], None]]
@ -178,10 +198,9 @@ class Notifier(object):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.appservice_handler = hs.get_application_service_handler() self.appservice_handler = hs.get_application_service_handler()
self.federation_sender = None
if hs.should_send_federation(): if hs.should_send_federation():
self.federation_sender = hs.get_federation_sender() self.federation_sender = hs.get_federation_sender()
else:
self.federation_sender = None
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
@ -193,12 +212,12 @@ class Notifier(object):
# when rendering the metrics page, which is likely once per minute at # when rendering the metrics page, which is likely once per minute at
# most when scraping it. # most when scraping it.
def count_listeners(): def count_listeners():
all_user_streams = set() all_user_streams = set() # type: Set[_NotifierUserStream]
for x in list(self.room_to_user_streams.values()): for streams in list(self.room_to_user_streams.values()):
all_user_streams |= x all_user_streams |= streams
for x in list(self.user_to_user_stream.values()): for stream in list(self.user_to_user_stream.values()):
all_user_streams.add(x) all_user_streams.add(stream)
return sum(stream.count_listeners() for stream in all_user_streams) return sum(stream.count_listeners() for stream in all_user_streams)
@ -223,7 +242,11 @@ class Notifier(object):
self.replication_callbacks.append(cb) self.replication_callbacks.append(cb)
def on_new_room_event( def on_new_room_event(
self, event, room_stream_id, max_room_stream_id, extra_users=[] self,
event: EventBase,
room_stream_id: int,
max_room_stream_id: int,
extra_users: Collection[str] = [],
): ):
""" Used by handlers to inform the notifier something has happened """ Used by handlers to inform the notifier something has happened
in the room, room event wise. in the room, room event wise.
@ -241,11 +264,11 @@ class Notifier(object):
self.notify_replication() self.notify_replication()
def _notify_pending_new_room_events(self, max_room_stream_id): def _notify_pending_new_room_events(self, max_room_stream_id: int):
"""Notify for the room events that were queued waiting for a previous """Notify for the room events that were queued waiting for a previous
event to be persisted. event to be persisted.
Args: Args:
max_room_stream_id(int): The highest stream_id below which all max_room_stream_id: The highest stream_id below which all
events have been persisted. events have been persisted.
""" """
pending = self.pending_new_room_events pending = self.pending_new_room_events
@ -258,7 +281,9 @@ class Notifier(object):
else: else:
self._on_new_room_event(event, room_stream_id, extra_users) self._on_new_room_event(event, room_stream_id, extra_users)
def _on_new_room_event(self, event, room_stream_id, extra_users=[]): def _on_new_room_event(
self, event: EventBase, room_stream_id: int, extra_users: Collection[str] = []
):
"""Notify any user streams that are interested in this room event""" """Notify any user streams that are interested in this room event"""
# poke any interested application service. # poke any interested application service.
run_as_background_process( run_as_background_process(
@ -275,13 +300,19 @@ class Notifier(object):
"room_key", room_stream_id, users=extra_users, rooms=[event.room_id] "room_key", room_stream_id, users=extra_users, rooms=[event.room_id]
) )
async def _notify_app_services(self, room_stream_id): async def _notify_app_services(self, room_stream_id: int):
try: try:
await self.appservice_handler.notify_interested_services(room_stream_id) await self.appservice_handler.notify_interested_services(room_stream_id)
except Exception: except Exception:
logger.exception("Error notifying application services of event") logger.exception("Error notifying application services of event")
def on_new_event(self, stream_key, new_token, users=[], rooms=[]): def on_new_event(
self,
stream_key: str,
new_token: int,
users: Collection[str] = [],
rooms: Collection[str] = [],
):
""" Used to inform listeners that something has happened event wise. """ Used to inform listeners that something has happened event wise.
Will wake up all listeners for the given users and rooms. Will wake up all listeners for the given users and rooms.
@ -307,14 +338,19 @@ class Notifier(object):
self.notify_replication() self.notify_replication()
def on_new_replication_data(self): def on_new_replication_data(self) -> None:
"""Used to inform replication listeners that something has happend """Used to inform replication listeners that something has happend
without waking up any of the normal user event streams""" without waking up any of the normal user event streams"""
self.notify_replication() self.notify_replication()
async def wait_for_events( async def wait_for_events(
self, user_id, timeout, callback, room_ids=None, from_token=StreamToken.START self,
): user_id: str,
timeout: int,
callback: Callable[[StreamToken, StreamToken], Awaitable[T]],
room_ids=None,
from_token=StreamToken.START,
) -> T:
"""Wait until the callback returns a non empty response or the """Wait until the callback returns a non empty response or the
timeout fires. timeout fires.
""" """
@ -377,19 +413,16 @@ class Notifier(object):
async def get_events_for( async def get_events_for(
self, self,
user, user: UserID,
pagination_config, pagination_config: PaginationConfig,
timeout, timeout: int,
only_keys=None, is_guest: bool = False,
is_guest=False, explicit_room_id: str = None,
explicit_room_id=None, ) -> EventStreamResult:
):
""" For the given user and rooms, return any new events for them. If """ For the given user and rooms, return any new events for them. If
there are no new events wait for up to `timeout` milliseconds for any there are no new events wait for up to `timeout` milliseconds for any
new events to happen before returning. new events to happen before returning.
If `only_keys` is not None, events from keys will be sent down.
If explicit_room_id is not set, the user's joined rooms will be polled If explicit_room_id is not set, the user's joined rooms will be polled
for events. for events.
If explicit_room_id is set, that room will be polled for events only if If explicit_room_id is set, that room will be polled for events only if
@ -404,11 +437,13 @@ class Notifier(object):
room_ids, is_joined = await self._get_room_ids(user, explicit_room_id) room_ids, is_joined = await self._get_room_ids(user, explicit_room_id)
is_peeking = not is_joined is_peeking = not is_joined
async def check_for_updates(before_token, after_token): async def check_for_updates(
before_token: StreamToken, after_token: StreamToken
) -> EventStreamResult:
if not after_token.is_after(before_token): if not after_token.is_after(before_token):
return EventStreamResult([], (from_token, from_token)) return EventStreamResult([], (from_token, from_token))
events = [] events = [] # type: List[EventBase]
end_token = from_token end_token = from_token
for name, source in self.event_sources.sources.items(): for name, source in self.event_sources.sources.items():
@ -417,8 +452,6 @@ class Notifier(object):
after_id = getattr(after_token, keyname) after_id = getattr(after_token, keyname)
if before_id == after_id: if before_id == after_id:
continue continue
if only_keys and name not in only_keys:
continue
new_events, new_key = await source.get_new_events( new_events, new_key = await source.get_new_events(
user=user, user=user,
@ -476,7 +509,9 @@ class Notifier(object):
return result return result
async def _get_room_ids(self, user, explicit_room_id): async def _get_room_ids(
self, user: UserID, explicit_room_id: Optional[str]
) -> Tuple[Collection[str], bool]:
joined_room_ids = await self.store.get_rooms_for_user(user.to_string()) joined_room_ids = await self.store.get_rooms_for_user(user.to_string())
if explicit_room_id: if explicit_room_id:
if explicit_room_id in joined_room_ids: if explicit_room_id in joined_room_ids:
@ -486,7 +521,7 @@ class Notifier(object):
raise AuthError(403, "Non-joined access not allowed") raise AuthError(403, "Non-joined access not allowed")
return joined_room_ids, True return joined_room_ids, True
async def _is_world_readable(self, room_id): async def _is_world_readable(self, room_id: str) -> bool:
state = await self.state_handler.get_current_state( state = await self.state_handler.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, "" room_id, EventTypes.RoomHistoryVisibility, ""
) )
@ -496,7 +531,7 @@ class Notifier(object):
return False return False
@log_function @log_function
def remove_expired_streams(self): def remove_expired_streams(self) -> None:
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
expired_streams = [] expired_streams = []
expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS
@ -510,21 +545,21 @@ class Notifier(object):
expired_stream.remove(self) expired_stream.remove(self)
@log_function @log_function
def _register_with_keys(self, user_stream): def _register_with_keys(self, user_stream: _NotifierUserStream):
self.user_to_user_stream[user_stream.user_id] = user_stream self.user_to_user_stream[user_stream.user_id] = user_stream
for room in user_stream.rooms: for room in user_stream.rooms:
s = self.room_to_user_streams.setdefault(room, set()) s = self.room_to_user_streams.setdefault(room, set())
s.add(user_stream) s.add(user_stream)
def _user_joined_room(self, user_id, room_id): def _user_joined_room(self, user_id: str, room_id: str):
new_user_stream = self.user_to_user_stream.get(user_id) new_user_stream = self.user_to_user_stream.get(user_id)
if new_user_stream is not None: if new_user_stream is not None:
room_streams = self.room_to_user_streams.setdefault(room_id, set()) room_streams = self.room_to_user_streams.setdefault(room_id, set())
room_streams.add(new_user_stream) room_streams.add(new_user_stream)
new_user_stream.rooms.add(room_id) new_user_stream.rooms.add(room_id)
def notify_replication(self): def notify_replication(self) -> None:
"""Notify the any replication listeners that there's a new event""" """Notify the any replication listeners that there's a new event"""
for cb in self.replication_callbacks: for cb in self.replication_callbacks:
cb() cb()

View File

@ -31,8 +31,10 @@ import synapse.server_notices.server_notices_sender
import synapse.state import synapse.state
import synapse.storage import synapse.storage
from synapse.events.builder import EventBuilderFactory from synapse.events.builder import EventBuilderFactory
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.handlers.typing import FollowerTypingHandler from synapse.handlers.typing import FollowerTypingHandler
from synapse.replication.tcp.streams import Stream from synapse.replication.tcp.streams import Stream
from synapse.streams.events import EventSources
class HomeServer(object): class HomeServer(object):
@property @property
@ -153,3 +155,7 @@ class HomeServer(object):
pass pass
def get_typing_handler(self) -> FollowerTypingHandler: def get_typing_handler(self) -> FollowerTypingHandler:
pass pass
def get_event_sources(self) -> EventSources:
pass
def get_application_service_handler(self):
return ApplicationServicesHandler(self)

View File

@ -198,6 +198,7 @@ commands = mypy \
synapse/logging/ \ synapse/logging/ \
synapse/metrics \ synapse/metrics \
synapse/module_api \ synapse/module_api \
synapse/notifier.py \
synapse/push/pusherpool.py \ synapse/push/pusherpool.py \
synapse/push/push_rule_evaluator.py \ synapse/push/push_rule_evaluator.py \
synapse/replication \ synapse/replication \