Merge branch 'develop' into babolivier/port_db_account_validity

This commit is contained in:
Brendan Abolivier 2019-06-04 09:13:42 +01:00
commit aeb2263320
51 changed files with 529 additions and 417 deletions

1
changelog.d/5226.misc Normal file
View File

@ -0,0 +1 @@
The base classes for the v1 and v2_alpha REST APIs have been unified.

1
changelog.d/5307.bugfix Normal file
View File

@ -0,0 +1 @@
Fix bug where a notary server would sometimes forget old keys.

1
changelog.d/5321.bugfix Normal file
View File

@ -0,0 +1 @@
Ensure that we have an up-to-date copy of the signing key when validating incoming federation requests.

1
changelog.d/5328.misc Normal file
View File

@ -0,0 +1 @@
The base classes for the v1 and v2_alpha REST APIs have been unified.

View File

@ -20,9 +20,7 @@ class CallVisitor(ast.NodeVisitor):
else:
return
if name == "client_path_patterns":
PATTERNS_V1.append(node.args[0].s)
elif name == "client_v2_patterns":
if name == "client_patterns":
PATTERNS_V2.append(node.args[0].s)

View File

@ -37,8 +37,7 @@ from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.client.v1.base import ClientV1RestServlet, client_path_patterns
from synapse.rest.client.v2_alpha._base import client_v2_patterns
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
@ -49,11 +48,11 @@ from synapse.util.versionstring import get_version_string
logger = logging.getLogger("synapse.app.frontend_proxy")
class PresenceStatusStubServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/presence/(?P<user_id>[^/]*)/status")
class PresenceStatusStubServlet(RestServlet):
PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status")
def __init__(self, hs):
super(PresenceStatusStubServlet, self).__init__(hs)
super(PresenceStatusStubServlet, self).__init__()
self.http_client = hs.get_simple_http_client()
self.auth = hs.get_auth()
self.main_uri = hs.config.worker_main_http_uri
@ -84,7 +83,7 @@ class PresenceStatusStubServlet(ClientV1RestServlet):
class KeyUploadServlet(RestServlet):
PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
def __init__(self, hs):
"""

View File

@ -15,6 +15,7 @@
# limitations under the License.
import logging
from collections import defaultdict
import six
from six import raise_from
@ -70,6 +71,9 @@ class VerifyKeyRequest(object):
json_object(dict): The JSON object to verify.
minimum_valid_until_ts (int): time at which we require the signing key to
be valid. (0 implies we don't care)
deferred(Deferred[str, str, nacl.signing.VerifyKey]):
A deferred (server_name, key_id, verify_key) tuple that resolves when
a verify key has been fetched. The deferreds' callbacks are run with no
@ -82,7 +86,8 @@ class VerifyKeyRequest(object):
server_name = attr.ib()
key_ids = attr.ib()
json_object = attr.ib()
deferred = attr.ib()
minimum_valid_until_ts = attr.ib()
deferred = attr.ib(default=attr.Factory(defer.Deferred))
class KeyLookupError(ValueError):
@ -90,14 +95,16 @@ class KeyLookupError(ValueError):
class Keyring(object):
def __init__(self, hs):
def __init__(self, hs, key_fetchers=None):
self.clock = hs.get_clock()
self._key_fetchers = (
StoreKeyFetcher(hs),
PerspectivesKeyFetcher(hs),
ServerKeyFetcher(hs),
)
if key_fetchers is None:
key_fetchers = (
StoreKeyFetcher(hs),
PerspectivesKeyFetcher(hs),
ServerKeyFetcher(hs),
)
self._key_fetchers = key_fetchers
# map from server name to Deferred. Has an entry for each server with
# an ongoing key download; the Deferred completes once the download
@ -106,9 +113,25 @@ class Keyring(object):
# These are regular, logcontext-agnostic Deferreds.
self.key_downloads = {}
def verify_json_for_server(self, server_name, json_object):
def verify_json_for_server(self, server_name, json_object, validity_time):
"""Verify that a JSON object has been signed by a given server
Args:
server_name (str): name of the server which must have signed this object
json_object (dict): object to be checked
validity_time (int): timestamp at which we require the signing key to
be valid. (0 implies we don't care)
Returns:
Deferred[None]: completes if the the object was correctly signed, otherwise
errbacks with an error
"""
req = server_name, json_object, validity_time
return logcontext.make_deferred_yieldable(
self.verify_json_objects_for_server([(server_name, json_object)])[0]
self.verify_json_objects_for_server((req,))[0]
)
def verify_json_objects_for_server(self, server_and_json):
@ -116,10 +139,12 @@ class Keyring(object):
necessary.
Args:
server_and_json (list): List of pairs of (server_name, json_object)
server_and_json (iterable[Tuple[str, dict, int]):
Iterable of triplets of (server_name, json_object, validity_time)
validity_time is a timestamp at which the signing key must be valid.
Returns:
List<Deferred>: for each input pair, a deferred indicating success
List<Deferred[None]>: for each input triplet, a deferred indicating success
or failure to verify each json object's signature for the given
server_name. The deferreds run their callbacks in the sentinel
logcontext.
@ -128,12 +153,12 @@ class Keyring(object):
verify_requests = []
handle = preserve_fn(_handle_key_deferred)
def process(server_name, json_object):
def process(server_name, json_object, validity_time):
"""Process an entry in the request list
Given a (server_name, json_object) pair from the request list,
adds a key request to verify_requests, and returns a deferred which will
complete or fail (in the sentinel context) when verification completes.
Given a (server_name, json_object, validity_time) triplet from the request
list, adds a key request to verify_requests, and returns a deferred which
will complete or fail (in the sentinel context) when verification completes.
"""
key_ids = signature_ids(json_object, server_name)
@ -148,7 +173,7 @@ class Keyring(object):
# add the key request to the queue, but don't start it off yet.
verify_request = VerifyKeyRequest(
server_name, key_ids, json_object, defer.Deferred()
server_name, key_ids, json_object, validity_time
)
verify_requests.append(verify_request)
@ -160,8 +185,8 @@ class Keyring(object):
return handle(verify_request)
results = [
process(server_name, json_object)
for server_name, json_object in server_and_json
process(server_name, json_object, validity_time)
for server_name, json_object, validity_time in server_and_json
]
if verify_requests:
@ -298,8 +323,12 @@ class Keyring(object):
verify_request.deferred.errback(
SynapseError(
401,
"No key for %s with id %s"
% (verify_request.server_name, verify_request.key_ids),
"No key for %s with ids in %s (min_validity %i)"
% (
verify_request.server_name,
verify_request.key_ids,
verify_request.minimum_valid_until_ts,
),
Codes.UNAUTHORIZED,
)
)
@ -323,18 +352,28 @@ class Keyring(object):
Args:
fetcher (KeyFetcher): fetcher to use to fetch the keys
remaining_requests (set[VerifyKeyRequest]): outstanding key requests.
Any successfully-completed requests will be reomved from the list.
Any successfully-completed requests will be removed from the list.
"""
# dict[str, set(str)]: keys to fetch for each server
missing_keys = {}
# dict[str, dict[str, int]]: keys to fetch.
# server_name -> key_id -> min_valid_ts
missing_keys = defaultdict(dict)
for verify_request in remaining_requests:
# any completed requests should already have been removed
assert not verify_request.deferred.called
missing_keys.setdefault(verify_request.server_name, set()).update(
verify_request.key_ids
)
keys_for_server = missing_keys[verify_request.server_name]
results = yield fetcher.get_keys(missing_keys.items())
for key_id in verify_request.key_ids:
# If we have several requests for the same key, then we only need to
# request that key once, but we should do so with the greatest
# min_valid_until_ts of the requests, so that we can satisfy all of
# the requests.
keys_for_server[key_id] = max(
keys_for_server.get(key_id, -1),
verify_request.minimum_valid_until_ts
)
results = yield fetcher.get_keys(missing_keys)
completed = list()
for verify_request in remaining_requests:
@ -344,25 +383,34 @@ class Keyring(object):
# complete this VerifyKeyRequest.
result_keys = results.get(server_name, {})
for key_id in verify_request.key_ids:
key = result_keys.get(key_id)
if key:
with PreserveLoggingContext():
verify_request.deferred.callback(
(server_name, key_id, key.verify_key)
)
completed.append(verify_request)
break
fetch_key_result = result_keys.get(key_id)
if not fetch_key_result:
# we didn't get a result for this key
continue
if (
fetch_key_result.valid_until_ts
< verify_request.minimum_valid_until_ts
):
# key was not valid at this point
continue
with PreserveLoggingContext():
verify_request.deferred.callback(
(server_name, key_id, fetch_key_result.verify_key)
)
completed.append(verify_request)
break
remaining_requests.difference_update(completed)
class KeyFetcher(object):
def get_keys(self, server_name_and_key_ids):
def get_keys(self, keys_to_fetch):
"""
Args:
server_name_and_key_ids (iterable[Tuple[str, iterable[str]]]):
list of (server_name, iterable[key_id]) tuples to fetch keys for
Note that the iterables may be iterated more than once.
keys_to_fetch (dict[str, dict[str, int]]):
the keys to be fetched. server_name -> key_id -> min_valid_ts
Returns:
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
@ -378,13 +426,15 @@ class StoreKeyFetcher(KeyFetcher):
self.store = hs.get_datastore()
@defer.inlineCallbacks
def get_keys(self, server_name_and_key_ids):
def get_keys(self, keys_to_fetch):
"""see KeyFetcher.get_keys"""
keys_to_fetch = (
(server_name, key_id)
for server_name, key_ids in server_name_and_key_ids
for key_id in key_ids
for server_name, keys_for_server in keys_to_fetch.items()
for key_id in keys_for_server.keys()
)
res = yield self.store.get_server_verify_keys(keys_to_fetch)
keys = {}
for (server_name, key_id), key in res.items():
@ -399,7 +449,7 @@ class BaseV2KeyFetcher(object):
@defer.inlineCallbacks
def process_v2_response(
self, from_server, response_json, time_added_ms, requested_ids=[]
self, from_server, response_json, time_added_ms
):
"""Parse a 'Server Keys' structure from the result of a /key request
@ -422,10 +472,6 @@ class BaseV2KeyFetcher(object):
time_added_ms (int): the timestamp to record in server_keys_json
requested_ids (iterable[str]): a list of the key IDs that were requested.
We will store the json for these key ids as well as any that are
actually in the response
Returns:
Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
"""
@ -481,11 +527,6 @@ class BaseV2KeyFetcher(object):
signed_key_json_bytes = encode_canonical_json(signed_key_json)
# for reasons I don't quite understand, we store this json for the key ids we
# requested, as well as those we got.
updated_key_ids = set(requested_ids)
updated_key_ids.update(verify_keys)
yield logcontext.make_deferred_yieldable(
defer.gatherResults(
[
@ -498,7 +539,7 @@ class BaseV2KeyFetcher(object):
ts_expires_ms=ts_valid_until_ms,
key_json_bytes=signed_key_json_bytes,
)
for key_id in updated_key_ids
for key_id in verify_keys
],
consumeErrors=True,
).addErrback(unwrapFirstError)
@ -517,14 +558,14 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
self.perspective_servers = self.config.perspectives
@defer.inlineCallbacks
def get_keys(self, server_name_and_key_ids):
def get_keys(self, keys_to_fetch):
"""see KeyFetcher.get_keys"""
@defer.inlineCallbacks
def get_key(perspective_name, perspective_keys):
try:
result = yield self.get_server_verify_key_v2_indirect(
server_name_and_key_ids, perspective_name, perspective_keys
keys_to_fetch, perspective_name, perspective_keys
)
defer.returnValue(result)
except KeyLookupError as e:
@ -558,13 +599,15 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
@defer.inlineCallbacks
def get_server_verify_key_v2_indirect(
self, server_names_and_key_ids, perspective_name, perspective_keys
self, keys_to_fetch, perspective_name, perspective_keys
):
"""
Args:
server_names_and_key_ids (iterable[Tuple[str, iterable[str]]]):
list of (server_name, iterable[key_id]) tuples to fetch keys for
keys_to_fetch (dict[str, dict[str, int]]):
the keys to be fetched. server_name -> key_id -> min_valid_ts
perspective_name (str): name of the notary server to query for the keys
perspective_keys (dict[str, VerifyKey]): map of key_id->key for the
notary server
@ -578,12 +621,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
"""
logger.info(
"Requesting keys %s from notary server %s",
server_names_and_key_ids,
keys_to_fetch.items(),
perspective_name,
)
# TODO(mark): Set the minimum_valid_until_ts to that needed by
# the events being validated or the current time if validating
# an incoming request.
try:
query_response = yield self.client.post_json(
destination=perspective_name,
@ -591,9 +632,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
data={
u"server_keys": {
server_name: {
key_id: {u"minimum_valid_until_ts": 0} for key_id in key_ids
key_id: {u"minimum_valid_until_ts": min_valid_ts}
for key_id, min_valid_ts in server_keys.items()
}
for server_name, key_ids in server_names_and_key_ids
for server_name, server_keys in keys_to_fetch.items()
}
},
long_retries=True,
@ -703,15 +745,18 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
self.client = hs.get_http_client()
@defer.inlineCallbacks
def get_keys(self, server_name_and_key_ids):
def get_keys(self, keys_to_fetch):
"""see KeyFetcher.get_keys"""
# TODO make this more resilient
results = yield logcontext.make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self.get_server_verify_key_v2_direct, server_name, key_ids
self.get_server_verify_key_v2_direct,
server_name,
server_keys.keys(),
)
for server_name, key_ids in server_name_and_key_ids
for server_name, server_keys in keys_to_fetch.items()
],
consumeErrors=True,
).addErrback(unwrapFirstError)
@ -730,6 +775,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
keys = {} # type: dict[str, FetchKeyResult]
for requested_key_id in key_ids:
# we may have found this key as a side-effect of asking for another.
if requested_key_id in keys:
continue
@ -754,7 +800,6 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
response_keys = yield self.process_v2_response(
from_server=server_name,
requested_ids=[requested_key_id],
response_json=response,
time_added_ms=time_now_ms,
)

View File

@ -265,7 +265,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
]
more_deferreds = keyring.verify_json_objects_for_server([
(p.sender_domain, p.redacted_pdu_json)
(p.sender_domain, p.redacted_pdu_json, 0)
for p in pdus_to_check_sender
])
@ -298,7 +298,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
]
more_deferreds = keyring.verify_json_objects_for_server([
(get_domain_from_id(p.pdu.event_id), p.redacted_pdu_json)
(get_domain_from_id(p.pdu.event_id), p.redacted_pdu_json, 0)
for p in pdus_to_check_event_id
])

View File

@ -94,6 +94,7 @@ class NoAuthenticationError(AuthenticationError):
class Authenticator(object):
def __init__(self, hs):
self._clock = hs.get_clock()
self.keyring = hs.get_keyring()
self.server_name = hs.hostname
self.store = hs.get_datastore()
@ -102,6 +103,7 @@ class Authenticator(object):
# A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks
def authenticate_request(self, request, content):
now = self._clock.time_msec()
json_request = {
"method": request.method.decode('ascii'),
"uri": request.uri.decode('ascii'),
@ -138,7 +140,7 @@ class Authenticator(object):
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
)
yield self.keyring.verify_json_for_server(origin, json_request)
yield self.keyring.verify_json_for_server(origin, json_request, now)
logger.info("Request from %s", origin)
request.authenticated_entity = origin

View File

@ -97,10 +97,11 @@ class GroupAttestationSigning(object):
# TODO: We also want to check that *new* attestations that people give
# us to store are valid for at least a little while.
if valid_until_ms < self.clock.time_msec():
now = self.clock.time_msec()
if valid_until_ms < now:
raise SynapseError(400, "Attestation expired")
yield self.keyring.verify_json_for_server(server_name, attestation)
yield self.keyring.verify_json_for_server(server_name, attestation, now)
def create_attestation(self, group_id, user_id):
"""Create an attestation for the group_id and user_id with default

View File

@ -1,65 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains base REST classes for constructing client v1 servlets.
"""
import logging
import re
from synapse.api.urls import CLIENT_API_PREFIX
from synapse.http.servlet import RestServlet
from synapse.rest.client.transactions import HttpTransactionCache
logger = logging.getLogger(__name__)
def client_path_patterns(path_regex, releases=(0,), include_in_unstable=True):
"""Creates a regex compiled client path with the correct client path
prefix.
Args:
path_regex (str): The regex string to match. This should NOT have a ^
as this will be prefixed.
Returns:
SRE_Pattern
"""
patterns = [re.compile("^" + CLIENT_API_PREFIX + "/api/v1" + path_regex)]
if include_in_unstable:
unstable_prefix = CLIENT_API_PREFIX + "/unstable"
patterns.append(re.compile("^" + unstable_prefix + path_regex))
for release in releases:
new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,)
patterns.append(re.compile("^" + new_prefix + path_regex))
return patterns
class ClientV1RestServlet(RestServlet):
"""A base Synapse REST Servlet for the client version 1 API.
"""
# This subclass was presumably created to allow the auth for the v1
# protocol version to be different, however this behaviour was removed.
# it may no longer be necessary
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer):
"""
self.hs = hs
self.builder_factory = hs.get_event_builder_factory()
self.auth = hs.get_auth()
self.txns = HttpTransactionCache(hs)

View File

@ -19,11 +19,10 @@ import logging
from twisted.internet import defer
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import parse_json_object_from_request
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.types import RoomAlias
from .base import ClientV1RestServlet, client_path_patterns
logger = logging.getLogger(__name__)
@ -33,13 +32,14 @@ def register_servlets(hs, http_server):
ClientAppserviceDirectoryListServer(hs).register(http_server)
class ClientDirectoryServer(ClientV1RestServlet):
PATTERNS = client_path_patterns("/directory/room/(?P<room_alias>[^/]*)$")
class ClientDirectoryServer(RestServlet):
PATTERNS = client_patterns("/directory/room/(?P<room_alias>[^/]*)$", v1=True)
def __init__(self, hs):
super(ClientDirectoryServer, self).__init__(hs)
super(ClientDirectoryServer, self).__init__()
self.store = hs.get_datastore()
self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, room_alias):
@ -120,13 +120,14 @@ class ClientDirectoryServer(ClientV1RestServlet):
defer.returnValue((200, {}))
class ClientDirectoryListServer(ClientV1RestServlet):
PATTERNS = client_path_patterns("/directory/list/room/(?P<room_id>[^/]*)$")
class ClientDirectoryListServer(RestServlet):
PATTERNS = client_patterns("/directory/list/room/(?P<room_id>[^/]*)$", v1=True)
def __init__(self, hs):
super(ClientDirectoryListServer, self).__init__(hs)
super(ClientDirectoryListServer, self).__init__()
self.store = hs.get_datastore()
self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
@ -162,15 +163,16 @@ class ClientDirectoryListServer(ClientV1RestServlet):
defer.returnValue((200, {}))
class ClientAppserviceDirectoryListServer(ClientV1RestServlet):
PATTERNS = client_path_patterns(
"/directory/list/appservice/(?P<network_id>[^/]*)/(?P<room_id>[^/]*)$"
class ClientAppserviceDirectoryListServer(RestServlet):
PATTERNS = client_patterns(
"/directory/list/appservice/(?P<network_id>[^/]*)/(?P<room_id>[^/]*)$", v1=True
)
def __init__(self, hs):
super(ClientAppserviceDirectoryListServer, self).__init__(hs)
super(ClientAppserviceDirectoryListServer, self).__init__()
self.store = hs.get_datastore()
self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
def on_PUT(self, request, network_id, room_id):
content = parse_json_object_from_request(request)

View File

@ -19,21 +19,22 @@ import logging
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.streams.config import PaginationConfig
from .base import ClientV1RestServlet, client_path_patterns
logger = logging.getLogger(__name__)
class EventStreamRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/events$")
class EventStreamRestServlet(RestServlet):
PATTERNS = client_patterns("/events$", v1=True)
DEFAULT_LONGPOLL_TIME_MS = 30000
def __init__(self, hs):
super(EventStreamRestServlet, self).__init__(hs)
super(EventStreamRestServlet, self).__init__()
self.event_stream_handler = hs.get_event_stream_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request):
@ -76,11 +77,11 @@ class EventStreamRestServlet(ClientV1RestServlet):
# TODO: Unit test gets, with and without auth, with different kinds of events.
class EventRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/events/(?P<event_id>[^/]*)$")
class EventRestServlet(RestServlet):
PATTERNS = client_patterns("/events/(?P<event_id>[^/]*)$", v1=True)
def __init__(self, hs):
super(EventRestServlet, self).__init__(hs)
super(EventRestServlet, self).__init__()
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()

View File

@ -15,19 +15,19 @@
from twisted.internet import defer
from synapse.http.servlet import parse_boolean
from synapse.http.servlet import RestServlet, parse_boolean
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.streams.config import PaginationConfig
from .base import ClientV1RestServlet, client_path_patterns
# TODO: Needs unit testing
class InitialSyncRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/initialSync$")
class InitialSyncRestServlet(RestServlet):
PATTERNS = client_patterns("/initialSync$", v1=True)
def __init__(self, hs):
super(InitialSyncRestServlet, self).__init__(hs)
super(InitialSyncRestServlet, self).__init__()
self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request):

View File

@ -29,12 +29,11 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import UserID, map_username_to_mxid_localpart
from synapse.util.msisdn import phone_number_to_msisdn
from .base import ClientV1RestServlet, client_path_patterns
logger = logging.getLogger(__name__)
@ -81,15 +80,16 @@ def login_id_thirdparty_from_phone(identifier):
}
class LoginRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/login$")
class LoginRestServlet(RestServlet):
PATTERNS = client_patterns("/login$", v1=True)
CAS_TYPE = "m.login.cas"
SSO_TYPE = "m.login.sso"
TOKEN_TYPE = "m.login.token"
JWT_TYPE = "m.login.jwt"
def __init__(self, hs):
super(LoginRestServlet, self).__init__(hs)
super(LoginRestServlet, self).__init__()
self.hs = hs
self.jwt_enabled = hs.config.jwt_enabled
self.jwt_secret = hs.config.jwt_secret
self.jwt_algorithm = hs.config.jwt_algorithm
@ -371,7 +371,7 @@ class LoginRestServlet(ClientV1RestServlet):
class CasRedirectServlet(RestServlet):
PATTERNS = client_path_patterns("/login/(cas|sso)/redirect")
PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
def __init__(self, hs):
super(CasRedirectServlet, self).__init__()
@ -394,27 +394,27 @@ class CasRedirectServlet(RestServlet):
finish_request(request)
class CasTicketServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/login/cas/ticket")
class CasTicketServlet(RestServlet):
PATTERNS = client_patterns("/login/cas/ticket", v1=True)
def __init__(self, hs):
super(CasTicketServlet, self).__init__(hs)
super(CasTicketServlet, self).__init__()
self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url
self.cas_required_attributes = hs.config.cas_required_attributes
self._sso_auth_handler = SSOAuthHandler(hs)
self._http_client = hs.get_simple_http_client()
@defer.inlineCallbacks
def on_GET(self, request):
client_redirect_url = parse_string(request, "redirectUrl", required=True)
http_client = self.hs.get_simple_http_client()
uri = self.cas_server_url + "/proxyValidate"
args = {
"ticket": parse_string(request, "ticket", required=True),
"service": self.cas_service_url
}
try:
body = yield http_client.get_raw(uri, args)
body = yield self._http_client.get_raw(uri, args)
except PartialDownloadError as pde:
# Twisted raises this error if the connection is closed,
# even if that's being used old-http style to signal end-of-data

View File

@ -17,17 +17,18 @@ import logging
from twisted.internet import defer
from .base import ClientV1RestServlet, client_path_patterns
from synapse.http.servlet import RestServlet
from synapse.rest.client.v2_alpha._base import client_patterns
logger = logging.getLogger(__name__)
class LogoutRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/logout$")
class LogoutRestServlet(RestServlet):
PATTERNS = client_patterns("/logout$", v1=True)
def __init__(self, hs):
super(LogoutRestServlet, self).__init__(hs)
self._auth = hs.get_auth()
super(LogoutRestServlet, self).__init__()
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
@ -41,7 +42,7 @@ class LogoutRestServlet(ClientV1RestServlet):
if requester.device_id is None:
# the acccess token wasn't associated with a device.
# Just delete the access token
access_token = self._auth.get_access_token_from_request(request)
access_token = self.auth.get_access_token_from_request(request)
yield self._auth_handler.delete_access_token(access_token)
else:
yield self._device_handler.delete_device(
@ -50,11 +51,11 @@ class LogoutRestServlet(ClientV1RestServlet):
defer.returnValue((200, {}))
class LogoutAllRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/logout/all$")
class LogoutAllRestServlet(RestServlet):
PATTERNS = client_patterns("/logout/all$", v1=True)
def __init__(self, hs):
super(LogoutAllRestServlet, self).__init__(hs)
super(LogoutAllRestServlet, self).__init__()
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()

View File

@ -23,21 +23,22 @@ from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError
from synapse.handlers.presence import format_user_presence_state
from synapse.http.servlet import parse_json_object_from_request
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.types import UserID
from .base import ClientV1RestServlet, client_path_patterns
logger = logging.getLogger(__name__)
class PresenceStatusRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/presence/(?P<user_id>[^/]*)/status")
class PresenceStatusRestServlet(RestServlet):
PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status", v1=True)
def __init__(self, hs):
super(PresenceStatusRestServlet, self).__init__(hs)
super(PresenceStatusRestServlet, self).__init__()
self.hs = hs
self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock()
self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, user_id):

View File

