Add type hints for the federation sender. (#9681)

Includes an abstract base class which both the FederationSender
and the FederationRemoteSendQueue must implement.
This commit is contained in:
Patrick Cloke 2021-03-29 11:43:20 -04:00 committed by GitHub
parent 4bbd535450
commit da75d2ea1f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 177 additions and 59 deletions

1
changelog.d/9681.misc Normal file
View File

@ -0,0 +1 @@
Add additional type hints to the Homeserver object.

View File

@ -787,13 +787,6 @@ class FederationSenderHandler:
self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer") self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
def on_start(self):
# There may be some events that are persisted but haven't been sent,
# so send them now.
self.federation_sender.notify_new_events(
self.store.get_room_max_stream_ordering()
)
def wake_destination(self, server: str): def wake_destination(self, server: str):
self.federation_sender.wake_destination(server) self.federation_sender.wake_destination(server)

View File

@ -31,25 +31,39 @@ Events are replicated via a separate events stream.
import logging import logging
from collections import namedtuple from collections import namedtuple
from typing import Dict, List, Tuple, Type from typing import (
TYPE_CHECKING,
Dict,
Hashable,
Iterable,
List,
Optional,
Sized,
Tuple,
Type,
)
from sortedcontainers import SortedDict from sortedcontainers import SortedDict
from twisted.internet import defer
from synapse.api.presence import UserPresenceState from synapse.api.presence import UserPresenceState
from synapse.federation.sender import AbstractFederationSender, FederationSender
from synapse.metrics import LaterGauge from synapse.metrics import LaterGauge
from synapse.replication.tcp.streams.federation import FederationStream
from synapse.types import JsonDict, ReadReceipt, RoomStreamToken
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from .units import Edu from .units import Edu
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class FederationRemoteSendQueue: class FederationRemoteSendQueue(AbstractFederationSender):
"""A drop in replacement for FederationSender""" """A drop in replacement for FederationSender"""
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.server_name = hs.hostname self.server_name = hs.hostname
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
@ -58,7 +72,7 @@ class FederationRemoteSendQueue:
# We may have multiple federation sender instances, so we need to track # We may have multiple federation sender instances, so we need to track
# their positions separately. # their positions separately.
self._sender_instances = hs.config.worker.federation_shard_config.instances self._sender_instances = hs.config.worker.federation_shard_config.instances
self._sender_positions = {} self._sender_positions = {} # type: Dict[str, int]
# Pending presence map user_id -> UserPresenceState # Pending presence map user_id -> UserPresenceState
self.presence_map = {} # type: Dict[str, UserPresenceState] self.presence_map = {} # type: Dict[str, UserPresenceState]
@ -71,7 +85,7 @@ class FederationRemoteSendQueue:
# Stream position -> (user_id, destinations) # Stream position -> (user_id, destinations)
self.presence_destinations = ( self.presence_destinations = (
SortedDict() SortedDict()
) # type: SortedDict[int, Tuple[str, List[str]]] ) # type: SortedDict[int, Tuple[str, Iterable[str]]]
# (destination, key) -> EDU # (destination, key) -> EDU
self.keyed_edu = {} # type: Dict[Tuple[str, tuple], Edu] self.keyed_edu = {} # type: Dict[Tuple[str, tuple], Edu]
@ -94,7 +108,7 @@ class FederationRemoteSendQueue:
# we make a new function, so we need to make a new function so the inner # we make a new function, so we need to make a new function so the inner
# lambda binds to the queue rather than to the name of the queue which # lambda binds to the queue rather than to the name of the queue which
# changes. ARGH. # changes. ARGH.
def register(name, queue): def register(name: str, queue: Sized) -> None:
LaterGauge( LaterGauge(
"synapse_federation_send_queue_%s_size" % (queue_name,), "synapse_federation_send_queue_%s_size" % (queue_name,),
"", "",
@ -115,13 +129,13 @@ class FederationRemoteSendQueue:
self.clock.looping_call(self._clear_queue, 30 * 1000) self.clock.looping_call(self._clear_queue, 30 * 1000)
def _next_pos(self): def _next_pos(self) -> int:
pos = self.pos pos = self.pos
self.pos += 1 self.pos += 1
self.pos_time[self.clock.time_msec()] = pos self.pos_time[self.clock.time_msec()] = pos
return pos return pos
def _clear_queue(self): def _clear_queue(self) -> None:
"""Clear the queues for anything older than N minutes""" """Clear the queues for anything older than N minutes"""
FIVE_MINUTES_AGO = 5 * 60 * 1000 FIVE_MINUTES_AGO = 5 * 60 * 1000
@ -138,7 +152,7 @@ class FederationRemoteSendQueue:
self._clear_queue_before_pos(position_to_delete) self._clear_queue_before_pos(position_to_delete)
def _clear_queue_before_pos(self, position_to_delete): def _clear_queue_before_pos(self, position_to_delete: int) -> None:
"""Clear all the queues from before a given position""" """Clear all the queues from before a given position"""
with Measure(self.clock, "send_queue._clear"): with Measure(self.clock, "send_queue._clear"):
# Delete things out of presence maps # Delete things out of presence maps
@ -188,13 +202,18 @@ class FederationRemoteSendQueue:
for key in keys[:i]: for key in keys[:i]:
del self.edus[key] del self.edus[key]
def notify_new_events(self, max_token): def notify_new_events(self, max_token: RoomStreamToken) -> None:
"""As per FederationSender""" """As per FederationSender"""
# We don't need to replicate this as it gets sent down a different # This should never get called.
# stream. raise NotImplementedError()
pass
def build_and_send_edu(self, destination, edu_type, content, key=None): def build_and_send_edu(
self,
destination: str,
edu_type: str,
content: JsonDict,
key: Optional[Hashable] = None,
) -> None:
"""As per FederationSender""" """As per FederationSender"""
if destination == self.server_name: if destination == self.server_name:
logger.info("Not sending EDU to ourselves") logger.info("Not sending EDU to ourselves")
@ -218,38 +237,39 @@ class FederationRemoteSendQueue:
self.notifier.on_new_replication_data() self.notifier.on_new_replication_data()
def send_read_receipt(self, receipt): async def send_read_receipt(self, receipt: ReadReceipt) -> None:
"""As per FederationSender """As per FederationSender
Args: Args:
receipt (synapse.types.ReadReceipt): receipt:
""" """
# nothing to do here: the replication listener will handle it. # nothing to do here: the replication listener will handle it.
return defer.succeed(None)
def send_presence(self, states): def send_presence(self, states: List[UserPresenceState]) -> None:
"""As per FederationSender """As per FederationSender
Args: Args:
states (list(UserPresenceState)) states
""" """
pos = self._next_pos() pos = self._next_pos()
# We only want to send presence for our own users, so lets always just # We only want to send presence for our own users, so lets always just
# filter here just in case. # filter here just in case.
local_states = list(filter(lambda s: self.is_mine_id(s.user_id), states)) local_states = [s for s in states if self.is_mine_id(s.user_id)]
self.presence_map.update({state.user_id: state for state in local_states}) self.presence_map.update({state.user_id: state for state in local_states})
self.presence_changed[pos] = [state.user_id for state in local_states] self.presence_changed[pos] = [state.user_id for state in local_states]
self.notifier.on_new_replication_data() self.notifier.on_new_replication_data()
def send_presence_to_destinations(self, states, destinations): def send_presence_to_destinations(
self, states: Iterable[UserPresenceState], destinations: Iterable[str]
) -> None:
"""As per FederationSender """As per FederationSender
Args: Args:
states (list[UserPresenceState]) states
destinations (list[str]) destinations
""" """
for state in states: for state in states:
pos = self._next_pos() pos = self._next_pos()
@ -258,15 +278,18 @@ class FederationRemoteSendQueue:
self.notifier.on_new_replication_data() self.notifier.on_new_replication_data()
def send_device_messages(self, destination): def send_device_messages(self, destination: str) -> None:
"""As per FederationSender""" """As per FederationSender"""
# We don't need to replicate this as it gets sent down a different # We don't need to replicate this as it gets sent down a different
# stream. # stream.
def get_current_token(self): def wake_destination(self, server: str) -> None:
pass
def get_current_token(self) -> int:
return self.pos - 1 return self.pos - 1
def federation_ack(self, instance_name, token): def federation_ack(self, instance_name: str, token: int) -> None:
if self._sender_instances: if self._sender_instances:
# If we have configured multiple federation sender instances we need # If we have configured multiple federation sender instances we need
# to track their positions separately, and only clear the queue up # to track their positions separately, and only clear the queue up
@ -504,13 +527,16 @@ ParsedFederationStreamData = namedtuple(
) )
def process_rows_for_federation(transaction_queue, rows): def process_rows_for_federation(
transaction_queue: FederationSender,
rows: List[FederationStream.FederationStreamRow],
) -> None:
"""Parse a list of rows from the federation stream and put them in the """Parse a list of rows from the federation stream and put them in the
transaction queue ready for sending to the relevant homeservers. transaction queue ready for sending to the relevant homeservers.
Args: Args:
transaction_queue (FederationSender) transaction_queue
rows (list(synapse.replication.tcp.streams.federation.FederationStream.FederationStreamRow)) rows
""" """
# The federation stream contains a bunch of different types of # The federation stream contains a bunch of different types of

View File

@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import abc
import logging import logging
from typing import Dict, Hashable, Iterable, List, Optional, Set, Tuple from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Set, Tuple
from prometheus_client import Counter from prometheus_client import Counter
from twisted.internet import defer from twisted.internet import defer
import synapse
import synapse.metrics import synapse.metrics
from synapse.api.presence import UserPresenceState from synapse.api.presence import UserPresenceState
from synapse.events import EventBase from synapse.events import EventBase
@ -40,9 +40,12 @@ from synapse.metrics import (
events_processed_counter, events_processed_counter,
) )
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import ReadReceipt, RoomStreamToken from synapse.types import JsonDict, ReadReceipt, RoomStreamToken
from synapse.util.metrics import Measure, measure_func from synapse.util.metrics import Measure, measure_func
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
sent_pdus_destination_dist_count = Counter( sent_pdus_destination_dist_count = Counter(
@ -65,8 +68,91 @@ CATCH_UP_STARTUP_DELAY_SEC = 15
CATCH_UP_STARTUP_INTERVAL_SEC = 5 CATCH_UP_STARTUP_INTERVAL_SEC = 5
class FederationSender: class AbstractFederationSender(metaclass=abc.ABCMeta):
def __init__(self, hs: "synapse.server.HomeServer"): @abc.abstractmethod
def notify_new_events(self, max_token: RoomStreamToken) -> None:
"""This gets called when we have some new events we might want to
send out to other servers.
"""
raise NotImplementedError()
@abc.abstractmethod
async def send_read_receipt(self, receipt: ReadReceipt) -> None:
"""Send a RR to any other servers in the room
Args:
receipt: receipt to be sent
"""
raise NotImplementedError()
@abc.abstractmethod
def send_presence(self, states: List[UserPresenceState]) -> None:
"""Send the new presence states to the appropriate destinations.
This actually queues up the presence states ready for sending and
triggers a background task to process them and send out the transactions.
"""
raise NotImplementedError()
@abc.abstractmethod
def send_presence_to_destinations(
self, states: Iterable[UserPresenceState], destinations: Iterable[str]
) -> None:
"""Send the given presence states to the given destinations.
Args:
destinations:
"""
raise NotImplementedError()
@abc.abstractmethod
def build_and_send_edu(
self,
destination: str,
edu_type: str,
content: JsonDict,
key: Optional[Hashable] = None,
) -> None:
"""Construct an Edu object, and queue it for sending
Args:
destination: name of server to send to
edu_type: type of EDU to send
content: content of EDU
key: clobbering key for this edu
"""
raise NotImplementedError()
@abc.abstractmethod
def send_device_messages(self, destination: str) -> None:
raise NotImplementedError()
@abc.abstractmethod
def wake_destination(self, destination: str) -> None:
"""Called when we want to retry sending transactions to a remote.
This is mainly useful if the remote server has been down and we think it
might have come back.
"""
raise NotImplementedError()
@abc.abstractmethod
def get_current_token(self) -> int:
raise NotImplementedError()
@abc.abstractmethod
def federation_ack(self, instance_name: str, token: int) -> None:
raise NotImplementedError()
@abc.abstractmethod
async def get_replication_rows(
self, instance_name: str, from_token: int, to_token: int, target_row_count: int
) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
raise NotImplementedError()
class FederationSender(AbstractFederationSender):
def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.server_name = hs.hostname self.server_name = hs.hostname
@ -432,7 +518,7 @@ class FederationSender:
queue.flush_read_receipts_for_room(room_id) queue.flush_read_receipts_for_room(room_id)
@preserve_fn # the caller should not yield on this @preserve_fn # the caller should not yield on this
async def send_presence(self, states: List[UserPresenceState]): async def send_presence(self, states: List[UserPresenceState]) -> None:
"""Send the new presence states to the appropriate destinations. """Send the new presence states to the appropriate destinations.
This actually queues up the presence states ready for sending and This actually queues up the presence states ready for sending and
@ -494,7 +580,7 @@ class FederationSender:
self._get_per_destination_queue(destination).send_presence(states) self._get_per_destination_queue(destination).send_presence(states)
@measure_func("txnqueue._process_presence") @measure_func("txnqueue._process_presence")
async def _process_presence_inner(self, states: List[UserPresenceState]): async def _process_presence_inner(self, states: List[UserPresenceState]) -> None:
"""Given a list of states populate self.pending_presence_by_dest and """Given a list of states populate self.pending_presence_by_dest and
poke to send a new transaction to each destination poke to send a new transaction to each destination
""" """
@ -516,9 +602,9 @@ class FederationSender:
self, self,
destination: str, destination: str,
edu_type: str, edu_type: str,
content: dict, content: JsonDict,
key: Optional[Hashable] = None, key: Optional[Hashable] = None,
): ) -> None:
"""Construct an Edu object, and queue it for sending """Construct an Edu object, and queue it for sending
Args: Args:
@ -545,7 +631,7 @@ class FederationSender:
self.send_edu(edu, key) self.send_edu(edu, key)
def send_edu(self, edu: Edu, key: Optional[Hashable]): def send_edu(self, edu: Edu, key: Optional[Hashable]) -> None:
"""Queue an EDU for sending """Queue an EDU for sending
Args: Args:
@ -563,7 +649,7 @@ class FederationSender:
else: else:
queue.send_edu(edu) queue.send_edu(edu)
def send_device_messages(self, destination: str): def send_device_messages(self, destination: str) -> None:
if destination == self.server_name: if destination == self.server_name:
logger.warning("Not sending device update to ourselves") logger.warning("Not sending device update to ourselves")
return return
@ -575,7 +661,7 @@ class FederationSender:
self._get_per_destination_queue(destination).attempt_new_transaction() self._get_per_destination_queue(destination).attempt_new_transaction()
def wake_destination(self, destination: str): def wake_destination(self, destination: str) -> None:
"""Called when we want to retry sending transactions to a remote. """Called when we want to retry sending transactions to a remote.
This is mainly useful if the remote server has been down and we think it This is mainly useful if the remote server has been down and we think it
@ -599,6 +685,10 @@ class FederationSender:
# to a worker. # to a worker.
return 0 return 0
def federation_ack(self, instance_name: str, token: int) -> None:
# It is not expected that this gets called on FederationSender.
raise NotImplementedError()
@staticmethod @staticmethod
async def get_replication_rows( async def get_replication_rows(
instance_name: str, from_token: int, to_token: int, target_row_count: int instance_name: str, from_token: int, to_token: int, target_row_count: int
@ -607,7 +697,7 @@ class FederationSender:
# to a worker. # to a worker.
return [], 0, False return [], 0, False
async def _wake_destinations_needing_catchup(self): async def _wake_destinations_needing_catchup(self) -> None:
""" """
Wakes up destinations that need catch-up and are not currently being Wakes up destinations that need catch-up and are not currently being
backed off from. backed off from.

View File

@ -312,16 +312,16 @@ class FederationAckCommand(Command):
NAME = "FEDERATION_ACK" NAME = "FEDERATION_ACK"
def __init__(self, instance_name, token): def __init__(self, instance_name: str, token: int):
self.instance_name = instance_name self.instance_name = instance_name
self.token = token self.token = token
@classmethod @classmethod
def from_line(cls, line): def from_line(cls, line: str) -> "FederationAckCommand":
instance_name, token = line.split(" ") instance_name, token = line.split(" ")
return cls(instance_name, int(token)) return cls(instance_name, int(token))
def to_line(self): def to_line(self) -> str:
return "%s %s" % (self.instance_name, self.token) return "%s %s" % (self.instance_name, self.token)

View File

@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections import namedtuple from collections import namedtuple
from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Tuple
from synapse.replication.tcp.streams._base import ( from synapse.replication.tcp.streams._base import (
Stream, Stream,
@ -21,6 +22,9 @@ from synapse.replication.tcp.streams._base import (
make_http_update_function, make_http_update_function,
) )
if TYPE_CHECKING:
from synapse.server import HomeServer
class FederationStream(Stream): class FederationStream(Stream):
"""Data to be sent over federation. Only available when master has federation """Data to be sent over federation. Only available when master has federation
@ -38,7 +42,7 @@ class FederationStream(Stream):
NAME = "federation" NAME = "federation"
ROW_TYPE = FederationStreamRow ROW_TYPE = FederationStreamRow
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
if hs.config.worker_app is None: if hs.config.worker_app is None:
# master process: get updates from the FederationRemoteSendQueue. # master process: get updates from the FederationRemoteSendQueue.
# (if the master is configured to send federation itself, federation_sender # (if the master is configured to send federation itself, federation_sender
@ -48,7 +52,9 @@ class FederationStream(Stream):
current_token = current_token_without_instance( current_token = current_token_without_instance(
federation_sender.get_current_token federation_sender.get_current_token
) )
update_function = federation_sender.get_replication_rows update_function = (
federation_sender.get_replication_rows
) # type: Callable[[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]]
elif hs.should_send_federation(): elif hs.should_send_federation():
# federation sender: Query master process # federation sender: Query master process
@ -69,5 +75,7 @@ class FederationStream(Stream):
return 0 return 0
@staticmethod @staticmethod
async def _stub_update_function(instance_name, from_token, upto_token, limit): async def _stub_update_function(
instance_name: str, from_token: int, upto_token: int, limit: int
) -> Tuple[list, int, bool]:
return [], upto_token, False return [], upto_token, False

View File

@ -60,7 +60,7 @@ from synapse.federation.federation_server import (
FederationServer, FederationServer,
) )
from synapse.federation.send_queue import FederationRemoteSendQueue from synapse.federation.send_queue import FederationRemoteSendQueue
from synapse.federation.sender import FederationSender from synapse.federation.sender import AbstractFederationSender, FederationSender
from synapse.federation.transport.client import TransportLayerClient from synapse.federation.transport.client import TransportLayerClient
from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer
from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler
@ -571,7 +571,7 @@ class HomeServer(metaclass=abc.ABCMeta):
return TransportLayerClient(self) return TransportLayerClient(self)
@cache_in_self @cache_in_self
def get_federation_sender(self): def get_federation_sender(self) -> AbstractFederationSender:
if self.should_send_federation(): if self.should_send_federation():
return FederationSender(self) return FederationSender(self)
elif not self.config.worker_app: elif not self.config.worker_app: