mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
ef771cc4c2
Broadly three things here: * disable W504 which seems a bit whacko * remove a bunch of `as e` expressions from exception handlers that don't use them * use `r""` for strings which include backslashes Also, we don't use pep8 any more, so we can get rid of the duplicate config there.
148 lines
4.8 KiB
Python
148 lines
4.8 KiB
Python
# -*- coding: utf-8 -*-
|
|
# Copyright 2014-2016 OpenMarket Ltd
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import logging
|
|
|
|
from canonicaljson import json
|
|
|
|
from twisted.internet import defer, reactor
|
|
from twisted.internet.error import ConnectError
|
|
from twisted.internet.protocol import Factory
|
|
from twisted.names.error import DomainError
|
|
from twisted.web.http import HTTPClient
|
|
|
|
from synapse.http.endpoint import matrix_federation_endpoint
|
|
from synapse.util import logcontext
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
# Request path used by fetch_server_key() when the caller does not supply one.
KEY_API_V1 = b"/_matrix/key/v1/"
|
|
|
|
|
|
@defer.inlineCallbacks
def fetch_server_key(server_name, tls_client_options_factory, path=KEY_API_V1):
    """Fetch the keys for a remote server.

    Up to five connection attempts are made. Connection/DNS errors and
    unexpected exceptions are logged and retried; a non-200 response whose
    status starts with "4" aborts immediately.

    Args:
        server_name: name of the remote server. Used both as the connection
            target and as the value of the Host header.
        tls_client_options_factory: passed through unchanged to
            matrix_federation_endpoint.
        path: request path to GET; defaults to KEY_API_V1.

    Returns:
        Deferred: resolves to the ``(server_response, server_certificate)``
        pair produced by the protocol's ``remote_key`` deferred.

    Raises:
        IOError: if the server answered with a 4xx status, or if every
            attempt failed.
    """
    factory = SynapseKeyClientFactory()
    factory.path = path
    factory.host = server_name
    endpoint = matrix_federation_endpoint(
        reactor, server_name, tls_client_options_factory, timeout=30
    )

    for _ in range(5):
        try:
            with logcontext.PreserveLoggingContext():
                protocol = yield endpoint.connect(factory)
                server_response, server_certificate = yield protocol.remote_key
                defer.returnValue((server_response, server_certificate))
        except SynapseKeyClientError as e:
            logger.warning("Error getting key for %r: %s", server_name, e)
            # `status` is None unless the raiser filled it in; treat a missing
            # status as retryable instead of crashing inside the handler.
            if e.status is not None and e.status.startswith(b"4"):
                # Don't retry for 4xx responses.
                raise IOError("Cannot get key for %r" % server_name)
        except (ConnectError, DomainError) as e:
            logger.warning("Error getting key for %r: %s", server_name, e)
        except Exception:
            logger.exception("Error getting key for %r", server_name)

    # Every attempt failed.
    raise IOError("Cannot get key for %r" % server_name)
|
|
|
|
|
|
class SynapseKeyClientError(Exception):
    """Raised when the key could not be retrieved from the remote server."""

    # Status of the failed response; stays None unless whoever creates the
    # error assigns it afterwards.
    status = None
|
|
|
|
|
|
class SynapseKeyClientProtocol(HTTPClient):
    """Low level HTTPS client which retrieves an application/json response from
    the server and extracts the X.509 certificate for the remote peer from the
    SSL connection.

    The outcome is reported through ``remote_key``: it fires with a
    ``(json_response, certificate)`` tuple on success and fails on a non-200
    status, an unparseable body, or a timeout.

    ``path`` (and normally ``host``) must be assigned on the instance before
    the connection is made; ``path`` has no default here.
    """

    # Seconds to wait for a response once the request has been sent.
    timeout = 30

    def __init__(self):
        self.remote_key = defer.Deferred()
        self.host = None
        # Delayed call for the response timeout; created in connectionMade.
        self.timer = None
        self._peer = None

    def connectionMade(self):
        """Send the GET request and start the response timeout."""
        self._peer = self.transport.getPeer()
        logger.debug("Connected to %s", self._peer)

        if not isinstance(self.path, bytes):
            self.path = self.path.encode('ascii')

        # host may still be the None default, in which case no Host header is
        # sent below; only encode it when there is something to encode.
        if self.host is not None and not isinstance(self.host, bytes):
            self.host = self.host.encode('ascii')

        self.sendCommand(b"GET", self.path)
        if self.host:
            self.sendHeader(b"Host", self.host)
        self.endHeaders()
        self.timer = reactor.callLater(
            self.timeout,
            self.on_timeout
        )

    def _cancel_timer(self):
        """Cancel the response timeout if it is still pending."""
        timer = self.timer
        # cancel() raises if the call has already run or been cancelled.
        if timer is not None and timer.active():
            timer.cancel()

    def errback(self, error):
        """Fail ``remote_key`` with ``error`` unless it has already fired."""
        if not self.remote_key.called:
            self.remote_key.errback(error)

    def callback(self, result):
        """Fire ``remote_key`` with ``result`` unless it has already fired."""
        if not self.remote_key.called:
            self.remote_key.callback(result)

    def handleStatus(self, version, status, message):
        """Fail the request and drop the connection on any non-200 status."""
        if status != b"200":
            error = SynapseKeyClientError(
                "Non-200 response %r from %r" % (status, self.host)
            )
            error.status = status
            self.errback(error)
            # The result is settled, so the timeout has nothing left to do.
            self._cancel_timer()
            self.transport.abortConnection()

    def handleResponse(self, response_body_bytes):
        """Parse the response body and report it with the peer certificate."""
        try:
            json_response = json.loads(response_body_bytes)
        except ValueError:
            # Report the failure straight away rather than leaving the caller
            # waiting for the timeout to fire.
            self.errback(IOError("Invalid JSON response"))
            self._cancel_timer()
            self.transport.abortConnection()
            return

        certificate = self.transport.getPeerCertificate()
        self.callback((json_response, certificate))
        self.transport.abortConnection()
        self._cancel_timer()

    def on_timeout(self):
        """Fail the request and drop the connection when the timeout fires."""
        logger.debug(
            "Timeout waiting for response from %s: %s",
            self.host, self._peer,
        )
        self.errback(IOError("Timeout waiting for response"))
        self.transport.abortConnection()
|
|
|
|
|
|
class SynapseKeyClientFactory(Factory):
    """Factory producing SynapseKeyClientProtocol instances.

    ``path`` and ``host`` must be assigned on the factory before a protocol
    is built; they are copied onto each new protocol.
    """

    def protocol(self):
        """Create a key client protocol configured with this factory's
        ``path`` and ``host``."""
        client = SynapseKeyClientProtocol()
        client.path = self.path
        client.host = self.host
        return client
|