@ -16,18 +16,19 @@
""" This module contains REST servlets to do with profile: /profile/<paths> """
from twisted.internet import defer
from synapse.http.servlet import parse_json_object_from_request
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.types import UserID
from .base import ClientV1RestServlet, client_path_patterns
class ProfileDisplaynameRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname")
class ProfileDisplaynameRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/displayname", v1=True)
def __init__(self, hs):
super(ProfileDisplaynameRestServlet, self).__init__(hs)
super(ProfileDisplaynameRestServlet, self).__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
@ -71,12 +72,14 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
return (200, {})
class ProfileAvatarURLRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url")
class ProfileAvatarURLRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True)
def __init__(self, hs):
super(ProfileAvatarURLRestServlet, self).__init__(hs)
super(ProfileAvatarURLRestServlet, self).__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
@ -119,12 +122,14 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
return (200, {})
class ProfileRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)")
class ProfileRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True)
def __init__(self, hs):
super(ProfileRestServlet, self).__init__(hs)
super(ProfileRestServlet, self).__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, user_id):

View File

@ -21,22 +21,22 @@ from synapse.api.errors import (
SynapseError,
UnrecognizedRequestError,
)
from synapse.http.servlet import parse_json_value_from_request, parse_string
from synapse.http.servlet import RestServlet, parse_json_value_from_request, parse_string
from synapse.push.baserules import BASE_RULE_IDS
from synapse.push.clientformat import format_push_rules_for_user
from synapse.push.rulekinds import PRIORITY_CLASS_MAP
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from .base import ClientV1RestServlet, client_path_patterns
class PushRuleRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/(?P<path>pushrules/.*)$")
class PushRuleRestServlet(RestServlet):
PATTERNS = client_patterns("/(?P<path>pushrules/.*)$", v1=True)
SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
"Unrecognised request: You probably wanted a trailing slash")
def __init__(self, hs):
super(PushRuleRestServlet, self).__init__(hs)
super(PushRuleRestServlet, self).__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker_app is not None

View File

@ -26,17 +26,18 @@ from synapse.http.servlet import (
parse_string,
)
from synapse.push import PusherConfigException
from .base import ClientV1RestServlet, client_path_patterns
from synapse.rest.client.v2_alpha._base import client_patterns
logger = logging.getLogger(__name__)
class PushersRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/pushers$")
class PushersRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers$", v1=True)
def __init__(self, hs):
super(PushersRestServlet, self).__init__(hs)
super(PushersRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request):
@ -69,11 +70,13 @@ class PushersRestServlet(ClientV1RestServlet):
return 200, {}
class PushersSetRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/pushers/set$")
class PushersSetRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers/set$", v1=True)
def __init__(self, hs):
super(PushersSetRestServlet, self).__init__(hs)
super(PushersSetRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
self.pusher_pool = self.hs.get_pusherpool()
@ -141,7 +144,7 @@ class PushersRemoveRestServlet(RestServlet):
"""
To allow pusher to be delete by clicking a link (ie. GET request)
"""
PATTERNS = client_path_patterns("/pushers/remove$")
PATTERNS = client_patterns("/pushers/remove$", v1=True)
SUCCESS_HTML = b"<html><body>You have been unsubscribed</body><html>"
def __init__(self, hs):

View File

@ -28,37 +28,45 @@ from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.api.filtering import Filter
from synapse.events.utils import format_event_for_client_v2
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_integer,
parse_json_object_from_request,
parse_string,
)
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig
from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
from .base import ClientV1RestServlet, client_path_patterns
logger = logging.getLogger(__name__)
class RoomCreateRestServlet(ClientV1RestServlet):
class TransactionRestServlet(RestServlet):
def __init__(self, hs):
super(TransactionRestServlet, self).__init__()
self.txns = HttpTransactionCache(hs)
class RoomCreateRestServlet(TransactionRestServlet):
# No PATTERN; we have custom dispatch rules here
def __init__(self, hs):
super(RoomCreateRestServlet, self).__init__(hs)
self._room_creation_handler = hs.get_room_creation_handler()
self.auth = hs.get_auth()
def register(self, http_server):
PATTERNS = "/createRoom"
register_txn_path(self, PATTERNS, http_server)
# define CORS for all of /rooms in RoomCreateRestServlet for simplicity
http_server.register_paths("OPTIONS",
client_path_patterns("/rooms(?:/.*)?$"),
client_patterns("/rooms(?:/.*)?$", v1=True),
self.on_OPTIONS)
# define CORS for /createRoom[/txnid]
http_server.register_paths("OPTIONS",
client_path_patterns("/createRoom(?:/.*)?$"),
client_patterns("/createRoom(?:/.*)?$", v1=True),
self.on_OPTIONS)
def on_PUT(self, request, txn_id):
@ -85,13 +93,14 @@ class RoomCreateRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for generic events
class RoomStateEventRestServlet(ClientV1RestServlet):
class RoomStateEventRestServlet(TransactionRestServlet):
def __init__(self, hs):
super(RoomStateEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
def register(self, http_server):
# /room/$roomid/state/$eventtype
@ -102,16 +111,16 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
"(?P<event_type>[^/]*)/(?P<state_key>[^/]*)$")
http_server.register_paths("GET",
client_path_patterns(state_key),
client_patterns(state_key, v1=True),
self.on_GET)
http_server.register_paths("PUT",
client_path_patterns(state_key),
client_patterns(state_key, v1=True),
self.on_PUT)
http_server.register_paths("GET",
client_path_patterns(no_state_key),
client_patterns(no_state_key, v1=True),
self.on_GET_no_state_key)
http_server.register_paths("PUT",
client_path_patterns(no_state_key),
client_patterns(no_state_key, v1=True),
self.on_PUT_no_state_key)
def on_GET_no_state_key(self, request, room_id, event_type):
@ -185,11 +194,12 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for generic events + feedback
class RoomSendEventRestServlet(ClientV1RestServlet):
class RoomSendEventRestServlet(TransactionRestServlet):
def __init__(self, hs):
super(RoomSendEventRestServlet, self).__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
self.auth = hs.get_auth()
def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id]
@ -229,10 +239,11 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for room ID + alias joins
class JoinRoomAliasServlet(ClientV1RestServlet):
class JoinRoomAliasServlet(TransactionRestServlet):
def __init__(self, hs):
super(JoinRoomAliasServlet, self).__init__(hs)
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
def register(self, http_server):
# /join/$room_identifier[/$txn_id]
@ -291,8 +302,13 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
# TODO: Needs unit testing
class PublicRoomListRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/publicRooms$")
class PublicRoomListRestServlet(TransactionRestServlet):
PATTERNS = client_patterns("/publicRooms$", v1=True)
def __init__(self, hs):
super(PublicRoomListRestServlet, self).__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request):
@ -382,12 +398,13 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing
class RoomMemberListRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/members$")
class RoomMemberListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/members$", v1=True)
def __init__(self, hs):
super(RoomMemberListRestServlet, self).__init__(hs)
super(RoomMemberListRestServlet, self).__init__()
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
@ -436,12 +453,13 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
# deprecated in favour of /members?membership=join?
# except it does custom AS logic and has a simpler return format
class JoinedRoomMemberListRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$")
class JoinedRoomMemberListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$", v1=True)
def __init__(self, hs):
super(JoinedRoomMemberListRestServlet, self).__init__(hs)
super(JoinedRoomMemberListRestServlet, self).__init__()
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
@ -457,12 +475,13 @@ class JoinedRoomMemberListRestServlet(ClientV1RestServlet):
# TODO: Needs better unit testing
class RoomMessageListRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$")
class RoomMessageListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/messages$", v1=True)
def __init__(self, hs):
super(RoomMessageListRestServlet, self).__init__(hs)
super(RoomMessageListRestServlet, self).__init__()
self.pagination_handler = hs.get_pagination_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
@ -491,12 +510,13 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing
class RoomStateRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/state$")
class RoomStateRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/state$", v1=True)
def __init__(self, hs):
super(RoomStateRestServlet, self).__init__(hs)
super(RoomStateRestServlet, self).__init__()
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
@ -511,12 +531,13 @@ class RoomStateRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing
class RoomInitialSyncRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$")
class RoomInitialSyncRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$", v1=True)
def __init__(self, hs):
super(RoomInitialSyncRestServlet, self).__init__(hs)
super(RoomInitialSyncRestServlet, self).__init__()
self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
@ -530,16 +551,17 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
defer.returnValue((200, content))
class RoomEventServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns(
"/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$"
class RoomEventServlet(RestServlet):
PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$", v1=True
)
def __init__(self, hs):
super(RoomEventServlet, self).__init__(hs)
super(RoomEventServlet, self).__init__()
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, room_id, event_id):
@ -554,16 +576,17 @@ class RoomEventServlet(ClientV1RestServlet):
defer.returnValue((404, "Event not found."))
class RoomEventContextServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns(
"/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$"
class RoomEventContextServlet(RestServlet):
PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$", v1=True
)
def __init__(self, hs):
super(RoomEventContextServlet, self).__init__(hs)
super(RoomEventContextServlet, self).__init__()
self.clock = hs.get_clock()
self.room_context_handler = hs.get_room_context_handler()
self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, room_id, event_id):
@ -609,10 +632,11 @@ class RoomEventContextServlet(ClientV1RestServlet):
defer.returnValue((200, results))
class RoomForgetRestServlet(ClientV1RestServlet):
class RoomForgetRestServlet(TransactionRestServlet):
def __init__(self, hs):
super(RoomForgetRestServlet, self).__init__(hs)
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
@ -639,11 +663,12 @@ class RoomForgetRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing
class RoomMembershipRestServlet(ClientV1RestServlet):
class RoomMembershipRestServlet(TransactionRestServlet):
def __init__(self, hs):
super(RoomMembershipRestServlet, self).__init__(hs)
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
def register(self, http_server):
# /rooms/$roomid/[invite|join|leave]
@ -722,11 +747,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
)
class RoomRedactEventRestServlet(ClientV1RestServlet):
class RoomRedactEventRestServlet(TransactionRestServlet):
def __init__(self, hs):
super(RoomRedactEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
self.event_creation_handler = hs.get_event_creation_handler()
self.auth = hs.get_auth()
def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
@ -757,15 +783,16 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
)
class RoomTypingRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns(
"/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$"
class RoomTypingRestServlet(RestServlet):
PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$", v1=True
)
def __init__(self, hs):
super(RoomTypingRestServlet, self).__init__(hs)
super(RoomTypingRestServlet, self).__init__()
self.presence_handler = hs.get_presence_handler()
self.typing_handler = hs.get_typing_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_PUT(self, request, room_id, user_id):
@ -798,14 +825,13 @@ class RoomTypingRestServlet(ClientV1RestServlet):
defer.returnValue((200, {}))
class SearchRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns(
"/search$"
)
class SearchRestServlet(RestServlet):
PATTERNS = client_patterns("/search$", v1=True)
def __init__(self, hs):
super(SearchRestServlet, self).__init__(hs)
super(SearchRestServlet, self).__init__()
self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_POST(self, request):
@ -823,12 +849,13 @@ class SearchRestServlet(ClientV1RestServlet):
defer.returnValue((200, results))
class JoinedRoomsRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/joined_rooms$")
class JoinedRoomsRestServlet(RestServlet):
PATTERNS = client_patterns("/joined_rooms$", v1=True)
def __init__(self, hs):
super(JoinedRoomsRestServlet, self).__init__(hs)
super(JoinedRoomsRestServlet, self).__init__()
self.store = hs.get_datastore()
self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request):
@ -853,18 +880,18 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False):
"""
http_server.register_paths(
"POST",
client_path_patterns(regex_string + "$"),
client_patterns(regex_string + "$", v1=True),
servlet.on_POST
)
http_server.register_paths(
"PUT",
client_path_patterns(regex_string + "/(?P<txn_id>[^/]*)$"),
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
servlet.on_PUT
)
if with_get:
http_server.register_paths(
"GET",
client_path_patterns(regex_string + "/(?P<txn_id>[^/]*)$"),
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
servlet.on_GET
)

View File

@ -19,11 +19,17 @@ import hmac
from twisted.internet import defer
from .base import ClientV1RestServlet, client_path_patterns
from synapse.http.servlet import RestServlet
from synapse.rest.client.v2_alpha._base import client_patterns
class VoipRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/voip/turnServer$")
class VoipRestServlet(RestServlet):
PATTERNS = client_patterns("/voip/turnServer$", v1=True)
def __init__(self, hs):
super(VoipRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request):

View File

@ -26,8 +26,7 @@ from synapse.api.urls import CLIENT_API_PREFIX
logger = logging.getLogger(__name__)
def client_v2_patterns(path_regex, releases=(0,),
unstable=True):
def client_patterns(path_regex, releases=(0,), unstable=True, v1=False):
"""Creates a regex compiled client path with the correct client path
prefix.
@ -41,6 +40,9 @@ def client_v2_patterns(path_regex, releases=(0,),
if unstable:
unstable_prefix = CLIENT_API_PREFIX + "/unstable"
patterns.append(re.compile("^" + unstable_prefix + path_regex))
if v1:
v1_prefix = CLIENT_API_PREFIX + "/api/v1"
patterns.append(re.compile("^" + v1_prefix + path_regex))
for release in releases:
new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,)
patterns.append(re.compile("^" + new_prefix + path_regex))

View File

@ -30,13 +30,13 @@ from synapse.http.servlet import (
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import check_3pid_allowed
from ._base import client_v2_patterns, interactive_auth_handler
from ._base import client_patterns, interactive_auth_handler
logger = logging.getLogger(__name__)
class EmailPasswordRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/password/email/requestToken$")
PATTERNS = client_patterns("/account/password/email/requestToken$")
def __init__(self, hs):
super(EmailPasswordRequestTokenRestServlet, self).__init__()
@ -70,7 +70,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
class MsisdnPasswordRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/password/msisdn/requestToken$")
PATTERNS = client_patterns("/account/password/msisdn/requestToken$")
def __init__(self, hs):
super(MsisdnPasswordRequestTokenRestServlet, self).__init__()
@ -108,7 +108,7 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
class PasswordRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/password$")
PATTERNS = client_patterns("/account/password$")
def __init__(self, hs):
super(PasswordRestServlet, self).__init__()
@ -180,7 +180,7 @@ class PasswordRestServlet(RestServlet):
class DeactivateAccountRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/deactivate$")
PATTERNS = client_patterns("/account/deactivate$")
def __init__(self, hs):
super(DeactivateAccountRestServlet, self).__init__()
@ -228,7 +228,7 @@ class DeactivateAccountRestServlet(RestServlet):
class EmailThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$")
PATTERNS = client_patterns("/account/3pid/email/requestToken$")
def __init__(self, hs):
self.hs = hs
@ -263,7 +263,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
class MsisdnThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid/msisdn/requestToken$")
PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$")
def __init__(self, hs):
self.hs = hs
@ -300,7 +300,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
class ThreepidRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid$")
PATTERNS = client_patterns("/account/3pid$")
def __init__(self, hs):
super(ThreepidRestServlet, self).__init__()
@ -364,7 +364,7 @@ class ThreepidRestServlet(RestServlet):
class ThreepidDeleteRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid/delete$")
PATTERNS = client_patterns("/account/3pid/delete$")
def __init__(self, hs):
super(ThreepidDeleteRestServlet, self).__init__()
@ -401,7 +401,7 @@ class ThreepidDeleteRestServlet(RestServlet):
class WhoamiRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/whoami$")
PATTERNS = client_patterns("/account/whoami$")
def __init__(self, hs):
super(WhoamiRestServlet, self).__init__()

View File

@ -20,7 +20,7 @@ from twisted.internet import defer
from synapse.api.errors import AuthError, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_v2_patterns
from ._base import client_patterns
logger = logging.getLogger(__name__)
@ -30,7 +30,7 @@ class AccountDataServlet(RestServlet):
PUT /user/{user_id}/account_data/{account_dataType} HTTP/1.1
GET /user/{user_id}/account_data/{account_dataType} HTTP/1.1
"""
PATTERNS = client_v2_patterns(
PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)"
)
@ -79,7 +79,7 @@ class RoomAccountDataServlet(RestServlet):
PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1
GET /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1
"""
PATTERNS = client_v2_patterns(
PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)"
"/rooms/(?P<room_id>[^/]*)"
"/account_data/(?P<account_data_type>[^/]*)"

View File

@ -21,13 +21,13 @@ from synapse.api.errors import AuthError, SynapseError
from synapse.http.server import finish_request
from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns
from ._base import client_patterns
logger = logging.getLogger(__name__)
class AccountValidityRenewServlet(RestServlet):
PATTERNS = client_v2_patterns("/account_validity/renew$")
PATTERNS = client_patterns("/account_validity/renew$")
SUCCESS_HTML = b"<html><body>Your account has been successfully renewed.</body><html>"
def __init__(self, hs):
@ -60,7 +60,7 @@ class AccountValidityRenewServlet(RestServlet):
class AccountValiditySendMailServlet(RestServlet):
PATTERNS = client_v2_patterns("/account_validity/send_mail$")
PATTERNS = client_patterns("/account_validity/send_mail$")
def __init__(self, hs):
"""

View File

@ -23,7 +23,7 @@ from synapse.api.urls import CLIENT_API_PREFIX
from synapse.http.server import finish_request
from synapse.http.servlet import RestServlet, parse_string
from ._base import client_v2_patterns
from ._base import client_patterns
logger = logging.getLogger(__name__)
@ -122,7 +122,7 @@ class AuthRestServlet(RestServlet):
cannot be handled in the normal flow (with requests to the same endpoint).
Current use is for web fallback auth.
"""
PATTERNS = client_v2_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
def __init__(self, hs):
super(AuthRestServlet, self).__init__()

View File

@ -19,7 +19,7 @@ from twisted.internet import defer
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns
from ._base import client_patterns
logger = logging.getLogger(__name__)
@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
class CapabilitiesRestServlet(RestServlet):
"""End point to expose the capabilities of the server."""
PATTERNS = client_v2_patterns("/capabilities$")
PATTERNS = client_patterns("/capabilities$")
def __init__(self, hs):
"""

View File

@ -24,13 +24,13 @@ from synapse.http.servlet import (
parse_json_object_from_request,
)
from ._base import client_v2_patterns, interactive_auth_handler
from ._base import client_patterns, interactive_auth_handler
logger = logging.getLogger(__name__)
class DevicesRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/devices$")
PATTERNS = client_patterns("/devices$")
def __init__(self, hs):
"""
@ -56,7 +56,7 @@ class DeleteDevicesRestServlet(RestServlet):
API for bulk deletion of devices. Accepts a JSON object with a devices
key which lists the device_ids to delete. Requires user interactive auth.
"""
PATTERNS = client_v2_patterns("/delete_devices")
PATTERNS = client_patterns("/delete_devices")
def __init__(self, hs):
super(DeleteDevicesRestServlet, self).__init__()
@ -95,7 +95,7 @@ class DeleteDevicesRestServlet(RestServlet):
class DeviceRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$")
PATTERNS = client_patterns("/devices/(?P<device_id>[^/]*)$")
def __init__(self, hs):
"""

View File

@ -21,13 +21,13 @@ from synapse.api.errors import AuthError, Codes, StoreError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID
from ._base import client_v2_patterns, set_timeline_upper_limit
from ._base import client_patterns, set_timeline_upper_limit
logger = logging.getLogger(__name__)
class GetFilterRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")
def __init__(self, hs):
super(GetFilterRestServlet, self).__init__()
@ -63,7 +63,7 @@ class GetFilterRestServlet(RestServlet):
class CreateFilterRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/user/(?P<user_id>[^/]*)/filter")
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter")
def __init__(self, hs):
super(CreateFilterRestServlet, self).__init__()

View File

@ -21,7 +21,7 @@ from twisted.internet import defer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import GroupID
from ._base import client_v2_patterns
from ._base import client_patterns
logger = logging.getLogger(__name__)
@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class GroupServlet(RestServlet):
"""Get the group profile
"""
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/profile$")
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/profile$")
def __init__(self, hs):
super(GroupServlet, self).__init__()
@ -65,7 +65,7 @@ class GroupServlet(RestServlet):
class GroupSummaryServlet(RestServlet):
"""Get the full group summary
"""
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/summary$")
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/summary$")
def __init__(self, hs):
super(GroupSummaryServlet, self).__init__()
@ -93,7 +93,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
- /groups/:group/summary/rooms/:room_id
- /groups/:group/summary/categories/:category/rooms/:room_id
"""
PATTERNS = client_v2_patterns(
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/summary"
"(/categories/(?P<category_id>[^/]+))?"
"/rooms/(?P<room_id>[^/]*)$"
@ -137,7 +137,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
class GroupCategoryServlet(RestServlet):
"""Get/add/update/delete a group category
"""
PATTERNS = client_v2_patterns(
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
)
@ -189,7 +189,7 @@ class GroupCategoryServlet(RestServlet):
class GroupCategoriesServlet(RestServlet):
"""Get all group categories
"""
PATTERNS = client_v2_patterns(
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/categories/$"
)
@ -214,7 +214,7 @@ class GroupCategoriesServlet(RestServlet):
class GroupRoleServlet(RestServlet):
"""Get/add/update/delete a group role
"""
PATTERNS = client_v2_patterns(
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$"
)
@ -266,7 +266,7 @@ class GroupRoleServlet(RestServlet):
class GroupRolesServlet(RestServlet):
"""Get all group roles
"""
PATTERNS = client_v2_patterns(
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/roles/$"
)
@ -295,7 +295,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
- /groups/:group/summary/users/:room_id
- /groups/:group/summary/roles/:role/users/:user_id
"""
PATTERNS = client_v2_patterns(
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/summary"
"(/roles/(?P<role_id>[^/]+))?"
"/users/(?P<user_id>[^/]*)$"
@ -339,7 +339,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
class GroupRoomServlet(RestServlet):
"""Get all rooms in a group
"""
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
def __init__(self, hs):
super(GroupRoomServlet, self).__init__()
@ -360,7 +360,7 @@ class GroupRoomServlet(RestServlet):
class GroupUsersServlet(RestServlet):
"""Get all users in a group
"""
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/users$")
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/users$")
def __init__(self, hs):
super(GroupUsersServlet, self).__init__()
@ -381,7 +381,7 @@ class GroupUsersServlet(RestServlet):
class GroupInvitedUsersServlet(RestServlet):
"""Get users invited to a group
"""
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
def __init__(self, hs):
super(GroupInvitedUsersServlet, self).__init__()
@ -405,7 +405,7 @@ class GroupInvitedUsersServlet(RestServlet):
class GroupSettingJoinPolicyServlet(RestServlet):
"""Set group join policy
"""
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
def __init__(self, hs):
super(GroupSettingJoinPolicyServlet, self).__init__()
@ -431,7 +431,7 @@ class GroupSettingJoinPolicyServlet(RestServlet):
class GroupCreateServlet(RestServlet):
"""Create a group
"""
PATTERNS = client_v2_patterns("/create_group$")
PATTERNS = client_patterns("/create_group$")
def __init__(self, hs):
super(GroupCreateServlet, self).__init__()
@ -462,7 +462,7 @@ class GroupCreateServlet(RestServlet):
class GroupAdminRoomsServlet(RestServlet):
"""Add a room to the group
"""
PATTERNS = client_v2_patterns(
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$"
)
@ -499,7 +499,7 @@ class GroupAdminRoomsServlet(RestServlet):
class GroupAdminRoomsConfigServlet(RestServlet):
"""Update the config of a room in a group
"""
PATTERNS = client_v2_patterns(
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)"
"/config/(?P<config_key>[^/]*)$"
)
@ -526,7 +526,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
class GroupAdminUsersInviteServlet(RestServlet):
"""Invite a user to the group
"""
PATTERNS = client_v2_patterns(
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$"
)
@ -555,7 +555,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
class GroupAdminUsersKickServlet(RestServlet):
"""Kick a user from the group
"""
PATTERNS = client_v2_patterns(
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$"
)
@ -581,7 +581,7 @@ class GroupAdminUsersKickServlet(RestServlet):
class GroupSelfLeaveServlet(RestServlet):
"""Leave a joined group
"""
PATTERNS = client_v2_patterns(
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/self/leave$"
)
@ -607,7 +607,7 @@ class GroupSelfLeaveServlet(RestServlet):
class GroupSelfJoinServlet(RestServlet):
"""Attempt to join a group, or knock
"""
PATTERNS = client_v2_patterns(
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/self/join$"
)
@ -633,7 +633,7 @@ class GroupSelfJoinServlet(RestServlet):
class GroupSelfAcceptInviteServlet(RestServlet):
"""Accept a group invite
"""
PATTERNS = client_v2_patterns(
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/self/accept_invite$"
)
@ -659,7 +659,7 @@ class GroupSelfAcceptInviteServlet(RestServlet):
class GroupSelfUpdatePublicityServlet(RestServlet):
"""Update whether we publicise a users membership of a group
"""
PATTERNS = client_v2_patterns(
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/self/update_publicity$"
)
@ -686,7 +686,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
class PublicisedGroupsForUserServlet(RestServlet):
"""Get the list of groups a user is advertising
"""
PATTERNS = client_v2_patterns(
PATTERNS = client_patterns(
"/publicised_groups/(?P<user_id>[^/]*)$"
)
@ -711,7 +711,7 @@ class PublicisedGroupsForUserServlet(RestServlet):
class PublicisedGroupsForUsersServlet(RestServlet):
"""Get the list of groups a user is advertising
"""
PATTERNS = client_v2_patterns(
PATTERNS = client_patterns(
"/publicised_groups$"
)
@ -739,7 +739,7 @@ class PublicisedGroupsForUsersServlet(RestServlet):
class GroupsForUserServlet(RestServlet):
"""Get all groups the logged in user is joined to
"""
PATTERNS = client_v2_patterns(
PATTERNS = client_patterns(
"/joined_groups$"
)

View File

@ -26,7 +26,7 @@ from synapse.http.servlet import (
)
from synapse.types import StreamToken
from ._base import client_v2_patterns
from ._base import client_patterns
logger = logging.getLogger(__name__)
@ -56,7 +56,7 @@ class KeyUploadServlet(RestServlet):
},
}
"""
PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
def __init__(self, hs):
"""
@ -130,7 +130,7 @@ class KeyQueryServlet(RestServlet):
} } } } } }
"""
PATTERNS = client_v2_patterns("/keys/query$")
PATTERNS = client_patterns("/keys/query$")
def __init__(self, hs):
"""
@ -159,7 +159,7 @@ class KeyChangesServlet(RestServlet):
200 OK
{ "changed": ["@foo:example.com"] }
"""
PATTERNS = client_v2_patterns("/keys/changes$")
PATTERNS = client_patterns("/keys/changes$")
def __init__(self, hs):
"""
@ -209,7 +209,7 @@ class OneTimeKeyServlet(RestServlet):
} } } }
"""
PATTERNS = client_v2_patterns("/keys/claim$")
PATTERNS = client_patterns("/keys/claim$")
def __init__(self, hs):
super(OneTimeKeyServlet, self).__init__()

View File

@ -20,13 +20,13 @@ from twisted.internet import defer
from synapse.events.utils import format_event_for_client_v2_without_room_id
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from ._base import client_v2_patterns
from ._base import client_patterns
logger = logging.getLogger(__name__)
class NotificationsServlet(RestServlet):
PATTERNS = client_v2_patterns("/notifications$")
PATTERNS = client_patterns("/notifications$")
def __init__(self, hs):
super(NotificationsServlet, self).__init__()

View File

@ -22,7 +22,7 @@ from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.util.stringutils import random_string
from ._base import client_v2_patterns
from ._base import client_patterns
logger = logging.getLogger(__name__)
@ -56,7 +56,7 @@ class IdTokenServlet(RestServlet):
"expires_in": 3600,
}
"""
PATTERNS = client_v2_patterns(
PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/openid/request_token"
)

