Merge branch 'develop' of github.com:matrix-org/synapse into release-v0.18.2

This commit is contained in:
Erik Johnston 2016-10-27 11:02:30 +01:00
commit 45bdacd9a7
27 changed files with 514 additions and 175 deletions

View File

@ -603,10 +603,12 @@ class Auth(object):
""" """
# Can optionally look elsewhere in the request (e.g. headers) # Can optionally look elsewhere in the request (e.g. headers)
try: try:
user_id = yield self._get_appservice_user_id(request) user_id, app_service = yield self._get_appservice_user_id(request)
if user_id: if user_id:
request.authenticated_entity = user_id request.authenticated_entity = user_id
defer.returnValue(synapse.types.create_requester(user_id)) defer.returnValue(
synapse.types.create_requester(user_id, app_service=app_service)
)
access_token = get_access_token_from_request( access_token = get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS request, self.TOKEN_NOT_FOUND_HTTP_STATUS
@ -644,7 +646,8 @@ class Auth(object):
request.authenticated_entity = user.to_string() request.authenticated_entity = user.to_string()
defer.returnValue(synapse.types.create_requester( defer.returnValue(synapse.types.create_requester(
user, token_id, is_guest, device_id)) user, token_id, is_guest, device_id, app_service=app_service)
)
except KeyError: except KeyError:
raise AuthError( raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
@ -659,14 +662,14 @@ class Auth(object):
) )
) )
if app_service is None: if app_service is None:
defer.returnValue(None) defer.returnValue((None, None))
if "user_id" not in request.args: if "user_id" not in request.args:
defer.returnValue(app_service.sender) defer.returnValue((app_service.sender, app_service))
user_id = request.args["user_id"][0] user_id = request.args["user_id"][0]
if app_service.sender == user_id: if app_service.sender == user_id:
defer.returnValue(app_service.sender) defer.returnValue((app_service.sender, app_service))
if not app_service.is_interested_in_user(user_id): if not app_service.is_interested_in_user(user_id):
raise AuthError( raise AuthError(
@ -678,7 +681,7 @@ class Auth(object):
403, 403,
"Application service has not registered this user" "Application service has not registered this user"
) )
defer.returnValue(user_id) defer.returnValue((user_id, app_service))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_user_by_access_token(self, token, rights="access"): def get_user_by_access_token(self, token, rights="access"):
@ -1167,7 +1170,8 @@ def has_access_token(request):
bool: False if no access_token was given, True otherwise. bool: False if no access_token was given, True otherwise.
""" """
query_params = request.args.get("access_token") query_params = request.args.get("access_token")
return bool(query_params) auth_headers = request.requestHeaders.getRawHeaders("Authorization")
return bool(query_params) or bool(auth_headers)
def get_access_token_from_request(request, token_not_found_http_status=401): def get_access_token_from_request(request, token_not_found_http_status=401):
@ -1185,7 +1189,34 @@ def get_access_token_from_request(request, token_not_found_http_status=401):
Raises: Raises:
AuthError: If there isn't an access_token in the request. AuthError: If there isn't an access_token in the request.
""" """
auth_headers = request.requestHeaders.getRawHeaders("Authorization")
query_params = request.args.get("access_token") query_params = request.args.get("access_token")
if auth_headers:
# Try the get the access_token from a "Authorization: Bearer"
# header
if query_params is not None:
raise AuthError(
token_not_found_http_status,
"Mixing Authorization headers and access_token query parameters.",
errcode=Codes.MISSING_TOKEN,
)
if len(auth_headers) > 1:
raise AuthError(
token_not_found_http_status,
"Too many Authorization headers.",
errcode=Codes.MISSING_TOKEN,
)
parts = auth_headers[0].split(" ")
if parts[0] == "Bearer" and len(parts) == 2:
return parts[1]
else:
raise AuthError(
token_not_found_http_status,
"Invalid Authorization header.",
errcode=Codes.MISSING_TOKEN,
)
else:
# Try to get the access_token from the query params. # Try to get the access_token from the query params.
if not query_params: if not query_params:
raise AuthError( raise AuthError(

View File

@ -23,7 +23,7 @@ class Ratelimiter(object):
def __init__(self): def __init__(self):
self.message_counts = collections.OrderedDict() self.message_counts = collections.OrderedDict()
def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count): def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count, update=True):
"""Can the user send a message? """Can the user send a message?
Args: Args:
user_id: The user sending a message. user_id: The user sending a message.
@ -32,12 +32,15 @@ class Ratelimiter(object):
second. second.
burst_count: How many messages the user can send before being burst_count: How many messages the user can send before being
limited. limited.
update (bool): Whether to update the message rates or not. This is
useful to check if a message would be allowed to be sent before
its ready to be actually sent.
Returns: Returns:
A pair of a bool indicating if they can send a message now and a A pair of a bool indicating if they can send a message now and a
time in seconds of when they can next send a message. time in seconds of when they can next send a message.
""" """
self.prune_message_counts(time_now_s) self.prune_message_counts(time_now_s)
message_count, time_start, _ignored = self.message_counts.pop( message_count, time_start, _ignored = self.message_counts.get(
user_id, (0., time_now_s, None), user_id, (0., time_now_s, None),
) )
time_delta = time_now_s - time_start time_delta = time_now_s - time_start
@ -52,6 +55,7 @@ class Ratelimiter(object):
allowed = True allowed = True
message_count += 1 message_count += 1
if update:
self.message_counts[user_id] = ( self.message_counts[user_id] = (
message_count, time_start, msg_rate_hz message_count, time_start, msg_rate_hz
) )

