mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
bb4a20174c
Conflicts: synapse/federation/transport.py synapse/handlers/message.py
699 lines
22 KiB
Python
699 lines
22 KiB
Python
# -*- coding: utf-8 -*-
|
|
# Copyright 2014 OpenMarket Ltd
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""The transport layer is responsible for both sending transactions to remote
|
|
home servers and receiving a variety of requests from other home servers.
|
|
|
|
Typically, this is done over HTTP (and all home servers are required to
|
|
support HTTP), however individual pairings of servers may decide to communicate
|
|
over a different (albeit still reliable) protocol.
|
|
"""
|
|
|
|
from twisted.internet import defer
|
|
|
|
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
|
|
from synapse.api.errors import Codes, SynapseError
|
|
from synapse.util.logutils import log_function
|
|
|
|
import logging
|
|
import json
|
|
import re
|
|
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TransportLayer(object):
    """This is a basic implementation of the transport layer that translates
    transactions and other requests to/from HTTP.

    Attributes:
        server_name (str): Local home server host

        server (synapse.http.server.HttpServer): the http server to
            register listeners on

        client (synapse.http.client.HttpClient): the http client used to
            send requests

        keyring: the object returned by ``homeserver.get_keyring()``; used to
            verify the signatures on incoming requests.

        request_handler (TransportRequestHandler): The handler to fire when we
            receive requests for data.

        received_handler (TransportReceivedHandler): The handler to fire when
            we receive data.
    """

    def __init__(self, homeserver, server_name, server, client):
        """
        Args:
            homeserver: the home server; only used here to obtain the keyring
                via ``homeserver.get_keyring()``.
            server_name (str): Local home server host
            server (synapse.http.server.HttpServer): the http server to
                register listeners on
            client (synapse.http.client.HttpClient): the http client used to
                send requests
        """
        self.keyring = homeserver.get_keyring()
        self.server_name = server_name
        self.server = server
        self.client = client
        # Both handlers are installed later via register_request_handler /
        # register_received_handler.
        self.request_handler = None
        self.received_handler = None

    @log_function
    def get_context_state(self, destination, context, pdu_id=None,
                          pdu_origin=None):
        """ Requests all state for a given context (i.e. room) from the
        given server.

        Args:
            destination (str): The host name of the remote home server we want
                to get the state from.
            context (str): The name of the context we want the state of
            pdu_id (str|None): Optional PDU id; only sent if pdu_origin is
                also given.
            pdu_origin (str|None): Optional PDU origin; only sent if pdu_id is
                also given.

        Returns:
            Deferred: Results in a dict received from the remote homeserver.
        """
        logger.debug("get_context_state dest=%s, context=%s",
                     destination, context)

        subpath = "/state/%s/" % context

        args = {}
        # A PDU reference is only meaningful as an (id, origin) pair, so it
        # is only added to the query when both halves are present.
        if pdu_id and pdu_origin:
            args["pdu_id"] = pdu_id
            args["pdu_origin"] = pdu_origin

        return self._do_request_for_transaction(
            destination, subpath, args=args
        )

    @log_function
    def get_pdu(self, destination, pdu_origin, pdu_id):
        """ Requests the pdu with the given id and origin from the given
        server.

        Args:
            destination (str): The host name of the remote home server we want
                to get the PDU from.
            pdu_origin (str): The home server which created the PDU.
            pdu_id (str): The id of the PDU being requested.

        Returns:
            Deferred: Results in a dict received from the remote homeserver.
        """
        logger.debug("get_pdu dest=%s, pdu_origin=%s, pdu_id=%s",
                     destination, pdu_origin, pdu_id)

        subpath = "/pdu/%s/%s/" % (pdu_origin, pdu_id)

        return self._do_request_for_transaction(destination, subpath)

    @log_function
    def backfill(self, dest, context, pdu_tuples, limit):
        """ Requests `limit` previous PDUs in a given context before list of
        PDUs.

        Args:
            dest (str): The remote home server to ask.
            context (str): The context to backfill.
            pdu_tuples (list): `(pdu_id, origin)` pairs to backfill from;
                each is sent as a "pdu_id,origin" `v` query argument.
            limit (int): The maximum number of PDUs to request.

        Returns:
            Deferred: Results in a dict received from the remote homeserver.
            If `pdu_tuples` is empty no request is made and None is returned
            (not a Deferred).
        """
        logger.debug(
            "backfill dest=%s, context=%s, pdu_tuples=%s, limit=%s",
            dest, context, repr(pdu_tuples), str(limit)
        )

        if not pdu_tuples:
            return

        subpath = "/backfill/%s/" % context

        args = {
            "v": ["%s,%s" % (i, o) for i, o in pdu_tuples],
            "limit": limit,
        }

        return self._do_request_for_transaction(
            dest,
            subpath,
            args=args,
        )

    @defer.inlineCallbacks
    @log_function
    def send_transaction(self, transaction, json_data_callback=None):
        """ Sends the given Transaction to its destination

        Args:
            transaction (Transaction)
            json_data_callback (callable|None): passed straight through to
                `client.put_json`.

        Returns:
            Deferred: Results of the deferred is a tuple in the form of
            (response_code, response_body) as returned by `client.put_json`.

        Raises:
            RuntimeError: if the transaction is addressed to this server.
        """
        logger.debug(
            "send_data dest=%s, txid=%s",
            transaction.destination, transaction.transaction_id
        )

        if transaction.destination == self.server_name:
            raise RuntimeError("Transport layer cannot send to itself!")

        # FIXME: This is only used by the tests. The actual json sent is
        # generated by the json_data_callback.
        json_data = transaction.get_dict()

        code, response = yield self.client.put_json(
            transaction.destination,
            path=PREFIX + "/send/%s/" % transaction.transaction_id,
            data=json_data,
            json_data_callback=json_data_callback,
        )

        logger.debug(
            "send_data dest=%s, txid=%s, got response: %d",
            transaction.destination, transaction.transaction_id, code
        )

        defer.returnValue((code, response))

    @defer.inlineCallbacks
    @log_function
    def make_query(self, destination, query_type, args, retry_on_dns_fail):
        """ Sends a GET /query/<query_type> request to a remote server.

        Args:
            destination (str): The remote home server to query.
            query_type (str): The type of query; becomes the last URL segment.
            args (dict): Query-string arguments for the request.
            retry_on_dns_fail (bool): passed through to `client.get_json`.

        Returns:
            Deferred: Results in whatever `client.get_json` returns.
        """
        path = PREFIX + "/query/%s" % query_type

        response = yield self.client.get_json(
            destination=destination,
            path=path,
            args=args,
            retry_on_dns_fail=retry_on_dns_fail,
        )

        defer.returnValue(response)

    @defer.inlineCallbacks
    @log_function
    def make_join(self, destination, context, user_id, retry_on_dns_fail=True):
        """ Sends a GET /make_join/<context>/<user_id> request.

        Args:
            destination (str): The remote home server to ask.
            context (str): The context (room) to join.
            user_id (str): The user who wants to join.
            retry_on_dns_fail (bool): passed through to `client.get_json`.

        Returns:
            Deferred: Results in whatever `client.get_json` returns.
        """
        path = PREFIX + "/make_join/%s/%s" % (context, user_id,)

        response = yield self.client.get_json(
            destination=destination,
            path=path,
            retry_on_dns_fail=retry_on_dns_fail,
        )

        defer.returnValue(response)

    @defer.inlineCallbacks
    @log_function
    def send_join(self, destination, context, pdu_id, origin, content):
        """ Sends a PUT /send_join/<context>/<origin>/<pdu_id> request.

        Args:
            destination (str): The remote home server to send to.
            context (str): The context (room) being joined.
            pdu_id (str): The id of the join PDU.
            origin (str): The origin of the join PDU.
            content (dict): The request body.

        Returns:
            Deferred: Results in the JSON-decoded response body.

        Raises:
            RuntimeError: if the response code is not 2xx.
        """
        path = PREFIX + "/send_join/%s/%s/%s" % (
            context,
            origin,
            pdu_id,
        )

        code, content = yield self.client.put_json(
            destination=destination,
            path=path,
            data=content,
        )

        if not 200 <= code < 300:
            # Interpolate the code into the message; previously it was passed
            # as a separate exception argument and never formatted.
            raise RuntimeError("Got %d from send_join" % (code,))

        defer.returnValue(json.loads(content))

    @defer.inlineCallbacks
    @log_function
    def send_invite(self, destination, context, pdu_id, origin, content):
        """ Sends a PUT /invite/<context>/<origin>/<pdu_id> request.

        Args:
            destination (str): The remote home server to send to.
            context (str): The context (room) of the invite.
            pdu_id (str): The id of the invite PDU.
            origin (str): The origin of the invite PDU.
            content (dict): The request body.

        Returns:
            Deferred: Results in the JSON-decoded response body.

        Raises:
            RuntimeError: if the response code is not 2xx.
        """
        path = PREFIX + "/invite/%s/%s/%s" % (
            context,
            origin,
            pdu_id,
        )

        code, content = yield self.client.put_json(
            destination=destination,
            path=path,
            data=content,
        )

        if not 200 <= code < 300:
            raise RuntimeError("Got %d from send_invite" % (code,))

        defer.returnValue(json.loads(content))

    @defer.inlineCallbacks
    def _authenticate_request(self, request):
        """ Checks the X-Matrix Authorization header(s) of an incoming request.

        Builds a JSON object describing the request (method, uri, destination,
        origin, signatures and, for PUT, the decoded body) and asks the
        keyring to verify it for the origin named in the header.

        Args:
            request (twisted.web.http.Request): The incoming request.

        Returns:
            Deferred: Results in a tuple `(origin, content)`, where `content`
            is the decoded JSON body for PUT requests and None otherwise.

        Raises:
            SynapseError: 400 if a PUT body is not valid JSON or an
                Authorization header is malformed; 401 if there is no
                X-Matrix Authorization header.
        """
        json_request = {
            "method": request.method,
            "uri": request.uri,
            "destination": self.server_name,
            "signatures": {},
        }

        content = None
        origin = None

        if request.method == "PUT":
            # TODO: Handle other method types? other content types?
            try:
                content_bytes = request.content.read()
                content = json.loads(content_bytes)
                json_request["content"] = content
            except Exception:
                raise SynapseError(400, "Unable to parse JSON", Codes.BAD_JSON)

        def parse_auth_header(header_str):
            """Split an `X-Matrix k=v,k=v` header into (origin, key, sig)."""
            try:
                params = header_str.split(" ")[1].split(",")
                # Split on the first "=" only, so a value that itself
                # contains "=" is not rejected as malformed.
                param_dict = dict(kv.split("=", 1) for kv in params)

                def strip_quotes(value):
                    if value.startswith("\""):
                        return value[1:-1]
                    else:
                        return value

                origin = strip_quotes(param_dict["origin"])
                key = strip_quotes(param_dict["key"])
                sig = strip_quotes(param_dict["sig"])
                return (origin, key, sig)
            except Exception:
                raise SynapseError(
                    400, "Malformed Authorization header", Codes.UNAUTHORIZED
                )

        auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")

        if not auth_headers:
            raise SynapseError(
                401, "Missing Authorization headers", Codes.UNAUTHORIZED,
            )

        for auth in auth_headers:
            if auth.startswith("X-Matrix"):
                (origin, key, sig) = parse_auth_header(auth)
                # NOTE(review): with several X-Matrix headers the last origin
                # wins and is the one verified below.
                json_request["origin"] = origin
                json_request["signatures"].setdefault(origin, {})[key] = sig

        if not json_request["signatures"]:
            raise SynapseError(
                401, "Missing Authorization headers", Codes.UNAUTHORIZED,
            )

        yield self.keyring.verify_json_for_server(origin, json_request)

        defer.returnValue((origin, content))

    def _with_authentication(self, handler):
        """ Wraps `handler` so it only runs after request authentication.

        The returned callable takes `(request, *args, **kwargs)`, authenticates
        the request and then calls
        `handler(origin, content, request.args, *args, **kwargs)`. Any failure
        (in authentication or in the handler) is logged and re-raised.
        """
        @defer.inlineCallbacks
        def new_handler(request, *args, **kwargs):
            try:
                (origin, content) = yield self._authenticate_request(request)
                response = yield handler(
                    origin, content, request.args, *args, **kwargs
                )
            except:
                logger.exception("_authenticate_request failed")
                raise
            defer.returnValue(response)
        return new_handler

    @log_function
    def register_received_handler(self, handler):
        """ Register a handler that will be fired when we receive data.

        Args:
            handler (TransportReceivedHandler)
        """
        self.received_handler = handler

        # This is when someone is trying to send us a bunch of data.
        self.server.register_path(
            "PUT",
            re.compile("^" + PREFIX + "/send/([^/]*)/$"),
            self._with_authentication(self._on_send_request)
        )

    @log_function
    def register_request_handler(self, handler):
        """ Register a handler that will be fired when we get asked for data.

        Args:
            handler (TransportRequestHandler)
        """
        self.request_handler = handler

        # TODO(markjh): Namespace the federation URI paths

        # This is for when someone asks us for everything since version X
        self.server.register_path(
            "GET",
            re.compile("^" + PREFIX + "/pull/$"),
            self._with_authentication(
                lambda origin, content, query:
                handler.on_pull_request(query["origin"][0], query["v"])
            )
        )

        # This is when someone asks for a data item for a given server
        # data_id pair.
        self.server.register_path(
            "GET",
            re.compile("^" + PREFIX + "/pdu/([^/]*)/([^/]*)/$"),
            self._with_authentication(
                lambda origin, content, query, pdu_origin, pdu_id:
                handler.on_pdu_request(pdu_origin, pdu_id)
            )
        )

        # This is when someone asks for all data for a given context.
        self.server.register_path(
            "GET",
            re.compile("^" + PREFIX + "/state/([^/]*)/$"),
            self._with_authentication(
                lambda origin, content, query, context:
                handler.on_context_state_request(
                    context,
                    query.get("pdu_id", [None])[0],
                    query.get("pdu_origin", [None])[0]
                )
            )
        )

        self.server.register_path(
            "GET",
            re.compile("^" + PREFIX + "/backfill/([^/]*)/$"),
            self._with_authentication(
                lambda origin, content, query, context:
                self._on_backfill_request(
                    context, query["v"], query["limit"]
                )
            )
        )

        self.server.register_path(
            "GET",
            re.compile("^" + PREFIX + "/context/([^/]*)/$"),
            self._with_authentication(
                lambda origin, content, query, context:
                handler.on_context_pdus_request(context)
            )
        )

        # This is when we receive a server-server Query
        self.server.register_path(
            "GET",
            re.compile("^" + PREFIX + "/query/([^/]*)$"),
            self._with_authentication(
                lambda origin, content, query, query_type:
                handler.on_query_request(
                    query_type, {k: v[0] for k, v in query.items()}
                )
            )
        )

        self.server.register_path(
            "GET",
            re.compile("^" + PREFIX + "/make_join/([^/]*)/([^/]*)$"),
            self._with_authentication(
                lambda origin, content, query, context, user_id:
                self._on_make_join_request(
                    origin, content, query, context, user_id
                )
            )
        )

        self.server.register_path(
            "PUT",
            re.compile("^" + PREFIX + "/send_join/([^/]*)/([^/]*)/([^/]*)$"),
            self._with_authentication(
                lambda origin, content, query, context, pdu_origin, pdu_id:
                self._on_send_join_request(
                    origin, content, query,
                )
            )
        )

        self.server.register_path(
            "PUT",
            re.compile("^" + PREFIX + "/invite/([^/]*)/([^/]*)/([^/]*)$"),
            self._with_authentication(
                lambda origin, content, query, context, pdu_origin, pdu_id:
                self._on_invite_request(
                    origin, content, query,
                )
            )
        )

    @defer.inlineCallbacks
    @log_function
    def _on_send_request(self, origin, content, query, transaction_id):
        """ Called on PUT /send/<transaction_id>/

        Args:
            origin (str): The authenticated origin of the request.
            content (dict): The decoded JSON body of the request.
            query (dict): The request's query-string arguments (unused).
            transaction_id (str): The transaction_id associated with this
                request. This is *not* None.

        Returns:
            Deferred: Results in a tuple of `(code, response)`, where
            `response` is a python dict to be converted into JSON that is
            used as the response body.
        """
        # Parse the request
        try:
            transaction_data = content

            logger.debug(
                "Decoded %s: %s",
                transaction_id, str(transaction_data)
            )

            # Add some extra data to the transaction dict that isn't included
            # in the request body.
            transaction_data.update(
                transaction_id=transaction_id,
                destination=self.server_name
            )

        except Exception as e:
            logger.exception(e)
            # returnValue raises, so nothing after this line runs.
            defer.returnValue((400, {"error": "Invalid transaction"}))

        try:
            code, response = yield self.received_handler.on_incoming_transaction(
                transaction_data
            )
        except:
            logger.exception("on_incoming_transaction failed")
            raise

        defer.returnValue((code, response))

    @defer.inlineCallbacks
    @log_function
    def _do_request_for_transaction(self, destination, subpath, args=None):
        """ GETs PREFIX + subpath from `destination` and tags the result.

        Args:
            destination (str)
            subpath (str): path below the federation PREFIX.
            args (dict|None): This is passed directly to the HttpClient;
                None means no arguments (an empty dict is sent).

        Returns:
            Deferred: Results in a dict, with `origin`, `destination` and
            `transaction_id` keys added.
        """
        # Avoid a shared mutable default argument.
        if args is None:
            args = {}

        data = yield self.client.get_json(
            destination,
            path=PREFIX + subpath,
            args=args,
        )

        # Add certain keys to the JSON, ready for decoding as a Transaction
        data.update(
            origin=destination,
            destination=self.server_name,
            transaction_id=None
        )

        defer.returnValue(data)

    @log_function
    def _on_backfill_request(self, context, v_list, limits):
        """ Parses backfill query arguments and forwards to the handler.

        Args:
            context (str): The context to backfill.
            v_list (list): "pdu_id,origin" strings from the `v` arguments.
            limits (list): values of the `limit` argument; the last one wins.

        Returns:
            Deferred: the request handler's result, or `(400, {...})` if no
            limit was given.
        """
        if not limits:
            return defer.succeed(
                (400, {"error": "Did not include limit param"})
            )

        limit = int(limits[-1])

        versions = [v.split(",", 1) for v in v_list]

        return self.request_handler.on_backfill_request(
            context, versions, limit
        )

    @defer.inlineCallbacks
    @log_function
    def _on_make_join_request(self, origin, content, query, context, user_id):
        """ Forwards a make_join request and wraps the result as (200, ...)."""
        content = yield self.request_handler.on_make_join_request(
            context, user_id,
        )
        defer.returnValue((200, content))

    @defer.inlineCallbacks
    @log_function
    def _on_send_join_request(self, origin, content, query):
        """ Forwards a send_join request and wraps the result as (200, ...)."""
        content = yield self.request_handler.on_send_join_request(
            origin, content,
        )

        defer.returnValue((200, content))

    @defer.inlineCallbacks
    @log_function
    def _on_invite_request(self, origin, content, query):
        """ Forwards an invite request and wraps the result as (200, ...)."""
        content = yield self.request_handler.on_invite_request(
            origin, content,
        )

        defer.returnValue((200, content))
|
|
|
|
|
|
class TransportReceivedHandler(object):
    """ Interface for the callbacks fired when a transaction arrives.
    """

    def on_incoming_transaction(self, transaction):
        """ Handle a transaction received via PUT /send/<transaction_id>, or
        as the response to a request we made (such as a backfill).

        Args:
            transaction (synapse.transaction.Transaction): The transaction
                that was sent to us.

        Returns:
            twisted.internet.defer.Deferred: Fires once the transaction has
            been fully processed, with a `(response_code, response_body)`
            tuple whose `response_body` is a python dict to be serialized
            as JSON.

            When something goes wrong, the dict should carry an `error` key
            holding a short description of the failure.
        """
|
|
|
|
|
|
class TransportRequestHandler(object):
    """ Handlers used when someone wants data from us
    """
    def on_pull_request(self, versions):
        """ Called on GET /pull/?v=...

        This is hit when a remote home server wants to get all data
        after a given transaction. Mainly used when a home server comes back
        online and wants to get everything it has missed.

        NOTE(review): TransportLayer actually invokes this as
        `on_pull_request(origin, versions)`; implementations need to accept
        both arguments.

        Args:
            versions (list): A list of transaction_ids that should be used to
                determine what PDUs the remote side have not yet seen.

        Returns:
            Deferred: Results in a tuple in the form of
            `(response_code, response_body)`, where `response_body` is a python
            dict that will get serialized to JSON.

            On errors, the dict should have an `error` key with a brief message
            of what went wrong.
        """
        pass

    def on_pdu_request(self, pdu_origin, pdu_id):
        """ Called on GET /pdu/<pdu_origin>/<pdu_id>/

        Someone wants a particular PDU. This PDU may or may not have originated
        from us.

        Args:
            pdu_origin (str)
            pdu_id (str)

        Returns:
            Deferred: Results in a tuple in the form of
            `(response_code, response_body)`, where `response_body` is a python
            dict that will get serialized to JSON.

            On errors, the dict should have an `error` key with a brief message
            of what went wrong.
        """
        pass

    def on_context_state_request(self, context, pdu_id=None, pdu_origin=None):
        """ Called on GET /state/<context>/

        Gets hit when someone wants all the *current* state for a given
        context.

        Args:
            context (str): The name of the context that we're interested in.
            pdu_id (str|None): The `pdu_id` query argument, if supplied.
            pdu_origin (str|None): The `pdu_origin` query argument, if
                supplied.

        Returns:
            twisted.internet.defer.Deferred: A deferred that gets fired when
            the request has finished being processed.

            The result should be a tuple in the form of
            `(response_code, response_body)`, where `response_body` is a python
            dict that will get serialized to JSON.

            On errors, the dict should have an `error` key with a brief message
            of what went wrong.
        """
        pass

    def on_backfill_request(self, context, versions, limit):
        """ Called on GET /backfill/<context>/?v=...&limit=...

        Gets hit when we want to backfill backwards on a given context from
        the given point.

        Args:
            context (str): The context to backfill
            versions (list): A list of 2-tuples representing where to backfill
                from, in the form `(pdu_id, origin)`
            limit (int): How many pdus to return.

        Returns:
            Deferred: Results in a tuple in the form of
            `(response_code, response_body)`, where `response_body` is a python
            dict that will get serialized to JSON.

            On errors, the dict should have an `error` key with a brief message
            of what went wrong.
        """
        pass

    def on_query_request(self, query_type=None, args=None):
        """ Called on a GET /query/<query_type> request.

        Args:
            query_type (str|None): The query type from the URL.
            args (dict|None): The query-string arguments, each reduced to its
                first value.
        """
|