View File

@ -19,13 +19,13 @@ from twisted.internet import defer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_v2_patterns
from ._base import client_patterns
logger = logging.getLogger(__name__)
class ReadMarkerRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$")
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$")
def __init__(self, hs):
super(ReadMarkerRestServlet, self).__init__()

View File

@ -20,13 +20,13 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns
from ._base import client_patterns
logger = logging.getLogger(__name__)
class ReceiptRestServlet(RestServlet):
PATTERNS = client_v2_patterns(
PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)"
"/receipt/(?P<receipt_type>[^/]*)"
"/(?P<event_id>[^/]*)$"

View File

@ -43,7 +43,7 @@ from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.threepids import check_3pid_allowed
from ._base import client_v2_patterns, interactive_auth_handler
from ._base import client_patterns, interactive_auth_handler
# We ought to be using hmac.compare_digest() but on older pythons it doesn't
# exist. It's a _really minor_ security flaw to use plain string comparison
@ -60,7 +60,7 @@ logger = logging.getLogger(__name__)
class EmailRegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register/email/requestToken$")
PATTERNS = client_patterns("/register/email/requestToken$")
def __init__(self, hs):
"""
@ -98,7 +98,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
class MsisdnRegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register/msisdn/requestToken$")
PATTERNS = client_patterns("/register/msisdn/requestToken$")
def __init__(self, hs):
"""
@ -142,7 +142,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
class UsernameAvailabilityRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register/available")
PATTERNS = client_patterns("/register/available")
def __init__(self, hs):
"""
@ -182,7 +182,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
class RegisterRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register$")
PATTERNS = client_patterns("/register$")
def __init__(self, hs):
"""