View File

@ -52,6 +52,7 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory from synapse.crypto import context_factory
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
from synapse.metrics import register_memory_metrics, get_metrics_for from synapse.metrics import register_memory_metrics, get_metrics_for
from synapse.metrics.process_collector import register_process_collector
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
from synapse.federation.transport.server import TransportLayerServer from synapse.federation.transport.server import TransportLayerServer
@ -337,6 +338,7 @@ def setup(config_options):
hs.get_replication_layer().start_get_pdu_cache() hs.get_replication_layer().start_get_pdu_cache()
register_memory_metrics(hs) register_memory_metrics(hs)
register_process_collector()
reactor.callWhenRunning(start) reactor.callWhenRunning(start)

View File

@ -81,7 +81,7 @@ class ApplicationService(object):
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS] NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
def __init__(self, token, url=None, namespaces=None, hs_token=None, def __init__(self, token, url=None, namespaces=None, hs_token=None,
sender=None, id=None, protocols=None): sender=None, id=None, protocols=None, rate_limited=True):
self.token = token self.token = token
self.url = url self.url = url
self.hs_token = hs_token self.hs_token = hs_token
@ -95,6 +95,8 @@ class ApplicationService(object):
else: else:
self.protocols = set() self.protocols = set()
self.rate_limited = rate_limited
def _check_namespaces(self, namespaces): def _check_namespaces(self, namespaces):
# Sanity check that it is of the form: # Sanity check that it is of the form:
# { # {
@ -234,5 +236,8 @@ class ApplicationService(object):
def is_exclusive_room(self, room_id): def is_exclusive_room(self, room_id):
return self._is_exclusive(ApplicationService.NS_ROOMS, room_id) return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
def is_rate_limited(self):
return self.rate_limited
def __str__(self): def __str__(self):
return "ApplicationService: %s" % (self.__dict__,) return "ApplicationService: %s" % (self.__dict__,)

View File

@ -110,6 +110,11 @@ def _load_appservice(hostname, as_info, config_filename):
user = UserID(localpart, hostname) user = UserID(localpart, hostname)
user_id = user.to_string() user_id = user.to_string()
# Rate limiting for users of this AS is on by default (excludes sender)
rate_limited = True
if isinstance(as_info.get("rate_limited"), bool):
rate_limited = as_info.get("rate_limited")
# namespace checks # namespace checks
if not isinstance(as_info.get("namespaces"), dict): if not isinstance(as_info.get("namespaces"), dict):
raise KeyError("Requires 'namespaces' object.") raise KeyError("Requires 'namespaces' object.")
@ -155,4 +160,5 @@ def _load_appservice(hostname, as_info, config_filename):
sender=user_id, sender=user_id,
id=as_info["id"], id=as_info["id"],
protocols=protocols, protocols=protocols,
rate_limited=rate_limited
) )

View File

@ -57,10 +57,16 @@ class BaseHandler(object):
time_now = self.clock.time() time_now = self.clock.time()
user_id = requester.user.to_string() user_id = requester.user.to_string()
# The AS user itself is never rate limited.
app_service = self.store.get_app_service_by_user_id(user_id) app_service = self.store.get_app_service_by_user_id(user_id)
if app_service is not None: if app_service is not None:
return # do not ratelimit app service senders return # do not ratelimit app service senders
# Disable rate limiting of users belonging to any AS that is configured
# not to be rate limited in its registration file (rate_limited: true|false).
if requester.app_service and not requester.app_service.is_rate_limited():
return
allowed, time_allowed = self.ratelimiter.send_message( allowed, time_allowed = self.ratelimiter.send_message(
user_id, time_now, user_id, time_now,
msg_rate_hz=self.hs.config.rc_messages_per_second, msg_rate_hz=self.hs.config.rc_messages_per_second,

View File

@ -611,6 +611,18 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def add_threepid(self, user_id, medium, address, validated_at): def add_threepid(self, user_id, medium, address, validated_at):
# 'Canonicalise' email addresses down to lower case.
# We've now moving towards the Home Server being the entity that
# is responsible for validating threepids used for resetting passwords
# on accounts, so in future Synapse will gain knowledge of specific
# types (mediums) of threepid. For now, we still use the existing
# infrastructure, but this is the start of synapse gaining knowledge
# of specific types of threepid (and fixes the fact that checking
# for the presenc eof an email address during password reset was
# case sensitive).
if medium == 'email':
address = address.lower()
yield self.store.user_add_threepid( yield self.store.user_add_threepid(
user_id, medium, address, validated_at, user_id, medium, address, validated_at,
self.hs.get_clock().time_msec() self.hs.get_clock().time_msec()

View File

@ -16,7 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes, SynapseError from synapse.api.errors import AuthError, Codes, SynapseError, LimitExceededError
from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
@ -82,8 +82,8 @@ class MessageHandler(BaseHandler):
room_token = pagin_config.from_token.room_key room_token = pagin_config.from_token.room_key
else: else:
pagin_config.from_token = ( pagin_config.from_token = (
yield self.hs.get_event_sources().get_current_token( yield self.hs.get_event_sources().get_current_token_for_room(
direction='b' room_id=room_id
) )
) )
room_token = pagin_config.from_token.room_key room_token = pagin_config.from_token.room_key
@ -239,6 +239,21 @@ class MessageHandler(BaseHandler):
"Tried to send member event through non-member codepath" "Tried to send member event through non-member codepath"
) )
# We check here if we are currently being rate limited, so that we
# don't do unnecessary work. We check again just before we actually
# send the event.
time_now = self.clock.time()
allowed, time_allowed = self.ratelimiter.send_message(
event.sender, time_now,
msg_rate_hz=self.hs.config.rc_messages_per_second,
burst_count=self.hs.config.rc_message_burst_count,
update=False,
)
if not allowed:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now)),
)
user = UserID.from_string(event.sender) user = UserID.from_string(event.sender)
assert self.hs.is_mine(user), "User must be our own: %s" % (user,) assert self.hs.is_mine(user), "User must be our own: %s" % (user,)

View File

@ -475,8 +475,11 @@ class RoomEventSource(object):
defer.returnValue((events, end_key)) defer.returnValue((events, end_key))
def get_current_key(self, direction='f'): def get_current_key(self):
return self.store.get_room_events_max_id(direction) return self.store.get_room_events_max_id()
def get_current_key_for_room(self, room_id):
return self.store.get_room_events_max_id(room_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_pagination_rows(self, user, config, key): def get_pagination_rows(self, user, config, key):

View File

@ -88,7 +88,7 @@ class TypingHandler(object):
continue continue
until = self._member_typing_until.get(member, None) until = self._member_typing_until.get(member, None)
if not until or until < now: if not until or until <= now:
logger.info("Timing out typing for: %s", member.user_id) logger.info("Timing out typing for: %s", member.user_id)
preserve_fn(self._stopped_typing)(member) preserve_fn(self._stopped_typing)(member)
continue continue
@ -97,12 +97,20 @@ class TypingHandler(object):
# user. # user.
if self.hs.is_mine_id(member.user_id): if self.hs.is_mine_id(member.user_id):
last_fed_poke = self._member_last_federation_poke.get(member, None) last_fed_poke = self._member_last_federation_poke.get(member, None)
if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL < now: if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now:
preserve_fn(self._push_remote)( preserve_fn(self._push_remote)(
member=member, member=member,
typing=True typing=True
) )
# Add a paranoia timer to ensure that we always have a timer for
# each person typing.
self.wheel_timer.insert(
now=now,
obj=member,
then=now + 60 * 1000,
)
def is_typing(self, member): def is_typing(self, member):
return member.user_id in self._room_typing.get(member.room_id, []) return member.user_id in self._room_typing.get(member.room_id, [])

View File

@ -13,14 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Because otherwise 'resource' collides with synapse.metrics.resource
from __future__ import absolute_import
import logging import logging
from resource import getrusage, RUSAGE_SELF
import functools import functools
import os
import stat
import time import time
import gc import gc
@ -36,6 +30,7 @@ logger = logging.getLogger(__name__)
all_metrics = [] all_metrics = []
all_collectors = []
class Metrics(object): class Metrics(object):
@ -46,6 +41,9 @@ class Metrics(object):
def __init__(self, name): def __init__(self, name):
self.name_prefix = name self.name_prefix = name
def register_collector(self, func):
all_collectors.append(func)
def _register(self, metric_class, name, *args, **kwargs): def _register(self, metric_class, name, *args, **kwargs):
full_name = "%s_%s" % (self.name_prefix, name) full_name = "%s_%s" % (self.name_prefix, name)
@ -94,8 +92,8 @@ def get_metrics_for(pkg_name):
def render_all(): def render_all():
strs = [] strs = []
# TODO(paul): Internal hack for collector in all_collectors:
update_resource_metrics() collector()
for metric in all_metrics: for metric in all_metrics:
try: try:
@ -109,62 +107,6 @@ def render_all():
return "\n".join(strs) return "\n".join(strs)
# Now register some standard process-wide state metrics, to give indications of
# process resource usage
rusage = None
def update_resource_metrics():
global rusage
rusage = getrusage(RUSAGE_SELF)
resource_metrics = get_metrics_for("process.resource")
# msecs
resource_metrics.register_callback("utime", lambda: rusage.ru_utime * 1000)
resource_metrics.register_callback("stime", lambda: rusage.ru_stime * 1000)
# kilobytes
resource_metrics.register_callback("maxrss", lambda: rusage.ru_maxrss * 1024)
TYPES = {
stat.S_IFSOCK: "SOCK",
stat.S_IFLNK: "LNK",
stat.S_IFREG: "REG",
stat.S_IFBLK: "BLK",
stat.S_IFDIR: "DIR",
stat.S_IFCHR: "CHR",
stat.S_IFIFO: "FIFO",
}
def _process_fds():
counts = {(k,): 0 for k in TYPES.values()}
counts[("other",)] = 0
# Not every OS will have a /proc/self/fd directory
if not os.path.exists("/proc/self/fd"):
return counts
for fd in os.listdir("/proc/self/fd"):
try:
s = os.stat("/proc/self/fd/%s" % (fd))
fmt = stat.S_IFMT(s.st_mode)
if fmt in TYPES:
t = TYPES[fmt]
else:
t = "other"
counts[(t,)] += 1
except OSError:
# the dirh itself used by listdir() is usually missing by now
pass
return counts
get_metrics_for("process").register_callback("fds", _process_fds, labels=["type"])
reactor_metrics = get_metrics_for("reactor") reactor_metrics = get_metrics_for("reactor")
tick_time = reactor_metrics.register_distribution("tick_time") tick_time = reactor_metrics.register_distribution("tick_time")
pending_calls_metric = reactor_metrics.register_distribution("pending_calls") pending_calls_metric = reactor_metrics.register_distribution("pending_calls")

View File

@ -98,9 +98,9 @@ class CallbackMetric(BaseMetric):
value = self.callback() value = self.callback()
if self.is_scalar(): if self.is_scalar():
return ["%s %d" % (self.name, value)] return ["%s %.12g" % (self.name, value)]
return ["%s%s %d" % (self.name, self._render_key(k), value[k]) return ["%s%s %.12g" % (self.name, self._render_key(k), value[k])
for k in sorted(value.keys())] for k in sorted(value.keys())]

View File

@ -0,0 +1,181 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Because otherwise 'resource' collides with synapse.metrics.resource
from __future__ import absolute_import
import os
import stat
from resource import getrusage, RUSAGE_SELF
from synapse.metrics import get_metrics_for
TICKS_PER_SEC = 100
BYTES_PER_PAGE = 4096
HAVE_PROC_STAT = os.path.exists("/proc/stat")
HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")
HAVE_PROC_SELF_LIMITS = os.path.exists("/proc/self/limits")
HAVE_PROC_SELF_FD = os.path.exists("/proc/self/fd")
TYPES = {
stat.S_IFSOCK: "SOCK",
stat.S_IFLNK: "LNK",
stat.S_IFREG: "REG",
stat.S_IFBLK: "BLK",
stat.S_IFDIR: "DIR",
stat.S_IFCHR: "CHR",
stat.S_IFIFO: "FIFO",
}
# Field indexes from /proc/self/stat, taken from the proc(5) manpage
STAT_FIELDS = {
"utime": 14,
"stime": 15,
"starttime": 22,
"vsize": 23,
"rss": 24,
}
rusage = None
stats = {}
fd_counts = None
# In order to report process_start_time_seconds we need to know the
# machine's boot time, because the value in /proc/self/stat is relative to
# this
boot_time = None
if HAVE_PROC_STAT:
with open("/proc/stat") as _procstat:
for line in _procstat:
if line.startswith("btime "):
boot_time = int(line.split()[1])
def update_resource_metrics():
global rusage
rusage = getrusage(RUSAGE_SELF)
if HAVE_PROC_SELF_STAT:
global stats
with open("/proc/self/stat") as s:
line = s.read()
# line is PID (command) more stats go here ...
raw_stats = line.split(") ", 1)[1].split(" ")
for (name, index) in STAT_FIELDS.iteritems():
# subtract 3 from the index, because proc(5) is 1-based, and
# we've lost the first two fields in PID and COMMAND above
stats[name] = int(raw_stats[index - 3])
global fd_counts
fd_counts = _process_fds()
def _process_fds():
counts = {(k,): 0 for k in TYPES.values()}
counts[("other",)] = 0
# Not every OS will have a /proc/self/fd directory
if not HAVE_PROC_SELF_FD:
return counts
for fd in os.listdir("/proc/self/fd"):
try:
s = os.stat("/proc/self/fd/%s" % (fd))
fmt = stat.S_IFMT(s.st_mode)
if fmt in TYPES:
t = TYPES[fmt]
else:
t = "other"
counts[(t,)] += 1
except OSError:
# the dirh itself used by listdir() is usually missing by now
pass
return counts
def register_process_collector():
# Legacy synapse-invented metric names
resource_metrics = get_metrics_for("process.resource")
resource_metrics.register_collector(update_resource_metrics)
# msecs
resource_metrics.register_callback("utime", lambda: rusage.ru_utime * 1000)
resource_metrics.register_callback("stime", lambda: rusage.ru_stime * 1000)
# kilobytes
resource_metrics.register_callback("maxrss", lambda: rusage.ru_maxrss * 1024)
get_metrics_for("process").register_callback("fds", _process_fds, labels=["type"])
# New prometheus-standard metric names
process_metrics = get_metrics_for("process")
if HAVE_PROC_SELF_STAT:
process_metrics.register_callback(
"cpu_user_seconds_total",
lambda: float(stats["utime"]) / TICKS_PER_SEC
)
process_metrics.register_callback(
"cpu_system_seconds_total",
lambda: float(stats["stime"]) / TICKS_PER_SEC
)
process_metrics.register_callback(
"cpu_seconds_total",
lambda: (float(stats["utime"] + stats["stime"])) / TICKS_PER_SEC
)
process_metrics.register_callback(
"virtual_memory_bytes",
lambda: int(stats["vsize"])
)
process_metrics.register_callback(
"resident_memory_bytes",
lambda: int(stats["rss"]) * BYTES_PER_PAGE
)
process_metrics.register_callback(
"start_time_seconds",
lambda: boot_time + int(stats["starttime"]) / TICKS_PER_SEC
)
if HAVE_PROC_SELF_FD:
process_metrics.register_callback(
"open_fds",
lambda: sum(fd_counts.values())
)
if HAVE_PROC_SELF_LIMITS:
def _get_max_fds():
with open("/proc/self/limits") as limits:
for line in limits:
if not line.startswith("Max open files "):
continue
# Line is Max open files $SOFT $HARD
return int(line.split()[3])
return None
process_metrics.register_callback(
"max_fds",
lambda: _get_max_fds()
)

View File

@ -150,6 +150,10 @@ class EmailPusher(object):
soonest_due_at = None soonest_due_at = None
if not unprocessed:
yield self.save_last_stream_ordering_and_success(self.max_stream_ordering)
return
for push_action in unprocessed: for push_action in unprocessed:
received_at = push_action['received_ts'] received_at = push_action['received_ts']
if received_at is None: if received_at is None:

View File

@ -372,7 +372,7 @@ class Mailer(object):
state_event_id = room_state_ids[room_id][ state_event_id = room_state_ids[room_id][
("m.room.member", event.sender) ("m.room.member", event.sender)
] ]
state_event = yield self.get_event(state_event_id) state_event = yield self.store.get_event(state_event_id)
sender_name = name_from_member_event(state_event) sender_name = name_from_member_event(state_event)
if sender_name is not None and room_name is not None: if sender_name is not None and room_name is not None:

View File

@ -15,7 +15,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError from synapse.api.errors import AuthError, SynapseError, StoreError, Codes
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID from synapse.types import UserID
@ -45,7 +45,7 @@ class GetFilterRestServlet(RestServlet):
raise AuthError(403, "Cannot get filters for other users") raise AuthError(403, "Cannot get filters for other users")
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only get filters for local users") raise AuthError(403, "Can only get filters for local users")
try: try:
filter_id = int(filter_id) filter_id = int(filter_id)
@ -59,8 +59,8 @@ class GetFilterRestServlet(RestServlet):
) )
defer.returnValue((200, filter.get_filter_json())) defer.returnValue((200, filter.get_filter_json()))
except KeyError: except (KeyError, StoreError):
raise SynapseError(400, "No such filter") raise SynapseError(400, "No such filter", errcode=Codes.NOT_FOUND)
class CreateFilterRestServlet(RestServlet): class CreateFilterRestServlet(RestServlet):
@ -74,6 +74,7 @@ class CreateFilterRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, user_id): def on_POST(self, request, user_id):
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
@ -81,10 +82,9 @@ class CreateFilterRestServlet(RestServlet):
raise AuthError(403, "Cannot create filters for other users") raise AuthError(403, "Cannot create filters for other users")
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only create filters for local users") raise AuthError(403, "Can only create filters for local users")
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
filter_id = yield self.filtering.add_user_filter( filter_id = yield self.filtering.add_user_filter(
user_localpart=target_user.localpart, user_localpart=target_user.localpart,
user_filter=content, user_filter=content,

View File

@ -85,7 +85,6 @@ class LoggingTransaction(object):
sql_logger.debug("[SQL] {%s} %s", self.name, sql) sql_logger.debug("[SQL] {%s} %s", self.name, sql)
sql = self.database_engine.convert_param_style(sql) sql = self.database_engine.convert_param_style(sql)
if args: if args:
try: try:
sql_logger.debug( sql_logger.debug(

View File

@ -0,0 +1,23 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* Update any email addresses that were stored with mixed case into all
* lowercase
*/
UPDATE user_threepids SET address = LOWER(address) where medium = 'email';
/* Add an index for the select we do on passwored reset */
CREATE INDEX user_threepids_medium_address on user_threepids (medium, address);

View File

@ -521,13 +521,20 @@ class StreamStore(SQLBaseStore):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def get_room_events_max_id(self, direction='f'): def get_room_events_max_id(self, room_id=None):
"""Returns the current token for rooms stream.
By default, it returns the current global stream token. Specifying a
`room_id` causes it to return the current room specific topological
token.
"""
token = yield self._stream_id_gen.get_current_token() token = yield self._stream_id_gen.get_current_token()
if direction != 'b': if room_id is None:
defer.returnValue("s%d" % (token,)) defer.returnValue("s%d" % (token,))
else: else:
topo = yield self.runInteraction( topo = yield self.runInteraction(
"_get_max_topological_txn", self._get_max_topological_txn "_get_max_topological_txn", self._get_max_topological_txn,
room_id,
) )
defer.returnValue("t%d-%d" % (topo, token)) defer.returnValue("t%d-%d" % (topo, token))
@ -579,11 +586,11 @@ class StreamStore(SQLBaseStore):
lambda r: r[0][0] if r else 0 lambda r: r[0][0] if r else 0
) )
def _get_max_topological_txn(self, txn): def _get_max_topological_txn(self, txn, room_id):
txn.execute( txn.execute(
"SELECT MAX(topological_ordering) FROM events" "SELECT MAX(topological_ordering) FROM events"
" WHERE outlier = ?", " WHERE room_id = ?",
(False,) (room_id,)
) )
rows = txn.fetchall() rows = txn.fetchall()

View File

@ -41,13 +41,39 @@ class EventSources(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_current_token(self, direction='f'): def get_current_token(self):
push_rules_key, _ = self.store.get_push_rules_stream_token() push_rules_key, _ = self.store.get_push_rules_stream_token()
to_device_key = self.store.get_to_device_stream_token() to_device_key = self.store.get_to_device_stream_token()
token = StreamToken( token = StreamToken(
room_key=( room_key=(
yield self.sources["room"].get_current_key(direction) yield self.sources["room"].get_current_key()
),
presence_key=(
yield self.sources["presence"].get_current_key()
),
typing_key=(
yield self.sources["typing"].get_current_key()
),
receipt_key=(
yield self.sources["receipt"].get_current_key()
),
account_data_key=(
yield self.sources["account_data"].get_current_key()
),
push_rules_key=push_rules_key,
to_device_key=to_device_key,
)
defer.returnValue(token)
@defer.inlineCallbacks
def get_current_token_for_room(self, room_id):
push_rules_key, _ = self.store.get_push_rules_stream_token()
to_device_key = self.store.get_to_device_stream_token()
token = StreamToken(
room_key=(
yield self.sources["room"].get_current_key_for_room(room_id)
), ),
presence_key=( presence_key=(
yield self.sources["presence"].get_current_key() yield self.sources["presence"].get_current_key()

View File

@ -18,8 +18,9 @@ from synapse.api.errors import SynapseError
from collections import namedtuple from collections import namedtuple
Requester = namedtuple("Requester", Requester = namedtuple("Requester", [
["user", "access_token_id", "is_guest", "device_id"]) "user", "access_token_id", "is_guest", "device_id", "app_service",
])
""" """
Represents the user making a request Represents the user making a request
@ -29,11 +30,12 @@ Attributes:
request, or None if it came via the appservice API or similar request, or None if it came via the appservice API or similar
is_guest (bool): True if the user making this request is a guest user is_guest (bool): True if the user making this request is a guest user
device_id (str|None): device_id which was set at authentication time device_id (str|None): device_id which was set at authentication time
app_service (ApplicationService|None): the AS requesting on behalf of the user
""" """
def create_requester(user_id, access_token_id=None, is_guest=False, def create_requester(user_id, access_token_id=None, is_guest=False,
device_id=None): device_id=None, app_service=None):
""" """
Create a new ``Requester`` object Create a new ``Requester`` object
@ -43,13 +45,14 @@ def create_requester(user_id, access_token_id=None, is_guest=False,
request, or None if it came via the appservice API or similar request, or None if it came via the appservice API or similar
is_guest (bool): True if the user making this request is a guest user is_guest (bool): True if the user making this request is a guest user
device_id (str|None): device_id which was set at authentication time device_id (str|None): device_id which was set at authentication time
app_service (ApplicationService|None): the AS requesting on behalf of the user
Returns: Returns:
Requester Requester
""" """
if not isinstance(user_id, UserID): if not isinstance(user_id, UserID):
user_id = UserID.from_string(user_id) user_id = UserID.from_string(user_id)
return Requester(user_id, access_token_id, is_guest, device_id) return Requester(user_id, access_token_id, is_guest, device_id, app_service)
def get_domain_from_id(string): def get_domain_from_id(string):

View File

@ -20,7 +20,7 @@ from mock import Mock
from synapse.api.auth import Auth from synapse.api.auth import Auth
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.types import UserID from synapse.types import UserID
from tests.utils import setup_test_homeserver from tests.utils import setup_test_homeserver, mock_getRawHeaders
import pymacaroons import pymacaroons
@ -51,7 +51,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""]) request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), self.test_user) self.assertEquals(requester.user.to_string(), self.test_user)
@ -61,7 +61,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""]) request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request) d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError) self.failureResultOf(d, AuthError)
@ -74,7 +74,7 @@ class AuthTestCase(unittest.TestCase):
self.store.get_user_by_access_token = Mock(return_value=user_info) self.store.get_user_by_access_token = Mock(return_value=user_info)
request = Mock(args={}) request = Mock(args={})
request.requestHeaders.getRawHeaders = Mock(return_value=[""]) request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request) d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError) self.failureResultOf(d, AuthError)
@ -86,7 +86,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""]) request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), self.test_user) self.assertEquals(requester.user.to_string(), self.test_user)
@ -96,7 +96,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""]) request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request) d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError) self.failureResultOf(d, AuthError)
@ -106,7 +106,7 @@ class AuthTestCase(unittest.TestCase):
self.store.get_user_by_access_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={}) request = Mock(args={})
request.requestHeaders.getRawHeaders = Mock(return_value=[""]) request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request) d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError) self.failureResultOf(d, AuthError)
@ -121,7 +121,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args["access_token"] = [self.test_token]
request.args["user_id"] = [masquerading_user_id] request.args["user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = Mock(return_value=[""]) request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), masquerading_user_id) self.assertEquals(requester.user.to_string(), masquerading_user_id)
@ -135,7 +135,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args["access_token"] = [self.test_token]
request.args["user_id"] = [masquerading_user_id] request.args["user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = Mock(return_value=[""]) request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request) d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError) self.failureResultOf(d, AuthError)

View File

@ -219,7 +219,8 @@ class TypingNotificationsTestCase(unittest.TestCase):
"user_id": self.u_onion.to_string(), "user_id": self.u_onion.to_string(),
"typing": True, "typing": True,
} }
) ),
federation_auth=True,
) )
self.on_new_event.assert_has_calls([ self.on_new_event.assert_has_calls([

View File

@ -17,6 +17,7 @@ from synapse.rest.client.v1.register import CreateUserRestServlet
from twisted.internet import defer from twisted.internet import defer
from mock import Mock from mock import Mock
from tests import unittest from tests import unittest
from tests.utils import mock_getRawHeaders
import json import json
@ -30,6 +31,7 @@ class CreateUserServletTestCase(unittest.TestCase):
path='/_matrix/client/api/v1/createUser' path='/_matrix/client/api/v1/createUser'
) )
self.request.args = {} self.request.args = {}
self.request.requestHeaders.getRawHeaders = mock_getRawHeaders()
self.registration_handler = Mock() self.registration_handler = Mock()

View File

@ -15,78 +15,125 @@
from twisted.internet import defer from twisted.internet import defer
from . import V2AlphaRestTestCase from tests import unittest
from synapse.rest.client.v2_alpha import filter from synapse.rest.client.v2_alpha import filter
from synapse.api.errors import StoreError from synapse.api.errors import Codes
import synapse.types
from synapse.types import UserID
from ....utils import MockHttpResource, setup_test_homeserver
PATH_PREFIX = "/_matrix/client/v2_alpha"
class FilterTestCase(V2AlphaRestTestCase): class FilterTestCase(unittest.TestCase):
USER_ID = "@apple:test" USER_ID = "@apple:test"
EXAMPLE_FILTER = {"type": ["m.*"]}
EXAMPLE_FILTER_JSON = '{"type": ["m.*"]}'
TO_REGISTER = [filter] TO_REGISTER = [filter]
def make_datastore_mock(self): @defer.inlineCallbacks
datastore = super(FilterTestCase, self).make_datastore_mock() def setUp(self):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
self._user_filters = {} self.hs = yield setup_test_homeserver(
http_client=None,
resource_for_client=self.mock_resource,
resource_for_federation=self.mock_resource,
)
def add_user_filter(user_localpart, definition): self.auth = self.hs.get_auth()
filters = self._user_filters.setdefault(user_localpart, [])
filter_id = len(filters)
filters.append(definition)
return defer.succeed(filter_id)
datastore.add_user_filter = add_user_filter
def get_user_filter(user_localpart, filter_id): def get_user_by_access_token(token=None, allow_guest=False):
if user_localpart not in self._user_filters: return {
raise StoreError(404, "No user") "user": UserID.from_string(self.USER_ID),
filters = self._user_filters[user_localpart] "token_id": 1,
if filter_id >= len(filters): "is_guest": False,
raise StoreError(404, "No filter") }
return defer.succeed(filters[filter_id])
datastore.get_user_filter = get_user_filter
return datastore def get_user_by_req(request, allow_guest=False, rights="access"):
return synapse.types.create_requester(
UserID.from_string(self.USER_ID), 1, False, None)
self.auth.get_user_by_access_token = get_user_by_access_token
self.auth.get_user_by_req = get_user_by_req
self.store = self.hs.get_datastore()
self.filtering = self.hs.get_filtering()
for r in self.TO_REGISTER:
r.register_servlets(self.hs, self.mock_resource)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_add_filter(self): def test_add_filter(self):
(code, response) = yield self.mock_resource.trigger( (code, response) = yield self.mock_resource.trigger(
"POST", "/user/%s/filter" % (self.USER_ID), '{"type": ["m.*"]}' "POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON
) )
self.assertEquals(200, code) self.assertEquals(200, code)
self.assertEquals({"filter_id": "0"}, response) self.assertEquals({"filter_id": "0"}, response)
filter = yield self.store.get_user_filter(
user_localpart='apple',
filter_id=0,
)
self.assertEquals(filter, self.EXAMPLE_FILTER)
self.assertIn("apple", self._user_filters) @defer.inlineCallbacks
self.assertEquals(len(self._user_filters["apple"]), 1) def test_add_filter_for_other_user(self):
self.assertEquals({"type": ["m.*"]}, self._user_filters["apple"][0]) (code, response) = yield self.mock_resource.trigger(
"POST", "/user/%s/filter" % ('@watermelon:test'), self.EXAMPLE_FILTER_JSON
)
self.assertEquals(403, code)
self.assertEquals(response['errcode'], Codes.FORBIDDEN)
@defer.inlineCallbacks
def test_add_filter_non_local_user(self):
_is_mine = self.hs.is_mine
self.hs.is_mine = lambda target_user: False
(code, response) = yield self.mock_resource.trigger(
"POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON
)
self.hs.is_mine = _is_mine
self.assertEquals(403, code)
self.assertEquals(response['errcode'], Codes.FORBIDDEN)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_filter(self): def test_get_filter(self):
self._user_filters["apple"] = [ filter_id = yield self.filtering.add_user_filter(
{"type": ["m.*"]} user_localpart='apple',
] user_filter=self.EXAMPLE_FILTER
)
(code, response) = yield self.mock_resource.trigger_get( (code, response) = yield self.mock_resource.trigger_get(
"/user/%s/filter/0" % (self.USER_ID) "/user/%s/filter/%s" % (self.USER_ID, filter_id)
) )
self.assertEquals(200, code) self.assertEquals(200, code)
self.assertEquals({"type": ["m.*"]}, response) self.assertEquals(self.EXAMPLE_FILTER, response)
@defer.inlineCallbacks
def test_get_filter_non_existant(self):
(code, response) = yield self.mock_resource.trigger_get(
"/user/%s/filter/12382148321" % (self.USER_ID)
)
self.assertEquals(400, code)
self.assertEquals(response['errcode'], Codes.NOT_FOUND)
# Currently invalid params do not have an appropriate errcode
# in errors.py
@defer.inlineCallbacks
def test_get_filter_invalid_id(self):
(code, response) = yield self.mock_resource.trigger_get(
"/user/%s/filter/foobar" % (self.USER_ID)
)
self.assertEquals(400, code)
# No ID also returns an invalid_id error
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_filter_no_id(self): def test_get_filter_no_id(self):
self._user_filters["apple"] = [
{"type": ["m.*"]}
]
(code, response) = yield self.mock_resource.trigger_get( (code, response) = yield self.mock_resource.trigger_get(
"/user/%s/filter/2" % (self.USER_ID) "/user/%s/filter/" % (self.USER_ID)
) )
self.assertEquals(404, code) self.assertEquals(400, code)
@defer.inlineCallbacks
def test_get_filter_no_user(self):
(code, response) = yield self.mock_resource.trigger_get(
"/user/%s/filter/0" % (self.USER_ID)
)
self.assertEquals(404, code)

View File

@ -3,6 +3,7 @@ from synapse.api.errors import SynapseError
from twisted.internet import defer from twisted.internet import defer
from mock import Mock from mock import Mock
from tests import unittest from tests import unittest
from tests.utils import mock_getRawHeaders
import json import json
@ -16,6 +17,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
path='/_matrix/api/v2_alpha/register' path='/_matrix/api/v2_alpha/register'
) )
self.request.args = {} self.request.args = {}
self.request.requestHeaders.getRawHeaders = mock_getRawHeaders()
self.appservice = None self.appservice = None
self.auth = Mock(get_appservice_by_req=Mock( self.auth = Mock(get_appservice_by_req=Mock(

View File

@ -116,6 +116,15 @@ def get_mock_call_args(pattern_func, mock_func):
return getcallargs(pattern_func, *invoked_args, **invoked_kargs) return getcallargs(pattern_func, *invoked_args, **invoked_kargs)
def mock_getRawHeaders(headers=None):
headers = headers if headers is not None else {}
def getRawHeaders(name, default=None):
return headers.get(name, default)
return getRawHeaders
# This is a mock /resource/ not an entire server # This is a mock /resource/ not an entire server
class MockHttpResource(HttpServer): class MockHttpResource(HttpServer):
@ -128,7 +137,7 @@ class MockHttpResource(HttpServer):
@patch('twisted.web.http.Request') @patch('twisted.web.http.Request')
@defer.inlineCallbacks @defer.inlineCallbacks
def trigger(self, http_method, path, content, mock_request): def trigger(self, http_method, path, content, mock_request, federation_auth=False):
""" Fire an HTTP event. """ Fire an HTTP event.
Args: Args:
@ -156,9 +165,10 @@ class MockHttpResource(HttpServer):
mock_request.getClientIP.return_value = "-" mock_request.getClientIP.return_value = "-"
mock_request.requestHeaders.getRawHeaders.return_value = [ headers = {}
"X-Matrix origin=test,key=,sig=" if federation_auth:
] headers["Authorization"] = ["X-Matrix origin=test,key=,sig="]
mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
# return the right path if the event requires it # return the right path if the event requires it
mock_request.path = path mock_request.path = path
@ -189,7 +199,7 @@ class MockHttpResource(HttpServer):
) )
defer.returnValue((code, response)) defer.returnValue((code, response))
except CodeMessageException as e: except CodeMessageException as e:
defer.returnValue((e.code, cs_error(e.msg))) defer.returnValue((e.code, cs_error(e.msg, code=e.errcode)))
raise KeyError("No event can handle %s" % path) raise KeyError("No event can handle %s" % path)