Convert replication code to async/await. (#7987)

This commit is contained in:
Patrick Cloke 2020-08-03 07:12:55 -04:00 committed by GitHub
parent db5970ac6d
commit 3b415e23a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 29 additions and 38 deletions

1
changelog.d/7987.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View File

@ -548,7 +548,7 @@ class RegistrationHandler(BaseHandler):
address (str|None): the IP address used to perform the registration. address (str|None): the IP address used to perform the registration.
Returns: Returns:
Deferred Awaitable
""" """
if self.hs.config.worker_app: if self.hs.config.worker_app:
return self._register_client( return self._register_client(

View File

@ -20,8 +20,6 @@ import urllib
from inspect import signature from inspect import signature
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
from twisted.internet import defer
from synapse.api.errors import ( from synapse.api.errors import (
CodeMessageException, CodeMessageException,
HttpResponseException, HttpResponseException,
@ -101,7 +99,7 @@ class ReplicationEndpoint(object):
assert self.METHOD in ("PUT", "POST", "GET") assert self.METHOD in ("PUT", "POST", "GET")
@abc.abstractmethod @abc.abstractmethod
def _serialize_payload(**kwargs): async def _serialize_payload(**kwargs):
"""Static method that is called when creating a request. """Static method that is called when creating a request.
Concrete implementations should have explicit parameters (rather than Concrete implementations should have explicit parameters (rather than
@ -110,9 +108,8 @@ class ReplicationEndpoint(object):
argument list. argument list.
Returns: Returns:
Deferred[dict]|dict: If POST/PUT request then dictionary must be dict: If POST/PUT request then dictionary must be JSON serialisable,
JSON serialisable, otherwise must be appropriate for adding as otherwise must be appropriate for adding as query args.
query args.
""" """
return {} return {}
@ -144,8 +141,7 @@ class ReplicationEndpoint(object):
instance_map = hs.config.worker.instance_map instance_map = hs.config.worker.instance_map
@trace(opname="outgoing_replication_request") @trace(opname="outgoing_replication_request")
@defer.inlineCallbacks async def send_request(instance_name="master", **kwargs):
def send_request(instance_name="master", **kwargs):
if instance_name == local_instance_name: if instance_name == local_instance_name:
raise Exception("Trying to send HTTP request to self") raise Exception("Trying to send HTTP request to self")
if instance_name == "master": if instance_name == "master":
@ -159,7 +155,7 @@ class ReplicationEndpoint(object):
"Instance %r not in 'instance_map' config" % (instance_name,) "Instance %r not in 'instance_map' config" % (instance_name,)
) )
data = yield cls._serialize_payload(**kwargs) data = await cls._serialize_payload(**kwargs)
url_args = [ url_args = [
urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
@ -197,7 +193,7 @@ class ReplicationEndpoint(object):
headers = {} # type: Dict[bytes, List[bytes]] headers = {} # type: Dict[bytes, List[bytes]]
inject_active_span_byte_dict(headers, None, check_destination=False) inject_active_span_byte_dict(headers, None, check_destination=False)
try: try:
result = yield request_func(uri, data, headers=headers) result = await request_func(uri, data, headers=headers)
break break
except CodeMessageException as e: except CodeMessageException as e:
if e.code != 504 or not cls.RETRY_ON_TIMEOUT: if e.code != 504 or not cls.RETRY_ON_TIMEOUT:
@ -207,7 +203,7 @@ class ReplicationEndpoint(object):
# If we timed out we probably don't need to worry about backing # If we timed out we probably don't need to worry about backing
# off too much, but lets just wait a little anyway. # off too much, but lets just wait a little anyway.
yield clock.sleep(1) await clock.sleep(1)
except HttpResponseException as e: except HttpResponseException as e:
# We convert to SynapseError as we know that it was a SynapseError # We convert to SynapseError as we know that it was a SynapseError
# on the master process that we should send to the client. (And # on the master process that we should send to the client. (And

View File

@ -60,7 +60,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock() self.clock = hs.get_clock()
@staticmethod @staticmethod
def _serialize_payload(user_id): async def _serialize_payload(user_id):
return {} return {}
async def _handle_request(self, request, user_id): async def _handle_request(self, request, user_id):

View File

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import make_event_from_dict from synapse.events import make_event_from_dict
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
@ -67,8 +65,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
self.federation_handler = hs.get_handlers().federation_handler self.federation_handler = hs.get_handlers().federation_handler
@staticmethod @staticmethod
@defer.inlineCallbacks async def _serialize_payload(store, event_and_contexts, backfilled):
def _serialize_payload(store, event_and_contexts, backfilled):
""" """
Args: Args:
store store
@ -78,9 +75,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
""" """
event_payloads = [] event_payloads = []
for event, context in event_and_contexts: for event, context in event_and_contexts:
serialized_context = yield defer.ensureDeferred( serialized_context = await context.serialize(event, store)
context.serialize(event, store)
)
event_payloads.append( event_payloads.append(
{ {
@ -156,7 +151,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
self.registry = hs.get_federation_registry() self.registry = hs.get_federation_registry()
@staticmethod @staticmethod
def _serialize_payload(edu_type, origin, content): async def _serialize_payload(edu_type, origin, content):
return {"origin": origin, "content": content} return {"origin": origin, "content": content}
async def _handle_request(self, request, edu_type): async def _handle_request(self, request, edu_type):
@ -199,7 +194,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
self.registry = hs.get_federation_registry() self.registry = hs.get_federation_registry()
@staticmethod @staticmethod
def _serialize_payload(query_type, args): async def _serialize_payload(query_type, args):
""" """
Args: Args:
query_type (str) query_type (str)
@ -240,7 +235,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@staticmethod @staticmethod
def _serialize_payload(room_id, args): async def _serialize_payload(room_id, args):
""" """
Args: Args:
room_id (str) room_id (str)
@ -275,7 +270,7 @@ class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@staticmethod @staticmethod
def _serialize_payload(room_id, room_version): async def _serialize_payload(room_id, room_version):
return {"room_version": room_version.identifier} return {"room_version": room_version.identifier}
async def _handle_request(self, request, room_id): async def _handle_request(self, request, room_id):

View File

@ -36,7 +36,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
self.registration_handler = hs.get_registration_handler() self.registration_handler = hs.get_registration_handler()
@staticmethod @staticmethod
def _serialize_payload(user_id, device_id, initial_display_name, is_guest): async def _serialize_payload(user_id, device_id, initial_display_name, is_guest):
""" """
Args: Args:
device_id (str|None): Device ID to use, if None a new one is device_id (str|None): Device ID to use, if None a new one is

View File

@ -52,7 +52,9 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock() self.clock = hs.get_clock()
@staticmethod @staticmethod
def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content): async def _serialize_payload(
requester, room_id, user_id, remote_room_hosts, content
):
""" """
Args: Args:
requester(Requester) requester(Requester)
@ -112,7 +114,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
self.member_handler = hs.get_room_member_handler() self.member_handler = hs.get_room_member_handler()
@staticmethod @staticmethod
def _serialize_payload( # type: ignore async def _serialize_payload( # type: ignore
invite_event_id: str, invite_event_id: str,
txn_id: Optional[str], txn_id: Optional[str],
requester: Requester, requester: Requester,
@ -174,7 +176,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
self.distributor = hs.get_distributor() self.distributor = hs.get_distributor()
@staticmethod @staticmethod
def _serialize_payload(room_id, user_id, change): async def _serialize_payload(room_id, user_id, change):
""" """
Args: Args:
room_id (str) room_id (str)

View File

@ -50,7 +50,7 @@ class ReplicationBumpPresenceActiveTime(ReplicationEndpoint):
self._presence_handler = hs.get_presence_handler() self._presence_handler = hs.get_presence_handler()
@staticmethod @staticmethod
def _serialize_payload(user_id): async def _serialize_payload(user_id):
return {} return {}
async def _handle_request(self, request, user_id): async def _handle_request(self, request, user_id):
@ -92,7 +92,7 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
self._presence_handler = hs.get_presence_handler() self._presence_handler = hs.get_presence_handler()
@staticmethod @staticmethod
def _serialize_payload(user_id, state, ignore_status_msg=False): async def _serialize_payload(user_id, state, ignore_status_msg=False):
return { return {
"state": state, "state": state,
"ignore_status_msg": ignore_status_msg, "ignore_status_msg": ignore_status_msg,

View File

@ -34,7 +34,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
self.registration_handler = hs.get_registration_handler() self.registration_handler = hs.get_registration_handler()
@staticmethod @staticmethod
def _serialize_payload( async def _serialize_payload(
user_id, user_id,
password_hash, password_hash,
was_guest, was_guest,
@ -105,7 +105,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
self.registration_handler = hs.get_registration_handler() self.registration_handler = hs.get_registration_handler()
@staticmethod @staticmethod
def _serialize_payload(user_id, auth_result, access_token): async def _serialize_payload(user_id, auth_result, access_token):
""" """
Args: Args:
user_id (str): The user ID that consented user_id (str): The user ID that consented

View File

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import make_event_from_dict from synapse.events import make_event_from_dict
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
@ -62,8 +60,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock() self.clock = hs.get_clock()
@staticmethod @staticmethod
@defer.inlineCallbacks async def _serialize_payload(
def _serialize_payload(
event_id, store, event, context, requester, ratelimit, extra_users event_id, store, event, context, requester, ratelimit, extra_users
): ):
""" """
@ -77,7 +74,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
extra_users (list(UserID)): Any extra users to notify about event extra_users (list(UserID)): Any extra users to notify about event
""" """
serialized_context = yield defer.ensureDeferred(context.serialize(event, store)) serialized_context = await context.serialize(event, store)
payload = { payload = {
"event": event.get_pdu_json(), "event": event.get_pdu_json(),

View File

@ -54,7 +54,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
self.streams = hs.get_replication_streams() self.streams = hs.get_replication_streams()
@staticmethod @staticmethod
def _serialize_payload(stream_name, from_token, upto_token): async def _serialize_payload(stream_name, from_token, upto_token):
return {"from_token": from_token, "upto_token": upto_token} return {"from_token": from_token, "upto_token": upto_token}
async def _handle_request(self, request, stream_name): async def _handle_request(self, request, stream_name):