Merge branch 'develop' of github.com:matrix-org/synapse into erikj/batch_edus

This commit is contained in:
Erik Johnston 2016-09-09 18:06:01 +01:00
commit 3265def8c7
9 changed files with 148 additions and 33 deletions

View File

@ -583,12 +583,15 @@ class Auth(object):
""" """
# Can optionally look elsewhere in the request (e.g. headers) # Can optionally look elsewhere in the request (e.g. headers)
try: try:
user_id = yield self._get_appservice_user_id(request.args) user_id = yield self._get_appservice_user_id(request)
if user_id: if user_id:
request.authenticated_entity = user_id request.authenticated_entity = user_id
defer.returnValue(synapse.types.create_requester(user_id)) defer.returnValue(synapse.types.create_requester(user_id))
access_token = request.args["access_token"][0] access_token = get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
)
user_info = yield self.get_user_by_access_token(access_token, rights) user_info = yield self.get_user_by_access_token(access_token, rights)
user = user_info["user"] user = user_info["user"]
token_id = user_info["token_id"] token_id = user_info["token_id"]
@ -629,17 +632,19 @@ class Auth(object):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_appservice_user_id(self, request_args): def _get_appservice_user_id(self, request):
app_service = yield self.store.get_app_service_by_token( app_service = yield self.store.get_app_service_by_token(
request_args["access_token"][0] get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
)
) )
if app_service is None: if app_service is None:
defer.returnValue(None) defer.returnValue(None)
if "user_id" not in request_args: if "user_id" not in request.args:
defer.returnValue(app_service.sender) defer.returnValue(app_service.sender)
user_id = request_args["user_id"][0] user_id = request.args["user_id"][0]
if app_service.sender == user_id: if app_service.sender == user_id:
defer.returnValue(app_service.sender) defer.returnValue(app_service.sender)
@ -833,7 +838,9 @@ class Auth(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_appservice_by_req(self, request): def get_appservice_by_req(self, request):
try: try:
token = request.args["access_token"][0] token = get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
)
service = yield self.store.get_app_service_by_token(token) service = yield self.store.get_app_service_by_token(token)
if not service: if not service:
logger.warn("Unrecognised appservice access token: %s" % (token,)) logger.warn("Unrecognised appservice access token: %s" % (token,))
@ -1142,3 +1149,40 @@ class Auth(object):
"This server requires you to be a moderator in the room to" "This server requires you to be a moderator in the room to"
" edit its room list entry" " edit its room list entry"
) )
def has_access_token(request):
"""Checks if the request has an access_token.
Returns:
bool: False if no access_token was given, True otherwise.
"""
query_params = request.args.get("access_token")
return bool(query_params)
def get_access_token_from_request(request, token_not_found_http_status=401):
"""Extracts the access_token from the request.
Args:
request: The http request.
token_not_found_http_status(int): The HTTP status code to set in the
AuthError if the token isn't found. This is used in some of the
legacy APIs to change the status code to 403 from the default of
401 since some of the old clients depended on auth errors returning
403.
Returns:
str: The access_token
Raises:
AuthError: If there isn't an access_token in the request.
"""
query_params = request.args.get("access_token")
# Try to get the access_token from the query params.
if not query_params:
raise AuthError(
token_not_found_http_status,
"Missing access token.",
errcode=Codes.MISSING_TOKEN
)
return query_params[0]

View File

@ -32,6 +32,14 @@ HOUR_IN_MS = 60 * 60 * 1000
APP_SERVICE_PREFIX = "/_matrix/app/unstable" APP_SERVICE_PREFIX = "/_matrix/app/unstable"
def _is_valid_3pe_metadata(info):
if "instances" not in info:
return False
if not isinstance(info["instances"], list):
return False
return True
def _is_valid_3pe_result(r, field): def _is_valid_3pe_result(r, field):
if not isinstance(r, dict): if not isinstance(r, dict):
return False return False
@ -162,11 +170,18 @@ class ApplicationServiceApi(SimpleHttpClient):
urllib.quote(protocol) urllib.quote(protocol)
) )
try: try:
defer.returnValue((yield self.get_json(uri, {}))) info = yield self.get_json(uri, {})
if not _is_valid_3pe_metadata(info):
logger.warning("query_3pe_protocol to %s did not return a"
" valid result", uri)
defer.returnValue(None)
defer.returnValue(info)
except Exception as ex: except Exception as ex:
logger.warning("query_3pe_protocol to %s threw exception %s", logger.warning("query_3pe_protocol to %s threw exception %s",
uri, ex) uri, ex)
defer.returnValue({}) defer.returnValue(None)
key = (service.id, protocol) key = (service.id, protocol)
return self.protocol_meta_cache.get(key) or ( return self.protocol_meta_cache.get(key) or (

View File

@ -176,12 +176,41 @@ class ApplicationServicesHandler(object):
defer.returnValue(ret) defer.returnValue(ret)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_3pe_protocols(self): def get_3pe_protocols(self, only_protocol=None):
services = yield self.store.get_app_services() services = yield self.store.get_app_services()
protocols = {} protocols = {}
# Collect up all the individual protocol responses out of the ASes
for s in services: for s in services:
for p in s.protocols: for p in s.protocols:
protocols[p] = yield self.appservice_api.get_3pe_protocol(s, p) if only_protocol is not None and p != only_protocol:
continue
if p not in protocols:
protocols[p] = []
info = yield self.appservice_api.get_3pe_protocol(s, p)
if info is not None:
protocols[p].append(info)
def _merge_instances(infos):
if not infos:
return {}
# Merge the 'instances' lists of multiple results, but just take
# the other fields from the first as they ought to be identical
# copy the result so as not to corrupt the cached one
combined = dict(infos[0])
combined["instances"] = list(combined["instances"])
for info in infos[1:]:
combined["instances"].extend(info["instances"])
return combined
for p in protocols.keys():
protocols[p] = _merge_instances(protocols[p])
defer.returnValue(protocols) defer.returnValue(protocols)

View File

@ -15,7 +15,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import AuthError, Codes from synapse.api.auth import get_access_token_from_request
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
@ -37,13 +37,7 @@ class LogoutRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
try: access_token = get_access_token_from_request(request)
access_token = request.args["access_token"][0]
except KeyError:
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
errcode=Codes.MISSING_TOKEN
)
yield self.store.delete_access_token(access_token) yield self.store.delete_access_token(access_token)
defer.returnValue((200, {})) defer.returnValue((200, {}))

View File

@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.auth import get_access_token_from_request
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
import synapse.util.stringutils as stringutils import synapse.util.stringutils as stringutils
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
@ -296,12 +297,11 @@ class RegisterRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_app_service(self, request, register_json, session): def _do_app_service(self, request, register_json, session):
if "access_token" not in request.args: as_token = get_access_token_from_request(request)
raise SynapseError(400, "Expected application service token.")
if "user" not in register_json: if "user" not in register_json:
raise SynapseError(400, "Expected 'user' key.") raise SynapseError(400, "Expected 'user' key.")
as_token = request.args["access_token"][0]
user_localpart = register_json["user"].encode("utf-8") user_localpart = register_json["user"].encode("utf-8")
handler = self.handlers.registration_handler handler = self.handlers.registration_handler
@ -390,11 +390,9 @@ class CreateUserRestServlet(ClientV1RestServlet):
def on_POST(self, request): def on_POST(self, request):
user_json = parse_json_object_from_request(request) user_json = parse_json_object_from_request(request)
if "access_token" not in request.args: access_token = get_access_token_from_request(request)
raise SynapseError(400, "Expected application service token.")
app_service = yield self.store.get_app_service_by_token( app_service = yield self.store.get_app_service_by_token(
request.args["access_token"][0] access_token
) )
if not app_service: if not app_service:
raise SynapseError(403, "Invalid application service token.") raise SynapseError(403, "Invalid application service token.")

View File

@ -22,7 +22,7 @@ from synapse.streams.config import PaginationConfig
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.types import UserID, RoomID, RoomAlias from synapse.types import UserID, RoomID, RoomAlias
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event, format_event_for_client_v2
from synapse.http.servlet import parse_json_object_from_request, parse_string from synapse.http.servlet import parse_json_object_from_request, parse_string
import logging import logging
@ -120,6 +120,8 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id, event_type, state_key): def on_GET(self, request, room_id, event_type, state_key):
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
format = parse_string(request, "format", default="content",
allowed_values=["content", "event"])
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
data = yield msg_handler.get_room_data( data = yield msg_handler.get_room_data(
@ -134,7 +136,12 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
raise SynapseError( raise SynapseError(
404, "Event not found.", errcode=Codes.NOT_FOUND 404, "Event not found.", errcode=Codes.NOT_FOUND
) )
defer.returnValue((200, data.get_dict()["content"]))
if format == "event":
event = format_event_for_client_v2(data.get_dict())
defer.returnValue((200, event))
elif format == "content":
defer.returnValue((200, data.get_dict()["content"]))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):

View File

@ -17,6 +17,8 @@
to ensure idempotency when performing PUTs using the REST API.""" to ensure idempotency when performing PUTs using the REST API."""
import logging import logging
from synapse.api.auth import get_access_token_from_request
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -90,6 +92,6 @@ class HttpTransactionStore(object):
return response return response
def _get_key(self, request): def _get_key(self, request):
token = request.args["access_token"][0] token = get_access_token_from_request(request)
path_without_txn_id = request.path.rsplit("/", 1)[0] path_without_txn_id = request.path.rsplit("/", 1)[0]
return path_without_txn_id + "/" + token return path_without_txn_id + "/" + token

View File

@ -15,6 +15,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.auth import get_access_token_from_request, has_access_token
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
@ -131,7 +132,7 @@ class RegisterRestServlet(RestServlet):
desired_username = body['username'] desired_username = body['username']
appservice = None appservice = None
if 'access_token' in request.args: if has_access_token(request):
appservice = yield self.auth.get_appservice_by_req(request) appservice = yield self.auth.get_appservice_by_req(request)
# fork off as soon as possible for ASes and shared secret auth which # fork off as soon as possible for ASes and shared secret auth which
@ -143,10 +144,11 @@ class RegisterRestServlet(RestServlet):
# 'user' key not 'username'). Since this is a new addition, we'll # 'user' key not 'username'). Since this is a new addition, we'll
# fallback to 'username' if they gave one. # fallback to 'username' if they gave one.
desired_username = body.get("user", desired_username) desired_username = body.get("user", desired_username)
access_token = get_access_token_from_request(request)
if isinstance(desired_username, basestring): if isinstance(desired_username, basestring):
result = yield self._do_appservice_registration( result = yield self._do_appservice_registration(
desired_username, request.args["access_token"][0], body desired_username, access_token, body
) )
defer.returnValue((200, result)) # we throw for non 200 responses defer.returnValue((200, result)) # we throw for non 200 responses
return return

View File

@ -42,6 +42,29 @@ class ThirdPartyProtocolsServlet(RestServlet):
defer.returnValue((200, protocols)) defer.returnValue((200, protocols))
class ThirdPartyProtocolServlet(RestServlet):
PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$",
releases=())
def __init__(self, hs):
super(ThirdPartyProtocolServlet, self).__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
@defer.inlineCallbacks
def on_GET(self, request, protocol):
yield self.auth.get_user_by_req(request)
protocols = yield self.appservice_handler.get_3pe_protocols(
only_protocol=protocol,
)
if protocol in protocols:
defer.returnValue((200, protocols[protocol]))
else:
defer.returnValue((404, {"error": "Unknown protocol"}))
class ThirdPartyUserServlet(RestServlet): class ThirdPartyUserServlet(RestServlet):
PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$", PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$",
releases=()) releases=())
@ -57,7 +80,7 @@ class ThirdPartyUserServlet(RestServlet):
yield self.auth.get_user_by_req(request) yield self.auth.get_user_by_req(request)
fields = request.args fields = request.args
del fields["access_token"] fields.pop("access_token", None)
results = yield self.appservice_handler.query_3pe( results = yield self.appservice_handler.query_3pe(
ThirdPartyEntityKind.USER, protocol, fields ThirdPartyEntityKind.USER, protocol, fields
@ -81,7 +104,7 @@ class ThirdPartyLocationServlet(RestServlet):
yield self.auth.get_user_by_req(request) yield self.auth.get_user_by_req(request)
fields = request.args fields = request.args
del fields["access_token"] fields.pop("access_token", None)
results = yield self.appservice_handler.query_3pe( results = yield self.appservice_handler.query_3pe(
ThirdPartyEntityKind.LOCATION, protocol, fields ThirdPartyEntityKind.LOCATION, protocol, fields
@ -92,5 +115,6 @@ class ThirdPartyLocationServlet(RestServlet):
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
ThirdPartyProtocolsServlet(hs).register(http_server) ThirdPartyProtocolsServlet(hs).register(http_server)
ThirdPartyProtocolServlet(hs).register(http_server)
ThirdPartyUserServlet(hs).register(http_server) ThirdPartyUserServlet(hs).register(http_server)
ThirdPartyLocationServlet(hs).register(http_server) ThirdPartyLocationServlet(hs).register(http_server)