Convert the federation agent and related code to async/await. (#7874)

This commit is contained in:
Patrick Cloke 2020-07-23 07:05:57 -04:00 committed by GitHub
parent 13d77464c9
commit 68cd935826
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 51 additions and 53 deletions

View file

@ -15,6 +15,7 @@
import logging
import urllib
from typing import List
from netaddr import AddrFormatError, IPAddress
from zope.interface import implementer
@ -236,11 +237,10 @@ class MatrixHostnameEndpoint(object):
return run_in_background(self._do_connect, protocol_factory)
@defer.inlineCallbacks
def _do_connect(self, protocol_factory):
async def _do_connect(self, protocol_factory):
first_exception = None
server_list = yield self._resolve_server()
server_list = await self._resolve_server()
for server in server_list:
host = server.host
@ -251,7 +251,7 @@ class MatrixHostnameEndpoint(object):
endpoint = HostnameEndpoint(self._reactor, host, port)
if self._tls_options:
endpoint = wrapClientTLS(self._tls_options, endpoint)
result = yield make_deferred_yieldable(
result = await make_deferred_yieldable(
endpoint.connect(protocol_factory)
)
@ -271,13 +271,9 @@ class MatrixHostnameEndpoint(object):
# to try and if that doesn't work then we'll have an exception.
raise Exception("Failed to resolve server %r" % (self._parsed_uri.netloc,))
@defer.inlineCallbacks
def _resolve_server(self):
async def _resolve_server(self) -> List[Server]:
"""Resolves the server name to a list of hosts and ports to attempt to
connect to.
Returns:
Deferred[list[Server]]
"""
if self._parsed_uri.scheme != b"matrix":
@ -298,7 +294,7 @@ class MatrixHostnameEndpoint(object):
if port or _is_ip_literal(host):
return [Server(host, port or 8448)]
server_list = yield self._srv_resolver.resolve_service(b"_matrix._tcp." + host)
server_list = await self._srv_resolver.resolve_service(b"_matrix._tcp." + host)
if server_list:
return server_list

View file

@ -17,10 +17,10 @@
import logging
import random
import time
from typing import List
import attr
from twisted.internet import defer
from twisted.internet.error import ConnectError
from twisted.names import client, dns
from twisted.names.error import DNSNameError, DomainError
@ -113,16 +113,14 @@ class SrvResolver(object):
self._cache = cache
self._get_time = get_time
@defer.inlineCallbacks
def resolve_service(self, service_name):
async def resolve_service(self, service_name: bytes) -> List[Server]:
"""Look up a SRV record
Args:
service_name (bytes): record to look up
Returns:
Deferred[list[Server]]:
a list of the SRV records, or an empty list if none found
a list of the SRV records, or an empty list if none found
"""
now = int(self._get_time())
@ -136,7 +134,7 @@ class SrvResolver(object):
return _sort_server_list(servers)
try:
answers, _, _ = yield make_deferred_yieldable(
answers, _, _ = await make_deferred_yieldable(
self._dns_client.lookupService(service_name)
)
except DNSNameError: