Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes

Commit f90eb60ae9 by Erik Johnston, 2017-03-15 10:00:10 +00:00
36 changed files with 1082 additions and 317 deletions


@@ -39,7 +39,9 @@ loggers:
     synapse:
         level: INFO

-    synapse.storage:
+    synapse.storage.SQL:
+        # beware: increasing this to DEBUG will make synapse log sensitive
+        # information such as access tokens.
         level: INFO

 # example of enabling debugging for a component:


@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -44,6 +45,7 @@ class JoinRules(object):
 class LoginType(object):
     PASSWORD = u"m.login.password"
     EMAIL_IDENTITY = u"m.login.email.identity"
+    MSISDN = u"m.login.msisdn"
     RECAPTCHA = u"m.login.recaptcha"
     DUMMY = u"m.login.dummy"
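Note: the new constant names the m.login.msisdn user-interactive auth stage. For orientation, a client's auth submission for that stage might look roughly like the sketch below; all values are illustrative, not taken from this commit.

auth = {
    "type": "m.login.msisdn",
    "session": "xxxxx",                      # UIA session id, made up
    "threepid_creds": {
        "id_server": "id.example.com",       # hypothetical identity server
        "sid": "123",                        # validation session id, made up
        "client_secret": "some_opaque_secret",
    },
}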


@@ -15,6 +15,7 @@
 """Contains exceptions and error codes."""

+import json
 import logging

 logger = logging.getLogger(__name__)
@@ -50,27 +51,35 @@ class Codes(object):
 class CodeMessageException(RuntimeError):
-    """An exception with integer code and message string attributes."""
+    """An exception with integer code and message string attributes.
+
+    Attributes:
+        code (int): HTTP error code
+        msg (str): string describing the error
+    """

     def __init__(self, code, msg):
         super(CodeMessageException, self).__init__("%d: %s" % (code, msg))
         self.code = code
         self.msg = msg
-        self.response_code_message = None

     def error_dict(self):
         return cs_error(self.msg)


 class SynapseError(CodeMessageException):
-    """A base error which can be caught for all synapse events."""
+    """A base exception type for matrix errors which have an errcode and error
+    message (as well as an HTTP status code).
+
+    Attributes:
+        errcode (str): Matrix error code e.g 'M_FORBIDDEN'
+    """
     def __init__(self, code, msg, errcode=Codes.UNKNOWN):
         """Constructs a synapse error.

         Args:
             code (int): The integer error code (an HTTP response code)
             msg (str): The human-readable error message.
-            err (str): The error code e.g 'M_FORBIDDEN'
+            errcode (str): The matrix error code e.g 'M_FORBIDDEN'
         """
         super(SynapseError, self).__init__(code, msg)
         self.errcode = errcode
@@ -81,6 +90,39 @@ class SynapseError(CodeMessageException):
             self.errcode,
         )

+    @classmethod
+    def from_http_response_exception(cls, err):
+        """Make a SynapseError based on an HTTPResponseException
+
+        This is useful when a proxied request has failed, and we need to
+        decide how to map the failure onto a matrix error to send back to the
+        client.
+
+        An attempt is made to parse the body of the http response as a matrix
+        error. If that succeeds, the errcode and error message from the body
+        are used as the errcode and error message in the new synapse error.
+
+        Otherwise, the errcode is set to M_UNKNOWN, and the error message is
+        set to the reason code from the HTTP response.
+
+        Args:
+            err (HttpResponseException):
+
+        Returns:
+            SynapseError:
+        """
+        # try to parse the body as json, to get better errcode/msg, but
+        # default to M_UNKNOWN with the HTTP status as the error text
+        try:
+            j = json.loads(err.response)
+        except ValueError:
+            j = {}
+        errcode = j.get('errcode', Codes.UNKNOWN)
+        errmsg = j.get('error', err.msg)
+
+        res = SynapseError(err.code, errmsg, errcode)
+        return res
+

 class RegistrationError(SynapseError):
     """An error raised when a registration event fails."""
@@ -106,13 +148,11 @@ class UnrecognizedRequestError(SynapseError):
 class NotFoundError(SynapseError):
     """An error indicating we can't find the thing you asked for"""
-    def __init__(self, *args, **kwargs):
-        if "errcode" not in kwargs:
-            kwargs["errcode"] = Codes.NOT_FOUND
-
+    def __init__(self, msg="Not found", errcode=Codes.NOT_FOUND):
         super(NotFoundError, self).__init__(
             404,
-            "Not found",
-            **kwargs
+            msg,
+            errcode=errcode
         )
@@ -173,7 +213,6 @@ class LimitExceededError(SynapseError):
                  errcode=Codes.LIMIT_EXCEEDED):
         super(LimitExceededError, self).__init__(code, msg, errcode)
         self.retry_after_ms = retry_after_ms
-        self.response_code_message = "Too Many Requests"

     def error_dict(self):
         return cs_error(
@@ -243,6 +282,19 @@ class FederationError(RuntimeError):
 class HttpResponseException(CodeMessageException):
+    """
+    Represents an HTTP-level failure of an outbound request
+
+    Attributes:
+        response (str): body of response
+    """
     def __init__(self, code, msg, response):
-        self.response = response
+        """
+        Args:
+            code (int): HTTP status code
+            msg (str): reason phrase from HTTP response status line
+            response (str): body of response
+        """
         super(HttpResponseException, self).__init__(code, msg)
+        self.response = response
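Note: the errcode-mapping logic added above can be sketched standalone. The class below is a cut-down stand-in for synapse's own, and the helper is a plain function rather than the classmethod; it is illustrative only.

import json

UNKNOWN = "M_UNKNOWN"

class HttpResponseException(Exception):
    """Cut-down stand-in: carries status code, reason phrase and body."""
    def __init__(self, code, msg, response):
        super(HttpResponseException, self).__init__("%d: %s" % (code, msg))
        self.code = code
        self.msg = msg
        self.response = response

def synapse_error_params(err):
    # try to parse the body as json to get a better errcode/msg; fall back
    # to M_UNKNOWN with the HTTP reason phrase as the error text
    try:
        j = json.loads(err.response)
    except ValueError:
        j = {}
    return err.code, j.get("error", err.msg), j.get("errcode", UNKNOWN)

print(synapse_error_params(HttpResponseException(
    404, "Not Found", '{"errcode": "M_NOT_FOUND", "error": "No such room"}'
)))  # -> (404, 'No such room', 'M_NOT_FOUND')

print(synapse_error_params(HttpResponseException(
    502, "Bad Gateway", "<html>oops</html>"
)))  # -> (502, 'Bad Gateway', 'M_UNKNOWN')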


@@ -157,7 +157,7 @@ def start(config_options):

     assert config.worker_app == "synapse.app.appservice"

-    setup_logging(config.worker_log_config, config.worker_log_file)
+    setup_logging(config, use_worker_options=True)

     events.USE_FROZEN_DICTS = config.use_frozen_dicts


@@ -171,7 +171,7 @@ def start(config_options):

     assert config.worker_app == "synapse.app.client_reader"

-    setup_logging(config.worker_log_config, config.worker_log_file)
+    setup_logging(config, use_worker_options=True)

     events.USE_FROZEN_DICTS = config.use_frozen_dicts


@@ -162,7 +162,7 @@ def start(config_options):

     assert config.worker_app == "synapse.app.federation_reader"

-    setup_logging(config.worker_log_config, config.worker_log_file)
+    setup_logging(config, use_worker_options=True)

     events.USE_FROZEN_DICTS = config.use_frozen_dicts


@@ -160,7 +160,7 @@ def start(config_options):

     assert config.worker_app == "synapse.app.federation_sender"

-    setup_logging(config.worker_log_config, config.worker_log_file)
+    setup_logging(config, use_worker_options=True)

     events.USE_FROZEN_DICTS = config.use_frozen_dicts


@@ -20,6 +20,8 @@ import gc
 import logging
 import os
 import sys

+import synapse.config.logger
+
 from synapse.config._base import ConfigError
 from synapse.python_dependencies import (
@@ -286,7 +288,7 @@ def setup(config_options):
         # generating config files and shouldn't try to continue.
         sys.exit(0)

-    config.setup_logging()
+    synapse.config.logger.setup_logging(config, use_worker_options=False)

     # check any extra requirements we have now we have a config
     check_requirements(config)


@@ -168,7 +168,7 @@ def start(config_options):

     assert config.worker_app == "synapse.app.media_repository"

-    setup_logging(config.worker_log_config, config.worker_log_file)
+    setup_logging(config, use_worker_options=True)

     events.USE_FROZEN_DICTS = config.use_frozen_dicts


@@ -245,7 +245,7 @@ def start(config_options):

     assert config.worker_app == "synapse.app.pusher"

-    setup_logging(config.worker_log_config, config.worker_log_file)
+    setup_logging(config, use_worker_options=True)

     events.USE_FROZEN_DICTS = config.use_frozen_dicts


@@ -478,7 +478,7 @@ def start(config_options):

     assert config.worker_app == "synapse.app.synchrotron"

-    setup_logging(config.worker_log_config, config.worker_log_file)
+    setup_logging(config, use_worker_options=True)

     synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts


@@ -45,7 +45,6 @@ handlers:
         maxBytes: 104857600
         backupCount: 10
         filters: [context]
-        level: INFO
     console:
         class: logging.StreamHandler
         formatter: precise
@@ -56,6 +55,8 @@ loggers:
         level: INFO

     synapse.storage.SQL:
+        # beware: increasing this to DEBUG will make synapse log sensitive
+        # information such as access tokens.
         level: INFO

 root:
@@ -68,6 +69,7 @@ class LoggingConfig(Config):
     def read_config(self, config):
         self.verbosity = config.get("verbose", 0)
+        self.no_redirect_stdio = config.get("no_redirect_stdio", False)
         self.log_config = self.abspath(config.get("log_config"))
         self.log_file = self.abspath(config.get("log_file"))
@@ -77,10 +79,10 @@ class LoggingConfig(Config):
             os.path.join(config_dir_path, server_name + ".log.config")
         )
         return """
-        # Logging verbosity level.
+        # Logging verbosity level. Ignored if log_config is specified.
         verbose: 0

-        # File to write logging to
+        # File to write logging to. Ignored if log_config is specified.
         log_file: "%(log_file)s"

         # A yaml python logging config file
@@ -90,6 +92,8 @@ class LoggingConfig(Config):
     def read_arguments(self, args):
         if args.verbose is not None:
             self.verbosity = args.verbose
+        if args.no_redirect_stdio is not None:
+            self.no_redirect_stdio = args.no_redirect_stdio
         if args.log_config is not None:
             self.log_config = args.log_config
         if args.log_file is not None:
@@ -99,16 +103,22 @@ class LoggingConfig(Config):
         logging_group = parser.add_argument_group("logging")
         logging_group.add_argument(
             '-v', '--verbose', dest="verbose", action='count',
-            help="The verbosity level."
+            help="The verbosity level. Specify multiple times to increase "
+            "verbosity. (Ignored if --log-config is specified.)"
         )
         logging_group.add_argument(
             '-f', '--log-file', dest="log_file",
-            help="File to log to."
+            help="File to log to. (Ignored if --log-config is specified.)"
         )
         logging_group.add_argument(
             '--log-config', dest="log_config", default=None,
             help="Python logging config file"
         )
+        logging_group.add_argument(
+            '-n', '--no-redirect-stdio',
+            action='store_true', default=None,
+            help="Do not redirect stdout/stderr to the log"
+        )

     def generate_files(self, config):
         log_config = config.get("log_config")
@@ -118,11 +128,22 @@ class LoggingConfig(Config):
             DEFAULT_LOG_CONFIG.substitute(log_file=config["log_file"])
         )

-    def setup_logging(self):
-        setup_logging(self.log_config, self.log_file, self.verbosity)

+def setup_logging(config, use_worker_options=False):
+    """ Set up python logging
+
+    Args:
+        config (LoggingConfig | synapse.config.workers.WorkerConfig):
+            configuration data
+
+        use_worker_options (bool): True to use 'worker_log_config' and
+            'worker_log_file' options instead of 'log_config' and 'log_file'.
+    """
+    log_config = (config.worker_log_config if use_worker_options
+                  else config.log_config)
+    log_file = (config.worker_log_file if use_worker_options
+                else config.log_file)

-def setup_logging(log_config=None, log_file=None, verbosity=None):
     log_format = (
         "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
         " - %(message)s"
@@ -131,9 +152,9 @@ def setup_logging(log_config=None, log_file=None, verbosity=None):
     level = logging.INFO
     level_for_storage = logging.INFO
-    if verbosity:
+    if config.verbosity:
         level = logging.DEBUG

-        if verbosity > 1:
+        if config.verbosity > 1:
             level_for_storage = logging.DEBUG

     # FIXME: we need a logging.WARN for a -q quiet option
@@ -153,14 +174,6 @@ def setup_logging(log_config=None, log_file=None, verbosity=None):
                 logger.info("Closing log file due to SIGHUP")
                 handler.doRollover()
                 logger.info("Opened new log file due to SIGHUP")

-        # TODO(paul): obviously this is a terrible mechanism for
-        #   stealing SIGHUP, because it means no other part of synapse
-        #   can use it instead. If we want to catch SIGHUP anywhere
-        #   else as well, I'd suggest we find a nicer way to broadcast
-        #   it around.
-        if getattr(signal, "SIGHUP"):
-            signal.signal(signal.SIGHUP, sighup)
     else:
         handler = logging.StreamHandler()
         handler.setFormatter(formatter)
@@ -169,8 +182,25 @@ def setup_logging(log_config=None, log_file=None, verbosity=None):
         logger.addHandler(handler)
     else:
-        with open(log_config, 'r') as f:
-            logging.config.dictConfig(yaml.load(f))
+        def load_log_config():
+            with open(log_config, 'r') as f:
+                logging.config.dictConfig(yaml.load(f))
+
+        def sighup(signum, stack):
+            # it might be better to use a file watcher or something for this.
+            logging.info("Reloading log config from %s due to SIGHUP",
+                         log_config)
+            load_log_config()
+
+        load_log_config()
+
+    # TODO(paul): obviously this is a terrible mechanism for
+    #   stealing SIGHUP, because it means no other part of synapse
+    #   can use it instead. If we want to catch SIGHUP anywhere
+    #   else as well, I'd suggest we find a nicer way to broadcast
+    #   it around.
+    if getattr(signal, "SIGHUP"):
+        signal.signal(signal.SIGHUP, sighup)

     # It's critical to point twisted's internal logging somewhere, otherwise it
     # stacks up and leaks kup to 64K object;
@@ -183,4 +213,7 @@ def setup_logging(log_config=None, log_file=None, verbosity=None):
     #
     # However this may not be too much of a problem if we are just writing to a file.
     observer = STDLibLogObserver()
-    globalLogBeginner.beginLoggingTo([observer])
+    globalLogBeginner.beginLoggingTo(
+        [observer],
+        redirectStandardIO=not config.no_redirect_stdio,
+    )
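Note: the reload-on-SIGHUP behaviour added above can be illustrated outside synapse. A minimal sketch, assuming a YAML log config at a made-up path:

import logging.config
import signal

import yaml  # PyYAML

LOG_CONFIG_PATH = "/etc/synapse/my.log.config"  # hypothetical path

def load_log_config():
    with open(LOG_CONFIG_PATH) as f:
        logging.config.dictConfig(yaml.safe_load(f))

def sighup(signum, stack):
    logging.info("Reloading log config from %s due to SIGHUP", LOG_CONFIG_PATH)
    load_log_config()

load_log_config()
if getattr(signal, "SIGHUP", None):  # SIGHUP is absent on Windows
    signal.signal(signal.SIGHUP, sighup)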


@@ -52,7 +52,6 @@ class FederationServer(FederationBase):
         self.auth = hs.get_auth()

-        self._room_pdu_linearizer = Linearizer("fed_room_pdu")
         self._server_linearizer = Linearizer("fed_server")

         # We cache responses to state queries, as they take a while and often
@@ -165,7 +164,7 @@ class FederationServer(FederationBase):
                 )

             try:
-                yield self._handle_new_pdu(transaction.origin, pdu)
+                yield self._handle_received_pdu(transaction.origin, pdu)
                 results.append({})
             except FederationError as e:
                 self.send_failure(e, transaction.origin)
@@ -497,27 +496,16 @@
         )

     @defer.inlineCallbacks
-    @log_function
-    def _handle_new_pdu(self, origin, pdu, get_missing=True):
+    def _handle_received_pdu(self, origin, pdu):
+        """ Process a PDU received in a federation /send/ transaction.

-        # We reprocess pdus when we have seen them only as outliers
-        existing = yield self._get_persisted_pdu(
-            origin, pdu.event_id, do_auth=False
-        )
-
-        # FIXME: Currently we fetch an event again when we already have it
-        # if it has been marked as an outlier.
-
-        already_seen = (
-            existing and (
-                not existing.internal_metadata.is_outlier()
-                or pdu.internal_metadata.is_outlier()
-            )
-        )
-        if already_seen:
-            logger.debug("Already seen pdu %s", pdu.event_id)
-            return
+        Args:
+            origin (str): server which sent the pdu
+            pdu (FrozenEvent): received pdu

+        Returns (Deferred): completes with None
+        Raises: FederationError if the signatures / hash do not match
+        """
         # Check signature.
         try:
             pdu = yield self._check_sigs_and_hash(pdu)
@@ -529,143 +517,7 @@
                 affected=pdu.event_id,
             )

-        state = None
-        auth_chain = []
-
-        have_seen = yield self.store.have_events(
-            [ev for ev, _ in pdu.prev_events]
-        )
-
-        fetch_state = False
-
-        # Get missing pdus if necessary.
-        if not pdu.internal_metadata.is_outlier():
-            # We only backfill backwards to the min depth.
-            min_depth = yield self.handler.get_min_depth_for_context(
-                pdu.room_id
-            )
-
-            logger.debug(
-                "_handle_new_pdu min_depth for %s: %d",
-                pdu.room_id, min_depth
-            )
-
-            prevs = {e_id for e_id, _ in pdu.prev_events}
-            seen = set(have_seen.keys())
-
-            if min_depth and pdu.depth < min_depth:
-                # This is so that we don't notify the user about this
-                # message, to work around the fact that some events will
-                # reference really really old events we really don't want to
-                # send to the clients.
-                pdu.internal_metadata.outlier = True
-            elif min_depth and pdu.depth > min_depth:
-                if get_missing and prevs - seen:
-                    # If we're missing stuff, ensure we only fetch stuff one
-                    # at a time.
-                    logger.info(
-                        "Acquiring lock for room %r to fetch %d missing events: %r...",
-                        pdu.room_id, len(prevs - seen), list(prevs - seen)[:5],
-                    )
-                    with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
-                        logger.info(
-                            "Acquired lock for room %r to fetch %d missing events",
-                            pdu.room_id, len(prevs - seen),
-                        )
-
-                        # We recalculate seen, since it may have changed.
-                        have_seen = yield self.store.have_events(prevs)
-                        seen = set(have_seen.keys())
-
-                        if prevs - seen:
-                            latest = yield self.store.get_latest_event_ids_in_room(
-                                pdu.room_id
-                            )
-
-                            # We add the prev events that we have seen to the latest
-                            # list to ensure the remote server doesn't give them to us
-                            latest = set(latest)
-                            latest |= seen
-
-                            logger.info(
-                                "Missing %d events for room %r: %r...",
-                                len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
-                            )
-
-                            # XXX: we set timeout to 10s to help workaround
-                            # https://github.com/matrix-org/synapse/issues/1733.
-                            # The reason is to avoid holding the linearizer lock
-                            # whilst processing inbound /send transactions, causing
-                            # FDs to stack up and block other inbound transactions
-                            # which empirically can currently take up to 30 minutes.
-                            #
-                            # N.B. this explicitly disables retry attempts.
-                            #
-                            # N.B. this also increases our chances of falling back to
-                            # fetching fresh state for the room if the missing event
-                            # can't be found, which slightly reduces our security.
-                            # it may also increase our DAG extremity count for the room,
-                            # causing additional state resolution?  See #1760.
-                            # However, fetching state doesn't hold the linearizer lock
-                            # apparently.
-                            #
-                            # see https://github.com/matrix-org/synapse/pull/1744
-
-                            missing_events = yield self.get_missing_events(
-                                origin,
-                                pdu.room_id,
-                                earliest_events_ids=list(latest),
-                                latest_events=[pdu],
-                                limit=10,
-                                min_depth=min_depth,
-                                timeout=10000,
-                            )
-
-                            # We want to sort these by depth so we process them and
-                            # tell clients about them in order.
-                            missing_events.sort(key=lambda x: x.depth)
-
-                            for e in missing_events:
-                                yield self._handle_new_pdu(
-                                    origin,
-                                    e,
-                                    get_missing=False
-                                )
-
-                            have_seen = yield self.store.have_events(
-                                [ev for ev, _ in pdu.prev_events]
-                            )
-
-            prevs = {e_id for e_id, _ in pdu.prev_events}
-            seen = set(have_seen.keys())
-            if prevs - seen:
-                logger.info(
-                    "Still missing %d events for room %r: %r...",
-                    len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
-                )
-                fetch_state = True
-
-        if fetch_state:
-            # We need to get the state at this event, since we haven't
-            # processed all the prev events.
-            logger.debug(
-                "_handle_new_pdu getting state for %s",
-                pdu.room_id
-            )
-            try:
-                state, auth_chain = yield self.get_state_for_room(
-                    origin, pdu.room_id, pdu.event_id,
-                )
-            except:
-                logger.exception("Failed to get state for event: %s", pdu.event_id)
-
-        yield self.handler.on_receive_pdu(
-            origin,
-            pdu,
-            state=state,
-            auth_chain=auth_chain,
-        )
+        yield self.handler.on_receive_pdu(origin, pdu, get_missing=True)

     def __str__(self):
         return "<ReplicationLayer(%s)>" % self.server_name


@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -47,6 +48,7 @@ class AuthHandler(BaseHandler):
             LoginType.PASSWORD: self._check_password_auth,
             LoginType.RECAPTCHA: self._check_recaptcha,
             LoginType.EMAIL_IDENTITY: self._check_email_identity,
+            LoginType.MSISDN: self._check_msisdn,
             LoginType.DUMMY: self._check_dummy_auth,
         }
         self.bcrypt_rounds = hs.config.bcrypt_rounds
@@ -307,31 +309,47 @@ class AuthHandler(BaseHandler):
             defer.returnValue(True)

         raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)

-    @defer.inlineCallbacks
     def _check_email_identity(self, authdict, _):
+        return self._check_threepid('email', authdict)
+
+    def _check_msisdn(self, authdict, _):
+        return self._check_threepid('msisdn', authdict)
+
+    @defer.inlineCallbacks
+    def _check_dummy_auth(self, authdict, _):
+        yield run_on_reactor()
+        defer.returnValue(True)
+
+    @defer.inlineCallbacks
+    def _check_threepid(self, medium, authdict):
         yield run_on_reactor()

         if 'threepid_creds' not in authdict:
             raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)

         threepid_creds = authdict['threepid_creds']
         identity_handler = self.hs.get_handlers().identity_handler

-        logger.info("Getting validated threepid. threepidcreds: %r" % (threepid_creds,))
+        logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,))
         threepid = yield identity_handler.threepid_from_creds(threepid_creds)

         if not threepid:
             raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)

+        if threepid['medium'] != medium:
+            raise LoginError(
+                401,
+                "Expecting threepid of type '%s', got '%s'" % (
+                    medium, threepid['medium'],
+                ),
+                errcode=Codes.UNAUTHORIZED
+            )
+
         threepid['threepid_creds'] = authdict['threepid_creds']

         defer.returnValue(threepid)

-    @defer.inlineCallbacks
-    def _check_dummy_auth(self, authdict, _):
-        yield run_on_reactor()
-        defer.returnValue(True)
-
     def _get_params_recaptcha(self):
         return {"public_key": self.hs.config.recaptcha_public_key}


@@ -169,6 +169,40 @@ class DeviceHandler(BaseHandler):

         yield self.notify_device_update(user_id, [device_id])

+    @defer.inlineCallbacks
+    def delete_devices(self, user_id, device_ids):
+        """ Delete several devices
+
+        Args:
+            user_id (str):
+            device_ids (list[str]): The list of device IDs to delete
+
+        Returns:
+            defer.Deferred:
+        """
+        try:
+            yield self.store.delete_devices(user_id, device_ids)
+        except errors.StoreError as e:
+            if e.code == 404:
+                # no match
+                pass
+            else:
+                raise
+
+        # Delete access tokens and e2e keys for each device. Not optimised as it is not
+        # considered as part of a critical path.
+        for device_id in device_ids:
+            yield self.store.user_delete_access_tokens(
+                user_id, device_id=device_id,
+                delete_refresh_tokens=True,
+            )
+            yield self.store.delete_e2e_keys_by_device(
+                user_id=user_id, device_id=device_id
+            )
+
+        yield self.notify_device_update(user_id, device_ids)
+
     @defer.inlineCallbacks
     def update_device(self, user_id, device_id, content):
         """ Update the given device
@@ -262,7 +296,7 @@ class DeviceHandler(BaseHandler):
             # ordering: treat it the same as a new room
             event_ids = []

-        current_state_ids = yield self.state.get_current_state_ids(room_id)
+        current_state_ids = yield self.store.get_current_state_ids(room_id)

         # special-case for an empty prev state: include all members
         # in the changed list
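Note: a hedged sketch of how a caller might drive the new bulk deletion from an inlineCallbacks context. Here `handler` stands in for hs.get_device_handler(), and the assumption that store.get_devices_by_user returns a dict keyed by device_id is mine, not the diff's.

from twisted.internet import defer

@defer.inlineCallbacks
def scrub_user_devices(handler, store, user_id):
    # delete every device in one call, so clients see a single
    # notify_device_update covering the whole list
    device_map = yield store.get_devices_by_user(user_id)  # assumed: dict keyed by device_id
    yield handler.delete_devices(user_id, list(device_map))
    defer.returnValue(len(device_map))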


@@ -31,7 +31,7 @@ from synapse.util.logcontext import (
 )
 from synapse.util.metrics import measure_func
 from synapse.util.logutils import log_function
-from synapse.util.async import run_on_reactor
+from synapse.util.async import run_on_reactor, Linearizer
 from synapse.util.frozenutils import unfreeze
 from synapse.crypto.event_signing import (
     compute_event_signature, add_hashes_and_signatures,
@@ -79,12 +79,204 @@ class FederationHandler(BaseHandler):
         # When joining a room we need to queue any events for that room up
         self.room_queues = {}

+        self._room_pdu_linearizer = Linearizer("fed_room_pdu")
+
+    @defer.inlineCallbacks
+    @log_function
+    def on_receive_pdu(self, origin, pdu, get_missing=True):
+        """ Process a PDU received via a federation /send/ transaction, or
+        via backfill of missing prev_events
+
+        Args:
+            origin (str): server which initiated the /send/ transaction. Will
+                be used to fetch missing events or state.
+            pdu (FrozenEvent): received PDU
+            get_missing (bool): True if we should fetch missing prev_events
+
+        Returns (Deferred): completes with None
+        """
+
+        # We reprocess pdus when we have seen them only as outliers
+        existing = yield self.get_persisted_pdu(
+            origin, pdu.event_id, do_auth=False
+        )
+
+        # FIXME: Currently we fetch an event again when we already have it
+        # if it has been marked as an outlier.
+
+        already_seen = (
+            existing and (
+                not existing.internal_metadata.is_outlier()
+                or pdu.internal_metadata.is_outlier()
+            )
+        )
+        if already_seen:
+            logger.debug("Already seen pdu %s", pdu.event_id)
+            return
+
+        state = None
+        auth_chain = []
+
+        have_seen = yield self.store.have_events(
+            [ev for ev, _ in pdu.prev_events]
+        )
+
+        fetch_state = False
+
+        # Get missing pdus if necessary.
+        if not pdu.internal_metadata.is_outlier():
+            # We only backfill backwards to the min depth.
+            min_depth = yield self.get_min_depth_for_context(
+                pdu.room_id
+            )
+
+            logger.debug(
+                "_handle_new_pdu min_depth for %s: %d",
+                pdu.room_id, min_depth
+            )
+
+            prevs = {e_id for e_id, _ in pdu.prev_events}
+            seen = set(have_seen.keys())
+
+            if min_depth and pdu.depth < min_depth:
+                # This is so that we don't notify the user about this
+                # message, to work around the fact that some events will
+                # reference really really old events we really don't want to
+                # send to the clients.
+                pdu.internal_metadata.outlier = True
+            elif min_depth and pdu.depth > min_depth:
+                if get_missing and prevs - seen:
+                    # If we're missing stuff, ensure we only fetch stuff one
+                    # at a time.
+                    logger.info(
+                        "Acquiring lock for room %r to fetch %d missing events: %r...",
+                        pdu.room_id, len(prevs - seen), list(prevs - seen)[:5],
+                    )
+                    with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
+                        logger.info(
+                            "Acquired lock for room %r to fetch %d missing events",
+                            pdu.room_id, len(prevs - seen),
+                        )
+
+                        yield self._get_missing_events_for_pdu(
+                            origin, pdu, prevs, min_depth
+                        )
+
+            prevs = {e_id for e_id, _ in pdu.prev_events}
+            seen = set(have_seen.keys())
+            if prevs - seen:
+                logger.info(
+                    "Still missing %d events for room %r: %r...",
+                    len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
+                )
+                fetch_state = True
+
+        if fetch_state:
+            # We need to get the state at this event, since we haven't
+            # processed all the prev events.
+            logger.debug(
+                "_handle_new_pdu getting state for %s",
+                pdu.room_id
+            )
+            try:
+                state, auth_chain = yield self.replication_layer.get_state_for_room(
+                    origin, pdu.room_id, pdu.event_id,
+                )
+            except:
+                logger.exception("Failed to get state for event: %s", pdu.event_id)
+
+        yield self._process_received_pdu(
+            origin,
+            pdu,
+            state=state,
+            auth_chain=auth_chain,
+        )
+
+    @defer.inlineCallbacks
+    def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
+        """
+        Args:
+            origin (str): Origin of the pdu. Will be called to get the missing events
+            pdu: received pdu
+            prevs (str[]): List of event ids which we are missing
+            min_depth (int): Minimum depth of events to return.
+
+        Returns:
+            Deferred<dict(str, str?)>: updated have_seen dictionary
+        """
+        # We recalculate seen, since it may have changed.
+        have_seen = yield self.store.have_events(prevs)
+        seen = set(have_seen.keys())
+
+        if not prevs - seen:
+            # nothing left to do
+            defer.returnValue(have_seen)
+
+        latest = yield self.store.get_latest_event_ids_in_room(
+            pdu.room_id
+        )
+
+        # We add the prev events that we have seen to the latest
+        # list to ensure the remote server doesn't give them to us
+        latest = set(latest)
+        latest |= seen
+
+        logger.info(
+            "Missing %d events for room %r: %r...",
+            len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
+        )
+
+        # XXX: we set timeout to 10s to help workaround
+        # https://github.com/matrix-org/synapse/issues/1733.
+        # The reason is to avoid holding the linearizer lock
+        # whilst processing inbound /send transactions, causing
+        # FDs to stack up and block other inbound transactions
+        # which empirically can currently take up to 30 minutes.
+        #
+        # N.B. this explicitly disables retry attempts.
+        #
+        # N.B. this also increases our chances of falling back to
+        # fetching fresh state for the room if the missing event
+        # can't be found, which slightly reduces our security.
+        # it may also increase our DAG extremity count for the room,
+        # causing additional state resolution?  See #1760.
+        # However, fetching state doesn't hold the linearizer lock
+        # apparently.
+        #
+        # see https://github.com/matrix-org/synapse/pull/1744
+
+        missing_events = yield self.replication_layer.get_missing_events(
+            origin,
+            pdu.room_id,
+            earliest_events_ids=list(latest),
+            latest_events=[pdu],
+            limit=10,
+            min_depth=min_depth,
+            timeout=10000,
+        )

+        # We want to sort these by depth so we process them and
+        # tell clients about them in order.
+        missing_events.sort(key=lambda x: x.depth)
+
+        for e in missing_events:
+            yield self.on_receive_pdu(
+                origin,
+                e,
+                get_missing=False
+            )
+
+        have_seen = yield self.store.have_events(
+            [ev for ev, _ in pdu.prev_events]
+        )
+
+        defer.returnValue(have_seen)
+
     @log_function
     @defer.inlineCallbacks
-    def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None):
-        """ Called by the ReplicationLayer when we have a new pdu. We need to
-        do auth checks and put it through the StateHandler.
+    def _process_received_pdu(self, origin, pdu, state=None, auth_chain=None):
+        """ Called when we have a new pdu. We need to do auth checks and put it
+        through the StateHandler.

         auth_chain and state are None if we already have the necessary state
         and prev_events in the db
@@ -738,7 +930,7 @@ class FederationHandler(BaseHandler):
                 continue

             try:
-                self.on_receive_pdu(origin, p)
+                self._process_received_pdu(origin, p)
             except:
                 logger.exception("Couldn't handle pdu")


@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -150,7 +151,7 @@ class IdentityHandler(BaseHandler):
         params.update(kwargs)

         try:
-            data = yield self.http_client.post_urlencoded_get_json(
+            data = yield self.http_client.post_json_get_json(
                 "https://%s%s" % (
                     id_server,
                     "/_matrix/identity/api/v1/validate/email/requestToken"
@@ -161,3 +162,37 @@ class IdentityHandler(BaseHandler):
         except CodeMessageException as e:
             logger.info("Proxied requestToken failed: %r", e)
             raise e
+
+    @defer.inlineCallbacks
+    def requestMsisdnToken(
+            self, id_server, country, phone_number,
+            client_secret, send_attempt, **kwargs
+    ):
+        yield run_on_reactor()
+
+        if not self._should_trust_id_server(id_server):
+            raise SynapseError(
+                400, "Untrusted ID server '%s'" % id_server,
+                Codes.SERVER_NOT_TRUSTED
+            )
+
+        params = {
+            'country': country,
+            'phone_number': phone_number,
+            'client_secret': client_secret,
+            'send_attempt': send_attempt,
+        }
+        params.update(kwargs)
+
+        try:
+            data = yield self.http_client.post_json_get_json(
+                "https://%s%s" % (
+                    id_server,
+                    "/_matrix/identity/api/v1/validate/msisdn/requestToken"
+                ),
+                params
+            )
+            defer.returnValue(data)
+        except CodeMessageException as e:
+            logger.info("Proxied requestToken failed: %r", e)
+            raise e
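Note: the proxied call above amounts to a JSON POST to the identity server. A standalone sketch with the `requests` library (synapse itself goes through its own http_client, as in the diff); the identity server and all values are placeholders:

import requests

id_server = "id.example.com"  # hypothetical trusted identity server

resp = requests.post(
    "https://%s/_matrix/identity/api/v1/validate/msisdn/requestToken" % id_server,
    json={
        "country": "GB",
        "phone_number": "07700 900123",
        "client_secret": "some_opaque_secret",
        "send_attempt": 1,
    },
)
print(resp.status_code, resp.json())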


@@ -21,6 +21,7 @@ from synapse.api.constants import (
     EventTypes, JoinRules,
 )
 from synapse.util.async import concurrently_execute
+from synapse.util.caches.descriptors import cachedInlineCallbacks
 from synapse.util.caches.response_cache import ResponseCache
 from synapse.types import ThirdPartyInstanceID
@@ -62,6 +63,10 @@ class RoomListHandler(BaseHandler):
             appservice and network id to use an appservice specific one.
             Setting to None returns all public rooms across all lists.
         """
+        logger.info(
+            "Getting public room list: limit=%r, since=%r, search=%r, network=%r",
+            limit, since_token, bool(search_filter), network_tuple,
+        )
         if search_filter:
             # We explicitly don't bother caching searches or requests for
             # appservice specific lists.
@@ -91,7 +96,6 @@ class RoomListHandler(BaseHandler):
         rooms_to_order_value = {}
         rooms_to_num_joined = {}
-        rooms_to_latest_event_ids = {}

         newly_visible = []
         newly_unpublished = []
@@ -116,19 +120,26 @@
         @defer.inlineCallbacks
         def get_order_for_room(room_id):
-            latest_event_ids = rooms_to_latest_event_ids.get(room_id, None)
-            if not latest_event_ids:
+            # Most of the rooms won't have changed between the since token and
+            # now (especially if the since token is "now"). So, we can ask what
+            # the current users are in a room (that will hit a cache) and then
+            # check if the room has changed since the since token. (We have to
+            # do it in that order to avoid races).
+            # If things have changed then fall back to getting the current state
+            # at the since token.
+            joined_users = yield self.store.get_users_in_room(room_id)
+            if self.store.has_room_changed_since(room_id, stream_token):
                 latest_event_ids = yield self.store.get_forward_extremeties_for_room(
                     room_id, stream_token
                 )
-                rooms_to_latest_event_ids[room_id] = latest_event_ids

-            if not latest_event_ids:
-                return
+                if not latest_event_ids:
+                    return
+
+                joined_users = yield self.state_handler.get_current_user_in_room(
+                    room_id, latest_event_ids,
+                )

-            joined_users = yield self.state_handler.get_current_user_in_room(
-                room_id, latest_event_ids,
-            )
             num_joined_users = len(joined_users)
             rooms_to_num_joined[room_id] = num_joined_users
@@ -165,19 +176,19 @@ class RoomListHandler(BaseHandler):
             rooms_to_scan = rooms_to_scan[:since_token.current_limit]
             rooms_to_scan.reverse()

-        # Actually generate the entries. _generate_room_entry will append to
+        # Actually generate the entries. _append_room_entry_to_chunk will append to
         # chunk but will stop if len(chunk) > limit
         chunk = []
         if limit and not search_filter:
             step = limit + 1
             for i in xrange(0, len(rooms_to_scan), step):
                 # We iterate here because the vast majority of cases we'll stop
-                # at first iteration, but occaisonally _generate_room_entry
+                # at first iteration, but occaisonally _append_room_entry_to_chunk
                 # won't append to the chunk and so we need to loop again.
                 # We don't want to scan over the entire range either as that
                 # would potentially waste a lot of work.
                 yield concurrently_execute(
-                    lambda r: self._generate_room_entry(
+                    lambda r: self._append_room_entry_to_chunk(
                         r, rooms_to_num_joined[r],
                         chunk, limit, search_filter
                     ),
@@ -187,7 +198,7 @@
                     break
         else:
             yield concurrently_execute(
-                lambda r: self._generate_room_entry(
+                lambda r: self._append_room_entry_to_chunk(
                     r, rooms_to_num_joined[r],
                     chunk, limit, search_filter
                 ),
@@ -256,21 +267,35 @@
         defer.returnValue(results)

     @defer.inlineCallbacks
-    def _generate_room_entry(self, room_id, num_joined_users, chunk, limit,
-                             search_filter):
+    def _append_room_entry_to_chunk(self, room_id, num_joined_users, chunk, limit,
+                                    search_filter):
+        """Generate the entry for a room in the public room list and append it
+        to the `chunk` if it matches the search filter
+        """
         if limit and len(chunk) > limit + 1:
             # We've already got enough, so lets just drop it.
             return

+        result = yield self._generate_room_entry(room_id, num_joined_users)
+
+        if result and _matches_room_entry(result, search_filter):
+            chunk.append(result)
+
+    @cachedInlineCallbacks(num_args=1, cache_context=True)
+    def _generate_room_entry(self, room_id, num_joined_users, cache_context):
+        """Returns the entry for a room
+        """
         result = {
             "room_id": room_id,
             "num_joined_members": num_joined_users,
         }

-        current_state_ids = yield self.state_handler.get_current_state_ids(room_id)
+        current_state_ids = yield self.store.get_current_state_ids(
+            room_id, on_invalidate=cache_context.invalidate,
+        )

         event_map = yield self.store.get_events([
-            event_id for key, event_id in current_state_ids.items()
+            event_id for key, event_id in current_state_ids.iteritems()
             if key[0] in (
                 EventTypes.JoinRules,
                 EventTypes.Name,
@@ -294,7 +319,9 @@
         if join_rule and join_rule != JoinRules.PUBLIC:
             defer.returnValue(None)

-        aliases = yield self.store.get_aliases_for_room(room_id)
+        aliases = yield self.store.get_aliases_for_room(
+            room_id, on_invalidate=cache_context.invalidate
+        )
         if aliases:
             result["aliases"] = aliases
@@ -334,8 +361,7 @@
         if avatar_url:
             result["avatar_url"] = avatar_url

-        if _matches_room_entry(result, search_filter):
-            chunk.append(result)
+        defer.returnValue(result)

     @defer.inlineCallbacks
     def get_remote_public_room_list(self, server_name, limit=None, since_token=None,


@@ -108,6 +108,12 @@ class MatrixFederationHttpClient(object):
                         query_bytes=b"", retry_on_dns_fail=True,
                         timeout=None, long_retries=False):
         """ Creates and sends a request to the given url
+
+        Returns:
+            Deferred: resolves with the http response object on success.
+
+            Fails with ``HTTPRequestException``: if we get an HTTP response
+            code >= 300.
         """
         headers_dict[b"User-Agent"] = [self.version_string]
         headers_dict[b"Host"] = [destination]
@@ -408,8 +414,11 @@ class MatrixFederationHttpClient(object):
             output_stream (file): File to write the response body to.
             args (dict): Optional dictionary used to create the query string.
         Returns:
-            A (int,dict) tuple of the file length and a dict of the response
-            headers.
+            Deferred: resolves with an (int,dict) tuple of the file length and
+            a dict of the response headers.
+
+            Fails with ``HTTPRequestException`` if we get an HTTP response code
+            >= 300
         """
         encoded_args = {}
@@ -419,7 +428,7 @@ class MatrixFederationHttpClient(object):
                 encoded_args[k] = [v.encode("UTF-8") for v in vs]

         query_bytes = urllib.urlencode(encoded_args, True)
-        logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
+        logger.debug("Query bytes: %s Retry DNS: %s", query_bytes, retry_on_dns_fail)

         def body_callback(method, url_bytes, headers_dict):
             self.sign_request(destination, method, url_bytes, headers_dict)


@@ -192,6 +192,16 @@ def parse_json_object_from_request(request):
     return content


+def assert_params_in_request(body, required):
+    absent = []
+    for k in required:
+        if k not in body:
+            absent.append(k)
+
+    if len(absent) > 0:
+        raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+
+
 class RestServlet(object):
     """ A Synapse REST Servlet.


@@ -73,6 +73,13 @@ class _NotifierUserStream(object):
         self.user_id = user_id
         self.rooms = set(rooms)
         self.current_token = current_token
+
+        # The last token for which we should wake up any streams that have a
+        # token that comes before it. This gets updated everytime we get poked.
+        # We start it at the current token since if we get any streams
+        # that have a token from before we have no idea whether they should be
+        # woken up or not, so lets just wake them up.
+        self.last_notified_token = current_token
         self.last_notified_ms = time_now_ms

         with PreserveLoggingContext():
@@ -89,6 +96,7 @@ class _NotifierUserStream(object):
         self.current_token = self.current_token.copy_and_advance(
             stream_key, stream_id
         )
+        self.last_notified_token = self.current_token
         self.last_notified_ms = time_now_ms
         noify_deferred = self.notify_deferred
@@ -113,8 +121,14 @@ class _NotifierUserStream(object):
     def new_listener(self, token):
         """Returns a deferred that is resolved when there is a new token
         greater than the given token.
+
+        Args:
+            token: The token from which we are streaming from, i.e. we shouldn't
+                notify for things that happened before this.
         """
-        if self.current_token.is_after(token):
+        # Immediately wake up stream if something has already since happened
+        # since their last token.
+        if self.last_notified_token.is_after(token):
             return _NotificationListener(defer.succeed(self.current_token))
         else:
             return _NotificationListener(self.notify_deferred.observe())
@@ -294,40 +308,44 @@ class Notifier(object):
             self._register_with_keys(user_stream)

         result = None
+        prev_token = from_token
         if timeout:
             end_time = self.clock.time_msec() + timeout

-            prev_token = from_token
             while not result:
                 try:
-                    current_token = user_stream.current_token
-
-                    result = yield callback(prev_token, current_token)
-                    if result:
-                        break
-
                     now = self.clock.time_msec()
                     if end_time <= now:
                         break

                     # Now we wait for the _NotifierUserStream to be told there
                     # is a new token.
-                    # We need to supply the token we supplied to callback so
-                    # that we don't miss any current_token updates.
-                    prev_token = current_token
                     listener = user_stream.new_listener(prev_token)
                     with PreserveLoggingContext():
                         yield self.clock.time_bound_deferred(
                             listener.deferred,
                             time_out=(end_time - now) / 1000.
                         )
+
+                    current_token = user_stream.current_token
+
+                    result = yield callback(prev_token, current_token)
+                    if result:
+                        break
+
+                    # Update the prev_token to the current_token since nothing
+                    # has happened between the old prev_token and the current_token
+                    prev_token = current_token
                 except DeferredTimedOutError:
                     break
                 except defer.CancelledError:
                     break
-        else:
+
+        if result is None:
+            # This happened if there was no timeout or if the timeout had
+            # already expired.
             current_token = user_stream.current_token
-            result = yield callback(from_token, current_token)
+            result = yield callback(prev_token, current_token)

         defer.returnValue(result)
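Note: the callback driven by this wait loop receives (prev_token, current_token) and returns a falsy value to keep waiting. A minimal sketch of that contract, with a stubbed event source; the shape of the tokens and the query are assumptions for illustration:

from twisted.internet import defer

@defer.inlineCallbacks
def check_for_updates(before_token, after_token):
    # stand-in for a real query of events between two stream tokens
    events = yield defer.succeed([])  # pretend nothing has happened yet
    defer.returnValue(events or None)  # falsy result -> the wait loop keeps waiting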


@@ -139,7 +139,7 @@ class Mailer(object):
         @defer.inlineCallbacks
         def _fetch_room_state(room_id):
-            room_state = yield self.state_handler.get_current_state_ids(room_id)
+            room_state = yield self.store.get_current_state_ids(room_id)
             state_by_room[room_id] = room_state

         # Run at most 3 of these at once: sync does 10 at a time but email


@@ -1,4 +1,5 @@
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -37,6 +38,7 @@ REQUIREMENTS = {
     "pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"],
     "pymacaroons-pynacl": ["pymacaroons"],
     "msgpack-python>=0.3.0": ["msgpack"],
+    "phonenumbers>=8.2.0": ["phonenumbers"],
 }
 CONDITIONAL_REQUIREMENTS = {
     "web_client": {


@@ -109,6 +109,10 @@ class SlavedEventStore(BaseSlavedStore):
     get_recent_event_ids_for_room = (
         StreamStore.__dict__["get_recent_event_ids_for_room"]
     )
+    get_current_state_ids = (
+        StateStore.__dict__["get_current_state_ids"]
+    )
+    has_room_changed_since = DataStore.has_room_changed_since.__func__

     get_unread_push_actions_for_user_in_range_for_http = (
         DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__


@@ -19,6 +19,7 @@ from synapse.api.errors import SynapseError, LoginError, Codes
 from synapse.types import UserID
 from synapse.http.server import finish_request
 from synapse.http.servlet import parse_json_object_from_request
+from synapse.util.msisdn import phone_number_to_msisdn

 from .base import ClientV1RestServlet, client_path_patterns
@@ -37,6 +38,49 @@ import xml.etree.ElementTree as ET
 logger = logging.getLogger(__name__)


+def login_submission_legacy_convert(submission):
+    """
+    If the input login submission is an old style object
+    (ie. with top-level user / medium / address) convert it
+    to a typed object.
+    """
+    if "user" in submission:
+        submission["identifier"] = {
+            "type": "m.id.user",
+            "user": submission["user"],
+        }
+        del submission["user"]
+
+    if "medium" in submission and "address" in submission:
+        submission["identifier"] = {
+            "type": "m.id.thirdparty",
+            "medium": submission["medium"],
+            "address": submission["address"],
+        }
+        del submission["medium"]
+        del submission["address"]
+
+
+def login_id_thirdparty_from_phone(identifier):
+    """
+    Convert a phone login identifier type to a generic threepid identifier
+
+    Args:
+        identifier(dict): Login identifier dict of type 'm.id.phone'
+
+    Returns: Login identifier dict of type 'm.id.threepid'
+    """
+    if "country" not in identifier or "number" not in identifier:
+        raise SynapseError(400, "Invalid phone-type identifier")
+
+    msisdn = phone_number_to_msisdn(identifier["country"], identifier["number"])
+
+    return {
+        "type": "m.id.thirdparty",
+        "medium": "msisdn",
+        "address": msisdn,
+    }
+
+
 class LoginRestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/login$")
     PASS_TYPE = "m.login.password"
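Note: a worked example of the two helpers added above (plain dicts, no server; the msisdn output assumes the E.164 conversion performed by synapse.util.msisdn):

submission = {"user": "alice", "password": "wonderland"}
login_submission_legacy_convert(submission)
print(submission["identifier"])  # {'type': 'm.id.user', 'user': 'alice'}

phone = {"type": "m.id.phone", "country": "GB", "number": "07700 900123"}
print(login_id_thirdparty_from_phone(phone))
# {'type': 'm.id.thirdparty', 'medium': 'msisdn', 'address': '447700900123'}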
@@ -117,20 +161,52 @@ class LoginRestServlet(ClientV1RestServlet):
     @defer.inlineCallbacks
     def do_password_login(self, login_submission):
-        if 'medium' in login_submission and 'address' in login_submission:
-            address = login_submission['address']
-            if login_submission['medium'] == 'email':
+        if "password" not in login_submission:
+            raise SynapseError(400, "Missing parameter: password")
+
+        login_submission_legacy_convert(login_submission)
+
+        if "identifier" not in login_submission:
+            raise SynapseError(400, "Missing param: identifier")
+
+        identifier = login_submission["identifier"]
+        if "type" not in identifier:
+            raise SynapseError(400, "Login identifier has no type")
+
+        # convert phone type identifiers to generic threepids
+        if identifier["type"] == "m.id.phone":
+            identifier = login_id_thirdparty_from_phone(identifier)
+
+        # convert threepid identifiers to user IDs
+        if identifier["type"] == "m.id.thirdparty":
+            if 'medium' not in identifier or 'address' not in identifier:
+                raise SynapseError(400, "Invalid thirdparty identifier")
+
+            address = identifier['address']
+            if identifier['medium'] == 'email':
                 # For emails, transform the address to lowercase.
                 # We store all email addreses as lowercase in the DB.
                 # (See add_threepid in synapse/handlers/auth.py)
                 address = address.lower()
+
             user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
-                login_submission['medium'], address
+                identifier['medium'], address
             )
             if not user_id:
                 raise LoginError(403, "", errcode=Codes.FORBIDDEN)
-        else:
-            user_id = login_submission['user']
+
+            identifier = {
+                "type": "m.id.user",
+                "user": user_id,
+            }
+
+        # by this point, the identifier should be an m.id.user: if it's anything
+        # else, we haven't understood it.
+        if identifier["type"] != "m.id.user":
+            raise SynapseError(400, "Unknown login identifier type")
+        if "user" not in identifier:
+            raise SynapseError(400, "User identifier is missing 'user' key")
+
+        user_id = identifier["user"]

         if not user_id.startswith('@'):
             user_id = UserID.create(


@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -17,8 +18,11 @@ from twisted.internet import defer
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import LoginError, SynapseError, Codes from synapse.api.errors import LoginError, SynapseError, Codes
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import (
RestServlet, parse_json_object_from_request, assert_params_in_request
)
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.msisdn import phone_number_to_msisdn
from ._base import client_v2_patterns from ._base import client_v2_patterns
@@ -28,11 +32,11 @@ import logging
 logger = logging.getLogger(__name__)


-class PasswordRequestTokenRestServlet(RestServlet):
+class EmailPasswordRequestTokenRestServlet(RestServlet):
     PATTERNS = client_v2_patterns("/account/password/email/requestToken$")

     def __init__(self, hs):
-        super(PasswordRequestTokenRestServlet, self).__init__()
+        super(EmailPasswordRequestTokenRestServlet, self).__init__()
         self.hs = hs
         self.identity_handler = hs.get_handlers().identity_handler
@@ -40,14 +44,9 @@ class PasswordRequestTokenRestServlet(RestServlet):
     def on_POST(self, request):
         body = parse_json_object_from_request(request)

-        required = ['id_server', 'client_secret', 'email', 'send_attempt']
-        absent = []
-        for k in required:
-            if k not in body:
-                absent.append(k)
-
-        if absent:
-            raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+        assert_params_in_request(body, [
+            'id_server', 'client_secret', 'email', 'send_attempt'
+        ])

         existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
             'email', body['email']
@@ -60,6 +59,37 @@ class PasswordRequestTokenRestServlet(RestServlet):
         defer.returnValue((200, ret))


+class MsisdnPasswordRequestTokenRestServlet(RestServlet):
+    PATTERNS = client_v2_patterns("/account/password/msisdn/requestToken$")
+
+    def __init__(self, hs):
+        super(MsisdnPasswordRequestTokenRestServlet, self).__init__()
+        self.hs = hs
+        self.datastore = self.hs.get_datastore()
+        self.identity_handler = hs.get_handlers().identity_handler
+
+    @defer.inlineCallbacks
+    def on_POST(self, request):
+        body = parse_json_object_from_request(request)
+
+        assert_params_in_request(body, [
+            'id_server', 'client_secret',
+            'country', 'phone_number', 'send_attempt',
+        ])
+
+        msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
+
+        existingUid = yield self.datastore.get_user_id_by_threepid(
+            'msisdn', msisdn
+        )
+
+        if existingUid is None:
+            raise SynapseError(400, "MSISDN not found", Codes.THREEPID_NOT_FOUND)
+
+        ret = yield self.identity_handler.requestMsisdnToken(**body)
+        defer.returnValue((200, ret))
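For orientation, a client would hit the new endpoint with a body along these lines (the values and the identity-server name are illustrative, not from the commit):

    # POST .../account/password/msisdn/requestToken
    body = {
        "id_server": "id.example.com",   # identity server that sends the SMS
        "client_secret": "d0n7tellanyb0dy",
        "country": "GB",                 # ISO-3166-1 alpha-2 code
        "phone_number": "07700 900123",  # national format is accepted
        "send_attempt": 1,
    }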
 class PasswordRestServlet(RestServlet):
     PATTERNS = client_v2_patterns("/account/password$")

@@ -68,6 +98,7 @@ class PasswordRestServlet(RestServlet):
         self.hs = hs
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
+        self.datastore = self.hs.get_datastore()

     @defer.inlineCallbacks
     def on_POST(self, request):
@@ -77,7 +108,8 @@ class PasswordRestServlet(RestServlet):
         authed, result, params, _ = yield self.auth_handler.check_auth([
             [LoginType.PASSWORD],
-            [LoginType.EMAIL_IDENTITY]
+            [LoginType.EMAIL_IDENTITY],
+            [LoginType.MSISDN],
         ], body, self.hs.get_ip_from_request(request))

         if not authed:
@@ -102,7 +134,7 @@ class PasswordRestServlet(RestServlet):
                 # (See add_threepid in synapse/handlers/auth.py)
                 threepid['address'] = threepid['address'].lower()
                 # if using email, we must know about the email they're authing with!
-                threepid_user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
+                threepid_user_id = yield self.datastore.get_user_id_by_threepid(
                     threepid['medium'], threepid['address']
                 )
                 if not threepid_user_id:
@@ -169,13 +201,14 @@ class DeactivateAccountRestServlet(RestServlet):
         defer.returnValue((200, {}))


-class ThreepidRequestTokenRestServlet(RestServlet):
+class EmailThreepidRequestTokenRestServlet(RestServlet):
     PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$")

     def __init__(self, hs):
         self.hs = hs
-        super(ThreepidRequestTokenRestServlet, self).__init__()
+        super(EmailThreepidRequestTokenRestServlet, self).__init__()
         self.identity_handler = hs.get_handlers().identity_handler
+        self.datastore = self.hs.get_datastore()

     @defer.inlineCallbacks
     def on_POST(self, request):
@@ -190,7 +223,7 @@ class ThreepidRequestTokenRestServlet(RestServlet):
         if absent:
             raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)

-        existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
+        existingUid = yield self.datastore.get_user_id_by_threepid(
             'email', body['email']
         )
@@ -201,6 +234,44 @@ class ThreepidRequestTokenRestServlet(RestServlet):
         defer.returnValue((200, ret))


+class MsisdnThreepidRequestTokenRestServlet(RestServlet):
+    PATTERNS = client_v2_patterns("/account/3pid/msisdn/requestToken$")
+
+    def __init__(self, hs):
+        self.hs = hs
+        super(MsisdnThreepidRequestTokenRestServlet, self).__init__()
+        self.identity_handler = hs.get_handlers().identity_handler
+        self.datastore = self.hs.get_datastore()
+
+    @defer.inlineCallbacks
+    def on_POST(self, request):
+        body = parse_json_object_from_request(request)
+
+        required = [
+            'id_server', 'client_secret',
+            'country', 'phone_number', 'send_attempt',
+        ]
+        absent = []
+        for k in required:
+            if k not in body:
+                absent.append(k)
+
+        if absent:
+            raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+
+        msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
+
+        existingUid = yield self.datastore.get_user_id_by_threepid(
+            'msisdn', msisdn
+        )
+
+        if existingUid is not None:
+            raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
+        ret = yield self.identity_handler.requestMsisdnToken(**body)
+        defer.returnValue((200, ret))
 class ThreepidRestServlet(RestServlet):
     PATTERNS = client_v2_patterns("/account/3pid$")

@@ -210,6 +281,7 @@ class ThreepidRestServlet(RestServlet):
         self.identity_handler = hs.get_handlers().identity_handler
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
+        self.datastore = self.hs.get_datastore()

     @defer.inlineCallbacks
     def on_GET(self, request):
@@ -217,7 +289,7 @@ class ThreepidRestServlet(RestServlet):
         requester = yield self.auth.get_user_by_req(request)

-        threepids = yield self.hs.get_datastore().user_get_threepids(
+        threepids = yield self.datastore.user_get_threepids(
             requester.user.to_string()
         )

@@ -258,7 +330,7 @@ class ThreepidRestServlet(RestServlet):
         if 'bind' in body and body['bind']:
             logger.debug(
-                "Binding emails %s to %s",
+                "Binding threepid %s to %s",
                 threepid, user_id
             )
             yield self.identity_handler.bind_threepid(
@@ -302,9 +374,11 @@ class ThreepidDeleteRestServlet(RestServlet):

 def register_servlets(hs, http_server):
-    PasswordRequestTokenRestServlet(hs).register(http_server)
+    EmailPasswordRequestTokenRestServlet(hs).register(http_server)
+    MsisdnPasswordRequestTokenRestServlet(hs).register(http_server)
     PasswordRestServlet(hs).register(http_server)
     DeactivateAccountRestServlet(hs).register(http_server)
-    ThreepidRequestTokenRestServlet(hs).register(http_server)
+    EmailThreepidRequestTokenRestServlet(hs).register(http_server)
+    MsisdnThreepidRequestTokenRestServlet(hs).register(http_server)
     ThreepidRestServlet(hs).register(http_server)
     ThreepidDeleteRestServlet(hs).register(http_server)
@@ -46,6 +46,52 @@ class DevicesRestServlet(servlet.RestServlet):
         defer.returnValue((200, {"devices": devices}))
+class DeleteDevicesRestServlet(servlet.RestServlet):
+    """
+    API for bulk deletion of devices. Accepts a JSON object with a devices
+    key which lists the device_ids to delete. Requires user interactive auth.
+    """
+    PATTERNS = client_v2_patterns("/delete_devices", releases=[], v2_alpha=False)
+
+    def __init__(self, hs):
+        super(DeleteDevicesRestServlet, self).__init__()
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.device_handler = hs.get_device_handler()
+        self.auth_handler = hs.get_auth_handler()
+
+    @defer.inlineCallbacks
+    def on_POST(self, request):
+        try:
+            body = servlet.parse_json_object_from_request(request)
+        except errors.SynapseError as e:
+            if e.errcode == errors.Codes.NOT_JSON:
+                # deal with older clients which didn't pass a JSON dict
+                # the same as those that pass an empty dict
+                body = {}
+            else:
+                raise e
+
+        if 'devices' not in body:
+            raise errors.SynapseError(
+                400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM
+            )
+
+        authed, result, params, _ = yield self.auth_handler.check_auth([
+            [constants.LoginType.PASSWORD],
+        ], body, self.hs.get_ip_from_request(request))
+
+        if not authed:
+            defer.returnValue((401, result))
+
+        requester = yield self.auth.get_user_by_req(request)
+        yield self.device_handler.delete_devices(
+            requester.user.to_string(),
+            body['devices'],
+        )
+        defer.returnValue((200, {}))
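A plausible request to this endpoint, including the user-interactive auth dict it expects (device IDs and credentials invented for illustration):

    # POST /_matrix/client/unstable/delete_devices
    body = {
        "devices": ["QBUAZIFURK", "AUIECTSRND"],
        "auth": {
            "type": "m.login.password",
            "user": "@alice:example.com",
            "password": "correct horse battery staple",
            "session": "xxxxx",  # session ID from the initial 401 response
        },
    }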
 class DeviceRestServlet(servlet.RestServlet):
     PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
                                   releases=[], v2_alpha=False)
@@ -111,5 +157,6 @@ class DeviceRestServlet(servlet.RestServlet):

 def register_servlets(hs, http_server):
+    DeleteDevicesRestServlet(hs).register(http_server)
     DevicesRestServlet(hs).register(http_server)
     DeviceRestServlet(hs).register(http_server)
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015 - 2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,7 +20,10 @@ import synapse
 from synapse.api.auth import get_access_token_from_request, has_access_token
 from synapse.api.constants import LoginType
 from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.servlet import (
+    RestServlet, parse_json_object_from_request, assert_params_in_request
+)
+from synapse.util.msisdn import phone_number_to_msisdn

 from ._base import client_v2_patterns
@@ -43,7 +47,7 @@ else:
 logger = logging.getLogger(__name__)


-class RegisterRequestTokenRestServlet(RestServlet):
+class EmailRegisterRequestTokenRestServlet(RestServlet):
     PATTERNS = client_v2_patterns("/register/email/requestToken$")

     def __init__(self, hs):
@@ -51,7 +55,7 @@ class RegisterRequestTokenRestServlet(RestServlet):
         Args:
             hs (synapse.server.HomeServer): server
         """
-        super(RegisterRequestTokenRestServlet, self).__init__()
+        super(EmailRegisterRequestTokenRestServlet, self).__init__()
         self.hs = hs
         self.identity_handler = hs.get_handlers().identity_handler
@@ -59,14 +63,9 @@ class RegisterRequestTokenRestServlet(RestServlet):
     def on_POST(self, request):
         body = parse_json_object_from_request(request)

-        required = ['id_server', 'client_secret', 'email', 'send_attempt']
-        absent = []
-        for k in required:
-            if k not in body:
-                absent.append(k)
-
-        if len(absent) > 0:
-            raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+        assert_params_in_request(body, [
+            'id_server', 'client_secret', 'email', 'send_attempt'
+        ])

         existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
             'email', body['email']
@@ -79,6 +78,43 @@ class RegisterRequestTokenRestServlet(RestServlet):
         defer.returnValue((200, ret))
+class MsisdnRegisterRequestTokenRestServlet(RestServlet):
+    PATTERNS = client_v2_patterns("/register/msisdn/requestToken$")
+
+    def __init__(self, hs):
+        """
+        Args:
+            hs (synapse.server.HomeServer): server
+        """
+        super(MsisdnRegisterRequestTokenRestServlet, self).__init__()
+        self.hs = hs
+        self.identity_handler = hs.get_handlers().identity_handler
+
+    @defer.inlineCallbacks
+    def on_POST(self, request):
+        body = parse_json_object_from_request(request)
+
+        assert_params_in_request(body, [
+            'id_server', 'client_secret',
+            'country', 'phone_number',
+            'send_attempt',
+        ])
+
+        msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
+
+        existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
+            'msisdn', msisdn
+        )
+
+        if existingUid is not None:
+            raise SynapseError(
+                400, "Phone number is already in use", Codes.THREEPID_IN_USE
+            )
+
+        ret = yield self.identity_handler.requestMsisdnToken(**body)
+        defer.returnValue((200, ret))
 class RegisterRestServlet(RestServlet):
     PATTERNS = client_v2_patterns("/register$")

@@ -200,16 +236,37 @@ class RegisterRestServlet(RestServlet):
             assigned_user_id=registered_user_id,
         )
+        # Only give msisdn flows if the x_show_msisdn flag is given:
+        # this is a hack to work around the fact that clients were shipped
+        # that use fallback registration if they see any flows that they don't
+        # recognise, which means we break registration for these clients if we
+        # advertise msisdn flows. Once usage of Riot iOS <=0.3.9 and Riot
+        # Android <=0.6.9 have fallen below an acceptable threshold, this
+        # parameter should go away and we should always advertise msisdn flows.
+        show_msisdn = False
+        if 'x_show_msisdn' in body and body['x_show_msisdn']:
+            show_msisdn = True
+
         if self.hs.config.enable_registration_captcha:
             flows = [
                 [LoginType.RECAPTCHA],
-                [LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]
+                [LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
             ]
+            if show_msisdn:
+                flows.extend([
+                    [LoginType.MSISDN, LoginType.RECAPTCHA],
+                    [LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
+                ])
         else:
             flows = [
                 [LoginType.DUMMY],
-                [LoginType.EMAIL_IDENTITY]
+                [LoginType.EMAIL_IDENTITY],
             ]
+            if show_msisdn:
+                flows.extend([
+                    [LoginType.MSISDN],
+                    [LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
+                ])
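So, for example, a homeserver with registration captcha enabled that receives x_show_msisdn in the request body ends up advertising:

    flows = [
        [LoginType.RECAPTCHA],
        [LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
        [LoginType.MSISDN, LoginType.RECAPTCHA],
        [LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
    ]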
         authed, auth_result, params, session_id = yield self.auth_handler.check_auth(
             flows, body, self.hs.get_ip_from_request(request)
@@ -224,8 +281,9 @@ class RegisterRestServlet(RestServlet):
                 "Already registered user ID %r for this session",
                 registered_user_id
             )
-            # don't re-register the email address
+            # don't re-register the threepids
             add_email = False
+            add_msisdn = False
         else:
             # NB: This may be from the auth handler and NOT from the POST
             if 'password' not in params:
@@ -250,6 +308,7 @@ class RegisterRestServlet(RestServlet):
             )

             add_email = True
+            add_msisdn = True
         return_dict = yield self._create_registration_details(
             registered_user_id, params
@@ -262,6 +321,13 @@ class RegisterRestServlet(RestServlet):
                 params.get("bind_email")
             )

+        if add_msisdn and auth_result and LoginType.MSISDN in auth_result:
+            threepid = auth_result[LoginType.MSISDN]
+            yield self._register_msisdn_threepid(
+                registered_user_id, threepid, return_dict["access_token"],
+                params.get("bind_msisdn")
+            )
+
         defer.returnValue((200, return_dict))
     def on_OPTIONS(self, _):
@@ -323,8 +389,9 @@ class RegisterRestServlet(RestServlet):
         """
         reqd = ('medium', 'address', 'validated_at')
         if any(x not in threepid for x in reqd):
+            # This will only happen if the ID server returns a malformed response
             logger.info("Can't add incomplete 3pid")
-            defer.returnValue()
+            return

         yield self.auth_handler.add_threepid(
             user_id,
@@ -371,6 +438,43 @@ class RegisterRestServlet(RestServlet):
         else:
             logger.info("bind_email not specified: not binding email")
+    @defer.inlineCallbacks
+    def _register_msisdn_threepid(self, user_id, threepid, token, bind_msisdn):
+        """Add a phone number as a 3pid identifier
+
+        Also optionally binds msisdn to the given user_id on the identity server
+
+        Args:
+            user_id (str): id of user
+            threepid (object): m.login.msisdn auth response
+            token (str): access_token for the user
+            bind_msisdn (bool): true if the client requested the msisdn to be
+                bound at the identity server
+        Returns:
+            defer.Deferred:
+        """
+        reqd = ('medium', 'address', 'validated_at')
+        if any(x not in threepid for x in reqd):
+            # This will only happen if the ID server returns a malformed response
+            logger.info("Can't add incomplete 3pid")
+            return
+
+        yield self.auth_handler.add_threepid(
+            user_id,
+            threepid['medium'],
+            threepid['address'],
+            threepid['validated_at'],
+        )
+
+        if bind_msisdn:
+            logger.info("bind_msisdn specified: binding")
+            logger.debug("Binding msisdn %s to %s", threepid, user_id)
+            yield self.identity_handler.bind_threepid(
+                threepid['threepid_creds'], user_id
+            )
+        else:
+            logger.info("bind_msisdn not specified: not binding msisdn")
     @defer.inlineCallbacks
     def _create_registration_details(self, user_id, params):
         """Complete registration of newly-registered user
@@ -449,5 +553,6 @@ class RegisterRestServlet(RestServlet):

 def register_servlets(hs, http_server):
-    RegisterRequestTokenRestServlet(hs).register(http_server)
+    EmailRegisterRequestTokenRestServlet(hs).register(http_server)
+    MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
     RegisterRestServlet(hs).register(http_server)
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import synapse.http.servlet

 from ._base import parse_media_id, respond_with_file, respond_404
 from twisted.web.resource import Resource
@@ -81,6 +82,17 @@ class DownloadResource(Resource):

     @defer.inlineCallbacks
     def _respond_remote_file(self, request, server_name, media_id, name):
+        # don't forward requests for remote media if allow_remote is false
+        allow_remote = synapse.http.servlet.parse_boolean(
+            request, "allow_remote", default=True)
+        if not allow_remote:
+            logger.info(
+                "Rejecting request for remote media %s/%s due to allow_remote",
+                server_name, media_id,
+            )
+            respond_404(request)
+            return
+
         media_info = yield self.media_repo.get_remote_media(server_name, media_id)

         media_type = media_info["media_type"]
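In effect, a server proxying a media request can now opt out of onward federation fetches via a query parameter, roughly (path shown for the v1 media API of this era; illustrative, not from the commit):

    # GET /_matrix/media/v1/download/remote.example.com/someMediaId?allow_remote=false
    # -> 404 if remote.example.com's media is not already known locally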
@@ -13,22 +13,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from twisted.internet import defer, threads
+import twisted.internet.error
+import twisted.web.http
+from twisted.web.resource import Resource
+
 from .upload_resource import UploadResource
 from .download_resource import DownloadResource
 from .thumbnail_resource import ThumbnailResource
 from .identicon_resource import IdenticonResource
 from .preview_url_resource import PreviewUrlResource
 from .filepath import MediaFilePaths

-from twisted.web.resource import Resource
-
 from .thumbnailer import Thumbnailer

 from synapse.http.matrixfederationclient import MatrixFederationHttpClient
 from synapse.util.stringutils import random_string
-from synapse.api.errors import SynapseError
+from synapse.api.errors import SynapseError, HttpResponseException, \
+    NotFoundError

-from twisted.internet import defer, threads
 from synapse.util.async import Linearizer
 from synapse.util.stringutils import is_ascii
@@ -157,11 +158,34 @@ class MediaRepository(object):
         try:
             length, headers = yield self.client.get_file(
                 server_name, request_path, output_stream=f,
-                max_size=self.max_upload_size,
+                max_size=self.max_upload_size, args={
+                    # tell the remote server to 404 if it doesn't
+                    # recognise the server_name, to make sure we don't
+                    # end up with a routing loop.
+                    "allow_remote": "false",
+                }
             )
-        except Exception as e:
-            logger.warn("Failed to fetch remoted media %r", e)
-            raise SynapseError(502, "Failed to fetch remoted media")
+        except twisted.internet.error.DNSLookupError as e:
+            logger.warn("DNS error fetching remote media %s/%s: %r",
+                        server_name, media_id, e)
+            raise NotFoundError()
+
+        except HttpResponseException as e:
+            logger.warn("HTTP error fetching remote media %s/%s: %s",
+                        server_name, media_id, e.response)
+            if e.code == twisted.web.http.NOT_FOUND:
+                raise SynapseError.from_http_response_exception(e)
+            raise SynapseError(502, "Failed to fetch remote media")
+
+        except SynapseError:
+            logger.exception("Failed to fetch remote media %s/%s",
+                             server_name, media_id)
+            raise
+        except Exception:
+            logger.exception("Failed to fetch remote media %s/%s",
+                             server_name, media_id)
+            raise SynapseError(502, "Failed to fetch remote media")

         media_type = headers["Content-Type"][0]

         time_now_ms = self.clock.time_msec()
@@ -840,6 +840,47 @@ class SQLBaseStore(object):
         return txn.execute(sql, keyvalues.values())
+    def _simple_delete_many(self, table, column, iterable, keyvalues, desc):
+        return self.runInteraction(
+            desc, self._simple_delete_many_txn, table, column, iterable, keyvalues
+        )
+
+    @staticmethod
+    def _simple_delete_many_txn(txn, table, column, iterable, keyvalues):
+        """Executes a DELETE query on the named table.
+
+        Filters rows by if value of `column` is in `iterable`.
+
+        Args:
+            txn : Transaction object
+            table : string giving the table name
+            column : column name to test for inclusion against `iterable`
+            iterable : list
+            keyvalues : dict of column names and values to select the rows with
+        """
+        if not iterable:
+            return
+
+        sql = "DELETE FROM %s" % table
+
+        clauses = []
+        values = []
+        clauses.append(
+            "%s IN (%s)" % (column, ",".join("?" for _ in iterable))
+        )
+        values.extend(iterable)
+
+        for key, value in keyvalues.items():
+            clauses.append("%s = ?" % (key,))
+            values.append(value)
+
+        if clauses:
+            sql = "%s WHERE %s" % (
+                sql,
+                " AND ".join(clauses),
+            )
+        return txn.execute(sql, values)
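For instance (hypothetical arguments), deleting two of a user's devices:

    # _simple_delete_many_txn(
    #     txn, table="devices", column="device_id",
    #     iterable=["QBUAZIFURK", "AUIECTSRND"],
    #     keyvalues={"user_id": "@alice:example.com"},
    # )
    # executes:
    #     DELETE FROM devices WHERE device_id IN (?,?) AND user_id = ?
    # with values ["QBUAZIFURK", "AUIECTSRND", "@alice:example.com"]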
     def _get_cache_dict(self, db_conn, table, entity_column, stream_column,
                         max_value, limit=100000):
         # Fetch a mapping of room_id -> max stream position for "recent" rooms.
@@ -108,6 +108,23 @@ class DeviceStore(SQLBaseStore):
             desc="delete_device",
         )
+    def delete_devices(self, user_id, device_ids):
+        """Deletes several devices.
+
+        Args:
+            user_id (str): The ID of the user which owns the devices
+            device_ids (list): The IDs of the devices to delete
+        Returns:
+            defer.Deferred
+        """
+        return self._simple_delete_many(
+            table="devices",
+            column="device_id",
+            iterable=device_ids,
+            keyvalues={"user_id": user_id},
+            desc="delete_devices",
+        )
+
     def update_device(self, user_id, device_id, new_display_name=None):
         """Update a device.
if not new_latest_event_ids: if not new_latest_event_ids:
current_state = {} current_state = {}
elif was_updated: elif was_updated:
# We work out the current state by passing the state sets to the
# state resolution algorithm. It may ask for some events, including
# the events we have yet to persist, so we need a slightly more
# complicated event lookup function than simply looking the events
# up in the db.
events_map = {ev.event_id: ev for ev, _ in events_context}
@defer.inlineCallbacks
def get_events(ev_ids):
# We get the events by first looking at the list of events we
# are trying to persist, and then fetching the rest from the DB.
db = []
to_return = {}
for ev_id in ev_ids:
ev = events_map.get(ev_id, None)
if ev:
to_return[ev_id] = ev
else:
db.append(ev_id)
+
+                if db:
+                    evs = yield self.get_events(
+                        db, get_prev_content=False, check_redacted=False,
+                    )
+                    to_return.update(evs)
+                defer.returnValue(to_return)
+
             current_state = yield resolve_events(
                 state_sets,
-                state_map_factory=lambda ev_ids: self.get_events(
-                    ev_ids, get_prev_content=False, check_redacted=False,
-                ),
+                state_map_factory=get_events,
             )
         else:
             return

-        existing_state_rows = yield self._simple_select_list(
-            table="current_state_events",
-            keyvalues={"room_id": room_id},
-            retcols=["event_id", "type", "state_key"],
-            desc="_calculate_state_delta",
-        )
+        existing_state = yield self.get_current_state_ids(room_id)

-        existing_events = set(row["event_id"] for row in existing_state_rows)
+        existing_events = set(existing_state.itervalues())
         new_events = set(ev_id for ev_id in current_state.itervalues())
         changed_events = existing_events ^ new_events

@@ -457,9 +477,8 @@ class EventsStore(SQLBaseStore):
             return

         to_delete = {
-            (row["type"], row["state_key"]): row["event_id"]
-            for row in existing_state_rows
-            if row["event_id"] in changed_events
+            key: ev_id for key, ev_id in existing_state.iteritems()
+            if ev_id in changed_events
         }
         events_to_insert = (new_events - existing_events)
         to_insert = {
@@ -585,6 +604,10 @@ class EventsStore(SQLBaseStore):
             txn, self.get_users_in_room, (room_id,)
         )

+        self._invalidate_cache_and_stream(
+            txn, self.get_current_state_ids, (room_id,)
+        )
+
         for room_id, new_extrem in new_forward_extremeties.items():
             self._simple_delete_txn(
                 txn,
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
from synapse.util.caches import intern_string from synapse.util.caches import intern_string
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
@ -69,6 +69,18 @@ class StateStore(SQLBaseStore):
where_clause="type='m.room.member'", where_clause="type='m.room.member'",
) )
+    @cachedInlineCallbacks(max_entries=100000, iterable=True)
+    def get_current_state_ids(self, room_id):
+        rows = yield self._simple_select_list(
+            table="current_state_events",
+            keyvalues={"room_id": room_id},
+            retcols=["event_id", "type", "state_key"],
+            desc="_calculate_state_delta",
+        )
+        defer.returnValue({
+            (r["type"], r["state_key"]): r["event_id"] for r in rows
+        })
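The cached value maps (event type, state key) pairs to event IDs, e.g. (event IDs invented for illustration):

    # {
    #     ("m.room.create", ""): "$14187045212DnyzA:example.com",
    #     ("m.room.member", "@alice:example.com"): "$14187048592QBmOp:example.com",
    # }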
     @defer.inlineCallbacks
     def get_state_groups_ids(self, room_id, event_ids):
         if not event_ids:
@@ -829,3 +829,6 @@ class StreamStore(SQLBaseStore):
             updatevalues={"stream_id": stream_id},
             desc="update_federation_out_pos",
         )
+
+    def has_room_changed_since(self, room_id, stream_id):
+        return self._events_stream_cache.has_entity_changed(room_id, stream_id)
synapse/util/msisdn.py (new file)
@@ -0,0 +1,40 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import phonenumbers
+
+from synapse.api.errors import SynapseError
+
+
+def phone_number_to_msisdn(country, number):
+    """
+    Takes an ISO-3166-1 2 letter country code and phone number and
+    returns an msisdn representing the canonical version of that
+    phone number.
+
+    Args:
+        country (str): ISO-3166-1 2 letter country code
+        number (str): Phone number in a national or international format
+
+    Returns:
+        (str) The canonical form of the phone number, as an msisdn
+    Raises:
+        SynapseError if the number could not be parsed.
+    """
+    try:
+        phoneNumber = phonenumbers.parse(number, country)
+    except phonenumbers.NumberParseException:
+        raise SynapseError(400, "Unable to parse phone number")
+
+    return phonenumbers.format_number(
+        phoneNumber, phonenumbers.PhoneNumberFormat.E164
+    )[1:]
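Example behaviour of the new utility (numbers taken from the reserved fictional ranges; outputs assume the usual E.164 formatting of the phonenumbers library):

    # phone_number_to_msisdn("GB", "07700 900123")    -> "447700900123"
    # phone_number_to_msisdn("GB", "+44 7700 900123") -> "447700900123"
    # phone_number_to_msisdn("US", "(202) 555-0123")  -> "12025550123"
    # phone_number_to_msisdn("GB", "not a number")    raises SynapseError(400, ...)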