diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 1b3b55d51..69b339273 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -603,10 +603,12 @@ class Auth(object): """ # Can optionally look elsewhere in the request (e.g. headers) try: - user_id = yield self._get_appservice_user_id(request) + user_id, app_service = yield self._get_appservice_user_id(request) if user_id: request.authenticated_entity = user_id - defer.returnValue(synapse.types.create_requester(user_id)) + defer.returnValue( + synapse.types.create_requester(user_id, app_service=app_service) + ) access_token = get_access_token_from_request( request, self.TOKEN_NOT_FOUND_HTTP_STATUS @@ -644,7 +646,8 @@ class Auth(object): request.authenticated_entity = user.to_string() defer.returnValue(synapse.types.create_requester( - user, token_id, is_guest, device_id)) + user, token_id, is_guest, device_id, app_service=app_service) + ) except KeyError: raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", @@ -659,14 +662,14 @@ class Auth(object): ) ) if app_service is None: - defer.returnValue(None) + defer.returnValue((None, None)) if "user_id" not in request.args: - defer.returnValue(app_service.sender) + defer.returnValue((app_service.sender, app_service)) user_id = request.args["user_id"][0] if app_service.sender == user_id: - defer.returnValue(app_service.sender) + defer.returnValue((app_service.sender, app_service)) if not app_service.is_interested_in_user(user_id): raise AuthError( @@ -678,7 +681,7 @@ class Auth(object): 403, "Application service has not registered this user" ) - defer.returnValue(user_id) + defer.returnValue((user_id, app_service)) @defer.inlineCallbacks def get_user_by_access_token(self, token, rights="access"): @@ -1167,7 +1170,8 @@ def has_access_token(request): bool: False if no access_token was given, True otherwise. """ query_params = request.args.get("access_token") - return bool(query_params) + auth_headers = request.requestHeaders.getRawHeaders("Authorization") + return bool(query_params) or bool(auth_headers) def get_access_token_from_request(request, token_not_found_http_status=401): @@ -1185,13 +1189,40 @@ def get_access_token_from_request(request, token_not_found_http_status=401): Raises: AuthError: If there isn't an access_token in the request. """ - query_params = request.args.get("access_token") - # Try to get the access_token from the query params. - if not query_params: - raise AuthError( - token_not_found_http_status, - "Missing access token.", - errcode=Codes.MISSING_TOKEN - ) - return query_params[0] + auth_headers = request.requestHeaders.getRawHeaders("Authorization") + query_params = request.args.get("access_token") + if auth_headers: + # Try the get the access_token from a "Authorization: Bearer" + # header + if query_params is not None: + raise AuthError( + token_not_found_http_status, + "Mixing Authorization headers and access_token query parameters.", + errcode=Codes.MISSING_TOKEN, + ) + if len(auth_headers) > 1: + raise AuthError( + token_not_found_http_status, + "Too many Authorization headers.", + errcode=Codes.MISSING_TOKEN, + ) + parts = auth_headers[0].split(" ") + if parts[0] == "Bearer" and len(parts) == 2: + return parts[1] + else: + raise AuthError( + token_not_found_http_status, + "Invalid Authorization header.", + errcode=Codes.MISSING_TOKEN, + ) + else: + # Try to get the access_token from the query params. + if not query_params: + raise AuthError( + token_not_found_http_status, + "Missing access token.", + errcode=Codes.MISSING_TOKEN + ) + + return query_params[0] diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 660dfb56e..06cc8d90b 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -23,7 +23,7 @@ class Ratelimiter(object): def __init__(self): self.message_counts = collections.OrderedDict() - def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count): + def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count, update=True): """Can the user send a message? Args: user_id: The user sending a message. @@ -32,12 +32,15 @@ class Ratelimiter(object): second. burst_count: How many messages the user can send before being limited. + update (bool): Whether to update the message rates or not. This is + useful to check if a message would be allowed to be sent before + its ready to be actually sent. Returns: A pair of a bool indicating if they can send a message now and a time in seconds of when they can next send a message. """ self.prune_message_counts(time_now_s) - message_count, time_start, _ignored = self.message_counts.pop( + message_count, time_start, _ignored = self.message_counts.get( user_id, (0., time_now_s, None), ) time_delta = time_now_s - time_start @@ -52,9 +55,10 @@ class Ratelimiter(object): allowed = True message_count += 1 - self.message_counts[user_id] = ( - message_count, time_start, msg_rate_hz - ) + if update: + self.message_counts[user_id] = ( + message_count, time_start, msg_rate_hz + ) if msg_rate_hz > 0: time_allowed = ( diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 54f35900f..f27150d41 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -52,6 +52,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory from synapse.util.logcontext import LoggingContext from synapse.metrics import register_memory_metrics, get_metrics_for +from synapse.metrics.process_collector import register_process_collector from synapse.metrics.resource import MetricsResource, METRICS_PREFIX from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX from synapse.federation.transport.server import TransportLayerServer @@ -337,6 +338,7 @@ def setup(config_options): hs.get_replication_layer().start_get_pdu_cache() register_memory_metrics(hs) + register_process_collector() reactor.callWhenRunning(start) diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 126a10efb..91471f7e8 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -81,7 +81,7 @@ class ApplicationService(object): NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS] def __init__(self, token, url=None, namespaces=None, hs_token=None, - sender=None, id=None, protocols=None): + sender=None, id=None, protocols=None, rate_limited=True): self.token = token self.url = url self.hs_token = hs_token @@ -95,6 +95,8 @@ class ApplicationService(object): else: self.protocols = set() + self.rate_limited = rate_limited + def _check_namespaces(self, namespaces): # Sanity check that it is of the form: # { @@ -234,5 +236,8 @@ class ApplicationService(object): def is_exclusive_room(self, room_id): return self._is_exclusive(ApplicationService.NS_ROOMS, room_id) + def is_rate_limited(self): + return self.rate_limited + def __str__(self): return "ApplicationService: %s" % (self.__dict__,) diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index d7537e8d4..82c50b824 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -110,6 +110,11 @@ def _load_appservice(hostname, as_info, config_filename): user = UserID(localpart, hostname) user_id = user.to_string() + # Rate limiting for users of this AS is on by default (excludes sender) + rate_limited = True + if isinstance(as_info.get("rate_limited"), bool): + rate_limited = as_info.get("rate_limited") + # namespace checks if not isinstance(as_info.get("namespaces"), dict): raise KeyError("Requires 'namespaces' object.") @@ -155,4 +160,5 @@ def _load_appservice(hostname, as_info, config_filename): sender=user_id, id=as_info["id"], protocols=protocols, + rate_limited=rate_limited ) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 498164316..90f96209f 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -57,10 +57,16 @@ class BaseHandler(object): time_now = self.clock.time() user_id = requester.user.to_string() + # The AS user itself is never rate limited. app_service = self.store.get_app_service_by_user_id(user_id) if app_service is not None: return # do not ratelimit app service senders + # Disable rate limiting of users belonging to any AS that is configured + # not to be rate limited in its registration file (rate_limited: true|false). + if requester.app_service and not requester.app_service.is_rate_limited(): + return + allowed, time_allowed = self.ratelimiter.send_message( user_id, time_now, msg_rate_hz=self.hs.config.rc_messages_per_second, diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index dc0fe60e1..363552123 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -611,6 +611,18 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def add_threepid(self, user_id, medium, address, validated_at): + # 'Canonicalise' email addresses down to lower case. + # We've now moving towards the Home Server being the entity that + # is responsible for validating threepids used for resetting passwords + # on accounts, so in future Synapse will gain knowledge of specific + # types (mediums) of threepid. For now, we still use the existing + # infrastructure, but this is the start of synapse gaining knowledge + # of specific types of threepid (and fixes the fact that checking + # for the presenc eof an email address during password reset was + # case sensitive). + if medium == 'email': + address = address.lower() + yield self.store.user_add_threepid( user_id, medium, address, validated_at, self.hs.get_clock().time_msec() diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 30ea9630f..abfa8c65a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -16,7 +16,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import AuthError, Codes, SynapseError +from synapse.api.errors import AuthError, Codes, SynapseError, LimitExceededError from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator @@ -82,8 +82,8 @@ class MessageHandler(BaseHandler): room_token = pagin_config.from_token.room_key else: pagin_config.from_token = ( - yield self.hs.get_event_sources().get_current_token( - direction='b' + yield self.hs.get_event_sources().get_current_token_for_room( + room_id=room_id ) ) room_token = pagin_config.from_token.room_key @@ -239,6 +239,21 @@ class MessageHandler(BaseHandler): "Tried to send member event through non-member codepath" ) + # We check here if we are currently being rate limited, so that we + # don't do unnecessary work. We check again just before we actually + # send the event. + time_now = self.clock.time() + allowed, time_allowed = self.ratelimiter.send_message( + event.sender, time_now, + msg_rate_hz=self.hs.config.rc_messages_per_second, + burst_count=self.hs.config.rc_message_burst_count, + update=False, + ) + if not allowed: + raise LimitExceededError( + retry_after_ms=int(1000 * (time_allowed - time_now)), + ) + user = UserID.from_string(event.sender) assert self.hs.is_mine(user), "User must be our own: %s" % (user,) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index a7f533f7b..59e4d1cd1 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -475,8 +475,11 @@ class RoomEventSource(object): defer.returnValue((events, end_key)) - def get_current_key(self, direction='f'): - return self.store.get_room_events_max_id(direction) + def get_current_key(self): + return self.store.get_room_events_max_id() + + def get_current_key_for_room(self, room_id): + return self.store.get_room_events_max_id(room_id) @defer.inlineCallbacks def get_pagination_rows(self, user, config, key): diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 08313417b..27ee715ff 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -88,7 +88,7 @@ class TypingHandler(object): continue until = self._member_typing_until.get(member, None) - if not until or until < now: + if not until or until <= now: logger.info("Timing out typing for: %s", member.user_id) preserve_fn(self._stopped_typing)(member) continue @@ -97,12 +97,20 @@ class TypingHandler(object): # user. if self.hs.is_mine_id(member.user_id): last_fed_poke = self._member_last_federation_poke.get(member, None) - if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL < now: + if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now: preserve_fn(self._push_remote)( member=member, typing=True ) + # Add a paranoia timer to ensure that we always have a timer for + # each person typing. + self.wheel_timer.insert( + now=now, + obj=member, + then=now + 60 * 1000, + ) + def is_typing(self, member): return member.user_id in self._room_typing.get(member.room_id, []) diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 76d5998d7..a6b868775 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -13,14 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Because otherwise 'resource' collides with synapse.metrics.resource -from __future__ import absolute_import - import logging -from resource import getrusage, RUSAGE_SELF import functools -import os -import stat import time import gc @@ -36,6 +30,7 @@ logger = logging.getLogger(__name__) all_metrics = [] +all_collectors = [] class Metrics(object): @@ -46,6 +41,9 @@ class Metrics(object): def __init__(self, name): self.name_prefix = name + def register_collector(self, func): + all_collectors.append(func) + def _register(self, metric_class, name, *args, **kwargs): full_name = "%s_%s" % (self.name_prefix, name) @@ -94,8 +92,8 @@ def get_metrics_for(pkg_name): def render_all(): strs = [] - # TODO(paul): Internal hack - update_resource_metrics() + for collector in all_collectors: + collector() for metric in all_metrics: try: @@ -109,62 +107,6 @@ def render_all(): return "\n".join(strs) -# Now register some standard process-wide state metrics, to give indications of -# process resource usage - -rusage = None - - -def update_resource_metrics(): - global rusage - rusage = getrusage(RUSAGE_SELF) - -resource_metrics = get_metrics_for("process.resource") - -# msecs -resource_metrics.register_callback("utime", lambda: rusage.ru_utime * 1000) -resource_metrics.register_callback("stime", lambda: rusage.ru_stime * 1000) - -# kilobytes -resource_metrics.register_callback("maxrss", lambda: rusage.ru_maxrss * 1024) - -TYPES = { - stat.S_IFSOCK: "SOCK", - stat.S_IFLNK: "LNK", - stat.S_IFREG: "REG", - stat.S_IFBLK: "BLK", - stat.S_IFDIR: "DIR", - stat.S_IFCHR: "CHR", - stat.S_IFIFO: "FIFO", -} - - -def _process_fds(): - counts = {(k,): 0 for k in TYPES.values()} - counts[("other",)] = 0 - - # Not every OS will have a /proc/self/fd directory - if not os.path.exists("/proc/self/fd"): - return counts - - for fd in os.listdir("/proc/self/fd"): - try: - s = os.stat("/proc/self/fd/%s" % (fd)) - fmt = stat.S_IFMT(s.st_mode) - if fmt in TYPES: - t = TYPES[fmt] - else: - t = "other" - - counts[(t,)] += 1 - except OSError: - # the dirh itself used by listdir() is usually missing by now - pass - - return counts - -get_metrics_for("process").register_callback("fds", _process_fds, labels=["type"]) - reactor_metrics = get_metrics_for("reactor") tick_time = reactor_metrics.register_distribution("tick_time") pending_calls_metric = reactor_metrics.register_distribution("pending_calls") diff --git a/synapse/metrics/metric.py b/synapse/metrics/metric.py index e81af2989..e87b2b80a 100644 --- a/synapse/metrics/metric.py +++ b/synapse/metrics/metric.py @@ -98,9 +98,9 @@ class CallbackMetric(BaseMetric): value = self.callback() if self.is_scalar(): - return ["%s %d" % (self.name, value)] + return ["%s %.12g" % (self.name, value)] - return ["%s%s %d" % (self.name, self._render_key(k), value[k]) + return ["%s%s %.12g" % (self.name, self._render_key(k), value[k]) for k in sorted(value.keys())] diff --git a/synapse/metrics/process_collector.py b/synapse/metrics/process_collector.py new file mode 100644 index 000000000..1c851d923 --- /dev/null +++ b/synapse/metrics/process_collector.py @@ -0,0 +1,181 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Because otherwise 'resource' collides with synapse.metrics.resource +from __future__ import absolute_import + +import os +import stat +from resource import getrusage, RUSAGE_SELF + +from synapse.metrics import get_metrics_for + + +TICKS_PER_SEC = 100 +BYTES_PER_PAGE = 4096 + +HAVE_PROC_STAT = os.path.exists("/proc/stat") +HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat") +HAVE_PROC_SELF_LIMITS = os.path.exists("/proc/self/limits") +HAVE_PROC_SELF_FD = os.path.exists("/proc/self/fd") + +TYPES = { + stat.S_IFSOCK: "SOCK", + stat.S_IFLNK: "LNK", + stat.S_IFREG: "REG", + stat.S_IFBLK: "BLK", + stat.S_IFDIR: "DIR", + stat.S_IFCHR: "CHR", + stat.S_IFIFO: "FIFO", +} + +# Field indexes from /proc/self/stat, taken from the proc(5) manpage +STAT_FIELDS = { + "utime": 14, + "stime": 15, + "starttime": 22, + "vsize": 23, + "rss": 24, +} + + +rusage = None +stats = {} +fd_counts = None + +# In order to report process_start_time_seconds we need to know the +# machine's boot time, because the value in /proc/self/stat is relative to +# this +boot_time = None +if HAVE_PROC_STAT: + with open("/proc/stat") as _procstat: + for line in _procstat: + if line.startswith("btime "): + boot_time = int(line.split()[1]) + + +def update_resource_metrics(): + global rusage + rusage = getrusage(RUSAGE_SELF) + + if HAVE_PROC_SELF_STAT: + global stats + with open("/proc/self/stat") as s: + line = s.read() + # line is PID (command) more stats go here ... + raw_stats = line.split(") ", 1)[1].split(" ") + + for (name, index) in STAT_FIELDS.iteritems(): + # subtract 3 from the index, because proc(5) is 1-based, and + # we've lost the first two fields in PID and COMMAND above + stats[name] = int(raw_stats[index - 3]) + + global fd_counts + fd_counts = _process_fds() + + +def _process_fds(): + counts = {(k,): 0 for k in TYPES.values()} + counts[("other",)] = 0 + + # Not every OS will have a /proc/self/fd directory + if not HAVE_PROC_SELF_FD: + return counts + + for fd in os.listdir("/proc/self/fd"): + try: + s = os.stat("/proc/self/fd/%s" % (fd)) + fmt = stat.S_IFMT(s.st_mode) + if fmt in TYPES: + t = TYPES[fmt] + else: + t = "other" + + counts[(t,)] += 1 + except OSError: + # the dirh itself used by listdir() is usually missing by now + pass + + return counts + + +def register_process_collector(): + # Legacy synapse-invented metric names + + resource_metrics = get_metrics_for("process.resource") + + resource_metrics.register_collector(update_resource_metrics) + + # msecs + resource_metrics.register_callback("utime", lambda: rusage.ru_utime * 1000) + resource_metrics.register_callback("stime", lambda: rusage.ru_stime * 1000) + + # kilobytes + resource_metrics.register_callback("maxrss", lambda: rusage.ru_maxrss * 1024) + + get_metrics_for("process").register_callback("fds", _process_fds, labels=["type"]) + + # New prometheus-standard metric names + + process_metrics = get_metrics_for("process") + + if HAVE_PROC_SELF_STAT: + process_metrics.register_callback( + "cpu_user_seconds_total", + lambda: float(stats["utime"]) / TICKS_PER_SEC + ) + process_metrics.register_callback( + "cpu_system_seconds_total", + lambda: float(stats["stime"]) / TICKS_PER_SEC + ) + process_metrics.register_callback( + "cpu_seconds_total", + lambda: (float(stats["utime"] + stats["stime"])) / TICKS_PER_SEC + ) + + process_metrics.register_callback( + "virtual_memory_bytes", + lambda: int(stats["vsize"]) + ) + process_metrics.register_callback( + "resident_memory_bytes", + lambda: int(stats["rss"]) * BYTES_PER_PAGE + ) + + process_metrics.register_callback( + "start_time_seconds", + lambda: boot_time + int(stats["starttime"]) / TICKS_PER_SEC + ) + + if HAVE_PROC_SELF_FD: + process_metrics.register_callback( + "open_fds", + lambda: sum(fd_counts.values()) + ) + + if HAVE_PROC_SELF_LIMITS: + def _get_max_fds(): + with open("/proc/self/limits") as limits: + for line in limits: + if not line.startswith("Max open files "): + continue + # Line is Max open files $SOFT $HARD + return int(line.split()[3]) + return None + + process_metrics.register_callback( + "max_fds", + lambda: _get_max_fds() + ) diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 6600c9cd5..2eb325c7c 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -150,6 +150,10 @@ class EmailPusher(object): soonest_due_at = None + if not unprocessed: + yield self.save_last_stream_ordering_and_success(self.max_stream_ordering) + return + for push_action in unprocessed: received_at = push_action['received_ts'] if received_at is None: diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 3b63c19ec..53551632b 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -372,7 +372,7 @@ class Mailer(object): state_event_id = room_state_ids[room_id][ ("m.room.member", event.sender) ] - state_event = yield self.get_event(state_event_id) + state_event = yield self.store.get_event(state_event_id) sender_name = name_from_member_event(state_event) if sender_name is not None and room_name is not None: diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index 510f8b2c7..b4084fec6 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -15,7 +15,7 @@ from twisted.internet import defer -from synapse.api.errors import AuthError, SynapseError +from synapse.api.errors import AuthError, SynapseError, StoreError, Codes from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.types import UserID @@ -45,7 +45,7 @@ class GetFilterRestServlet(RestServlet): raise AuthError(403, "Cannot get filters for other users") if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only get filters for local users") + raise AuthError(403, "Can only get filters for local users") try: filter_id = int(filter_id) @@ -59,8 +59,8 @@ class GetFilterRestServlet(RestServlet): ) defer.returnValue((200, filter.get_filter_json())) - except KeyError: - raise SynapseError(400, "No such filter") + except (KeyError, StoreError): + raise SynapseError(400, "No such filter", errcode=Codes.NOT_FOUND) class CreateFilterRestServlet(RestServlet): @@ -74,6 +74,7 @@ class CreateFilterRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, user_id): + target_user = UserID.from_string(user_id) requester = yield self.auth.get_user_by_req(request) @@ -81,10 +82,9 @@ class CreateFilterRestServlet(RestServlet): raise AuthError(403, "Cannot create filters for other users") if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only create filters for local users") + raise AuthError(403, "Can only create filters for local users") content = parse_json_object_from_request(request) - filter_id = yield self.filtering.add_user_filter( user_localpart=target_user.localpart, user_filter=content, diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 49fa8614f..d828d6ee1 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -85,7 +85,6 @@ class LoggingTransaction(object): sql_logger.debug("[SQL] {%s} %s", self.name, sql) sql = self.database_engine.convert_param_style(sql) - if args: try: sql_logger.debug( diff --git a/synapse/storage/schema/delta/37/user_threepids.sql b/synapse/storage/schema/delta/37/user_threepids.sql new file mode 100644 index 000000000..ef8813e72 --- /dev/null +++ b/synapse/storage/schema/delta/37/user_threepids.sql @@ -0,0 +1,23 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Update any email addresses that were stored with mixed case into all + * lowercase + */ +UPDATE user_threepids SET address = LOWER(address) where medium = 'email'; + +/* Add an index for the select we do on passwored reset */ +CREATE INDEX user_threepids_medium_address on user_threepids (medium, address); diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 07ea969d4..888b1cb35 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -521,13 +521,20 @@ class StreamStore(SQLBaseStore): ) @defer.inlineCallbacks - def get_room_events_max_id(self, direction='f'): + def get_room_events_max_id(self, room_id=None): + """Returns the current token for rooms stream. + + By default, it returns the current global stream token. Specifying a + `room_id` causes it to return the current room specific topological + token. + """ token = yield self._stream_id_gen.get_current_token() - if direction != 'b': + if room_id is None: defer.returnValue("s%d" % (token,)) else: topo = yield self.runInteraction( - "_get_max_topological_txn", self._get_max_topological_txn + "_get_max_topological_txn", self._get_max_topological_txn, + room_id, ) defer.returnValue("t%d-%d" % (topo, token)) @@ -579,11 +586,11 @@ class StreamStore(SQLBaseStore): lambda r: r[0][0] if r else 0 ) - def _get_max_topological_txn(self, txn): + def _get_max_topological_txn(self, txn, room_id): txn.execute( "SELECT MAX(topological_ordering) FROM events" - " WHERE outlier = ?", - (False,) + " WHERE room_id = ?", + (room_id,) ) rows = txn.fetchall() diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 6bf21d6f5..4d44c3d4c 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -41,13 +41,39 @@ class EventSources(object): self.store = hs.get_datastore() @defer.inlineCallbacks - def get_current_token(self, direction='f'): + def get_current_token(self): push_rules_key, _ = self.store.get_push_rules_stream_token() to_device_key = self.store.get_to_device_stream_token() token = StreamToken( room_key=( - yield self.sources["room"].get_current_key(direction) + yield self.sources["room"].get_current_key() + ), + presence_key=( + yield self.sources["presence"].get_current_key() + ), + typing_key=( + yield self.sources["typing"].get_current_key() + ), + receipt_key=( + yield self.sources["receipt"].get_current_key() + ), + account_data_key=( + yield self.sources["account_data"].get_current_key() + ), + push_rules_key=push_rules_key, + to_device_key=to_device_key, + ) + defer.returnValue(token) + + @defer.inlineCallbacks + def get_current_token_for_room(self, room_id): + push_rules_key, _ = self.store.get_push_rules_stream_token() + to_device_key = self.store.get_to_device_stream_token() + + token = StreamToken( + room_key=( + yield self.sources["room"].get_current_key_for_room(room_id) ), presence_key=( yield self.sources["presence"].get_current_key() diff --git a/synapse/types.py b/synapse/types.py index 1694af125..ffab12df0 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -18,8 +18,9 @@ from synapse.api.errors import SynapseError from collections import namedtuple -Requester = namedtuple("Requester", - ["user", "access_token_id", "is_guest", "device_id"]) +Requester = namedtuple("Requester", [ + "user", "access_token_id", "is_guest", "device_id", "app_service", +]) """ Represents the user making a request @@ -29,11 +30,12 @@ Attributes: request, or None if it came via the appservice API or similar is_guest (bool): True if the user making this request is a guest user device_id (str|None): device_id which was set at authentication time + app_service (ApplicationService|None): the AS requesting on behalf of the user """ def create_requester(user_id, access_token_id=None, is_guest=False, - device_id=None): + device_id=None, app_service=None): """ Create a new ``Requester`` object @@ -43,13 +45,14 @@ def create_requester(user_id, access_token_id=None, is_guest=False, request, or None if it came via the appservice API or similar is_guest (bool): True if the user making this request is a guest user device_id (str|None): device_id which was set at authentication time + app_service (ApplicationService|None): the AS requesting on behalf of the user Returns: Requester """ if not isinstance(user_id, UserID): user_id = UserID.from_string(user_id) - return Requester(user_id, access_token_id, is_guest, device_id) + return Requester(user_id, access_token_id, is_guest, device_id, app_service) def get_domain_from_id(string): diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index e91723ca3..2cf262bb4 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -20,7 +20,7 @@ from mock import Mock from synapse.api.auth import Auth from synapse.api.errors import AuthError from synapse.types import UserID -from tests.utils import setup_test_homeserver +from tests.utils import setup_test_homeserver, mock_getRawHeaders import pymacaroons @@ -51,7 +51,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args["access_token"] = [self.test_token] - request.requestHeaders.getRawHeaders = Mock(return_value=[""]) + request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request) self.assertEquals(requester.user.to_string(), self.test_user) @@ -61,7 +61,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args["access_token"] = [self.test_token] - request.requestHeaders.getRawHeaders = Mock(return_value=[""]) + request.requestHeaders.getRawHeaders = mock_getRawHeaders() d = self.auth.get_user_by_req(request) self.failureResultOf(d, AuthError) @@ -74,7 +74,7 @@ class AuthTestCase(unittest.TestCase): self.store.get_user_by_access_token = Mock(return_value=user_info) request = Mock(args={}) - request.requestHeaders.getRawHeaders = Mock(return_value=[""]) + request.requestHeaders.getRawHeaders = mock_getRawHeaders() d = self.auth.get_user_by_req(request) self.failureResultOf(d, AuthError) @@ -86,7 +86,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args["access_token"] = [self.test_token] - request.requestHeaders.getRawHeaders = Mock(return_value=[""]) + request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request) self.assertEquals(requester.user.to_string(), self.test_user) @@ -96,7 +96,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args["access_token"] = [self.test_token] - request.requestHeaders.getRawHeaders = Mock(return_value=[""]) + request.requestHeaders.getRawHeaders = mock_getRawHeaders() d = self.auth.get_user_by_req(request) self.failureResultOf(d, AuthError) @@ -106,7 +106,7 @@ class AuthTestCase(unittest.TestCase): self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) - request.requestHeaders.getRawHeaders = Mock(return_value=[""]) + request.requestHeaders.getRawHeaders = mock_getRawHeaders() d = self.auth.get_user_by_req(request) self.failureResultOf(d, AuthError) @@ -121,7 +121,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args["access_token"] = [self.test_token] request.args["user_id"] = [masquerading_user_id] - request.requestHeaders.getRawHeaders = Mock(return_value=[""]) + request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request) self.assertEquals(requester.user.to_string(), masquerading_user_id) @@ -135,7 +135,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args["access_token"] = [self.test_token] request.args["user_id"] = [masquerading_user_id] - request.requestHeaders.getRawHeaders = Mock(return_value=[""]) + request.requestHeaders.getRawHeaders = mock_getRawHeaders() d = self.auth.get_user_by_req(request) self.failureResultOf(d, AuthError) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index c3108f518..c718d1f98 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -219,7 +219,8 @@ class TypingNotificationsTestCase(unittest.TestCase): "user_id": self.u_onion.to_string(), "typing": True, } - ) + ), + federation_auth=True, ) self.on_new_event.assert_has_calls([ diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py index 44ba9ff58..a6a4e2ffe 100644 --- a/tests/rest/client/v1/test_register.py +++ b/tests/rest/client/v1/test_register.py @@ -17,6 +17,7 @@ from synapse.rest.client.v1.register import CreateUserRestServlet from twisted.internet import defer from mock import Mock from tests import unittest +from tests.utils import mock_getRawHeaders import json @@ -30,6 +31,7 @@ class CreateUserServletTestCase(unittest.TestCase): path='/_matrix/client/api/v1/createUser' ) self.request.args = {} + self.request.requestHeaders.getRawHeaders = mock_getRawHeaders() self.registration_handler = Mock() diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py index d1442aafa..3d27d03cb 100644 --- a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/v2_alpha/test_filter.py @@ -15,78 +15,125 @@ from twisted.internet import defer -from . import V2AlphaRestTestCase +from tests import unittest from synapse.rest.client.v2_alpha import filter -from synapse.api.errors import StoreError +from synapse.api.errors import Codes + +import synapse.types + +from synapse.types import UserID + +from ....utils import MockHttpResource, setup_test_homeserver + +PATH_PREFIX = "/_matrix/client/v2_alpha" -class FilterTestCase(V2AlphaRestTestCase): +class FilterTestCase(unittest.TestCase): + USER_ID = "@apple:test" + EXAMPLE_FILTER = {"type": ["m.*"]} + EXAMPLE_FILTER_JSON = '{"type": ["m.*"]}' TO_REGISTER = [filter] - def make_datastore_mock(self): - datastore = super(FilterTestCase, self).make_datastore_mock() + @defer.inlineCallbacks + def setUp(self): + self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - self._user_filters = {} + self.hs = yield setup_test_homeserver( + http_client=None, + resource_for_client=self.mock_resource, + resource_for_federation=self.mock_resource, + ) - def add_user_filter(user_localpart, definition): - filters = self._user_filters.setdefault(user_localpart, []) - filter_id = len(filters) - filters.append(definition) - return defer.succeed(filter_id) - datastore.add_user_filter = add_user_filter + self.auth = self.hs.get_auth() - def get_user_filter(user_localpart, filter_id): - if user_localpart not in self._user_filters: - raise StoreError(404, "No user") - filters = self._user_filters[user_localpart] - if filter_id >= len(filters): - raise StoreError(404, "No filter") - return defer.succeed(filters[filter_id]) - datastore.get_user_filter = get_user_filter + def get_user_by_access_token(token=None, allow_guest=False): + return { + "user": UserID.from_string(self.USER_ID), + "token_id": 1, + "is_guest": False, + } - return datastore + def get_user_by_req(request, allow_guest=False, rights="access"): + return synapse.types.create_requester( + UserID.from_string(self.USER_ID), 1, False, None) + + self.auth.get_user_by_access_token = get_user_by_access_token + self.auth.get_user_by_req = get_user_by_req + + self.store = self.hs.get_datastore() + self.filtering = self.hs.get_filtering() + + for r in self.TO_REGISTER: + r.register_servlets(self.hs, self.mock_resource) @defer.inlineCallbacks def test_add_filter(self): (code, response) = yield self.mock_resource.trigger( - "POST", "/user/%s/filter" % (self.USER_ID), '{"type": ["m.*"]}' + "POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON ) self.assertEquals(200, code) self.assertEquals({"filter_id": "0"}, response) + filter = yield self.store.get_user_filter( + user_localpart='apple', + filter_id=0, + ) + self.assertEquals(filter, self.EXAMPLE_FILTER) - self.assertIn("apple", self._user_filters) - self.assertEquals(len(self._user_filters["apple"]), 1) - self.assertEquals({"type": ["m.*"]}, self._user_filters["apple"][0]) + @defer.inlineCallbacks + def test_add_filter_for_other_user(self): + (code, response) = yield self.mock_resource.trigger( + "POST", "/user/%s/filter" % ('@watermelon:test'), self.EXAMPLE_FILTER_JSON + ) + self.assertEquals(403, code) + self.assertEquals(response['errcode'], Codes.FORBIDDEN) + + @defer.inlineCallbacks + def test_add_filter_non_local_user(self): + _is_mine = self.hs.is_mine + self.hs.is_mine = lambda target_user: False + (code, response) = yield self.mock_resource.trigger( + "POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON + ) + self.hs.is_mine = _is_mine + self.assertEquals(403, code) + self.assertEquals(response['errcode'], Codes.FORBIDDEN) @defer.inlineCallbacks def test_get_filter(self): - self._user_filters["apple"] = [ - {"type": ["m.*"]} - ] - + filter_id = yield self.filtering.add_user_filter( + user_localpart='apple', + user_filter=self.EXAMPLE_FILTER + ) (code, response) = yield self.mock_resource.trigger_get( - "/user/%s/filter/0" % (self.USER_ID) + "/user/%s/filter/%s" % (self.USER_ID, filter_id) ) self.assertEquals(200, code) - self.assertEquals({"type": ["m.*"]}, response) + self.assertEquals(self.EXAMPLE_FILTER, response) + @defer.inlineCallbacks + def test_get_filter_non_existant(self): + (code, response) = yield self.mock_resource.trigger_get( + "/user/%s/filter/12382148321" % (self.USER_ID) + ) + self.assertEquals(400, code) + self.assertEquals(response['errcode'], Codes.NOT_FOUND) + + # Currently invalid params do not have an appropriate errcode + # in errors.py + @defer.inlineCallbacks + def test_get_filter_invalid_id(self): + (code, response) = yield self.mock_resource.trigger_get( + "/user/%s/filter/foobar" % (self.USER_ID) + ) + self.assertEquals(400, code) + + # No ID also returns an invalid_id error @defer.inlineCallbacks def test_get_filter_no_id(self): - self._user_filters["apple"] = [ - {"type": ["m.*"]} - ] - (code, response) = yield self.mock_resource.trigger_get( - "/user/%s/filter/2" % (self.USER_ID) + "/user/%s/filter/" % (self.USER_ID) ) - self.assertEquals(404, code) - - @defer.inlineCallbacks - def test_get_filter_no_user(self): - (code, response) = yield self.mock_resource.trigger_get( - "/user/%s/filter/0" % (self.USER_ID) - ) - self.assertEquals(404, code) + self.assertEquals(400, code) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index e9cb416e4..b4a787c43 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -3,6 +3,7 @@ from synapse.api.errors import SynapseError from twisted.internet import defer from mock import Mock from tests import unittest +from tests.utils import mock_getRawHeaders import json @@ -16,6 +17,7 @@ class RegisterRestServletTestCase(unittest.TestCase): path='/_matrix/api/v2_alpha/register' ) self.request.args = {} + self.request.requestHeaders.getRawHeaders = mock_getRawHeaders() self.appservice = None self.auth = Mock(get_appservice_by_req=Mock( diff --git a/tests/utils.py b/tests/utils.py index f74526b6a..5929f1c72 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -116,6 +116,15 @@ def get_mock_call_args(pattern_func, mock_func): return getcallargs(pattern_func, *invoked_args, **invoked_kargs) +def mock_getRawHeaders(headers=None): + headers = headers if headers is not None else {} + + def getRawHeaders(name, default=None): + return headers.get(name, default) + + return getRawHeaders + + # This is a mock /resource/ not an entire server class MockHttpResource(HttpServer): @@ -128,7 +137,7 @@ class MockHttpResource(HttpServer): @patch('twisted.web.http.Request') @defer.inlineCallbacks - def trigger(self, http_method, path, content, mock_request): + def trigger(self, http_method, path, content, mock_request, federation_auth=False): """ Fire an HTTP event. Args: @@ -156,9 +165,10 @@ class MockHttpResource(HttpServer): mock_request.getClientIP.return_value = "-" - mock_request.requestHeaders.getRawHeaders.return_value = [ - "X-Matrix origin=test,key=,sig=" - ] + headers = {} + if federation_auth: + headers["Authorization"] = ["X-Matrix origin=test,key=,sig="] + mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers) # return the right path if the event requires it mock_request.path = path @@ -189,7 +199,7 @@ class MockHttpResource(HttpServer): ) defer.returnValue((code, response)) except CodeMessageException as e: - defer.returnValue((e.code, cs_error(e.msg))) + defer.returnValue((e.code, cs_error(e.msg, code=e.errcode))) raise KeyError("No event can handle %s" % path)