Slightly neater(?) arrangement of authentication wrapper for HTTP servlet methods

This commit is contained in:
Paul "LeoNerd" Evans 2015-03-05 20:33:16 +00:00
parent ba8ac996f9
commit 7644cb79b2

View File

@ -19,6 +19,7 @@ from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
import functools
import logging import logging
import simplejson as json import simplejson as json
import re import re
@ -30,8 +31,9 @@ logger = logging.getLogger(__name__)
class TransportLayerServer(object): class TransportLayerServer(object):
"""Handles incoming federation HTTP requests""" """Handles incoming federation HTTP requests"""
# A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks @defer.inlineCallbacks
def _authenticate_request(self, request): def authenticate_request(self, request):
json_request = { json_request = {
"method": request.method, "method": request.method,
"uri": request.uri, "uri": request.uri,
@ -93,22 +95,6 @@ class TransportLayerServer(object):
defer.returnValue((origin, content)) defer.returnValue((origin, content))
def _with_authentication(self, handler):
@defer.inlineCallbacks
def new_handler(request, *args, **kwargs):
try:
(origin, content) = yield self._authenticate_request(request)
with self.ratelimiter.ratelimit(origin) as d:
yield d
response = yield handler(
origin, content, request.args, *args, **kwargs
)
except:
logger.exception("_authenticate_request failed")
raise
defer.returnValue(response)
return new_handler
@log_function @log_function
def register_received_handler(self, handler): def register_received_handler(self, handler):
""" Register a handler that will be fired when we receive data. """ Register a handler that will be fired when we receive data.
@ -116,8 +102,10 @@ class TransportLayerServer(object):
Args: Args:
handler (TransportReceivedHandler) handler (TransportReceivedHandler)
""" """
FederationSendServlet( FederationSendServlet(handler,
handler, self._with_authentication, self.server_name authenticator=self,
ratelimiter=self.ratelimiter,
server_name=self.server_name,
).register(self.server) ).register(self.server)
@log_function @log_function
@ -140,13 +128,37 @@ class TransportLayerServer(object):
FederationQueryAuthServlet, FederationQueryAuthServlet,
FederationGetMissingEventsServlet, FederationGetMissingEventsServlet,
): ):
servletclass(handler, self._with_authentication).register(self.server) servletclass(handler,
authenticator=self,
ratelimiter=self.ratelimiter,
).register(self.server)
class BaseFederationServlet(object): class BaseFederationServlet(object):
def __init__(self, handler, wrapper): def __init__(self, handler, authenticator, ratelimiter):
self.handler = handler self.handler = handler
self.wrapper = wrapper self.authenticator = authenticator
self.ratelimiter = ratelimiter
def _wrap(self, code):
authenticator = self.authenticator
ratelimiter = self.ratelimiter
@defer.inlineCallbacks
@functools.wraps(code)
def new_code(request, *args, **kwargs):
try:
(origin, content) = yield authenticator.authenticate_request(request)
with ratelimiter.ratelimit(origin) as d:
yield d
response = yield code(
origin, content, request.args, *args, **kwargs
)
except:
logger.exception("authenticate_request failed")
raise
defer.returnValue(response)
return new_code
def register(self, server): def register(self, server):
pattern = re.compile("^" + PREFIX + self.PATH) pattern = re.compile("^" + PREFIX + self.PATH)
@ -156,14 +168,14 @@ class BaseFederationServlet(object):
if code is None: if code is None:
continue continue
server.register_path(method, pattern, self.wrapper(code)) server.register_path(method, pattern, self._wrap(code))
class FederationSendServlet(BaseFederationServlet): class FederationSendServlet(BaseFederationServlet):
PATH = "/send/([^/]*)/$" PATH = "/send/([^/]*)/$"
def __init__(self, handler, wrapper, server_name): def __init__(self, handler, server_name, **kwargs):
super(FederationSendServlet, self).__init__(handler, wrapper) super(FederationSendServlet, self).__init__(handler, **kwargs)
self.server_name = server_name self.server_name = server_name
# This is when someone is trying to send us a bunch of data. # This is when someone is trying to send us a bunch of data.