View File

@ -34,7 +34,7 @@ from synapse.http.servlet import (
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.storage.relations import AggregationPaginationToken, RelationPaginationToken
from ._base import client_v2_patterns
from ._base import client_patterns
logger = logging.getLogger(__name__)
@ -66,12 +66,12 @@ class RelationSendServlet(RestServlet):
def register(self, http_server):
http_server.register_paths(
"POST",
client_v2_patterns(self.PATTERN + "$", releases=()),
client_patterns(self.PATTERN + "$", releases=()),
self.on_PUT_or_POST,
)
http_server.register_paths(
"PUT",
client_v2_patterns(self.PATTERN + "/(?P<txn_id>[^/]*)$", releases=()),
client_patterns(self.PATTERN + "/(?P<txn_id>[^/]*)$", releases=()),
self.on_PUT,
)
@ -120,7 +120,7 @@ class RelationPaginationServlet(RestServlet):
filtered by relation type and event type.
"""
PATTERNS = client_v2_patterns(
PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/relations/(?P<parent_id>[^/]*)"
"(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$",
releases=(),
@ -197,7 +197,7 @@ class RelationAggregationPaginationServlet(RestServlet):
}
"""
PATTERNS = client_v2_patterns(
PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)"
"(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$",
releases=(),
@ -269,7 +269,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
}
"""
PATTERNS = client_v2_patterns(
PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)"
"/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)/(?P<key>[^/]*)$",
releases=(),

View File

@ -27,13 +27,13 @@ from synapse.http.servlet import (
parse_json_object_from_request,
)
from ._base import client_v2_patterns
from ._base import client_patterns
logger = logging.getLogger(__name__)
class ReportEventRestServlet(RestServlet):
PATTERNS = client_v2_patterns(
PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$"
)

View File

@ -24,13 +24,13 @@ from synapse.http.servlet import (
parse_string,
)
from ._base import client_v2_patterns
from ._base import client_patterns
logger = logging.getLogger(__name__)
class RoomKeysServlet(RestServlet):
PATTERNS = client_v2_patterns(
PATTERNS = client_patterns(
"/room_keys/keys(/(?P<room_id>[^/]+))?(/(?P<session_id>[^/]+))?$"
)
@ -256,7 +256,7 @@ class RoomKeysServlet(RestServlet):
class RoomKeysNewVersionServlet(RestServlet):
PATTERNS = client_v2_patterns(
PATTERNS = client_patterns(
"/room_keys/version$"
)
@ -314,7 +314,7 @@ class RoomKeysNewVersionServlet(RestServlet):
class RoomKeysVersionServlet(RestServlet):
PATTERNS = client_v2_patterns(
PATTERNS = client_patterns(
"/room_keys/version(/(?P<version>[^/]+))?$"
)

View File

@ -25,7 +25,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
)
from ._base import client_v2_patterns
from ._base import client_patterns
logger = logging.getLogger(__name__)
@ -47,7 +47,7 @@ class RoomUpgradeRestServlet(RestServlet):
Args:
hs (synapse.server.HomeServer):
"""
PATTERNS = client_v2_patterns(
PATTERNS = client_patterns(
# /rooms/$roomid/upgrade
"/rooms/(?P<room_id>[^/]*)/upgrade$",
)

View File

@ -21,13 +21,13 @@ from synapse.http import servlet
from synapse.http.servlet import parse_json_object_from_request
from synapse.rest.client.transactions import HttpTransactionCache
from ._base import client_v2_patterns
from ._base import client_patterns
logger = logging.getLogger(__name__)
class SendToDeviceRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns(
PATTERNS = client_patterns(
"/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$",
)

View File

@ -32,7 +32,7 @@ from synapse.handlers.sync import SyncConfig
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.types import StreamToken
from ._base import client_v2_patterns, set_timeline_upper_limit
from ._base import client_patterns, set_timeline_upper_limit
logger = logging.getLogger(__name__)
@ -73,7 +73,7 @@ class SyncRestServlet(RestServlet):
}
"""
PATTERNS = client_v2_patterns("/sync$")
PATTERNS = client_patterns("/sync$")
ALLOWED_PRESENCE = set(["online", "offline", "unavailable"])
def __init__(self, hs):

View File

@ -20,7 +20,7 @@ from twisted.internet import defer
from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_v2_patterns
from ._base import client_patterns
logger = logging.getLogger(__name__)
@ -29,7 +29,7 @@ class TagListServlet(RestServlet):
"""
GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1
"""
PATTERNS = client_v2_patterns(
PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags"
)
@ -54,7 +54,7 @@ class TagServlet(RestServlet):
PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
"""
PATTERNS = client_v2_patterns(
PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)"
)

View File

@ -21,13 +21,13 @@ from twisted.internet import defer
from synapse.api.constants import ThirdPartyEntityKind
from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns
from ._base import client_patterns
logger = logging.getLogger(__name__)
class ThirdPartyProtocolsServlet(RestServlet):
PATTERNS = client_v2_patterns("/thirdparty/protocols")
PATTERNS = client_patterns("/thirdparty/protocols")
def __init__(self, hs):
super(ThirdPartyProtocolsServlet, self).__init__()
@ -44,7 +44,7 @@ class ThirdPartyProtocolsServlet(RestServlet):
class ThirdPartyProtocolServlet(RestServlet):
PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$")
PATTERNS = client_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$")
def __init__(self, hs):
super(ThirdPartyProtocolServlet, self).__init__()
@ -66,7 +66,7 @@ class ThirdPartyProtocolServlet(RestServlet):
class ThirdPartyUserServlet(RestServlet):
PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$")
PATTERNS = client_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$")
def __init__(self, hs):
super(ThirdPartyUserServlet, self).__init__()
@ -89,7 +89,7 @@ class ThirdPartyUserServlet(RestServlet):
class ThirdPartyLocationServlet(RestServlet):
PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$")
PATTERNS = client_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$")
def __init__(self, hs):
super(ThirdPartyLocationServlet, self).__init__()

View File

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns
from ._base import client_patterns
class TokenRefreshRestServlet(RestServlet):
@ -26,7 +26,7 @@ class TokenRefreshRestServlet(RestServlet):
Exchanges refresh tokens for a pair of an access token and a new refresh
token.
"""
PATTERNS = client_v2_patterns("/tokenrefresh")
PATTERNS = client_patterns("/tokenrefresh")
def __init__(self, hs):
super(TokenRefreshRestServlet, self).__init__()

View File

@ -20,13 +20,13 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_v2_patterns
from ._base import client_patterns
logger = logging.getLogger(__name__)
class UserDirectorySearchRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/user_directory/search$")
PATTERNS = client_patterns("/user_directory/search$")
def __init__(self, hs):
"""

View File

@ -21,4 +21,4 @@ import tests.patch_inline_callbacks
# attempt to do the patch before we load any synapse code
tests.patch_inline_callbacks.do_patch()
util.DEFAULT_TIMEOUT_DURATION = 10
util.DEFAULT_TIMEOUT_DURATION = 20

View File

@ -19,6 +19,7 @@ from mock import Mock
import canonicaljson
import signedjson.key
import signedjson.sign
from signedjson.key import get_verify_key
from twisted.internet import defer
@ -137,7 +138,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
context_11.request = "11"
res_deferreds = kr.verify_json_objects_for_server(
[("server10", json1), ("server11", {})]
[("server10", json1, 0), ("server11", {}, 0)]
)
# the unsigned json should be rejected pretty quickly
@ -174,7 +175,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.http_client.post_json.return_value = defer.Deferred()
res_deferreds_2 = kr.verify_json_objects_for_server(
[("server10", json1)]
[("server10", json1, 0)]
)
res_deferreds_2[0].addBoth(self.check_context, None)
yield logcontext.make_deferred_yieldable(res_deferreds_2[0])
@ -197,31 +198,108 @@ class KeyringTestCase(unittest.HomeserverTestCase):
kr = keyring.Keyring(self.hs)
key1 = signedjson.key.generate_signing_key(1)
key1_id = "%s:%s" % (key1.alg, key1.version)
r = self.hs.datastore.store_server_verify_keys(
"server9",
time.time() * 1000,
[
(
"server9",
key1_id,
FetchKeyResult(signedjson.key.get_verify_key(key1), 1000),
),
],
[("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))],
)
self.get_success(r)
json1 = {}
signedjson.sign.sign_json(json1, "server9", key1)
# should fail immediately on an unsigned object
d = _verify_json_for_server(kr, "server9", {})
d = _verify_json_for_server(kr, "server9", {}, 0)
self.failureResultOf(d, SynapseError)
d = _verify_json_for_server(kr, "server9", json1)
self.assertFalse(d.called)
# should suceed on a signed object
d = _verify_json_for_server(kr, "server9", json1, 500)
# self.assertFalse(d.called)
self.get_success(d)
def test_verify_json_dedupes_key_requests(self):
"""Two requests for the same key should be deduped."""
key1 = signedjson.key.generate_signing_key(1)
def get_keys(keys_to_fetch):
# there should only be one request object (with the max validity)
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
return defer.succeed(
{
"server1": {
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
}
}
)
mock_fetcher = keyring.KeyFetcher()
mock_fetcher.get_keys = Mock(side_effect=get_keys)
kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
json1 = {}
signedjson.sign.sign_json(json1, "server1", key1)
# the first request should succeed; the second should fail because the key
# has expired
results = kr.verify_json_objects_for_server(
[("server1", json1, 500), ("server1", json1, 1500)]
)
self.assertEqual(len(results), 2)
self.get_success(results[0])
e = self.get_failure(results[1], SynapseError).value
self.assertEqual(e.errcode, "M_UNAUTHORIZED")
self.assertEqual(e.code, 401)
# there should have been a single call to the fetcher
mock_fetcher.get_keys.assert_called_once()
def test_verify_json_falls_back_to_other_fetchers(self):
"""If the first fetcher cannot provide a recent enough key, we fall back"""
key1 = signedjson.key.generate_signing_key(1)
def get_keys1(keys_to_fetch):
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
return defer.succeed(
{
"server1": {
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)
}
}
)
def get_keys2(keys_to_fetch):
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
return defer.succeed(
{
"server1": {
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
}
}
)
mock_fetcher1 = keyring.KeyFetcher()
mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
mock_fetcher2 = keyring.KeyFetcher()
mock_fetcher2.get_keys = Mock(side_effect=get_keys2)
kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher1, mock_fetcher2))
json1 = {}
signedjson.sign.sign_json(json1, "server1", key1)
results = kr.verify_json_objects_for_server(
[("server1", json1, 1200), ("server1", json1, 1500)]
)
self.assertEqual(len(results), 2)
self.get_success(results[0])
e = self.get_failure(results[1], SynapseError).value
self.assertEqual(e.errcode, "M_UNAUTHORIZED")
self.assertEqual(e.code, 401)
# there should have been a single call to each fetcher
mock_fetcher1.get_keys.assert_called_once()
mock_fetcher2.get_keys.assert_called_once()
class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
@ -260,8 +338,8 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
self.http_client.get_json.side_effect = get_json
server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
keys = self.get_success(fetcher.get_keys(server_name_and_key_ids))
keys_to_fetch = {SERVER_NAME: {"key1": 0}}
keys = self.get_success(fetcher.get_keys(keys_to_fetch))
k = keys[SERVER_NAME][testverifykey_id]
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
self.assertEqual(k.verify_key, testverifykey)
@ -288,9 +366,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
# change the server name: it should cause a rejection
response["server_name"] = "OTHER_SERVER"
self.get_failure(
fetcher.get_keys(server_name_and_key_ids), KeyLookupError
)
self.get_failure(fetcher.get_keys(keys_to_fetch), KeyLookupError)
class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
@ -342,8 +418,8 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.http_client.post_json.side_effect = post_json
server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
keys = self.get_success(fetcher.get_keys(server_name_and_key_ids))
keys_to_fetch = {SERVER_NAME: {"key1": 0}}
keys = self.get_success(fetcher.get_keys(keys_to_fetch))
self.assertIn(SERVER_NAME, keys)
k = keys[SERVER_NAME][testverifykey_id]
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
@ -401,7 +477,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
def get_key_from_perspectives(response):
fetcher = PerspectivesKeyFetcher(self.hs)
server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
keys_to_fetch = {SERVER_NAME: {"key1": 0}}
def post_json(destination, path, data, **kwargs):
self.assertEqual(destination, self.mock_perspective_server.server_name)
@ -410,9 +486,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.http_client.post_json.side_effect = post_json
return self.get_success(
fetcher.get_keys(server_name_and_key_ids)
)
return self.get_success(fetcher.get_keys(keys_to_fetch))
# start with a valid response so we can check we are testing the right thing
response = build_response()
@ -435,6 +509,11 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig")
def get_key_id(key):
"""Get the matrix ID tag for a given SigningKey or VerifyKey"""
return "%s:%s" % (key.alg, key.version)
@defer.inlineCallbacks
def run_in_context(f, *args, **kwargs):
with LoggingContext("testctx") as ctx:
@ -445,14 +524,16 @@ def run_in_context(f, *args, **kwargs):
defer.returnValue(rv)
def _verify_json_for_server(keyring, server_name, json_object):
def _verify_json_for_server(keyring, server_name, json_object, validity_time):
"""thin wrapper around verify_json_for_server which makes sure it is wrapped
with the patched defer.inlineCallbacks.
"""
@defer.inlineCallbacks
def v():
rv1 = yield keyring.verify_json_for_server(server_name, json_object)
rv1 = yield keyring.verify_json_for_server(
server_name, json_object, validity_time
)
defer.returnValue(rv1)
return run_in_context(v)

View File

@ -408,7 +408,6 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
users_in_room = self.get_success(self.store.get_users_in_room(room_id))
self.assertEqual([], users_in_room)
@unittest.DEBUG
def test_shutdown_room_block_peek(self):
"""Test that a world_readable room can no longer be peeked into after
it has been shut down.

View File

@ -30,7 +30,7 @@ from tests import unittest
from ....utils import MockHttpResource, setup_test_homeserver
myid = "@1234ABCD:test"
PATH_PREFIX = "/_matrix/client/api/v1"
PATH_PREFIX = "/_matrix/client/r0"
class MockHandlerProfileTestCase(unittest.TestCase):