Verify signatures for server2server requests

This commit is contained in:
Mark Haines 2014-10-13 14:37:46 +01:00
parent 10ef8e6e4b
commit 6684855767
5 changed files with 100 additions and 25 deletions

View File

@ -22,6 +22,7 @@ from .transport import TransportLayer
def initialize_http_replication(homeserver):
transport = TransportLayer(
homeserver,
homeserver.hostname,
server=homeserver.get_resource_for_federation(),
client=homeserver.get_http_client()

View File

@ -54,7 +54,7 @@ class TransportLayer(object):
we receive data.
"""
def __init__(self, server_name, server, client):
def __init__(self, homeserver, server_name, server, client):
"""
Args:
server_name (str): Local home server host
@ -63,6 +63,7 @@ class TransportLayer(object):
client (synapse.protocol.http.HttpClient): the http client used to
send requests
"""
self.keyring = homeserver.get_keyring()
self.server_name = server_name
self.server = server
self.client = client
@ -195,6 +196,66 @@ class TransportLayer(object):
defer.returnValue(response)
@defer.inlineCallbacks
def _authenticate_request(self, request):
json_request = {
"method": request.method,
"uri": request.uri,
"destination": self.server_name,
"signatures": {},
}
content = None
origin = None
if request.method == "PUT":
#TODO: Handle other method types? other content types?
content_bytes = request.content.read()
content = json.loads(content_bytes)
json_request["content"] = content
def parse_auth_header(header_str):
params = auth.split(" ")[1].split(",")
param_dict = dict(kv.split("=") for kv in params)
def strip_quotes(value):
if value.startswith("\""):
return value[1:-1]
else:
return value
origin = strip_quotes(param_dict["origin"])
key = strip_quotes(param_dict["key"])
sig = strip_quotes(param_dict["sig"])
return (origin, key, sig)
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
if not auth_headers:
#TODO(markjh): Send a 401 response?
raise Exception("Missing auth headers")
for auth in auth_headers:
if auth.startswith("X-Matrix"):
(origin, key, sig) = parse_auth_header(auth)
json_request["origin"] = origin
json_request["signatures"].setdefault(origin,{})[key] = sig
from syutil.jsonutil import encode_canonical_json
logger.debug("Checking %s %s",
origin, encode_canonical_json(json_request))
yield self.keyring.verify_json_for_server(origin, json_request)
defer.returnValue((origin, content))
def _with_authentication(self, handler):
@defer.inlineCallbacks
def new_handler(request, *args, **kwargs):
(origin, content) = yield self._authenticate_request(request)
response = yield handler(
origin, content, request.args, *args, **kwargs
)
defer.returnValue(response)
return new_handler
@log_function
def register_received_handler(self, handler):
""" Register a handler that will be fired when we receive data.
@ -208,7 +269,7 @@ class TransportLayer(object):
self.server.register_path(
"PUT",
re.compile("^" + PREFIX + "/send/([^/]*)/$"),
self._on_send_request
self._with_authentication(self._on_send_request)
)
@log_function
@ -226,9 +287,9 @@ class TransportLayer(object):
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/pull/$"),
lambda request: handler.on_pull_request(
request.args["origin"][0],
request.args["v"]
self._with_authentication(
lambda origin, content, query:
handler.on_pull_request(query["origin"][0], query["v"])
)
)
@ -237,8 +298,9 @@ class TransportLayer(object):
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/pdu/([^/]*)/([^/]*)/$"),
lambda request, pdu_origin, pdu_id: handler.on_pdu_request(
pdu_origin, pdu_id
self._with_authentication(
lambda origin, content, query, pdu_origin, pdu_id:
handler.on_pdu_request(pdu_origin, pdu_id)
)
)
@ -246,38 +308,47 @@ class TransportLayer(object):
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/state/([^/]*)/$"),
lambda request, context: handler.on_context_state_request(
context
self._with_authentication(
lambda origin, content, query, context:
handler.on_context_state_request(context)
)
)
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/backfill/([^/]*)/$"),
lambda request, context: self._on_backfill_request(
context, request.args["v"],
request.args["limit"]
self._with_authentication(
lambda origin, content, query, context:
self._on_backfill_request(
context, query["v"], query["limit"]
)
)
)
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/context/([^/]*)/$"),
lambda request, context: handler.on_context_pdus_request(context)
self._with_authentication(
lambda origin, content, query, context:
handler.on_context_pdus_request(context)
)
)
# This is when we receive a server-server Query
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/query/([^/]*)$"),
lambda request, query_type: handler.on_query_request(
query_type, {k: v[0] for k, v in request.args.items()}
self._with_authentication(
lambda origin, content, query, query_type:
handler.on_query_request(
query_type, {k: v[0] for k, v in query.items()}
)
)
)
@defer.inlineCallbacks
@log_function
def _on_send_request(self, request, transaction_id):
def _on_send_request(self, origin, content, query, transaction_id):
""" Called on PUT /send/<transaction_id>/
Args:
@ -292,12 +363,7 @@ class TransportLayer(object):
"""
# Parse the request
try:
data = request.content.read()
l = data[:20].encode("string_escape")
logger.debug("Got data: \"%s\"", l)
transaction_data = json.loads(data)
transaction_data = content
logger.debug(
"Decoded %s: %s",

View File

@ -177,16 +177,20 @@ class MatrixHttpClient(BaseHttpClient):
request = sign_json(request, self.server_name, self.signing_key)
from syutil.jsonutil import encode_canonical_json
logger.debug("Signing " + " " * 11 + "%s %s",
self.server_name, encode_canonical_json(request))
auth_headers = []
for key,sig in request["signatures"][self.server_name].items():
auth_headers.append(
auth_headers.append(bytes(
"X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
self.server_name, key, sig,
)
)
))
headers_dict["Authorization"] = auth_headers
headers_dict[b"Authorization"] = auth_headers
@defer.inlineCallbacks
def put_json(self, destination, path, data={}, json_data_callback=None):

View File

@ -221,6 +221,7 @@ class FederationTestCase(unittest.TestCase):
json_data_callback=ANY,
)
@defer.inlineCallbacks
def test_recv_edu(self):
recv_observer = Mock()

View File

@ -76,6 +76,9 @@ class MockHttpResource(HttpServer):
mock_content.configure_mock(**config)
mock_request.content = mock_content
mock_request.method = http_method
mock_request.uri = path
# return the right path if the event requires it
mock_request.path = path