Add typing to synapse.federation.sender (#6871)

This commit is contained in:
Erik Johnston 2020-02-07 13:56:38 +00:00 committed by GitHub
parent de2d267375
commit b08b0a22d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 138 additions and 107 deletions

1
changelog.d/6871.misc Normal file
View File

@ -0,0 +1 @@
Add typing to `synapse.federation.sender` and port to async/await.

View File

@ -294,7 +294,12 @@ class FederationServer(FederationBase):
async def _process_edu(edu_dict): async def _process_edu(edu_dict):
received_edus_counter.inc() received_edus_counter.inc()
edu = Edu(**edu_dict) edu = Edu(
origin=origin,
destination=self.server_name,
edu_type=edu_dict["edu_type"],
content=edu_dict["content"],
)
await self.registry.on_edu(edu.edu_type, origin, edu.content) await self.registry.on_edu(edu.edu_type, origin, edu.content)
await concurrently_execute( await concurrently_execute(

View File

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Dict, Hashable, Iterable, List, Optional, Set
from six import itervalues from six import itervalues
@ -23,6 +24,7 @@ from twisted.internet import defer
import synapse import synapse
import synapse.metrics import synapse.metrics
from synapse.events import EventBase
from synapse.federation.sender.per_destination_queue import PerDestinationQueue from synapse.federation.sender.per_destination_queue import PerDestinationQueue
from synapse.federation.sender.transaction_manager import TransactionManager from synapse.federation.sender.transaction_manager import TransactionManager
from synapse.federation.units import Edu from synapse.federation.units import Edu
@ -39,6 +41,8 @@ from synapse.metrics import (
events_processed_counter, events_processed_counter,
) )
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.presence import UserPresenceState
from synapse.types import ReadReceipt
from synapse.util.metrics import Measure, measure_func from synapse.util.metrics import Measure, measure_func
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -68,7 +72,7 @@ class FederationSender(object):
self._transaction_manager = TransactionManager(hs) self._transaction_manager = TransactionManager(hs)
# map from destination to PerDestinationQueue # map from destination to PerDestinationQueue
self._per_destination_queues = {} # type: dict[str, PerDestinationQueue] self._per_destination_queues = {} # type: Dict[str, PerDestinationQueue]
LaterGauge( LaterGauge(
"synapse_federation_transaction_queue_pending_destinations", "synapse_federation_transaction_queue_pending_destinations",
@ -84,7 +88,7 @@ class FederationSender(object):
# Map of user_id -> UserPresenceState for all the pending presence # Map of user_id -> UserPresenceState for all the pending presence
# to be sent out by user_id. Entries here get processed and put in # to be sent out by user_id. Entries here get processed and put in
# pending_presence_by_dest # pending_presence_by_dest
self.pending_presence = {} self.pending_presence = {} # type: Dict[str, UserPresenceState]
LaterGauge( LaterGauge(
"synapse_federation_transaction_queue_pending_pdus", "synapse_federation_transaction_queue_pending_pdus",
@ -116,20 +120,17 @@ class FederationSender(object):
# and that there is a pending call to _flush_rrs_for_room in the system. # and that there is a pending call to _flush_rrs_for_room in the system.
self._queues_awaiting_rr_flush_by_room = ( self._queues_awaiting_rr_flush_by_room = (
{} {}
) # type: dict[str, set[PerDestinationQueue]] ) # type: Dict[str, Set[PerDestinationQueue]]
self._rr_txn_interval_per_room_ms = ( self._rr_txn_interval_per_room_ms = (
1000.0 / hs.get_config().federation_rr_transactions_per_room_per_second 1000.0 / hs.config.federation_rr_transactions_per_room_per_second
) )
def _get_per_destination_queue(self, destination): def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue:
"""Get or create a PerDestinationQueue for the given destination """Get or create a PerDestinationQueue for the given destination
Args: Args:
destination (str): server_name of remote server destination: server_name of remote server
Returns:
PerDestinationQueue
""" """
queue = self._per_destination_queues.get(destination) queue = self._per_destination_queues.get(destination)
if not queue: if not queue:
@ -137,7 +138,7 @@ class FederationSender(object):
self._per_destination_queues[destination] = queue self._per_destination_queues[destination] = queue
return queue return queue
def notify_new_events(self, current_id): def notify_new_events(self, current_id: int) -> None:
"""This gets called when we have some new events we might want to """This gets called when we have some new events we might want to
send out to other servers. send out to other servers.
""" """
@ -151,13 +152,12 @@ class FederationSender(object):
"process_event_queue_for_federation", self._process_event_queue_loop "process_event_queue_for_federation", self._process_event_queue_loop
) )
@defer.inlineCallbacks async def _process_event_queue_loop(self) -> None:
def _process_event_queue_loop(self):
try: try:
self._is_processing = True self._is_processing = True
while True: while True:
last_token = yield self.store.get_federation_out_pos("events") last_token = await self.store.get_federation_out_pos("events")
next_token, events = yield self.store.get_all_new_events_stream( next_token, events = await self.store.get_all_new_events_stream(
last_token, self._last_poked_id, limit=100 last_token, self._last_poked_id, limit=100
) )
@ -166,8 +166,7 @@ class FederationSender(object):
if not events and next_token >= self._last_poked_id: if not events and next_token >= self._last_poked_id:
break break
@defer.inlineCallbacks async def handle_event(event: EventBase) -> None:
def handle_event(event):
# Only send events for this server. # Only send events for this server.
send_on_behalf_of = event.internal_metadata.get_send_on_behalf_of() send_on_behalf_of = event.internal_metadata.get_send_on_behalf_of()
is_mine = self.is_mine_id(event.sender) is_mine = self.is_mine_id(event.sender)
@ -184,7 +183,7 @@ class FederationSender(object):
# Otherwise if the last member on a server in a room is # Otherwise if the last member on a server in a room is
# banned then it won't receive the event because it won't # banned then it won't receive the event because it won't
# be in the room after the ban. # be in the room after the ban.
destinations = yield self.state.get_hosts_in_room_at_events( destinations = await self.state.get_hosts_in_room_at_events(
event.room_id, event_ids=event.prev_event_ids() event.room_id, event_ids=event.prev_event_ids()
) )
except Exception: except Exception:
@ -206,17 +205,16 @@ class FederationSender(object):
self._send_pdu(event, destinations) self._send_pdu(event, destinations)
@defer.inlineCallbacks async def handle_room_events(events: Iterable[EventBase]) -> None:
def handle_room_events(events):
with Measure(self.clock, "handle_room_events"): with Measure(self.clock, "handle_room_events"):
for event in events: for event in events:
yield handle_event(event) await handle_event(event)
events_by_room = {} events_by_room = {} # type: Dict[str, List[EventBase]]
for event in events: for event in events:
events_by_room.setdefault(event.room_id, []).append(event) events_by_room.setdefault(event.room_id, []).append(event)
yield make_deferred_yieldable( await make_deferred_yieldable(
defer.gatherResults( defer.gatherResults(
[ [
run_in_background(handle_room_events, evs) run_in_background(handle_room_events, evs)
@ -226,11 +224,11 @@ class FederationSender(object):
) )
) )
yield self.store.update_federation_out_pos("events", next_token) await self.store.update_federation_out_pos("events", next_token)
if events: if events:
now = self.clock.time_msec() now = self.clock.time_msec()
ts = yield self.store.get_received_ts(events[-1].event_id) ts = await self.store.get_received_ts(events[-1].event_id)
synapse.metrics.event_processing_lag.labels( synapse.metrics.event_processing_lag.labels(
"federation_sender" "federation_sender"
@ -254,7 +252,7 @@ class FederationSender(object):
finally: finally:
self._is_processing = False self._is_processing = False
def _send_pdu(self, pdu, destinations): def _send_pdu(self, pdu: EventBase, destinations: Iterable[str]) -> None:
# We loop through all destinations to see whether we already have # We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus # a transaction in progress. If we do, stick it in the pending_pdus
# table and we'll get back to it later. # table and we'll get back to it later.
@ -276,11 +274,11 @@ class FederationSender(object):
self._get_per_destination_queue(destination).send_pdu(pdu, order) self._get_per_destination_queue(destination).send_pdu(pdu, order)
@defer.inlineCallbacks @defer.inlineCallbacks
def send_read_receipt(self, receipt): def send_read_receipt(self, receipt: ReadReceipt):
"""Send a RR to any other servers in the room """Send a RR to any other servers in the room
Args: Args:
receipt (synapse.types.ReadReceipt): receipt to be sent receipt: receipt to be sent
""" """
# Some background on the rate-limiting going on here. # Some background on the rate-limiting going on here.
@ -343,7 +341,7 @@ class FederationSender(object):
else: else:
queue.flush_read_receipts_for_room(room_id) queue.flush_read_receipts_for_room(room_id)
def _schedule_rr_flush_for_room(self, room_id, n_domains): def _schedule_rr_flush_for_room(self, room_id: str, n_domains: int) -> None:
# that is going to cause approximately len(domains) transactions, so now back # that is going to cause approximately len(domains) transactions, so now back
# off for that multiplied by RR_TXN_INTERVAL_PER_ROOM # off for that multiplied by RR_TXN_INTERVAL_PER_ROOM
backoff_ms = self._rr_txn_interval_per_room_ms * n_domains backoff_ms = self._rr_txn_interval_per_room_ms * n_domains
@ -352,7 +350,7 @@ class FederationSender(object):
self.clock.call_later(backoff_ms, self._flush_rrs_for_room, room_id) self.clock.call_later(backoff_ms, self._flush_rrs_for_room, room_id)
self._queues_awaiting_rr_flush_by_room[room_id] = set() self._queues_awaiting_rr_flush_by_room[room_id] = set()
def _flush_rrs_for_room(self, room_id): def _flush_rrs_for_room(self, room_id: str) -> None:
queues = self._queues_awaiting_rr_flush_by_room.pop(room_id) queues = self._queues_awaiting_rr_flush_by_room.pop(room_id)
logger.debug("Flushing RRs in %s to %s", room_id, queues) logger.debug("Flushing RRs in %s to %s", room_id, queues)
@ -368,14 +366,11 @@ class FederationSender(object):
@preserve_fn # the caller should not yield on this @preserve_fn # the caller should not yield on this
@defer.inlineCallbacks @defer.inlineCallbacks
def send_presence(self, states): def send_presence(self, states: List[UserPresenceState]):
"""Send the new presence states to the appropriate destinations. """Send the new presence states to the appropriate destinations.
This actually queues up the presence states ready for sending and This actually queues up the presence states ready for sending and
triggers a background task to process them and send out the transactions. triggers a background task to process them and send out the transactions.
Args:
states (list(UserPresenceState))
""" """
if not self.hs.config.use_presence: if not self.hs.config.use_presence:
# No-op if presence is disabled. # No-op if presence is disabled.
@ -412,11 +407,10 @@ class FederationSender(object):
finally: finally:
self._processing_pending_presence = False self._processing_pending_presence = False
def send_presence_to_destinations(self, states, destinations): def send_presence_to_destinations(
self, states: List[UserPresenceState], destinations: List[str]
) -> None:
"""Send the given presence states to the given destinations. """Send the given presence states to the given destinations.
Args:
states (list[UserPresenceState])
destinations (list[str]) destinations (list[str])
""" """
@ -431,12 +425,9 @@ class FederationSender(object):
@measure_func("txnqueue._process_presence") @measure_func("txnqueue._process_presence")
@defer.inlineCallbacks @defer.inlineCallbacks
def _process_presence_inner(self, states): def _process_presence_inner(self, states: List[UserPresenceState]):
"""Given a list of states populate self.pending_presence_by_dest and """Given a list of states populate self.pending_presence_by_dest and
poke to send a new transaction to each destination poke to send a new transaction to each destination
Args:
states (list(UserPresenceState))
""" """
hosts_and_states = yield get_interested_remotes(self.store, states, self.state) hosts_and_states = yield get_interested_remotes(self.store, states, self.state)
@ -446,14 +437,20 @@ class FederationSender(object):
continue continue
self._get_per_destination_queue(destination).send_presence(states) self._get_per_destination_queue(destination).send_presence(states)
def build_and_send_edu(self, destination, edu_type, content, key=None): def build_and_send_edu(
self,
destination: str,
edu_type: str,
content: dict,
key: Optional[Hashable] = None,
):
"""Construct an Edu object, and queue it for sending """Construct an Edu object, and queue it for sending
Args: Args:
destination (str): name of server to send to destination: name of server to send to
edu_type (str): type of EDU to send edu_type: type of EDU to send
content (dict): content of EDU content: content of EDU
key (Any|None): clobbering key for this edu key: clobbering key for this edu
""" """
if destination == self.server_name: if destination == self.server_name:
logger.info("Not sending EDU to ourselves") logger.info("Not sending EDU to ourselves")
@ -468,12 +465,12 @@ class FederationSender(object):
self.send_edu(edu, key) self.send_edu(edu, key)
def send_edu(self, edu, key): def send_edu(self, edu: Edu, key: Optional[Hashable]):
"""Queue an EDU for sending """Queue an EDU for sending
Args: Args:
edu (Edu): edu to send edu: edu to send
key (Any|None): clobbering key for this edu key: clobbering key for this edu
""" """
queue = self._get_per_destination_queue(edu.destination) queue = self._get_per_destination_queue(edu.destination)
if key: if key:
@ -481,7 +478,7 @@ class FederationSender(object):
else: else:
queue.send_edu(edu) queue.send_edu(edu)
def send_device_messages(self, destination): def send_device_messages(self, destination: str):
if destination == self.server_name: if destination == self.server_name:
logger.warning("Not sending device update to ourselves") logger.warning("Not sending device update to ourselves")
return return
@ -501,5 +498,5 @@ class FederationSender(object):
self._get_per_destination_queue(destination).attempt_new_transaction() self._get_per_destination_queue(destination).attempt_new_transaction()
def get_current_token(self): def get_current_token(self) -> int:
return 0 return 0

View File

@ -15,11 +15,11 @@
# limitations under the License. # limitations under the License.
import datetime import datetime
import logging import logging
from typing import Dict, Hashable, Iterable, List, Tuple
from prometheus_client import Counter from prometheus_client import Counter
from twisted.internet import defer import synapse.server
from synapse.api.errors import ( from synapse.api.errors import (
FederationDeniedError, FederationDeniedError,
HttpResponseException, HttpResponseException,
@ -31,7 +31,7 @@ from synapse.handlers.presence import format_user_presence_state
from synapse.metrics import sent_transactions_counter from synapse.metrics import sent_transactions_counter
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.presence import UserPresenceState from synapse.storage.presence import UserPresenceState
from synapse.types import StateMap from synapse.types import ReadReceipt
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
# This is defined in the Matrix spec and enforced by the receiver. # This is defined in the Matrix spec and enforced by the receiver.
@ -56,13 +56,18 @@ class PerDestinationQueue(object):
Manages the per-destination transmission queues. Manages the per-destination transmission queues.
Args: Args:
hs (synapse.HomeServer): hs
transaction_sender (TransactionManager): transaction_sender
destination (str): the server_name of the destination that we are managing destination: the server_name of the destination that we are managing
transmission for. transmission for.
""" """
def __init__(self, hs, transaction_manager, destination): def __init__(
self,
hs: "synapse.server.HomeServer",
transaction_manager: "synapse.federation.sender.TransactionManager",
destination: str,
):
self._server_name = hs.hostname self._server_name = hs.hostname
self._clock = hs.get_clock() self._clock = hs.get_clock()
self._store = hs.get_datastore() self._store = hs.get_datastore()
@ -72,20 +77,20 @@ class PerDestinationQueue(object):
self.transmission_loop_running = False self.transmission_loop_running = False
# a list of tuples of (pending pdu, order) # a list of tuples of (pending pdu, order)
self._pending_pdus = [] # type: list[tuple[EventBase, int]] self._pending_pdus = [] # type: List[Tuple[EventBase, int]]
self._pending_edus = [] # type: list[Edu] self._pending_edus = [] # type: List[Edu]
# Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered # Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered
# based on their key (e.g. typing events by room_id) # based on their key (e.g. typing events by room_id)
# Map of (edu_type, key) -> Edu # Map of (edu_type, key) -> Edu
self._pending_edus_keyed = {} # type: StateMap[Edu] self._pending_edus_keyed = {} # type: Dict[Tuple[str, Hashable], Edu]
# Map of user_id -> UserPresenceState of pending presence to be sent to this # Map of user_id -> UserPresenceState of pending presence to be sent to this
# destination # destination
self._pending_presence = {} # type: dict[str, UserPresenceState] self._pending_presence = {} # type: Dict[str, UserPresenceState]
# room_id -> receipt_type -> user_id -> receipt_dict # room_id -> receipt_type -> user_id -> receipt_dict
self._pending_rrs = {} self._pending_rrs = {} # type: Dict[str, Dict[str, Dict[str, dict]]]
self._rrs_pending_flush = False self._rrs_pending_flush = False
# stream_id of last successfully sent to-device message. # stream_id of last successfully sent to-device message.
@ -95,50 +100,50 @@ class PerDestinationQueue(object):
# stream_id of last successfully sent device list update. # stream_id of last successfully sent device list update.
self._last_device_list_stream_id = 0 self._last_device_list_stream_id = 0
def __str__(self): def __str__(self) -> str:
return "PerDestinationQueue[%s]" % self._destination return "PerDestinationQueue[%s]" % self._destination
def pending_pdu_count(self): def pending_pdu_count(self) -> int:
return len(self._pending_pdus) return len(self._pending_pdus)
def pending_edu_count(self): def pending_edu_count(self) -> int:
return ( return (
len(self._pending_edus) len(self._pending_edus)
+ len(self._pending_presence) + len(self._pending_presence)
+ len(self._pending_edus_keyed) + len(self._pending_edus_keyed)
) )
def send_pdu(self, pdu, order): def send_pdu(self, pdu: EventBase, order: int) -> None:
"""Add a PDU to the queue, and start the transmission loop if neccessary """Add a PDU to the queue, and start the transmission loop if neccessary
Args: Args:
pdu (EventBase): pdu to send pdu: pdu to send
order (int): order
""" """
self._pending_pdus.append((pdu, order)) self._pending_pdus.append((pdu, order))
self.attempt_new_transaction() self.attempt_new_transaction()
def send_presence(self, states): def send_presence(self, states: Iterable[UserPresenceState]) -> None:
"""Add presence updates to the queue. Start the transmission loop if neccessary. """Add presence updates to the queue. Start the transmission loop if neccessary.
Args: Args:
states (iterable[UserPresenceState]): presence to send states: presence to send
""" """
self._pending_presence.update({state.user_id: state for state in states}) self._pending_presence.update({state.user_id: state for state in states})
self.attempt_new_transaction() self.attempt_new_transaction()
def queue_read_receipt(self, receipt): def queue_read_receipt(self, receipt: ReadReceipt) -> None:
"""Add a RR to the list to be sent. Doesn't start the transmission loop yet """Add a RR to the list to be sent. Doesn't start the transmission loop yet
(see flush_read_receipts_for_room) (see flush_read_receipts_for_room)
Args: Args:
receipt (synapse.api.receipt_info.ReceiptInfo): receipt to be queued receipt: receipt to be queued
""" """
self._pending_rrs.setdefault(receipt.room_id, {}).setdefault( self._pending_rrs.setdefault(receipt.room_id, {}).setdefault(
receipt.receipt_type, {} receipt.receipt_type, {}
)[receipt.user_id] = {"event_ids": receipt.event_ids, "data": receipt.data} )[receipt.user_id] = {"event_ids": receipt.event_ids, "data": receipt.data}
def flush_read_receipts_for_room(self, room_id): def flush_read_receipts_for_room(self, room_id: str) -> None:
# if we don't have any read-receipts for this room, it may be that we've already # if we don't have any read-receipts for this room, it may be that we've already
# sent them out, so we don't need to flush. # sent them out, so we don't need to flush.
if room_id not in self._pending_rrs: if room_id not in self._pending_rrs:
@ -146,15 +151,15 @@ class PerDestinationQueue(object):
self._rrs_pending_flush = True self._rrs_pending_flush = True
self.attempt_new_transaction() self.attempt_new_transaction()
def send_keyed_edu(self, edu, key): def send_keyed_edu(self, edu: Edu, key: Hashable) -> None:
self._pending_edus_keyed[(edu.edu_type, key)] = edu self._pending_edus_keyed[(edu.edu_type, key)] = edu
self.attempt_new_transaction() self.attempt_new_transaction()
def send_edu(self, edu): def send_edu(self, edu) -> None:
self._pending_edus.append(edu) self._pending_edus.append(edu)
self.attempt_new_transaction() self.attempt_new_transaction()
def attempt_new_transaction(self): def attempt_new_transaction(self) -> None:
"""Try to start a new transaction to this destination """Try to start a new transaction to this destination
If there is already a transaction in progress to this destination, If there is already a transaction in progress to this destination,
@ -177,23 +182,22 @@ class PerDestinationQueue(object):
self._transaction_transmission_loop, self._transaction_transmission_loop,
) )
@defer.inlineCallbacks async def _transaction_transmission_loop(self) -> None:
def _transaction_transmission_loop(self): pending_pdus = [] # type: List[Tuple[EventBase, int]]
pending_pdus = []
try: try:
self.transmission_loop_running = True self.transmission_loop_running = True
# This will throw if we wouldn't retry. We do this here so we fail # This will throw if we wouldn't retry. We do this here so we fail
# quickly, but we will later check this again in the http client, # quickly, but we will later check this again in the http client,
# hence why we throw the result away. # hence why we throw the result away.
yield get_retry_limiter(self._destination, self._clock, self._store) await get_retry_limiter(self._destination, self._clock, self._store)
pending_pdus = [] pending_pdus = []
while True: while True:
# We have to keep 2 free slots for presence and rr_edus # We have to keep 2 free slots for presence and rr_edus
limit = MAX_EDUS_PER_TRANSACTION - 2 limit = MAX_EDUS_PER_TRANSACTION - 2
device_update_edus, dev_list_id = yield self._get_device_update_edus( device_update_edus, dev_list_id = await self._get_device_update_edus(
limit limit
) )
@ -202,7 +206,7 @@ class PerDestinationQueue(object):
( (
to_device_edus, to_device_edus,
device_stream_id, device_stream_id,
) = yield self._get_to_device_message_edus(limit) ) = await self._get_to_device_message_edus(limit)
pending_edus = device_update_edus + to_device_edus pending_edus = device_update_edus + to_device_edus
@ -269,7 +273,7 @@ class PerDestinationQueue(object):
# END CRITICAL SECTION # END CRITICAL SECTION
success = yield self._transaction_manager.send_new_transaction( success = await self._transaction_manager.send_new_transaction(
self._destination, pending_pdus, pending_edus self._destination, pending_pdus, pending_edus
) )
if success: if success:
@ -280,7 +284,7 @@ class PerDestinationQueue(object):
# Remove the acknowledged device messages from the database # Remove the acknowledged device messages from the database
# Only bother if we actually sent some device messages # Only bother if we actually sent some device messages
if to_device_edus: if to_device_edus:
yield self._store.delete_device_msgs_for_remote( await self._store.delete_device_msgs_for_remote(
self._destination, device_stream_id self._destination, device_stream_id
) )
@ -289,7 +293,7 @@ class PerDestinationQueue(object):
logger.info( logger.info(
"Marking as sent %r %r", self._destination, dev_list_id "Marking as sent %r %r", self._destination, dev_list_id
) )
yield self._store.mark_as_sent_devices_by_remote( await self._store.mark_as_sent_devices_by_remote(
self._destination, dev_list_id self._destination, dev_list_id
) )
@ -334,7 +338,7 @@ class PerDestinationQueue(object):
# We want to be *very* sure we clear this after we stop processing # We want to be *very* sure we clear this after we stop processing
self.transmission_loop_running = False self.transmission_loop_running = False
def _get_rr_edus(self, force_flush): def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]:
if not self._pending_rrs: if not self._pending_rrs:
return return
if not force_flush and not self._rrs_pending_flush: if not force_flush and not self._rrs_pending_flush:
@ -351,17 +355,16 @@ class PerDestinationQueue(object):
self._rrs_pending_flush = False self._rrs_pending_flush = False
yield edu yield edu
def _pop_pending_edus(self, limit): def _pop_pending_edus(self, limit: int) -> List[Edu]:
pending_edus = self._pending_edus pending_edus = self._pending_edus
pending_edus, self._pending_edus = pending_edus[:limit], pending_edus[limit:] pending_edus, self._pending_edus = pending_edus[:limit], pending_edus[limit:]
return pending_edus return pending_edus
@defer.inlineCallbacks async def _get_device_update_edus(self, limit: int) -> Tuple[List[Edu], int]:
def _get_device_update_edus(self, limit):
last_device_list = self._last_device_list_stream_id last_device_list = self._last_device_list_stream_id
# Retrieve list of new device updates to send to the destination # Retrieve list of new device updates to send to the destination
now_stream_id, results = yield self._store.get_device_updates_by_remote( now_stream_id, results = await self._store.get_device_updates_by_remote(
self._destination, last_device_list, limit=limit self._destination, last_device_list, limit=limit
) )
edus = [ edus = [
@ -378,11 +381,10 @@ class PerDestinationQueue(object):
return (edus, now_stream_id) return (edus, now_stream_id)
@defer.inlineCallbacks async def _get_to_device_message_edus(self, limit: int) -> Tuple[List[Edu], int]:
def _get_to_device_message_edus(self, limit):
last_device_stream_id = self._last_device_stream_id last_device_stream_id = self._last_device_stream_id
to_device_stream_id = self._store.get_to_device_stream_token() to_device_stream_id = self._store.get_to_device_stream_token()
contents, stream_id = yield self._store.get_new_device_msgs_for_remote( contents, stream_id = await self._store.get_new_device_msgs_for_remote(
self._destination, last_device_stream_id, to_device_stream_id, limit self._destination, last_device_stream_id, to_device_stream_id, limit
) )
edus = [ edus = [

View File

@ -13,14 +13,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import List
from canonicaljson import json from canonicaljson import json
from twisted.internet import defer import synapse.server
from synapse.api.errors import HttpResponseException from synapse.api.errors import HttpResponseException
from synapse.events import EventBase
from synapse.federation.persistence import TransactionActions from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Transaction from synapse.federation.units import Edu, Transaction
from synapse.logging.opentracing import ( from synapse.logging.opentracing import (
extract_text_map, extract_text_map,
set_tag, set_tag,
@ -39,7 +40,7 @@ class TransactionManager(object):
shared between PerDestinationQueue objects shared between PerDestinationQueue objects
""" """
def __init__(self, hs): def __init__(self, hs: "synapse.server.HomeServer"):
self._server_name = hs.hostname self._server_name = hs.hostname
self.clock = hs.get_clock() # nb must be called this for @measure_func self.clock = hs.get_clock() # nb must be called this for @measure_func
self._store = hs.get_datastore() self._store = hs.get_datastore()
@ -50,8 +51,9 @@ class TransactionManager(object):
self._next_txn_id = int(self.clock.time_msec()) self._next_txn_id = int(self.clock.time_msec())
@measure_func("_send_new_transaction") @measure_func("_send_new_transaction")
@defer.inlineCallbacks async def send_new_transaction(
def send_new_transaction(self, destination, pending_pdus, pending_edus): self, destination: str, pending_pdus: List[EventBase], pending_edus: List[Edu]
):
# Make a transaction-sending opentracing span. This span follows on from # Make a transaction-sending opentracing span. This span follows on from
# all the edus in that transaction. This needs to be done since there is # all the edus in that transaction. This needs to be done since there is
@ -127,7 +129,7 @@ class TransactionManager(object):
return data return data
try: try:
response = yield self._transport_layer.send_transaction( response = await self._transport_layer.send_transaction(
transaction, json_data_cb transaction, json_data_cb
) )
code = 200 code = 200

View File

@ -19,11 +19,15 @@ server protocol.
import logging import logging
import attr
from synapse.types import JsonDict
from synapse.util.jsonobject import JsonEncodedObject from synapse.util.jsonobject import JsonEncodedObject
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@attr.s(slots=True)
class Edu(JsonEncodedObject): class Edu(JsonEncodedObject):
""" An Edu represents a piece of data sent from one homeserver to another. """ An Edu represents a piece of data sent from one homeserver to another.
@ -32,11 +36,24 @@ class Edu(JsonEncodedObject):
internal ID or previous references graph. internal ID or previous references graph.
""" """
valid_keys = ["origin", "destination", "edu_type", "content"] edu_type = attr.ib(type=str)
content = attr.ib(type=dict)
origin = attr.ib(type=str)
destination = attr.ib(type=str)
required_keys = ["edu_type"] def get_dict(self) -> JsonDict:
return {
"edu_type": self.edu_type,
"content": self.content,
}
internal_keys = ["origin", "destination"] def get_internal_dict(self) -> JsonDict:
return {
"edu_type": self.edu_type,
"content": self.content,
"origin": self.origin,
"destination": self.destination,
}
def get_context(self): def get_context(self):
return getattr(self, "content", {}).get("org.matrix.opentracing_context", "{}") return getattr(self, "content", {}).get("org.matrix.opentracing_context", "{}")

View File

@ -107,3 +107,5 @@ class HomeServer(object):
self, self,
) -> synapse.replication.tcp.client.ReplicationClientHandler: ) -> synapse.replication.tcp.client.ReplicationClientHandler:
pass pass
def is_mine_id(self, domain_id: str) -> bool:
pass

View File

@ -111,7 +111,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
retry_timings_res retry_timings_res
) )
self.datastore.get_device_updates_by_remote.return_value = (0, []) self.datastore.get_device_updates_by_remote.return_value = defer.succeed(
(0, [])
)
def get_received_txn_response(*args): def get_received_txn_response(*args):
return defer.succeed(None) return defer.succeed(None)
@ -144,7 +146,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore.get_current_state_deltas.return_value = (0, None) self.datastore.get_current_state_deltas.return_value = (0, None)
self.datastore.get_to_device_stream_token = lambda: 0 self.datastore.get_to_device_stream_token = lambda: 0
self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: ([], 0) self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: defer.succeed(
([], 0)
)
self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed( self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed(
None None

View File

@ -179,6 +179,7 @@ extras = all
commands = mypy \ commands = mypy \
synapse/api \ synapse/api \
synapse/config/ \ synapse/config/ \
synapse/federation/sender \
synapse/federation/transport \ synapse/federation/transport \
synapse/handlers/sync.py \ synapse/handlers/sync.py \
synapse/handlers/ui_auth \ synapse/handlers/ui_auth \