Merge branch 'develop' into matthew/gin_work_mem

This commit is contained in:
Richard van der Hoff 2018-02-03 22:40:28 +00:00
commit bb9f0f3cdb
87 changed files with 2822 additions and 1122 deletions

View File

@ -1,3 +1,56 @@
Unreleased
==========
synctl no longer starts the main synapse when using ``-a`` option with workers.
A new worker file should be added with ``worker_app: synapse.app.homeserver``.
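For reference, a minimal sketch of such a worker file (the path is illustrative, and ``daemonize: true`` must be set either here or in the main ``homeserver.yaml``)::

    # workers/homeserver.yaml -- stub telling synctl to manage the main process
    worker_app: synapse.app.homeserver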
This release also begins the process of renaming a number of the metrics
reported to prometheus. See `docs/metrics-howto.rst <docs/metrics-howto.rst#block-and-response-metrics-renamed-for-0-27-0>`_.
Changes in synapse v0.26.0 (2018-01-05)
=======================================
No changes since v0.26.0-rc1
Changes in synapse v0.26.0-rc1 (2017-12-13)
===========================================
Features:
* Add ability for ASes to publicise groups for their users (PR #2686)
* Add all local users to the user_directory and optionally search them (PR
#2723)
* Add support for custom login types for validating users (PR #2729)
Changes:
* Update example Prometheus config to new format (PR #2648) Thanks to
@krombel!
* Rename redact_content option to include_content in Push API (PR #2650)
* Declare support for r0.3.0 (PR #2677)
* Improve upserts (PR #2684, #2688, #2689, #2713)
* Improve documentation of workers (PR #2700)
* Improve tracebacks on exceptions (PR #2705)
* Allow guest access to group APIs for reading (PR #2715)
* Support for posting content in federation_client script (PR #2716)
* Delete devices and pushers on logouts etc (PR #2722)
Bug fixes:
* Fix database port script (PR #2673)
* Fix internal server error on login with ldap_auth_provider (PR #2678) Thanks
to @jkolo!
* Fix error on sqlite 3.7 (PR #2697)
* Fix OPTIONS on preview_url (PR #2707)
* Fix error handling on dns lookup (PR #2711)
* Fix wrong avatars when inviting multiple users when creating room (PR #2717)
* Fix 500 when joining matrix-dev (PR #2719)
Changes in synapse v0.25.1 (2017-11-17)
=======================================

View File

@ -632,6 +632,11 @@ largest boxes pause for thought.)
Troubleshooting
---------------
You can use the federation tester to check if your homeserver is all set:
``https://matrix.org/federationtester/api/report?server_name=<your_server_name>``
If any of the attributes under "checks" is false, federation won't work.
The typical failure mode with federation is that when you try to join a room,
it is rejected with "401: Unauthorized". Generally this means that other
servers in the room couldn't access yours. (Joining a room over federation is a

View File

@ -16,7 +16,7 @@ How to monitor Synapse metrics using Prometheus
metrics_port: 9092
Also ensure that ``enable_metrics`` is set to ``True``.
Restart synapse.
3. Add a prometheus target for synapse.
@ -28,11 +28,58 @@ How to monitor Synapse metrics using Prometheus
static_configs:
- targets: ["my.server.here:9092"]
If your prometheus is older than 1.5.2, you will need to replace
``static_configs`` in the above with ``target_groups``.
Restart prometheus.
Block and response metrics renamed for 0.27.0
---------------------------------------------
Synapse 0.27.0 begins the process of rationalising the duplicate ``*:count``
metrics reported for the resource tracking for code blocks and HTTP requests.
At the same time, the corresponding ``*:total`` metrics are being renamed, as
the ``:total`` suffix no longer makes sense in the absence of a corresponding
``:count`` metric.
To enable a graceful migration path, this release just adds new names for the
metrics being renamed. A future release will remove the old ones.
The following table shows the new metrics, and the old metrics which they are
replacing.
==================================================== ===================================================
New name Old name
==================================================== ===================================================
synapse_util_metrics_block_count synapse_util_metrics_block_timer:count
synapse_util_metrics_block_count synapse_util_metrics_block_ru_utime:count
synapse_util_metrics_block_count synapse_util_metrics_block_ru_stime:count
synapse_util_metrics_block_count synapse_util_metrics_block_db_txn_count:count
synapse_util_metrics_block_count synapse_util_metrics_block_db_txn_duration:count
synapse_util_metrics_block_time_seconds synapse_util_metrics_block_timer:total
synapse_util_metrics_block_ru_utime_seconds synapse_util_metrics_block_ru_utime:total
synapse_util_metrics_block_ru_stime_seconds synapse_util_metrics_block_ru_stime:total
synapse_util_metrics_block_db_txn_count synapse_util_metrics_block_db_txn_count:total
synapse_util_metrics_block_db_txn_duration_seconds synapse_util_metrics_block_db_txn_duration:total
synapse_http_server_response_count synapse_http_server_requests
synapse_http_server_response_count synapse_http_server_response_time:count
synapse_http_server_response_count synapse_http_server_response_ru_utime:count
synapse_http_server_response_count synapse_http_server_response_ru_stime:count
synapse_http_server_response_count synapse_http_server_response_db_txn_count:count
synapse_http_server_response_count synapse_http_server_response_db_txn_duration:count
synapse_http_server_response_time_seconds synapse_http_server_response_time:total
synapse_http_server_response_ru_utime_seconds synapse_http_server_response_ru_utime:total
synapse_http_server_response_ru_stime_seconds synapse_http_server_response_ru_stime:total
synapse_http_server_response_db_txn_count synapse_http_server_response_db_txn_count:total
synapse_http_server_response_db_txn_duration_seconds synapse_http_server_response_db_txn_duration:total
==================================================== ===================================================
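If existing dashboards or alerts query the old names, one way to keep them
working after the old names are eventually removed is a Prometheus recording
rule that re-creates an old name from its replacement. A sketch, assuming
Prometheus 2.x rule-file syntax (the file and group names are illustrative)::

    # synapse-renames.rules.yml
    groups:
      - name: synapse_metric_renames
        rules:
          # expose the old-style names, computed from the new metrics
          - record: "synapse_util_metrics_block_timer:count"
            expr: synapse_util_metrics_block_count
          - record: "synapse_http_server_requests"
            expr: synapse_http_server_response_count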
Standard Metric Names
---------------------
@ -42,7 +89,7 @@ have been changed to seconds, from milliseconds.
================================== ============================= ================================== =============================
New name Old name New name Old name
---------------------------------- ----------------------------- ================================== =============================
process_cpu_user_seconds_total process_resource_utime / 1000 process_cpu_user_seconds_total process_resource_utime / 1000
process_cpu_system_seconds_total process_resource_stime / 1000 process_cpu_system_seconds_total process_resource_stime / 1000
process_open_fds (no 'type' label) process_fds process_open_fds (no 'type' label) process_fds
@ -52,8 +99,8 @@ The python-specific counts of garbage collector performance have been renamed.
=========================== ====================== =========================== ======================
New name Old name New name Old name
--------------------------- ---------------------- =========================== ======================
python_gc_time reactor_gc_time python_gc_time reactor_gc_time
python_gc_unreachable_total reactor_gc_unreachable python_gc_unreachable_total reactor_gc_unreachable
python_gc_counts reactor_gc_counts python_gc_counts reactor_gc_counts
=========================== ====================== =========================== ======================
@ -62,7 +109,7 @@ The twisted-specific reactor metrics have been renamed.
==================================== ===================== ==================================== =====================
New name Old name New name Old name
------------------------------------ --------------------- ==================================== =====================
python_twisted_reactor_pending_calls reactor_pending_calls python_twisted_reactor_pending_calls reactor_pending_calls
python_twisted_reactor_tick_time reactor_tick_time python_twisted_reactor_tick_time reactor_tick_time
==================================== ===================== ==================================== =====================

View File

@ -0,0 +1,133 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Moves a list of remote media from one media store to another.
The input should be a list of media files to be moved, one per line. Each line
should be formatted::
<origin server>|<file id>
This can be extracted from postgres with::
psql --tuples-only -A -c "select media_origin, filesystem_id from
matrix.remote_media_cache where ..."
To use, pipe the above into::
PYTHONPATH=. ./scripts/move_remote_media_to_new_store.py <source repo> <dest repo>
"""
from __future__ import print_function
import argparse
import logging
import sys
import os
import shutil
from synapse.rest.media.v1.filepath import MediaFilePaths
logger = logging.getLogger()
def main(src_repo, dest_repo):
src_paths = MediaFilePaths(src_repo)
dest_paths = MediaFilePaths(dest_repo)
for line in sys.stdin:
line = line.strip()
parts = line.split('|')
if len(parts) != 2:
print("Unable to parse input line %s" % line, file=sys.stderr)
exit(1)
move_media(parts[0], parts[1], src_paths, dest_paths)
def move_media(origin_server, file_id, src_paths, dest_paths):
"""Move the given file, and any thumbnails, to the dest repo
Args:
origin_server (str):
file_id (str):
src_paths (MediaFilePaths):
dest_paths (MediaFilePaths):
"""
logger.info("%s/%s", origin_server, file_id)
# check that the original exists
original_file = src_paths.remote_media_filepath(origin_server, file_id)
if not os.path.exists(original_file):
logger.warn(
"Original for %s/%s (%s) does not exist",
origin_server, file_id, original_file,
)
else:
mkdir_and_move(
original_file,
dest_paths.remote_media_filepath(origin_server, file_id),
)
# now look for thumbnails
original_thumb_dir = src_paths.remote_media_thumbnail_dir(
origin_server, file_id,
)
if not os.path.exists(original_thumb_dir):
return
mkdir_and_move(
original_thumb_dir,
dest_paths.remote_media_thumbnail_dir(origin_server, file_id)
)
def mkdir_and_move(original_file, dest_file):
dirname = os.path.dirname(dest_file)
if not os.path.exists(dirname):
logger.debug("mkdir %s", dirname)
os.makedirs(dirname)
logger.debug("mv %s %s", original_file, dest_file)
shutil.move(original_file, dest_file)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"-v", action='store_true', help='enable debug logging')
parser.add_argument(
"src_repo",
help="Path to source content repo",
)
parser.add_argument(
"dest_repo",
help="Path to destination content repo",
)
args = parser.parse_args()
logging_config = {
"level": logging.DEBUG if args.v else logging.INFO,
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s"
}
logging.basicConfig(**logging_config)
main(args.src_repo, args.dest_repo)

View File

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.25.1" __version__ = "0.26.0"

View File

@ -46,6 +46,7 @@ class Codes(object):
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED" THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
THREEPID_IN_USE = "M_THREEPID_IN_USE" THREEPID_IN_USE = "M_THREEPID_IN_USE"
THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND" THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND"
THREEPID_DENIED = "M_THREEPID_DENIED"
INVALID_USERNAME = "M_INVALID_USERNAME" INVALID_USERNAME = "M_INVALID_USERNAME"
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED" SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
@ -140,6 +141,32 @@ class RegistrationError(SynapseError):
pass pass
class FederationDeniedError(SynapseError):
"""An error raised when the server tries to federate with a server which
is not on its federation whitelist.
Attributes:
destination (str): The destination which has been denied
"""
def __init__(self, destination):
"""Raised by federation client or server to indicate that we are
deliberately not attempting to contact a given server because it is
not on our federation whitelist.
Args:
destination (str): the domain in question
"""
self.destination = destination
super(FederationDeniedError, self).__init__(
code=403,
msg="Federation denied with %s." % (self.destination,),
errcode=Codes.FORBIDDEN,
)
class InteractiveAuthIncompleteError(Exception): class InteractiveAuthIncompleteError(Exception):
"""An error raised when UI auth is not yet complete """An error raised when UI auth is not yet complete

View File

@ -49,19 +49,6 @@ class AppserviceSlaveStore(
class AppserviceServer(HomeServer): class AppserviceServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self): def setup(self):
logger.info("Setting up.") logger.info("Setting up.")
self.datastore = AppserviceSlaveStore(self.get_db_conn(), self) self.datastore = AppserviceSlaveStore(self.get_db_conn(), self)

View File

@ -64,19 +64,6 @@ class ClientReaderSlavedStore(
class ClientReaderServer(HomeServer): class ClientReaderServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self): def setup(self):
logger.info("Setting up.") logger.info("Setting up.")
self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self) self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self)

View File

@ -58,19 +58,6 @@ class FederationReaderSlavedStore(
class FederationReaderServer(HomeServer): class FederationReaderServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self): def setup(self):
logger.info("Setting up.") logger.info("Setting up.")
self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self) self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self)

View File

@ -76,19 +76,6 @@ class FederationSenderSlaveStore(
class FederationSenderServer(HomeServer): class FederationSenderServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self): def setup(self):
logger.info("Setting up.") logger.info("Setting up.")
self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self) self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self)

View File

@ -118,19 +118,6 @@ class FrontendProxySlavedStore(
class FrontendProxyServer(HomeServer): class FrontendProxyServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self): def setup(self):
logger.info("Setting up.") logger.info("Setting up.")
self.datastore = FrontendProxySlavedStore(self.get_db_conn(), self) self.datastore = FrontendProxySlavedStore(self.get_db_conn(), self)

View File

@ -266,19 +266,6 @@ class SynapseHomeServer(HomeServer):
except IncorrectDatabaseSetup as e: except IncorrectDatabaseSetup as e:
quit_with_error(e.message) quit_with_error(e.message)
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(config_options): def setup(config_options):
""" """

View File

@ -60,19 +60,6 @@ class MediaRepositorySlavedStore(
class MediaRepositoryServer(HomeServer): class MediaRepositoryServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self): def setup(self):
logger.info("Setting up.") logger.info("Setting up.")
self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self) self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self)

View File

@ -81,19 +81,6 @@ class PusherSlaveStore(
class PusherServer(HomeServer): class PusherServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self): def setup(self):
logger.info("Setting up.") logger.info("Setting up.")
self.datastore = PusherSlaveStore(self.get_db_conn(), self) self.datastore = PusherSlaveStore(self.get_db_conn(), self)

View File

@ -246,19 +246,6 @@ class SynchrotronApplicationService(object):
class SynchrotronServer(HomeServer): class SynchrotronServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self): def setup(self):
logger.info("Setting up.") logger.info("Setting up.")
self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self) self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self)

View File

@ -184,6 +184,9 @@ def main():
worker_configfiles.append(worker_configfile) worker_configfiles.append(worker_configfile)
if options.all_processes: if options.all_processes:
# To start the main synapse with -a you need to add a worker file
# with worker_app == "synapse.app.homeserver"
start_stop_synapse = False
worker_configdir = options.all_processes worker_configdir = options.all_processes
if not os.path.isdir(worker_configdir): if not os.path.isdir(worker_configdir):
write( write(
@ -200,11 +203,29 @@ def main():
with open(worker_configfile) as stream: with open(worker_configfile) as stream:
worker_config = yaml.load(stream) worker_config = yaml.load(stream)
worker_app = worker_config["worker_app"] worker_app = worker_config["worker_app"]
worker_pidfile = worker_config["worker_pid_file"] if worker_app == "synapse.app.homeserver":
worker_daemonize = worker_config["worker_daemonize"] # We need to special case all of this to pick up options that may
assert worker_daemonize, "In config %r: expected '%s' to be True" % ( # be set in the main config file or in this worker config file.
worker_configfile, "worker_daemonize") worker_pidfile = (
worker_cache_factor = worker_config.get("synctl_cache_factor") worker_config.get("pid_file")
or pidfile
)
worker_cache_factor = worker_config.get("synctl_cache_factor") or cache_factor
daemonize = worker_config.get("daemonize") or config.get("daemonize")
assert daemonize, "Main process must have daemonize set to true"
# The master process doesn't support using worker_* config.
for key in worker_config:
if key == "worker_app": # But we allow worker_app
continue
assert not key.startswith("worker_"), \
"Main process cannot use worker_* config"
else:
worker_pidfile = worker_config["worker_pid_file"]
worker_daemonize = worker_config["worker_daemonize"]
assert worker_daemonize, "In config %r: expected '%s' to be True" % (
worker_configfile, "worker_daemonize")
worker_cache_factor = worker_config.get("synctl_cache_factor")
workers.append(Worker( workers.append(Worker(
worker_app, worker_configfile, worker_pidfile, worker_cache_factor, worker_app, worker_configfile, worker_pidfile, worker_cache_factor,
)) ))
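In practice this means that, when using ``-a <config_dir>``, the main process is
only managed by synctl if the directory contains a stub config with
``worker_app: synapse.app.homeserver``. A sketch of such a file, limited to the
keys the code above reads (file name and values are illustrative)::

    # <config_dir>/main.yaml
    worker_app: synapse.app.homeserver
    # optional overrides; otherwise the values from homeserver.yaml are used
    pid_file: /var/run/matrix-synapse.pid
    synctl_cache_factor: 2.0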

View File

@ -92,19 +92,6 @@ class UserDirectorySlaveStore(
class UserDirectoryServer(HomeServer): class UserDirectoryServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self): def setup(self):
logger.info("Setting up.") logger.info("Setting up.")
self.datastore = UserDirectorySlaveStore(self.get_db_conn(), self) self.datastore = UserDirectorySlaveStore(self.get_db_conn(), self)

View File

@ -28,27 +28,27 @@ DEFAULT_LOG_CONFIG = Template("""
version: 1 version: 1
formatters: formatters:
precise: precise:
format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s\ format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - \
- %(message)s' %(request)s - %(message)s'
filters: filters:
context: context:
(): synapse.util.logcontext.LoggingContextFilter (): synapse.util.logcontext.LoggingContextFilter
request: "" request: ""
handlers: handlers:
file: file:
class: logging.handlers.RotatingFileHandler class: logging.handlers.RotatingFileHandler
formatter: precise formatter: precise
filename: ${log_file} filename: ${log_file}
maxBytes: 104857600 maxBytes: 104857600
backupCount: 10 backupCount: 10
filters: [context] filters: [context]
console: console:
class: logging.StreamHandler class: logging.StreamHandler
formatter: precise formatter: precise
filters: [context] filters: [context]
loggers: loggers:
synapse: synapse:
@ -74,17 +74,10 @@ class LoggingConfig(Config):
self.log_file = self.abspath(config.get("log_file")) self.log_file = self.abspath(config.get("log_file"))
def default_config(self, config_dir_path, server_name, **kwargs): def default_config(self, config_dir_path, server_name, **kwargs):
log_file = self.abspath("homeserver.log")
log_config = self.abspath( log_config = self.abspath(
os.path.join(config_dir_path, server_name + ".log.config") os.path.join(config_dir_path, server_name + ".log.config")
) )
return """ return """
# Logging verbosity level. Ignored if log_config is specified.
verbose: 0
# File to write logging to. Ignored if log_config is specified.
log_file: "%(log_file)s"
# A yaml python logging config file # A yaml python logging config file
log_config: "%(log_config)s" log_config: "%(log_config)s"
""" % locals() """ % locals()
@ -123,9 +116,10 @@ class LoggingConfig(Config):
def generate_files(self, config): def generate_files(self, config):
log_config = config.get("log_config") log_config = config.get("log_config")
if log_config and not os.path.exists(log_config): if log_config and not os.path.exists(log_config):
log_file = self.abspath("homeserver.log")
with open(log_config, "wb") as log_config_file: with open(log_config, "wb") as log_config_file:
log_config_file.write( log_config_file.write(
DEFAULT_LOG_CONFIG.substitute(log_file=config["log_file"]) DEFAULT_LOG_CONFIG.substitute(log_file=log_file)
) )
@ -150,6 +144,9 @@ def setup_logging(config, use_worker_options=False):
) )
if log_config is None: if log_config is None:
# We don't have a logfile, so fall back to the 'verbosity' param from
# the config or cmdline. (Note that we generate a log config for new
# installs, so this will be an unusual case)
level = logging.INFO level = logging.INFO
level_for_storage = logging.INFO level_for_storage = logging.INFO
if config.verbosity: if config.verbosity:
@ -157,11 +154,10 @@ def setup_logging(config, use_worker_options=False):
if config.verbosity > 1: if config.verbosity > 1:
level_for_storage = logging.DEBUG level_for_storage = logging.DEBUG
# FIXME: we need a logging.WARN for a -q quiet option
logger = logging.getLogger('') logger = logging.getLogger('')
logger.setLevel(level) logger.setLevel(level)
logging.getLogger('synapse.storage').setLevel(level_for_storage) logging.getLogger('synapse.storage.SQL').setLevel(level_for_storage)
formatter = logging.Formatter(log_format) formatter = logging.Formatter(log_format)
if log_file: if log_file:

View File

@ -31,6 +31,8 @@ class RegistrationConfig(Config):
strtobool(str(config["disable_registration"])) strtobool(str(config["disable_registration"]))
) )
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
self.registration_shared_secret = config.get("registration_shared_secret") self.registration_shared_secret = config.get("registration_shared_secret")
self.bcrypt_rounds = config.get("bcrypt_rounds", 12) self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
@ -52,6 +54,23 @@ class RegistrationConfig(Config):
# Enable registration for new users. # Enable registration for new users.
enable_registration: False enable_registration: False
# The user must provide all of the below types of 3PID when registering.
#
# registrations_require_3pid:
# - email
# - msisdn
# Mandate that users are only allowed to associate certain formats of
# 3PIDs with accounts on this server.
#
# allowed_local_3pids:
# - medium: email
# pattern: ".*@matrix\\.org"
# - medium: email
# pattern: ".*@vector\\.im"
# - medium: msisdn
# pattern: "\\+44"
# If set, allows registration by anyone who also has the shared # If set, allows registration by anyone who also has the shared
# secret, even if registration is otherwise disabled. # secret, even if registration is otherwise disabled.
registration_shared_secret: "%(registration_shared_secret)s" registration_shared_secret: "%(registration_shared_secret)s"

View File

@ -16,6 +16,8 @@
from ._base import Config, ConfigError from ._base import Config, ConfigError
from collections import namedtuple from collections import namedtuple
from synapse.util.module_loader import load_module
MISSING_NETADDR = ( MISSING_NETADDR = (
"Missing netaddr library. This is required for URL preview API." "Missing netaddr library. This is required for URL preview API."
@ -36,6 +38,14 @@ ThumbnailRequirement = namedtuple(
"ThumbnailRequirement", ["width", "height", "method", "media_type"] "ThumbnailRequirement", ["width", "height", "method", "media_type"]
) )
MediaStorageProviderConfig = namedtuple(
"MediaStorageProviderConfig", (
"store_local", # Whether to store newly uploaded local files
"store_remote", # Whether to store newly downloaded remote files
"store_synchronous", # Whether to wait for successful storage for local uploads
),
)
def parse_thumbnail_requirements(thumbnail_sizes): def parse_thumbnail_requirements(thumbnail_sizes):
""" Takes a list of dictionaries with "width", "height", and "method" keys """ Takes a list of dictionaries with "width", "height", and "method" keys
@ -73,16 +83,61 @@ class ContentRepositoryConfig(Config):
self.media_store_path = self.ensure_directory(config["media_store_path"]) self.media_store_path = self.ensure_directory(config["media_store_path"])
self.backup_media_store_path = config.get("backup_media_store_path") backup_media_store_path = config.get("backup_media_store_path")
if self.backup_media_store_path:
self.backup_media_store_path = self.ensure_directory(
self.backup_media_store_path
)
self.synchronous_backup_media_store = config.get( synchronous_backup_media_store = config.get(
"synchronous_backup_media_store", False "synchronous_backup_media_store", False
) )
storage_providers = config.get("media_storage_providers", [])
if backup_media_store_path:
if storage_providers:
raise ConfigError(
"Cannot use both 'backup_media_store_path' and 'storage_providers'"
)
storage_providers = [{
"module": "file_system",
"store_local": True,
"store_synchronous": synchronous_backup_media_store,
"store_remote": True,
"config": {
"directory": backup_media_store_path,
}
}]
# This is a list of config that can be used to create the storage
# providers. The entries are tuples of (Class, class_config,
# MediaStorageProviderConfig), where Class is the class of the provider,
# the class_config the config to pass to it, and
# MediaStorageProviderConfig are options for StorageProviderWrapper.
#
# We don't create the storage providers here as not all workers need
# them to be started.
self.media_storage_providers = []
for provider_config in storage_providers:
# We special case the module "file_system" so as not to need to
# expose FileStorageProviderBackend
if provider_config["module"] == "file_system":
provider_config["module"] = (
"synapse.rest.media.v1.storage_provider"
".FileStorageProviderBackend"
)
provider_class, parsed_config = load_module(provider_config)
wrapper_config = MediaStorageProviderConfig(
provider_config.get("store_local", False),
provider_config.get("store_remote", False),
provider_config.get("store_synchronous", False),
)
self.media_storage_providers.append(
(provider_class, parsed_config, wrapper_config,)
)
self.uploads_path = self.ensure_directory(config["uploads_path"]) self.uploads_path = self.ensure_directory(config["uploads_path"])
self.dynamic_thumbnails = config["dynamic_thumbnails"] self.dynamic_thumbnails = config["dynamic_thumbnails"]
self.thumbnail_requirements = parse_thumbnail_requirements( self.thumbnail_requirements = parse_thumbnail_requirements(
@ -127,13 +182,19 @@ class ContentRepositoryConfig(Config):
# Directory where uploaded images and attachments are stored. # Directory where uploaded images and attachments are stored.
media_store_path: "%(media_store)s" media_store_path: "%(media_store)s"
# A secondary directory where uploaded images and attachments are # Media storage providers allow media to be stored in different
# stored as a backup. # locations.
# backup_media_store_path: "%(media_store)s" # media_storage_providers:
# - module: file_system
# Whether to wait for successful write to backup media store before # # Whether to write new local files.
# returning successfully. # store_local: false
# synchronous_backup_media_store: false # # Whether to write new remote media
# store_remote: false
# # Whether to block upload requests waiting for write to this
# # provider to complete
# store_synchronous: false
# config:
# directory: /mnt/some/other/directory
# Directory where in-progress uploads are stored. # Directory where in-progress uploads are stored.
uploads_path: "%(uploads_path)s" uploads_path: "%(uploads_path)s"
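For admins migrating off the deprecated ``backup_media_store_path``, this change
treats it as equivalent to a single ``file_system`` provider; a sketch of the
explicit form (the directory is illustrative)::

    media_storage_providers:
      - module: file_system
        store_local: true
        store_remote: true
        # mirrors the old synchronous_backup_media_store setting
        store_synchronous: false
        config:
          directory: /mnt/media-backup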

View File

@ -55,6 +55,17 @@ class ServerConfig(Config):
"block_non_admin_invites", False, "block_non_admin_invites", False,
) )
# FIXME: federation_domain_whitelist needs sytests
self.federation_domain_whitelist = None
federation_domain_whitelist = config.get(
"federation_domain_whitelist", None
)
# turn the whitelist into a hash for speed of lookup
if federation_domain_whitelist is not None:
self.federation_domain_whitelist = {}
for domain in federation_domain_whitelist:
self.federation_domain_whitelist[domain] = True
if self.public_baseurl is not None: if self.public_baseurl is not None:
if self.public_baseurl[-1] != '/': if self.public_baseurl[-1] != '/':
self.public_baseurl += '/' self.public_baseurl += '/'
@ -210,6 +221,17 @@ class ServerConfig(Config):
# (except those sent by local server admins). The default is False. # (except those sent by local server admins). The default is False.
# block_non_admin_invites: True # block_non_admin_invites: True
# Restrict federation to the following whitelist of domains.
# N.B. we recommend also firewalling your federation listener to limit
# inbound federation traffic as early as possible, rather than relying
# purely on this application-layer restriction. If not specified, the
# default is to whitelist everything.
#
# federation_domain_whitelist:
# - lon.example.com
# - nyc.example.com
# - syd.example.com
# List of ports that Synapse should listen on, their purpose and their # List of ports that Synapse should listen on, their purpose and their
# configuration. # configuration.
listeners: listeners:

View File

@ -96,7 +96,7 @@ class TlsConfig(Config):
# certificates returned by this server match one of the fingerprints. # certificates returned by this server match one of the fingerprints.
# #
# Synapse automatically adds the fingerprint of its own certificate # Synapse automatically adds the fingerprint of its own certificate
# to the list. So if federation traffic is handle directly by synapse # to the list. So if federation traffic is handled directly by synapse
# then no modification to the list is required. # then no modification to the list is required.
# #
# If synapse is run behind a load balancer that handles the TLS then it # If synapse is run behind a load balancer that handles the TLS then it

View File

@ -23,6 +23,11 @@ class WorkerConfig(Config):
def read_config(self, config): def read_config(self, config):
self.worker_app = config.get("worker_app") self.worker_app = config.get("worker_app")
# Canonicalise worker_app so that master always has None
if self.worker_app == "synapse.app.homeserver":
self.worker_app = None
self.worker_listeners = config.get("worker_listeners") self.worker_listeners = config.get("worker_listeners")
self.worker_daemonize = config.get("worker_daemonize") self.worker_daemonize = config.get("worker_daemonize")
self.worker_pid_file = config.get("worker_pid_file") self.worker_pid_file = config.get("worker_pid_file")

View File

@ -319,7 +319,7 @@ def _is_membership_change_allowed(event, auth_events):
# TODO (erikj): Implement kicks. # TODO (erikj): Implement kicks.
if target_banned and user_level < ban_level: if target_banned and user_level < ban_level:
raise AuthError( raise AuthError(
403, "You cannot unban user &s." % (target_user_id,) 403, "You cannot unban user %s." % (target_user_id,)
) )
elif target_user_id != event.user_id: elif target_user_id != event.user_id:
kick_level = _get_named_level(auth_events, "kick", 50) kick_level = _get_named_level(auth_events, "kick", 50)

View File

@ -16,7 +16,9 @@ import logging
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.crypto.event_signing import check_event_content_hash from synapse.crypto.event_signing import check_event_content_hash
from synapse.events import FrozenEvent
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.http.servlet import assert_params_in_request
from synapse.util import unwrapFirstError, logcontext from synapse.util import unwrapFirstError, logcontext
from twisted.internet import defer from twisted.internet import defer
@ -169,3 +171,28 @@ class FederationBase(object):
) )
return deferreds return deferreds
def event_from_pdu_json(pdu_json, outlier=False):
"""Construct a FrozenEvent from an event json received over federation
Args:
pdu_json (object): pdu as received over federation
outlier (bool): True to mark this event as an outlier
Returns:
FrozenEvent
Raises:
SynapseError: if the pdu is missing required fields
"""
# we could probably enforce a bunch of other fields here (room_id, sender,
# origin, etc etc)
assert_params_in_request(pdu_json, ('event_id', 'type'))
event = FrozenEvent(
pdu_json
)
event.internal_metadata.outlier = outlier
return event

View File

@ -14,28 +14,28 @@
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from .federation_base import FederationBase
from synapse.api.constants import Membership
from synapse.api.errors import (
CodeMessageException, HttpResponseException, SynapseError,
)
from synapse.util import unwrapFirstError, logcontext
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.events import FrozenEvent, builder
import synapse.metrics
from synapse.util.retryutils import NotRetryingDestination
import copy import copy
import itertools import itertools
import logging import logging
import random import random
from twisted.internet import defer
from synapse.api.constants import Membership
from synapse.api.errors import (
CodeMessageException, HttpResponseException, SynapseError, FederationDeniedError
)
from synapse.events import builder
from synapse.federation.federation_base import (
FederationBase,
event_from_pdu_json,
)
import synapse.metrics
from synapse.util import logcontext, unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.util.logutils import log_function
from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -184,7 +184,7 @@ class FederationClient(FederationBase):
logger.debug("backfill transaction_data=%s", repr(transaction_data)) logger.debug("backfill transaction_data=%s", repr(transaction_data))
pdus = [ pdus = [
self.event_from_pdu_json(p, outlier=False) event_from_pdu_json(p, outlier=False)
for p in transaction_data["pdus"] for p in transaction_data["pdus"]
] ]
@ -244,7 +244,7 @@ class FederationClient(FederationBase):
logger.debug("transaction_data %r", transaction_data) logger.debug("transaction_data %r", transaction_data)
pdu_list = [ pdu_list = [
self.event_from_pdu_json(p, outlier=outlier) event_from_pdu_json(p, outlier=outlier)
for p in transaction_data["pdus"] for p in transaction_data["pdus"]
] ]
@ -266,6 +266,9 @@ class FederationClient(FederationBase):
except NotRetryingDestination as e: except NotRetryingDestination as e:
logger.info(e.message) logger.info(e.message)
continue continue
except FederationDeniedError as e:
logger.info(e.message)
continue
except Exception as e: except Exception as e:
pdu_attempts[destination] = now pdu_attempts[destination] = now
@ -336,11 +339,11 @@ class FederationClient(FederationBase):
) )
pdus = [ pdus = [
self.event_from_pdu_json(p, outlier=True) for p in result["pdus"] event_from_pdu_json(p, outlier=True) for p in result["pdus"]
] ]
auth_chain = [ auth_chain = [
self.event_from_pdu_json(p, outlier=True) event_from_pdu_json(p, outlier=True)
for p in result.get("auth_chain", []) for p in result.get("auth_chain", [])
] ]
@ -441,7 +444,7 @@ class FederationClient(FederationBase):
) )
auth_chain = [ auth_chain = [
self.event_from_pdu_json(p, outlier=True) event_from_pdu_json(p, outlier=True)
for p in res["auth_chain"] for p in res["auth_chain"]
] ]
@ -570,12 +573,12 @@ class FederationClient(FederationBase):
logger.debug("Got content: %s", content) logger.debug("Got content: %s", content)
state = [ state = [
self.event_from_pdu_json(p, outlier=True) event_from_pdu_json(p, outlier=True)
for p in content.get("state", []) for p in content.get("state", [])
] ]
auth_chain = [ auth_chain = [
self.event_from_pdu_json(p, outlier=True) event_from_pdu_json(p, outlier=True)
for p in content.get("auth_chain", []) for p in content.get("auth_chain", [])
] ]
@ -650,7 +653,7 @@ class FederationClient(FederationBase):
logger.debug("Got response to send_invite: %s", pdu_dict) logger.debug("Got response to send_invite: %s", pdu_dict)
pdu = self.event_from_pdu_json(pdu_dict) pdu = event_from_pdu_json(pdu_dict)
# Check signatures are correct. # Check signatures are correct.
pdu = yield self._check_sigs_and_hash(pdu) pdu = yield self._check_sigs_and_hash(pdu)
@ -740,7 +743,7 @@ class FederationClient(FederationBase):
) )
auth_chain = [ auth_chain = [
self.event_from_pdu_json(e) event_from_pdu_json(e)
for e in content["auth_chain"] for e in content["auth_chain"]
] ]
@ -788,7 +791,7 @@ class FederationClient(FederationBase):
) )
events = [ events = [
self.event_from_pdu_json(e) event_from_pdu_json(e)
for e in content.get("events", []) for e in content.get("events", [])
] ]
@ -805,15 +808,6 @@ class FederationClient(FederationBase):
defer.returnValue(signed_events) defer.returnValue(signed_events)
def event_from_pdu_json(self, pdu_json, outlier=False):
event = FrozenEvent(
pdu_json
)
event.internal_metadata.outlier = outlier
return event
@defer.inlineCallbacks @defer.inlineCallbacks
def forward_third_party_invite(self, destinations, room_id, event_dict): def forward_third_party_invite(self, destinations, room_id, event_dict):
for destination in destinations: for destination in destinations:

View File

@ -12,25 +12,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer import logging
from .federation_base import FederationBase
from .units import Transaction, Edu
from synapse.util import async
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.util.logutils import log_function
from synapse.util.caches.response_cache import ResponseCache
from synapse.events import FrozenEvent
from synapse.types import get_domain_from_id
import synapse.metrics
from synapse.api.errors import AuthError, FederationError, SynapseError
from synapse.crypto.event_signing import compute_event_signature
import simplejson as json import simplejson as json
import logging from twisted.internet import defer
from synapse.api.errors import AuthError, FederationError, SynapseError
from synapse.crypto.event_signing import compute_event_signature
from synapse.federation.federation_base import (
FederationBase,
event_from_pdu_json,
)
from synapse.federation.units import Edu, Transaction
import synapse.metrics
from synapse.types import get_domain_from_id
from synapse.util import async
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.util.logutils import log_function
# when processing incoming transactions, we try to handle multiple rooms in # when processing incoming transactions, we try to handle multiple rooms in
# parallel, up to this limit. # parallel, up to this limit.
@ -172,7 +171,7 @@ class FederationServer(FederationBase):
p["age_ts"] = request_time - int(p["age"]) p["age_ts"] = request_time - int(p["age"])
del p["age"] del p["age"]
event = self.event_from_pdu_json(p) event = event_from_pdu_json(p)
room_id = event.room_id room_id = event.room_id
pdus_by_room.setdefault(room_id, []).append(event) pdus_by_room.setdefault(room_id, []).append(event)
@ -346,7 +345,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_invite_request(self, origin, content): def on_invite_request(self, origin, content):
pdu = self.event_from_pdu_json(content) pdu = event_from_pdu_json(content)
ret_pdu = yield self.handler.on_invite_request(origin, pdu) ret_pdu = yield self.handler.on_invite_request(origin, pdu)
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)})) defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
@ -354,7 +353,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_send_join_request(self, origin, content): def on_send_join_request(self, origin, content):
logger.debug("on_send_join_request: content: %s", content) logger.debug("on_send_join_request: content: %s", content)
pdu = self.event_from_pdu_json(content) pdu = event_from_pdu_json(content)
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
res_pdus = yield self.handler.on_send_join_request(origin, pdu) res_pdus = yield self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
@ -374,7 +373,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_send_leave_request(self, origin, content): def on_send_leave_request(self, origin, content):
logger.debug("on_send_leave_request: content: %s", content) logger.debug("on_send_leave_request: content: %s", content)
pdu = self.event_from_pdu_json(content) pdu = event_from_pdu_json(content)
logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures) logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
yield self.handler.on_send_leave_request(origin, pdu) yield self.handler.on_send_leave_request(origin, pdu)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -411,7 +410,7 @@ class FederationServer(FederationBase):
""" """
with (yield self._server_linearizer.queue((origin, room_id))): with (yield self._server_linearizer.queue((origin, room_id))):
auth_chain = [ auth_chain = [
self.event_from_pdu_json(e) event_from_pdu_json(e)
for e in content["auth_chain"] for e in content["auth_chain"]
] ]
@ -586,15 +585,6 @@ class FederationServer(FederationBase):
def __str__(self): def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name return "<ReplicationLayer(%s)>" % self.server_name
def event_from_pdu_json(self, pdu_json, outlier=False):
event = FrozenEvent(
pdu_json
)
event.internal_metadata.outlier = outlier
return event
@defer.inlineCallbacks @defer.inlineCallbacks
def exchange_third_party_invite( def exchange_third_party_invite(
self, self,

View File

@ -19,7 +19,7 @@ from twisted.internet import defer
from .persistence import TransactionActions from .persistence import TransactionActions
from .units import Transaction, Edu from .units import Transaction, Edu
from synapse.api.errors import HttpResponseException from synapse.api.errors import HttpResponseException, FederationDeniedError
from synapse.util import logcontext, PreserveLoggingContext from synapse.util import logcontext, PreserveLoggingContext
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
@ -42,6 +42,8 @@ sent_edus_counter = client_metrics.register_counter("sent_edus")
sent_transactions_counter = client_metrics.register_counter("sent_transactions") sent_transactions_counter = client_metrics.register_counter("sent_transactions")
events_processed_counter = client_metrics.register_counter("events_processed")
class TransactionQueue(object): class TransactionQueue(object):
"""This class makes sure we only have one transaction in flight at """This class makes sure we only have one transaction in flight at
@ -205,6 +207,8 @@ class TransactionQueue(object):
self._send_pdu(event, destinations) self._send_pdu(event, destinations)
events_processed_counter.inc_by(len(events))
yield self.store.update_federation_out_pos( yield self.store.update_federation_out_pos(
"events", next_token "events", next_token
) )
@ -486,6 +490,8 @@ class TransactionQueue(object):
(e.retry_last_ts + e.retry_interval) / 1000.0 (e.retry_last_ts + e.retry_interval) / 1000.0
), ),
) )
except FederationDeniedError as e:
logger.info(e)
except Exception as e: except Exception as e:
logger.warn( logger.warn(
"TX [%s] Failed to send transaction: %s", "TX [%s] Failed to send transaction: %s",

View File

@ -212,6 +212,9 @@ class TransportLayerClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server. to retry this server.
Fails with ``FederationDeniedError`` if the remote destination
is not in our federation whitelist
""" """
valid_memberships = {Membership.JOIN, Membership.LEAVE} valid_memberships = {Membership.JOIN, Membership.LEAVE}
if membership not in valid_memberships: if membership not in valid_memberships:

View File

@ -16,7 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError, FederationDeniedError
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.http.servlet import ( from synapse.http.servlet import (
parse_json_object_from_request, parse_integer_from_args, parse_string_from_args, parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
@ -81,6 +81,7 @@ class Authenticator(object):
self.keyring = hs.get_keyring() self.keyring = hs.get_keyring()
self.server_name = hs.hostname self.server_name = hs.hostname
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
# A method just so we can pass 'self' as the authenticator to the Servlets # A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks @defer.inlineCallbacks
@ -92,6 +93,12 @@ class Authenticator(object):
"signatures": {}, "signatures": {},
} }
if (
self.federation_domain_whitelist is not None and
self.server_name not in self.federation_domain_whitelist
):
raise FederationDeniedError(self.server_name)
if content is not None: if content is not None:
json_request["content"] = content json_request["content"] = content

View File

@ -15,6 +15,7 @@
from twisted.internet import defer from twisted.internet import defer
import synapse
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
@ -23,6 +24,10 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
metrics = synapse.metrics.get_metrics_for(__name__)
events_processed_counter = metrics.register_counter("events_processed")
def log_failure(failure): def log_failure(failure):
logger.error( logger.error(
@ -103,6 +108,8 @@ class ApplicationServicesHandler(object):
service, event service, event
) )
events_processed_counter.inc_by(len(events))
yield self.store.set_appservice_last_pos(upper_bound) yield self.store.set_appservice_last_pos(upper_bound)
finally: finally:
self.is_processing = False self.is_processing = False

View File

@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer, threads
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
@ -25,6 +25,7 @@ from synapse.module_api import ModuleApi
from synapse.types import UserID from synapse.types import UserID
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logcontext import make_deferred_yieldable
from twisted.web.client import PartialDownloadError from twisted.web.client import PartialDownloadError
@ -714,7 +715,7 @@ class AuthHandler(BaseHandler):
if not lookupres: if not lookupres:
defer.returnValue(None) defer.returnValue(None)
(user_id, password_hash) = lookupres (user_id, password_hash) = lookupres
result = self.validate_hash(password, password_hash) result = yield self.validate_hash(password, password_hash)
if not result: if not result:
logger.warn("Failed password login for user %s", user_id) logger.warn("Failed password login for user %s", user_id)
defer.returnValue(None) defer.returnValue(None)
@ -842,10 +843,13 @@ class AuthHandler(BaseHandler):
password (str): Password to hash. password (str): Password to hash.
Returns: Returns:
Hashed password (str). Deferred(str): Hashed password.
""" """
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper, def _do_hash():
bcrypt.gensalt(self.bcrypt_rounds)) return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
bcrypt.gensalt(self.bcrypt_rounds))
return make_deferred_yieldable(threads.deferToThread(_do_hash))
def validate_hash(self, password, stored_hash): def validate_hash(self, password, stored_hash):
"""Validates that self.hash(password) == stored_hash. """Validates that self.hash(password) == stored_hash.
@ -855,13 +859,17 @@ class AuthHandler(BaseHandler):
stored_hash (str): Expected hash value. stored_hash (str): Expected hash value.
Returns: Returns:
Whether self.hash(password) == stored_hash (bool). Deferred(bool): Whether self.hash(password) == stored_hash.
""" """
if stored_hash:
def _do_validate_hash():
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper, return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
stored_hash.encode('utf8')) == stored_hash stored_hash.encode('utf8')) == stored_hash
if stored_hash:
return make_deferred_yieldable(threads.deferToThread(_do_validate_hash))
else: else:
return False return defer.succeed(False)
class MacaroonGeneartor(object): class MacaroonGeneartor(object):

View File

@@ -14,6 +14,7 @@
 # limitations under the License.
 from synapse.api import errors
 from synapse.api.constants import EventTypes
+from synapse.api.errors import FederationDeniedError
 from synapse.util import stringutils
 from synapse.util.async import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -513,6 +514,9 @@ class DeviceListEduUpdater(object):
                 # This makes it more likely that the device lists will
                 # eventually become consistent.
                 return
+            except FederationDeniedError as e:
+                logger.info(e)
+                return
             except Exception:
                 # TODO: Remember that we are now out of sync and try again
                 # later
View File
@@ -17,7 +17,8 @@ import logging

 from twisted.internet import defer

-from synapse.types import get_domain_from_id
+from synapse.api.errors import SynapseError
+from synapse.types import get_domain_from_id, UserID
 from synapse.util.stringutils import random_string
@@ -33,7 +34,7 @@ class DeviceMessageHandler(object):
         """
         self.store = hs.get_datastore()
         self.notifier = hs.get_notifier()
-        self.is_mine_id = hs.is_mine_id
+        self.is_mine = hs.is_mine
         self.federation = hs.get_federation_sender()

         hs.get_replication_layer().register_edu_handler(
@@ -52,6 +53,12 @@ class DeviceMessageHandler(object):
         message_type = content["type"]
         message_id = content["message_id"]
         for user_id, by_device in content["messages"].items():
+            # we use UserID.from_string to catch invalid user ids
+            if not self.is_mine(UserID.from_string(user_id)):
+                logger.warning("Request for keys for non-local user %s",
+                               user_id)
+                raise SynapseError(400, "Not a user here")
+
             messages_by_device = {
                 device_id: {
                     "content": message_content,
@@ -77,7 +84,8 @@ class DeviceMessageHandler(object):
         local_messages = {}
         remote_messages = {}
         for user_id, by_device in messages.items():
-            if self.is_mine_id(user_id):
+            # we use UserID.from_string to catch invalid user ids
+            if self.is_mine(UserID.from_string(user_id)):
                 messages_by_device = {
                     device_id: {
                         "content": message_content,
View File
@@ -19,8 +19,10 @@ import logging

 from canonicaljson import encode_canonical_json
 from twisted.internet import defer

-from synapse.api.errors import SynapseError, CodeMessageException
-from synapse.types import get_domain_from_id
+from synapse.api.errors import (
+    SynapseError, CodeMessageException, FederationDeniedError,
+)
+from synapse.types import get_domain_from_id, UserID
 from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
 from synapse.util.retryutils import NotRetryingDestination
@@ -32,7 +34,7 @@ class E2eKeysHandler(object):
         self.store = hs.get_datastore()
         self.federation = hs.get_replication_layer()
         self.device_handler = hs.get_device_handler()
-        self.is_mine_id = hs.is_mine_id
+        self.is_mine = hs.is_mine
         self.clock = hs.get_clock()

         # doesn't really work as part of the generic query API, because the
@@ -70,7 +72,8 @@ class E2eKeysHandler(object):
         remote_queries = {}
         for user_id, device_ids in device_keys_query.items():
-            if self.is_mine_id(user_id):
+            # we use UserID.from_string to catch invalid user ids
+            if self.is_mine(UserID.from_string(user_id)):
                 local_query[user_id] = device_ids
             else:
                 remote_queries[user_id] = device_ids
@@ -139,6 +142,10 @@ class E2eKeysHandler(object):
                 failures[destination] = {
                     "status": 503, "message": "Not ready for retry",
                 }
+            except FederationDeniedError as e:
+                failures[destination] = {
+                    "status": 403, "message": "Federation Denied",
+                }
             except Exception as e:
                 # include ConnectionRefused and other errors
                 failures[destination] = {
@@ -170,7 +177,8 @@ class E2eKeysHandler(object):
         result_dict = {}
         for user_id, device_ids in query.items():
-            if not self.is_mine_id(user_id):
+            # we use UserID.from_string to catch invalid user ids
+            if not self.is_mine(UserID.from_string(user_id)):
                 logger.warning("Request for keys for non-local user %s",
                                user_id)
                 raise SynapseError(400, "Not a user here")
@@ -213,7 +221,8 @@ class E2eKeysHandler(object):
         remote_queries = {}
         for user_id, device_keys in query.get("one_time_keys", {}).items():
-            if self.is_mine_id(user_id):
+            # we use UserID.from_string to catch invalid user ids
+            if self.is_mine(UserID.from_string(user_id)):
                 for device_id, algorithm in device_keys.items():
                     local_query.append((user_id, device_id, algorithm))
             else:
View File
@@ -22,6 +22,7 @@ from ._base import BaseHandler

 from synapse.api.errors import (
     AuthError, FederationError, StoreError, CodeMessageException, SynapseError,
+    FederationDeniedError,
 )
 from synapse.api.constants import EventTypes, Membership, RejectedReason
 from synapse.events.validator import EventValidator
@@ -782,6 +783,9 @@ class FederationHandler(BaseHandler):
             except NotRetryingDestination as e:
                 logger.info(e.message)
                 continue
+            except FederationDeniedError as e:
+                logger.info(e)
+                continue
             except Exception as e:
                 logger.exception(
                     "Failed to backfill from %s because %s",
View File
@@ -383,11 +383,12 @@ class GroupsLocalHandler(object):
             defer.returnValue({"groups": result})
         else:
-            result = yield self.transport_client.get_publicised_groups_for_user(
-                get_domain_from_id(user_id), user_id
+            bulk_result = yield self.transport_client.bulk_get_publicised_groups(
+                get_domain_from_id(user_id), [user_id],
             )
+            result = bulk_result.get("users", {}).get(user_id)
             # TODO: Verify attestations
-            defer.returnValue(result)
+            defer.returnValue({"groups": result})

     @defer.inlineCallbacks
     def bulk_get_publicised_groups(self, user_ids, proxy=True):
View File
@@ -25,6 +25,7 @@ from synapse.http.client import CaptchaServerHttpClient
 from synapse import types
 from synapse.types import UserID
 from synapse.util.async import run_on_reactor
+from synapse.util.threepids import check_3pid_allowed

 from ._base import BaseHandler

 logger = logging.getLogger(__name__)
@@ -131,7 +132,7 @@ class RegistrationHandler(BaseHandler):
         yield run_on_reactor()
         password_hash = None
         if password:
-            password_hash = self.auth_handler().hash(password)
+            password_hash = yield self.auth_handler().hash(password)

         if localpart:
             yield self.check_username(localpart, guest_access_token=guest_access_token)
@@ -293,7 +294,7 @@ class RegistrationHandler(BaseHandler):
         """
         for c in threepidCreds:
-            logger.info("validating theeepidcred sid %s on id server %s",
+            logger.info("validating threepidcred sid %s on id server %s",
                         c['sid'], c['idServer'])
             try:
                 identity_handler = self.hs.get_handlers().identity_handler
@@ -307,6 +308,11 @@ class RegistrationHandler(BaseHandler):
             logger.info("got threepid with medium '%s' and address '%s'",
                         threepid['medium'], threepid['address'])

+            if not check_3pid_allowed(self.hs, threepid['medium'], threepid['address']):
+                raise RegistrationError(
+                    403, "Third party identifier is not allowed"
+                )
+
     @defer.inlineCallbacks
     def bind_emails(self, user_id, threepidCreds):
         """Links emails with a user ID and informs an identity server.
View File
@@ -203,7 +203,8 @@ class RoomListHandler(BaseHandler):
         if limit:
             step = limit + 1
         else:
-            step = len(rooms_to_scan)
+            # step cannot be zero
+            step = len(rooms_to_scan) if len(rooms_to_scan) != 0 else 1

         chunk = []
         for i in xrange(0, len(rooms_to_scan), step):
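Note: the guard exists because a zero step is not a valid argument to xrange/range. A tiny illustration:

rooms_to_scan = []
step = len(rooms_to_scan) if len(rooms_to_scan) != 0 else 1
# Without the "else 1", an empty room list would give step = 0 and
# range()/xrange() would raise "ValueError: range() arg 3 must not be zero".
for i in range(0, len(rooms_to_scan), step):
    pass  # never entered when the list is empty, but no crash either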
View File
@@ -31,7 +31,7 @@ class SetPasswordHandler(BaseHandler):
     @defer.inlineCallbacks
     def set_password(self, user_id, newpassword, requester=None):
-        password_hash = self._auth_handler.hash(newpassword)
+        password_hash = yield self._auth_handler.hash(newpassword)

         except_device_id = requester.device_id if requester else None
         except_access_token_id = requester.access_token_id if requester else None
View File
@@ -18,6 +18,7 @@ from OpenSSL.SSL import VERIFY_NONE
 from synapse.api.errors import (
     CodeMessageException, MatrixCodeMessageException, SynapseError, Codes,
 )
+from synapse.util.caches import CACHE_SIZE_FACTOR
 from synapse.util.logcontext import make_deferred_yieldable
 from synapse.util import logcontext
 import synapse.metrics
@@ -30,6 +31,7 @@ from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
 from twisted.web.client import (
     BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent,
     readBody, PartialDownloadError,
+    HTTPConnectionPool,
 )
 from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer
 from twisted.web.http import PotentialDataLoss
@@ -64,13 +66,23 @@ class SimpleHttpClient(object):
     """
     def __init__(self, hs):
         self.hs = hs
+
+        pool = HTTPConnectionPool(reactor)
+
+        # the pusher makes lots of concurrent SSL connections to sygnal, and
+        # tends to do so in batches, so we need to allow the pool to keep lots
+        # of idle connections around.
+        pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5))
+        pool.cachedConnectionTimeout = 2 * 60
+
         # The default context factory in Twisted 14.0.0 (which we require) is
         # BrowserLikePolicyForHTTPS which will do regular cert validation
         # 'like a browser'
         self.agent = Agent(
             reactor,
             connectTimeout=15,
-            contextFactory=hs.get_http_client_context_factory()
+            contextFactory=hs.get_http_client_context_factory(),
+            pool=pool,
         )
         self.user_agent = hs.version_string
         self.clock = hs.get_clock()
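Note: the SimpleHttpClient change above gives the Agent a persistent connection pool so batches of requests to the same host reuse sockets. A standalone sketch of the same Twisted APIs (the limits here are illustrative, not Synapse's computed values):

from twisted.internet import reactor
from twisted.internet.defer import inlineCallbacks
from twisted.web.client import Agent, HTTPConnectionPool, readBody

pool = HTTPConnectionPool(reactor, persistent=True)
pool.maxPersistentPerHost = 5        # idle connections kept per host
pool.cachedConnectionTimeout = 120   # seconds before an idle connection is dropped

agent = Agent(reactor, connectTimeout=15, pool=pool)

@inlineCallbacks
def fetch(url):
    response = yield agent.request(b"GET", url)
    body = yield readBody(response)
    print(response.code, len(body))

if __name__ == "__main__":
    fetch(b"http://example.com/").addBoth(lambda _: reactor.stop())
    reactor.run()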
View File
@@ -357,8 +357,7 @@ def _get_hosts_for_srv_record(dns_client, host):
     def eb(res, record_type):
         if res.check(DNSNameError):
             return []
-        logger.warn("Error looking up %s for %s: %s",
-                    record_type, host, res, res.value)
+        logger.warn("Error looking up %s for %s: %s", record_type, host, res)
         return res

     # no logcontexts here, so we can safely fire these off and gatherResults
View File
@@ -27,7 +27,7 @@ import synapse.metrics
 from canonicaljson import encode_canonical_json

 from synapse.api.errors import (
-    SynapseError, Codes, HttpResponseException,
+    SynapseError, Codes, HttpResponseException, FederationDeniedError,
 )

 from signedjson.sign import sign_json
@@ -123,11 +123,22 @@ class MatrixFederationHttpClient(object):
             Fails with ``HTTPRequestException``: if we get an HTTP response
             code >= 300.

             Fails with ``NotRetryingDestination`` if we are not yet ready
             to retry this server.

+            Fails with ``FederationDeniedError`` if this destination
+            is not on our federation whitelist
+
             (May also fail with plenty of other Exceptions for things like DNS
             failures, connection failures, SSL failures.)
         """
+        if (
+            self.hs.config.federation_domain_whitelist and
+            destination not in self.hs.config.federation_domain_whitelist
+        ):
+            raise FederationDeniedError(destination)
+
         limiter = yield synapse.util.retryutils.get_retry_limiter(
             destination,
             self.clock,
@@ -308,6 +319,9 @@ class MatrixFederationHttpClient(object):
             Fails with ``NotRetryingDestination`` if we are not yet ready
             to retry this server.
+
+            Fails with ``FederationDeniedError`` if this destination
+            is not on our federation whitelist
         """

         if not json_data_callback:
@@ -368,6 +382,9 @@ class MatrixFederationHttpClient(object):
             Fails with ``NotRetryingDestination`` if we are not yet ready
             to retry this server.
+
+            Fails with ``FederationDeniedError`` if this destination
+            is not on our federation whitelist
         """

         def body_callback(method, url_bytes, headers_dict):
@@ -422,6 +439,9 @@ class MatrixFederationHttpClient(object):
             Fails with ``NotRetryingDestination`` if we are not yet ready
             to retry this server.
+
+            Fails with ``FederationDeniedError`` if this destination
+            is not on our federation whitelist
         """
         logger.debug("get_json args: %s", args)
@@ -475,6 +495,9 @@ class MatrixFederationHttpClient(object):
             Fails with ``NotRetryingDestination`` if we are not yet ready
             to retry this server.
+
+            Fails with ``FederationDeniedError`` if this destination
+            is not on our federation whitelist
         """

         response = yield self._request(
@@ -518,6 +541,9 @@ class MatrixFederationHttpClient(object):
             Fails with ``NotRetryingDestination`` if we are not yet ready
             to retry this server.
+
+            Fails with ``FederationDeniedError`` if this destination
+            is not on our federation whitelist
         """
         encoded_args = {}
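Note: every outbound federation request is now gated on the optional federation_domain_whitelist. A simplified sketch of that gate (the exception class here is a stand-in, not Synapse's, which subclasses SynapseError with a 403 code):

# If the whitelist option is unset (None/empty) everything is allowed;
# otherwise only listed destinations get outbound requests.
class FederationDeniedError(RuntimeError):
    def __init__(self, destination):
        self.destination = destination
        super(FederationDeniedError, self).__init__(
            "Federation denied with %s." % (destination,)
        )

def check_destination_allowed(destination, whitelist):
    if whitelist and destination not in whitelist:
        raise FederationDeniedError(destination)

check_destination_allowed("matrix.org", None)                  # ok: no whitelist configured
check_destination_allowed("matrix.org", ["matrix.org"])        # ok: listed
try:
    check_destination_allowed("evil.example", ["matrix.org"])  # raises
except FederationDeniedError as e:
    print(e)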
View File
@@ -42,36 +42,70 @@ logger = logging.getLogger(__name__)

 metrics = synapse.metrics.get_metrics_for(__name__)

-incoming_requests_counter = metrics.register_counter(
-    "requests",
+# total number of responses served, split by method/servlet/tag
+response_count = metrics.register_counter(
+    "response_count",
     labels=["method", "servlet", "tag"],
+    alternative_names=(
+        # the following are all deprecated aliases for the same metric
+        metrics.name_prefix + x for x in (
+            "_requests",
+            "_response_time:count",
+            "_response_ru_utime:count",
+            "_response_ru_stime:count",
+            "_response_db_txn_count:count",
+            "_response_db_txn_duration:count",
+        )
+    )
 )

 outgoing_responses_counter = metrics.register_counter(
     "responses",
     labels=["method", "code"],
 )

-response_timer = metrics.register_distribution(
-    "response_time",
-    labels=["method", "servlet", "tag"]
+response_timer = metrics.register_counter(
+    "response_time_seconds",
+    labels=["method", "servlet", "tag"],
+    alternative_names=(
+        metrics.name_prefix + "_response_time:total",
+    ),
 )

-response_ru_utime = metrics.register_distribution(
-    "response_ru_utime", labels=["method", "servlet", "tag"]
+response_ru_utime = metrics.register_counter(
+    "response_ru_utime_seconds", labels=["method", "servlet", "tag"],
+    alternative_names=(
+        metrics.name_prefix + "_response_ru_utime:total",
+    ),
 )

-response_ru_stime = metrics.register_distribution(
-    "response_ru_stime", labels=["method", "servlet", "tag"]
+response_ru_stime = metrics.register_counter(
+    "response_ru_stime_seconds", labels=["method", "servlet", "tag"],
+    alternative_names=(
+        metrics.name_prefix + "_response_ru_stime:total",
+    ),
 )

-response_db_txn_count = metrics.register_distribution(
-    "response_db_txn_count", labels=["method", "servlet", "tag"]
+response_db_txn_count = metrics.register_counter(
+    "response_db_txn_count", labels=["method", "servlet", "tag"],
+    alternative_names=(
+        metrics.name_prefix + "_response_db_txn_count:total",
+    ),
 )

-response_db_txn_duration = metrics.register_distribution(
-    "response_db_txn_duration", labels=["method", "servlet", "tag"]
+# seconds spent waiting for db txns, excluding scheduling time, when processing
+# this request
+response_db_txn_duration = metrics.register_counter(
+    "response_db_txn_duration_seconds", labels=["method", "servlet", "tag"],
+    alternative_names=(
+        metrics.name_prefix + "_response_db_txn_duration:total",
+    ),
 )

+# seconds spent waiting for a db connection, when processing this request
+response_db_sched_duration = metrics.register_counter(
+    "response_db_sched_duration_seconds", labels=["method", "servlet", "tag"]
+)

 _next_request_id = 0
@@ -107,6 +141,10 @@ def wrap_request_handler(request_handler, include_metrics=False):
         with LoggingContext(request_id) as request_context:
             with Measure(self.clock, "wrapped_request_handler"):
                 request_metrics = RequestMetrics()
+                # we start the request metrics timer here with an initial stab
+                # at the servlet name. For most requests that name will be
+                # JsonResource (or a subclass), and JsonResource._async_render
+                # will update it once it picks a servlet.
                 request_metrics.start(self.clock, name=self.__class__.__name__)

                 request_context.request = request_id
@@ -249,12 +287,23 @@ class JsonResource(HttpServer, resource.Resource):
             if not m:
                 continue

-            # We found a match! Trigger callback and then return the
-            # returned response. We pass both the request and any
-            # matched groups from the regex to the callback.
+            # We found a match! First update the metrics object to indicate
+            # which servlet is handling the request.

             callback = path_entry.callback

+            servlet_instance = getattr(callback, "__self__", None)
+            if servlet_instance is not None:
+                servlet_classname = servlet_instance.__class__.__name__
+            else:
+                servlet_classname = "%r" % callback
+            request_metrics.name = servlet_classname
+
+            # Now trigger the callback. If it returns a response, we send it
+            # here. If it throws an exception, that is handled by the wrapper
+            # installed by @request_handler.
+
             kwargs = intern_dict({
                 name: urllib.unquote(value).decode("UTF-8") if value else value
                 for name, value in m.groupdict().items()
@@ -265,30 +314,14 @@ class JsonResource(HttpServer, resource.Resource):
                 code, response = callback_return
                 self._send_response(request, code, response)

-            servlet_instance = getattr(callback, "__self__", None)
-            if servlet_instance is not None:
-                servlet_classname = servlet_instance.__class__.__name__
-            else:
-                servlet_classname = "%r" % callback
-            request_metrics.name = servlet_classname
-
             return

         # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
+        request_metrics.name = self.__class__.__name__ + ".UnrecognizedRequest"
         raise UnrecognizedRequestError()

     def _send_response(self, request, code, response_json_object,
                        response_code_message=None):
-        # could alternatively use request.notifyFinish() and flip a flag when
-        # the Deferred fires, but since the flag is RIGHT THERE it seems like
-        # a waste.
-        if request._disconnected:
-            logger.warn(
-                "Not sending response to request %s, already disconnected.",
-                request)
-            return
-
         outgoing_responses_counter.inc(request.method, str(code))

         # TODO: Only enable CORS for the requests that need it.
@@ -322,7 +355,7 @@ class RequestMetrics(object):
             )
             return

-        incoming_requests_counter.inc(request.method, self.name, tag)
+        response_count.inc(request.method, self.name, tag)

         response_timer.inc_by(
             clock.time_msec() - self.start, request.method,
@@ -341,7 +374,10 @@ class RequestMetrics(object):
             context.db_txn_count, request.method, self.name, tag
         )
         response_db_txn_duration.inc_by(
-            context.db_txn_duration, request.method, self.name, tag
+            context.db_txn_duration_ms / 1000., request.method, self.name, tag
+        )
+        response_db_sched_duration.inc_by(
+            context.db_sched_duration_ms / 1000., request.method, self.name, tag
         )
@@ -364,6 +400,15 @@ class RootRedirect(resource.Resource):
 def respond_with_json(request, code, json_object, send_cors=False,
                       response_code_message=None, pretty_print=False,
                       version_string="", canonical_json=True):
+    # could alternatively use request.notifyFinish() and flip a flag when
+    # the Deferred fires, but since the flag is RIGHT THERE it seems like
+    # a waste.
+    if request._disconnected:
+        logger.warn(
+            "Not sending response to request %s, already disconnected.",
+            request)
+        return
+
     if pretty_print:
         json_bytes = encode_pretty_printed_json(json_object) + "\n"
     else:
View File
@@ -66,14 +66,15 @@ class SynapseRequest(Request):
             context = LoggingContext.current_context()
             ru_utime, ru_stime = context.get_resource_usage()
             db_txn_count = context.db_txn_count
-            db_txn_duration = context.db_txn_duration
+            db_txn_duration_ms = context.db_txn_duration_ms
+            db_sched_duration_ms = context.db_sched_duration_ms
         except Exception:
             ru_utime, ru_stime = (0, 0)
-            db_txn_count, db_txn_duration = (0, 0)
+            db_txn_count, db_txn_duration_ms = (0, 0)

         self.site.access_logger.info(
             "%s - %s - {%s}"
-            " Processed request: %dms (%dms, %dms) (%dms/%d)"
+            " Processed request: %dms (%dms, %dms) (%dms/%dms/%d)"
             " %sB %s \"%s %s %s\" \"%s\"",
             self.getClientIP(),
             self.site.site_tag,
@@ -81,7 +82,8 @@ class SynapseRequest(Request):
             int(time.time() * 1000) - self.start_time,
             int(ru_utime * 1000),
             int(ru_stime * 1000),
-            int(db_txn_duration * 1000),
+            db_sched_duration_ms,
+            db_txn_duration_ms,
             int(db_txn_count),
             self.sentLength,
             self.code,
View File
@@ -146,10 +146,15 @@ def runUntilCurrentTimer(func):
             num_pending += 1

         num_pending += len(reactor.threadCallQueue)
         start = time.time() * 1000
         ret = func(*args, **kwargs)
         end = time.time() * 1000
+
+        # record the amount of wallclock time spent running pending calls.
+        # This is a proxy for the actual amount of time between reactor polls,
+        # since about 25% of time is actually spent running things triggered by
+        # I/O events, but that is harder to capture without rewriting half the
+        # reactor.
         tick_time.inc_by(end - start)
         pending_calls_metric.inc_by(num_pending)
View File
@@ -15,18 +15,38 @@

 from itertools import chain

+import logging
+
+logger = logging.getLogger(__name__)

-# TODO(paul): I can't believe Python doesn't have one of these
-def map_concat(func, items):
-    # flatten a list-of-lists
-    return list(chain.from_iterable(map(func, items)))
+
+def flatten(items):
+    """Flatten a list of lists
+
+    Args:
+        items: iterable[iterable[X]]
+
+    Returns:
+        list[X]: flattened list
+    """
+    return list(chain.from_iterable(items))


 class BaseMetric(object):
+    """Base class for metrics which report a single value per label set
+    """

-    def __init__(self, name, labels=[]):
-        self.name = name
+    def __init__(self, name, labels=[], alternative_names=[]):
+        """
+        Args:
+            name (str): principal name for this metric
+            labels (list(str)): names of the labels which will be reported
+                for this metric
+            alternative_names (iterable(str)): list of alternative names for
+                this metric. This can be useful to provide a migration path
+                when renaming metrics.
+        """
+        self._names = [name] + list(alternative_names)
         self.labels = labels  # OK not to clone as we never write it

     def dimension(self):
@@ -36,7 +56,7 @@ class BaseMetric(object):
         return not len(self.labels)

     def _render_labelvalue(self, value):
-        # TODO: some kind of value escape
+        # TODO: escape backslashes, quotes and newlines
         return '"%s"' % (value)

     def _render_key(self, values):
@@ -47,19 +67,60 @@ class BaseMetric(object):
                          for k, v in zip(self.labels, values)])
         )

+    def _render_for_labels(self, label_values, value):
+        """Render this metric for a single set of labels
+
+        Args:
+            label_values (list[str]): values for each of the labels
+            value: value of the metric at with these labels
+
+        Returns:
+            iterable[str]: rendered metric
+        """
+        rendered_labels = self._render_key(label_values)
+        return (
+            "%s%s %.12g" % (name, rendered_labels, value)
+            for name in self._names
+        )
+
+    def render(self):
+        """Render this metric
+
+        Each metric is rendered as:
+
+            name{label1="val1",label2="val2"} value
+
+        https://prometheus.io/docs/instrumenting/exposition_formats/#text-format-details
+
+        Returns:
+            iterable[str]: rendered metrics
+        """
+        raise NotImplementedError()
+

 class CounterMetric(BaseMetric):
     """The simplest kind of metric; one that stores a monotonically-increasing
-    integer that counts events."""
+    value that counts events or running totals.
+
+    Example use cases for Counters:
+    - Number of requests processed
+    - Number of items that were inserted into a queue
+    - Total amount of data that a system has processed
+    Counters can only go up (and be reset when the process restarts).
+    """

     def __init__(self, *args, **kwargs):
         super(CounterMetric, self).__init__(*args, **kwargs)

+        # dict[list[str]]: value for each set of label values. the keys are the
+        # label values, in the same order as the labels in self.labels.
+        #
+        # (if the metric is a scalar, the (single) key is the empty list).
         self.counts = {}

         # Scalar metrics are never empty
         if self.is_scalar():
-            self.counts[()] = 0
+            self.counts[()] = 0.

     def inc_by(self, incr, *values):
         if len(values) != self.dimension():
@@ -77,11 +138,11 @@ class CounterMetric(BaseMetric):
     def inc(self, *values):
         self.inc_by(1, *values)

-    def render_item(self, k):
-        return ["%s%s %d" % (self.name, self._render_key(k), self.counts[k])]
-
     def render(self):
-        return map_concat(self.render_item, sorted(self.counts.keys()))
+        return flatten(
+            self._render_for_labels(k, self.counts[k])
+            for k in sorted(self.counts.keys())
+        )


 class CallbackMetric(BaseMetric):
@@ -95,13 +156,19 @@ class CallbackMetric(BaseMetric):
         self.callback = callback

     def render(self):
-        value = self.callback()
+        try:
+            value = self.callback()
+        except Exception:
+            logger.exception("Failed to render %s", self.name)
+            return ["# FAILED to render " + self.name]

         if self.is_scalar():
-            return ["%s %.12g" % (self.name, value)]
+            return list(self._render_for_labels([], value))

-        return ["%s%s %.12g" % (self.name, self._render_key(k), value[k])
-                for k in sorted(value.keys())]
+        return flatten(
+            self._render_for_labels(k, value[k])
+            for k in sorted(value.keys())
+        )


 class DistributionMetric(object):
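Note: with alternative_names, each label set is rendered once per name, so a renamed metric and its deprecated alias both appear in the Prometheus text exposition. A rough sketch of that rendering in isolation (metric names here are assumed for illustration, not taken from the code above):

def render_counter(names, labels, counts):
    lines = []
    for label_values in sorted(counts.keys()):
        rendered_labels = "{%s}" % ",".join(
            '%s="%s"' % (k, v) for k, v in zip(labels, label_values)
        )
        # one exposition line per name: the new name plus any deprecated aliases
        for name in names:
            lines.append("%s%s %.12g" % (name, rendered_labels, counts[label_values]))
    return lines

lines = render_counter(
    names=["synapse_http_server_response_count", "synapse_http_server_requests"],
    labels=["method", "servlet"],
    counts={("GET", "SyncRestServlet"): 12},
)
print("\n".join(lines))
# synapse_http_server_response_count{method="GET",servlet="SyncRestServlet"} 12
# synapse_http_server_requests{method="GET",servlet="SyncRestServlet"} 12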
View File
@@ -13,21 +13,30 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging

-from synapse.push import PusherConfigException
 from twisted.internet import defer, reactor
 from twisted.internet.error import AlreadyCalled, AlreadyCancelled

-import logging
 import push_rule_evaluator
 import push_tools
+import synapse
+from synapse.push import PusherConfigException
 from synapse.util.logcontext import LoggingContext
 from synapse.util.metrics import Measure

 logger = logging.getLogger(__name__)

+metrics = synapse.metrics.get_metrics_for(__name__)
+
+http_push_processed_counter = metrics.register_counter(
+    "http_pushes_processed",
+)
+
+http_push_failed_counter = metrics.register_counter(
+    "http_pushes_failed",
+)
+

 class HttpPusher(object):
     INITIAL_BACKOFF_SEC = 1  # in seconds because that's what Twisted takes
@@ -152,9 +161,16 @@ class HttpPusher(object):
             self.user_id, self.last_stream_ordering, self.max_stream_ordering
         )

+        logger.info(
+            "Processing %i unprocessed push actions for %s starting at "
+            "stream_ordering %s",
+            len(unprocessed), self.name, self.last_stream_ordering,
+        )
+
         for push_action in unprocessed:
             processed = yield self._process_one(push_action)
             if processed:
+                http_push_processed_counter.inc()
                 self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
                 self.last_stream_ordering = push_action['stream_ordering']
                 yield self.store.update_pusher_last_stream_ordering_and_success(
@@ -169,6 +185,7 @@ class HttpPusher(object):
                     self.failing_since
                 )
             else:
+                http_push_failed_counter.inc()
                 if not self.failing_since:
                     self.failing_since = self.clock.time_msec()
                     yield self.store.update_pusher_failing_since(
@@ -316,7 +333,10 @@ class HttpPusher(object):
         try:
             resp = yield self.http_client.post_json_get_json(self.url, notification_dict)
         except Exception:
-            logger.warn("Failed to push %s ", self.url)
+            logger.warn(
+                "Failed to push event %s to %s",
+                event.event_id, self.name, exc_info=True,
+            )
             defer.returnValue(False)
         rejected = []
         if 'rejected' in resp:
@@ -325,7 +345,7 @@ class HttpPusher(object):

     @defer.inlineCallbacks
     def _send_badge(self, badge):
-        logger.info("Sending updated badge count %d to %r", badge, self.user_id)
+        logger.info("Sending updated badge count %d to %s", badge, self.name)
         d = {
             'notification': {
                 'id': '',
@@ -347,7 +367,10 @@ class HttpPusher(object):
         try:
             resp = yield self.http_client.post_json_get_json(self.url, d)
         except Exception:
-            logger.exception("Failed to push %s ", self.url)
+            logger.warn(
+                "Failed to send badge count to %s",
+                self.name, exc_info=True,
+            )
             defer.returnValue(False)
         rejected = []
         if 'rejected' in resp:
View File
@@ -517,25 +517,28 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
             self.send_error("Wrong remote")

     def on_RDATA(self, cmd):
+        stream_name = cmd.stream_name
+        inbound_rdata_count.inc(stream_name)
+
         try:
-            row = STREAMS_MAP[cmd.stream_name].ROW_TYPE(*cmd.row)
+            row = STREAMS_MAP[stream_name].ROW_TYPE(*cmd.row)
         except Exception:
             logger.exception(
                 "[%s] Failed to parse RDATA: %r %r",
-                self.id(), cmd.stream_name, cmd.row
+                self.id(), stream_name, cmd.row
             )
             raise

         if cmd.token is None:
             # I.e. this is part of a batch of updates for this stream. Batch
             # until we get an update for the stream with a non None token
-            self.pending_batches.setdefault(cmd.stream_name, []).append(row)
+            self.pending_batches.setdefault(stream_name, []).append(row)
         else:
             # Check if this is the last of a batch of updates
-            rows = self.pending_batches.pop(cmd.stream_name, [])
+            rows = self.pending_batches.pop(stream_name, [])
             rows.append(row)
-            self.handler.on_rdata(cmd.stream_name, cmd.token, rows)
+            self.handler.on_rdata(stream_name, cmd.token, rows)

     def on_POSITION(self, cmd):
         self.handler.on_position(cmd.stream_name, cmd.token)
@@ -644,3 +647,9 @@ metrics.register_callback(
     },
     labels=["command", "name", "conn_id"],
 )
+
+# number of updates received for each RDATA stream
+inbound_rdata_count = metrics.register_counter(
+    "inbound_rdata_count",
+    labels=["stream_name"],
+)
View File
@@ -191,19 +191,25 @@ class LoginRestServlet(ClientV1RestServlet):

         # convert threepid identifiers to user IDs
         if identifier["type"] == "m.id.thirdparty":
-            if 'medium' not in identifier or 'address' not in identifier:
+            address = identifier.get('address')
+            medium = identifier.get('medium')
+            if medium is None or address is None:
                 raise SynapseError(400, "Invalid thirdparty identifier")

-            address = identifier['address']
-            if identifier['medium'] == 'email':
+            if medium == 'email':
                 # For emails, transform the address to lowercase.
                 # We store all email addreses as lowercase in the DB.
                 # (See add_threepid in synapse/handlers/auth.py)
                 address = address.lower()
             user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
-                identifier['medium'], address
+                medium, address,
             )
             if not user_id:
+                logger.warn(
+                    "unknown 3pid identifier medium %s, address %r",
+                    medium, address,
+                )
                 raise LoginError(403, "", errcode=Codes.FORBIDDEN)

             identifier = {
View File
@@ -70,10 +70,15 @@ class RegisterRestServlet(ClientV1RestServlet):
     def on_GET(self, request):
+        require_email = 'email' in self.hs.config.registrations_require_3pid
+        require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
+
+        flows = []
         if self.hs.config.enable_registration_captcha:
-            return (
-                200,
-                {"flows": [
+            # only support the email-only flow if we don't require MSISDN 3PIDs
+            if not require_msisdn:
+                flows.extend([
                     {
                         "type": LoginType.RECAPTCHA,
                         "stages": [
@@ -82,27 +87,34 @@ class RegisterRestServlet(ClientV1RestServlet):
                             LoginType.PASSWORD
                         ]
                     },
+                ])
+            # only support 3PIDless registration if no 3PIDs are required
+            if not require_email and not require_msisdn:
+                flows.extend([
                     {
                         "type": LoginType.RECAPTCHA,
                         "stages": [LoginType.RECAPTCHA, LoginType.PASSWORD]
                     }
-                ]}
-            )
+                ])
         else:
-            return (
-                200,
-                {"flows": [
+            # only support the email-only flow if we don't require MSISDN 3PIDs
+            if require_email or not require_msisdn:
+                flows.extend([
                     {
                         "type": LoginType.EMAIL_IDENTITY,
                         "stages": [
                             LoginType.EMAIL_IDENTITY, LoginType.PASSWORD
                         ]
-                    },
+                    }
+                ])
+            # only support 3PIDless registration if no 3PIDs are required
+            if not require_email and not require_msisdn:
+                flows.extend([
                     {
                         "type": LoginType.PASSWORD
                     }
-                ]}
-            )
+                ])
+        return (200, {"flows": flows})

     @defer.inlineCallbacks
     def on_POST(self, request):
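Note: the advertised flows now depend on which 3PIDs the server requires. A worked example of the selection logic under one assumed config (captcha disabled, registrations_require_3pid: ['email']); the flow dicts use the standard Matrix login type strings:

registrations_require_3pid = ["email"]
require_email = 'email' in registrations_require_3pid
require_msisdn = 'msisdn' in registrations_require_3pid

flows = []
# email-only flow is kept because MSISDN is not required
if require_email or not require_msisdn:
    flows.append({"type": "m.login.email.identity",
                  "stages": ["m.login.email.identity", "m.login.password"]})
# password-only flow is dropped because an email 3PID is required
if not require_email and not require_msisdn:
    flows.append({"type": "m.login.password"})

print(flows)
# only the email flow remains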
View File
@@ -195,15 +195,20 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
         requester = yield self.auth.get_user_by_req(request, allow_guest=True)
         content = parse_json_object_from_request(request)

+        event_dict = {
+            "type": event_type,
+            "content": content,
+            "room_id": room_id,
+            "sender": requester.user.to_string(),
+        }
+
+        if 'ts' in request.args and requester.app_service:
+            event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)
+
         msg_handler = self.handlers.message_handler
         event = yield msg_handler.create_and_send_nonmember_event(
             requester,
-            {
-                "type": event_type,
-                "content": content,
-                "room_id": room_id,
-                "sender": requester.user.to_string(),
-            },
+            event_dict,
             txn_id=txn_id,
         )
@@ -487,13 +492,35 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
         defer.returnValue((200, content))


-class RoomEventContext(ClientV1RestServlet):
+class RoomEventServlet(ClientV1RestServlet):
+    PATTERNS = client_path_patterns(
+        "/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$"
+    )
+
+    def __init__(self, hs):
+        super(RoomEventServlet, self).__init__(hs)
+        self.clock = hs.get_clock()
+        self.event_handler = hs.get_event_handler()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, room_id, event_id):
+        requester = yield self.auth.get_user_by_req(request)
+        event = yield self.event_handler.get_event(requester.user, event_id)
+
+        time_now = self.clock.time_msec()
+        if event:
+            defer.returnValue((200, serialize_event(event, time_now)))
+        else:
+            defer.returnValue((404, "Event not found."))
+
+
+class RoomEventContextServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns(
         "/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$"
     )

     def __init__(self, hs):
-        super(RoomEventContext, self).__init__(hs)
+        super(RoomEventContextServlet, self).__init__(hs)
         self.clock = hs.get_clock()
         self.handlers = hs.get_handlers()
@@ -803,4 +830,5 @@ def register_servlets(hs, http_server):
     RoomTypingRestServlet(hs).register(http_server)
     SearchRestServlet(hs).register(http_server)
     JoinedRoomsRestServlet(hs).register(http_server)
-    RoomEventContext(hs).register(http_server)
+    RoomEventServlet(hs).register(http_server)
+    RoomEventContextServlet(hs).register(http_server)
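Note: the new RoomEventServlet exposes GET /_matrix/client/r0/rooms/{roomId}/event/{eventId}. An illustrative client-side call (hostname, room id, event id and access token are placeholders; this uses the third-party requests library rather than anything from Synapse):

import requests
try:
    from urllib.parse import quote
except ImportError:
    from urllib import quote

base = "https://my.server.here"
room_id = quote("!room:my.server.here", safe="")
event_id = quote("$someevent:my.server.here", safe="")

resp = requests.get(
    "%s/_matrix/client/r0/rooms/%s/event/%s" % (base, room_id, event_id),
    headers={"Authorization": "Bearer MDAxAccessTokenPlaceholder"},
)
print(resp.status_code)  # 200 with the serialized event, or 404 if not found
print(resp.json())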
View File
@@ -26,6 +26,7 @@ from synapse.http.servlet import (
 )
 from synapse.util.async import run_on_reactor
 from synapse.util.msisdn import phone_number_to_msisdn
+from synapse.util.threepids import check_3pid_allowed

 from ._base import client_v2_patterns, interactive_auth_handler

 logger = logging.getLogger(__name__)
@@ -47,6 +48,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
             'id_server', 'client_secret', 'email', 'send_attempt'
         ])

+        if not check_3pid_allowed(self.hs, "email", body['email']):
+            raise SynapseError(
+                403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+            )
+
         existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
             'email', body['email']
         )
@@ -78,6 +84,11 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):

         msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])

+        if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+            raise SynapseError(
+                403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+            )
+
         existingUid = yield self.datastore.get_user_id_by_threepid(
             'msisdn', msisdn
         )
@@ -217,6 +228,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
         if absent:
             raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)

+        if not check_3pid_allowed(self.hs, "email", body['email']):
+            raise SynapseError(
+                403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+            )
+
         existingUid = yield self.datastore.get_user_id_by_threepid(
             'email', body['email']
         )
@@ -255,6 +271,11 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):

         msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])

+        if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+            raise SynapseError(
+                403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+            )
+
         existingUid = yield self.datastore.get_user_id_by_threepid(
             'msisdn', msisdn
         )
View File
@@ -26,6 +26,7 @@ from synapse.http.servlet import (
     RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string
 )
 from synapse.util.msisdn import phone_number_to_msisdn
+from synapse.util.threepids import check_3pid_allowed

 from ._base import client_v2_patterns, interactive_auth_handler
@@ -70,6 +71,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
             'id_server', 'client_secret', 'email', 'send_attempt'
         ])

+        if not check_3pid_allowed(self.hs, "email", body['email']):
+            raise SynapseError(
+                403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+            )
+
         existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
             'email', body['email']
         )
@@ -105,6 +111,11 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):

         msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])

+        if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+            raise SynapseError(
+                403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+            )
+
         existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
             'msisdn', msisdn
         )
@@ -305,31 +316,67 @@ class RegisterRestServlet(RestServlet):
         if 'x_show_msisdn' in body and body['x_show_msisdn']:
             show_msisdn = True

+        # FIXME: need a better error than "no auth flow found" for scenarios
+        # where we required 3PID for registration but the user didn't give one
+        require_email = 'email' in self.hs.config.registrations_require_3pid
+        require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
+
+        flows = []
         if self.hs.config.enable_registration_captcha:
-            flows = [
-                [LoginType.RECAPTCHA],
-                [LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
-            ]
+            # only support 3PIDless registration if no 3PIDs are required
+            if not require_email and not require_msisdn:
+                flows.extend([[LoginType.RECAPTCHA]])
+            # only support the email-only flow if we don't require MSISDN 3PIDs
+            if not require_msisdn:
+                flows.extend([[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]])
+
             if show_msisdn:
+                # only support the MSISDN-only flow if we don't require email 3PIDs
+                if not require_email:
+                    flows.extend([[LoginType.MSISDN, LoginType.RECAPTCHA]])
+                # always let users provide both MSISDN & email
                 flows.extend([
-                    [LoginType.MSISDN, LoginType.RECAPTCHA],
                     [LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
                 ])
         else:
-            flows = [
-                [LoginType.DUMMY],
-                [LoginType.EMAIL_IDENTITY],
-            ]
+            # only support 3PIDless registration if no 3PIDs are required
+            if not require_email and not require_msisdn:
+                flows.extend([[LoginType.DUMMY]])
+            # only support the email-only flow if we don't require MSISDN 3PIDs
+            if not require_msisdn:
+                flows.extend([[LoginType.EMAIL_IDENTITY]])
+
            if show_msisdn:
+                # only support the MSISDN-only flow if we don't require email 3PIDs
+                if not require_email or require_msisdn:
+                    flows.extend([[LoginType.MSISDN]])
+                # always let users provide both MSISDN & email
                 flows.extend([
-                    [LoginType.MSISDN],
-                    [LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
+                    [LoginType.MSISDN, LoginType.EMAIL_IDENTITY]
                 ])

         auth_result, params, session_id = yield self.auth_handler.check_auth(
             flows, body, self.hs.get_ip_from_request(request)
         )

+        # Check that we're not trying to register a denied 3pid.
+        #
+        # the user-facing checks will probably already have happened in
+        # /register/email/requestToken when we requested a 3pid, but that's not
+        # guaranteed.
+
+        if auth_result:
+            for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
+                if login_type in auth_result:
+                    medium = auth_result[login_type]['medium']
+                    address = auth_result[login_type]['address']
+
+                    if not check_3pid_allowed(self.hs, medium, address):
+                        raise SynapseError(
+                            403, "Third party identifier is not allowed",
+                            Codes.THREEPID_DENIED,
+                        )
+
         if registered_user_id is not None:
             logger.info(
                 "Already registered user ID %r for this session",
View File
@@ -93,6 +93,7 @@ class RemoteKey(Resource):
         self.store = hs.get_datastore()
         self.version_string = hs.version_string
         self.clock = hs.get_clock()
+        self.federation_domain_whitelist = hs.config.federation_domain_whitelist

     def render_GET(self, request):
         self.async_render_GET(request)
@@ -137,6 +138,13 @@ class RemoteKey(Resource):
         logger.info("Handling query for keys %r", query)

         store_queries = []
         for server_name, key_ids in query.items():
+            if (
+                self.federation_domain_whitelist is not None and
+                server_name not in self.federation_domain_whitelist
+            ):
+                logger.debug("Federation denied with %s", server_name)
+                continue
+
             if not key_ids:
                 key_ids = (None,)
             for key_id in key_ids:
View File
@@ -70,38 +70,11 @@ def respond_with_file(request, media_type, file_path,
     logger.debug("Responding with %r", file_path)

     if os.path.isfile(file_path):
-        request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
-        if upload_name:
-            if is_ascii(upload_name):
-                request.setHeader(
-                    b"Content-Disposition",
-                    b"inline; filename=%s" % (
-                        urllib.quote(upload_name.encode("utf-8")),
-                    ),
-                )
-            else:
-                request.setHeader(
-                    b"Content-Disposition",
-                    b"inline; filename*=utf-8''%s" % (
-                        urllib.quote(upload_name.encode("utf-8")),
-                    ),
-                )
-
-        # cache for at least a day.
-        # XXX: we might want to turn this off for data we don't want to
-        # recommend caching as it's sensitive or private - or at least
-        # select private. don't bother setting Expires as all our
-        # clients are smart enough to be happy with Cache-Control
-        request.setHeader(
-            b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
-        )
-
         if file_size is None:
             stat = os.stat(file_path)
             file_size = stat.st_size

-        request.setHeader(
-            b"Content-Length", b"%d" % (file_size,)
-        )
+        add_file_headers(request, media_type, file_size, upload_name)

         with open(file_path, "rb") as f:
             yield logcontext.make_deferred_yieldable(
@@ -111,3 +84,118 @@ def respond_with_file(request, media_type, file_path,
         finish_request(request)
     else:
         respond_404(request)
+
+
+def add_file_headers(request, media_type, file_size, upload_name):
+    """Adds the correct response headers in preparation for responding with the
+    media.
+
+    Args:
+        request (twisted.web.http.Request)
+        media_type (str): The media/content type.
+        file_size (int): Size in bytes of the media, if known.
+        upload_name (str): The name of the requested file, if any.
+    """
+    request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
+    if upload_name:
+        if is_ascii(upload_name):
+            request.setHeader(
+                b"Content-Disposition",
+                b"inline; filename=%s" % (
+                    urllib.quote(upload_name.encode("utf-8")),
+                ),
+            )
+        else:
+            request.setHeader(
+                b"Content-Disposition",
+                b"inline; filename*=utf-8''%s" % (
+                    urllib.quote(upload_name.encode("utf-8")),
+                ),
+            )
+
+    # cache for at least a day.
+    # XXX: we might want to turn this off for data we don't want to
+    # recommend caching as it's sensitive or private - or at least
+    # select private. don't bother setting Expires as all our
+    # clients are smart enough to be happy with Cache-Control
+    request.setHeader(
+        b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
+    )
+
+    request.setHeader(
+        b"Content-Length", b"%d" % (file_size,)
+    )
+
+
+@defer.inlineCallbacks
+def respond_with_responder(request, responder, media_type, file_size, upload_name=None):
+    """Responds to the request with given responder. If responder is None then
+    returns 404.
+
+    Args:
+        request (twisted.web.http.Request)
+        responder (Responder|None)
+        media_type (str): The media/content type.
+        file_size (int|None): Size in bytes of the media. If not known it should be None
+        upload_name (str|None): The name of the requested file, if any.
+    """
+    if not responder:
+        respond_404(request)
+        return
+
+    add_file_headers(request, media_type, file_size, upload_name)
+    with responder:
+        yield responder.write_to_consumer(request)
+    finish_request(request)
+
+
+class Responder(object):
+    """Represents a response that can be streamed to the requester.
+
+    Responder is a context manager which *must* be used, so that any resources
+    held can be cleaned up.
+    """
+    def write_to_consumer(self, consumer):
+        """Stream response into consumer
+
+        Args:
+            consumer (IConsumer)
+
+        Returns:
+            Deferred: Resolves once the response has finished being written
+        """
+        pass
+
+    def __enter__(self):
+        pass
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        pass
+
+
+class FileInfo(object):
+    """Details about a requested/uploaded file.
+
+    Attributes:
+        server_name (str): The server name where the media originated from,
+            or None if local.
+        file_id (str): The local ID of the file. For local files this is the
+            same as the media_id
+        url_cache (bool): If the file is for the url preview cache
+        thumbnail (bool): Whether the file is a thumbnail or not.
+        thumbnail_width (int)
+        thumbnail_height (int)
+        thumbnail_method (str)
+        thumbnail_type (str): Content type of thumbnail, e.g. image/png
+    """
+    def __init__(self, server_name, file_id, url_cache=False,
+                 thumbnail=False, thumbnail_width=None, thumbnail_height=None,
+                 thumbnail_method=None, thumbnail_type=None):
+        self.server_name = server_name
+        self.file_id = file_id
+        self.url_cache = url_cache
+        self.thumbnail = thumbnail
+        self.thumbnail_width = thumbnail_width
+        self.thumbnail_height = thumbnail_height
+        self.thumbnail_method = thumbnail_method
+        self.thumbnail_type = thumbnail_type
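Note: Responder is a context manager that streams media to the request and releases resources on exit. An illustrative file-backed implementation, under the assumption that a simple blocking copy is acceptable (the responders shipped with this change stream via Twisted's IConsumer machinery instead):

class FileResponder(object):
    """Wraps an open file and writes it to a consumer-like object with a
    write() method, closing the file when the context exits."""
    def __init__(self, open_file):
        self.open_file = open_file

    def write_to_consumer(self, consumer):
        # blocking sketch; a real implementation would stream asynchronously
        while True:
            chunk = self.open_file.read(64 * 1024)
            if not chunk:
                break
            consumer.write(chunk)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.open_file.close()

# usage sketch: copy this very file into an in-memory buffer
import io
sink = io.BytesIO()
with open(__file__, "rb") as f:
    with FileResponder(f) as responder:
        responder.write_to_consumer(sink)
print(len(sink.getvalue()))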
View File
@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import synapse.http.servlet import synapse.http.servlet
from ._base import parse_media_id, respond_with_file, respond_404 from ._base import parse_media_id, respond_404
from twisted.web.resource import Resource from twisted.web.resource import Resource
from synapse.http.server import request_handler, set_cors_headers from synapse.http.server import request_handler, set_cors_headers
@ -32,12 +32,12 @@ class DownloadResource(Resource):
def __init__(self, hs, media_repo): def __init__(self, hs, media_repo):
Resource.__init__(self) Resource.__init__(self)
self.filepaths = media_repo.filepaths
self.media_repo = media_repo self.media_repo = media_repo
self.server_name = hs.hostname self.server_name = hs.hostname
self.store = hs.get_datastore()
self.version_string = hs.version_string # Both of these are expected by @request_handler()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.version_string = hs.version_string
def render_GET(self, request): def render_GET(self, request):
self._async_render_GET(request) self._async_render_GET(request)
@ -57,59 +57,16 @@ class DownloadResource(Resource):
) )
server_name, media_id, name = parse_media_id(request) server_name, media_id, name = parse_media_id(request)
if server_name == self.server_name: if server_name == self.server_name:
yield self._respond_local_file(request, media_id, name) yield self.media_repo.get_local_media(request, media_id, name)
else: else:
yield self._respond_remote_file( allow_remote = synapse.http.servlet.parse_boolean(
request, server_name, media_id, name request, "allow_remote", default=True)
) if not allow_remote:
logger.info(
"Rejecting request for remote media %s/%s due to allow_remote",
server_name, media_id,
)
respond_404(request)
return
@defer.inlineCallbacks yield self.media_repo.get_remote_media(request, server_name, media_id, name)
def _respond_local_file(self, request, media_id, name):
media_info = yield self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]:
respond_404(request)
return
media_type = media_info["media_type"]
media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
if media_info["url_cache"]:
# TODO: Check the file still exists, if it doesn't we can redownload
# it from the url `media_info["url_cache"]`
file_path = self.filepaths.url_cache_filepath(media_id)
else:
file_path = self.filepaths.local_media_filepath(media_id)
yield respond_with_file(
request, media_type, file_path, media_length,
upload_name=upload_name,
)
@defer.inlineCallbacks
def _respond_remote_file(self, request, server_name, media_id, name):
# don't forward requests for remote media if allow_remote is false
allow_remote = synapse.http.servlet.parse_boolean(
request, "allow_remote", default=True)
if not allow_remote:
logger.info(
"Rejecting request for remote media %s/%s due to allow_remote",
server_name, media_id,
)
respond_404(request)
return
media_info = yield self.media_repo.get_remote_media(server_name, media_id)
media_type = media_info["media_type"]
media_length = media_info["media_length"]
filesystem_id = media_info["filesystem_id"]
upload_name = name if name else media_info["upload_name"]
file_path = self.filepaths.remote_media_filepath(
server_name, filesystem_id
)
yield respond_with_file(
request, media_type, file_path, media_length,
upload_name=upload_name,
)
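The ``allow_remote`` flag is what stops two servers proxying a request back and forth: when a server fetches remote media on behalf of a client it adds ``allow_remote=false``, so the remote side will 404 rather than forward the request again. A hedged example of such a request (hostname and media ID are made up):

    GET /_matrix/media/v1/download/other.example.com/someMediaId?allow_remote=false HTTP/1.1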


@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -18,6 +19,7 @@ import twisted.internet.error
import twisted.web.http import twisted.web.http
from twisted.web.resource import Resource from twisted.web.resource import Resource
from ._base import respond_404, FileInfo, respond_with_responder
from .upload_resource import UploadResource from .upload_resource import UploadResource
from .download_resource import DownloadResource from .download_resource import DownloadResource
from .thumbnail_resource import ThumbnailResource from .thumbnail_resource import ThumbnailResource
@ -25,15 +27,18 @@ from .identicon_resource import IdenticonResource
from .preview_url_resource import PreviewUrlResource from .preview_url_resource import PreviewUrlResource
from .filepath import MediaFilePaths from .filepath import MediaFilePaths
from .thumbnailer import Thumbnailer from .thumbnailer import Thumbnailer
from .storage_provider import StorageProviderWrapper
from .media_storage import MediaStorage
from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from synapse.api.errors import SynapseError, HttpResponseException, \ from synapse.api.errors import (
NotFoundError SynapseError, HttpResponseException, NotFoundError, FederationDeniedError,
)
from synapse.util.async import Linearizer from synapse.util.async import Linearizer
from synapse.util.stringutils import is_ascii from synapse.util.stringutils import is_ascii
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn from synapse.util.logcontext import make_deferred_yieldable
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
import os import os
@ -47,7 +52,7 @@ import urlparse
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
UPDATE_RECENTLY_ACCESSED_REMOTES_TS = 60 * 1000 UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
class MediaRepository(object): class MediaRepository(object):
@ -63,96 +68,62 @@ class MediaRepository(object):
self.primary_base_path = hs.config.media_store_path self.primary_base_path = hs.config.media_store_path
self.filepaths = MediaFilePaths(self.primary_base_path) self.filepaths = MediaFilePaths(self.primary_base_path)
self.backup_base_path = hs.config.backup_media_store_path
self.synchronous_backup_media_store = hs.config.synchronous_backup_media_store
self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.thumbnail_requirements = hs.config.thumbnail_requirements self.thumbnail_requirements = hs.config.thumbnail_requirements
self.remote_media_linearizer = Linearizer(name="media_remote") self.remote_media_linearizer = Linearizer(name="media_remote")
self.recently_accessed_remotes = set() self.recently_accessed_remotes = set()
self.recently_accessed_locals = set()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
# List of StorageProviders where we should search for media and
# potentially upload to.
storage_providers = []
for clz, provider_config, wrapper_config in hs.config.media_storage_providers:
backend = clz(hs, provider_config)
provider = StorageProviderWrapper(
backend,
store_local=wrapper_config.store_local,
store_remote=wrapper_config.store_remote,
store_synchronous=wrapper_config.store_synchronous,
)
storage_providers.append(provider)
self.media_storage = MediaStorage(
self.primary_base_path, self.filepaths, storage_providers,
)
self.clock.looping_call( self.clock.looping_call(
self._update_recently_accessed_remotes, self._update_recently_accessed,
UPDATE_RECENTLY_ACCESSED_REMOTES_TS UPDATE_RECENTLY_ACCESSED_TS,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def _update_recently_accessed_remotes(self): def _update_recently_accessed(self):
media = self.recently_accessed_remotes remote_media = self.recently_accessed_remotes
self.recently_accessed_remotes = set() self.recently_accessed_remotes = set()
local_media = self.recently_accessed_locals
self.recently_accessed_locals = set()
yield self.store.update_cached_last_access_time( yield self.store.update_cached_last_access_time(
media, self.clock.time_msec() local_media, remote_media, self.clock.time_msec()
) )
@staticmethod def mark_recently_accessed(self, server_name, media_id):
def _makedirs(filepath): """Mark the given media as recently accessed.
dirname = os.path.dirname(filepath)
if not os.path.exists(dirname):
os.makedirs(dirname)
@staticmethod
def _write_file_synchronously(source, fname):
"""Write `source` to the path `fname` synchronously. Should be called
from a thread.
Args: Args:
source: A file like object to be written server_name (str|None): Origin server of media, or None if local
fname (str): Path to write to media_id (str): The media ID of the content
""" """
MediaRepository._makedirs(fname) if server_name:
source.seek(0) # Ensure we read from the start of the file self.recently_accessed_remotes.add((server_name, media_id))
with open(fname, "wb") as f: else:
shutil.copyfileobj(source, f) self.recently_accessed_locals.add(media_id)
@defer.inlineCallbacks
def write_to_file_and_backup(self, source, path):
"""Write `source` to the on disk media store, and also the backup store
if configured.
Args:
source: A file like object that should be written
path (str): Relative path to write file to
Returns:
Deferred[str]: the file path written to in the primary media store
"""
fname = os.path.join(self.primary_base_path, path)
# Write to the main repository
yield make_deferred_yieldable(threads.deferToThread(
self._write_file_synchronously, source, fname,
))
# Write to backup repository
yield self.copy_to_backup(path)
defer.returnValue(fname)
@defer.inlineCallbacks
def copy_to_backup(self, path):
"""Copy a file from the primary to backup media store, if configured.
Args:
path(str): Relative path to write file to
"""
if self.backup_base_path:
primary_fname = os.path.join(self.primary_base_path, path)
backup_fname = os.path.join(self.backup_base_path, path)
# We can either wait for successful writing to the backup repository
# or write in the background and immediately return
if self.synchronous_backup_media_store:
yield make_deferred_yieldable(threads.deferToThread(
shutil.copyfile, primary_fname, backup_fname,
))
else:
preserve_fn(threads.deferToThread)(
shutil.copyfile, primary_fname, backup_fname,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def create_content(self, media_type, upload_name, content, content_length, def create_content(self, media_type, upload_name, content, content_length,
@ -171,10 +142,13 @@ class MediaRepository(object):
""" """
media_id = random_string(24) media_id = random_string(24)
fname = yield self.write_to_file_and_backup( file_info = FileInfo(
content, self.filepaths.local_media_filepath_rel(media_id) server_name=None,
file_id=media_id,
) )
fname = yield self.media_storage.store_file(content, file_info)
logger.info("Stored local media in file %r", fname) logger.info("Stored local media in file %r", fname)
yield self.store.store_local_media( yield self.store.store_local_media(
@ -185,134 +159,275 @@ class MediaRepository(object):
media_length=content_length, media_length=content_length,
user_id=auth_user, user_id=auth_user,
) )
media_info = {
"media_type": media_type,
"media_length": content_length,
}
yield self._generate_thumbnails(None, media_id, media_info) yield self._generate_thumbnails(
None, media_id, media_id, media_type,
)
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id)) defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_remote_media(self, server_name, media_id): def get_local_media(self, request, media_id, name):
"""Responds to reqests for local media, if exists, or returns 404.
Args:
request(twisted.web.http.Request)
media_id (str): The media ID of the content. (This is the same as
the file_id for local content.)
name (str|None): Optional name that, if specified, will be used as
the filename in the Content-Disposition header of the response.
Returns:
Deferred: Resolves once a response has successfully been written
to request
"""
media_info = yield self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]:
respond_404(request)
return
self.mark_recently_accessed(None, media_id)
media_type = media_info["media_type"]
media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
url_cache = media_info["url_cache"]
file_info = FileInfo(
None, media_id,
url_cache=url_cache,
)
responder = yield self.media_storage.fetch_media(file_info)
yield respond_with_responder(
request, responder, media_type, media_length, upload_name,
)
@defer.inlineCallbacks
def get_remote_media(self, request, server_name, media_id, name):
"""Respond to requests for remote media.
Args:
request(twisted.web.http.Request)
server_name (str): Remote server_name where the media originated.
media_id (str): The media ID of the content (as defined by the
remote server).
name (str|None): Optional name that, if specified, will be used as
the filename in the Content-Disposition header of the response.
Returns:
Deferred: Resolves once a response has successfully been written
to request
"""
if (
self.federation_domain_whitelist is not None and
server_name not in self.federation_domain_whitelist
):
raise FederationDeniedError(server_name)
self.mark_recently_accessed(server_name, media_id)
# We linearize here to ensure that we don't try and download remote
# media multiple times concurrently
key = (server_name, media_id) key = (server_name, media_id)
with (yield self.remote_media_linearizer.queue(key)): with (yield self.remote_media_linearizer.queue(key)):
media_info = yield self._get_remote_media_impl(server_name, media_id) responder, media_info = yield self._get_remote_media_impl(
server_name, media_id,
)
# We deliberately stream the file outside the lock
if responder:
media_type = media_info["media_type"]
media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
yield respond_with_responder(
request, responder, media_type, media_length, upload_name,
)
else:
respond_404(request)
@defer.inlineCallbacks
def get_remote_media_info(self, server_name, media_id):
"""Gets the media info associated with the remote file, downloading
if necessary.
Args:
server_name (str): Remote server_name where the media originated.
media_id (str): The media ID of the content (as defined by the
remote server).
Returns:
Deferred[dict]: The media_info of the file
"""
if (
self.federation_domain_whitelist is not None and
server_name not in self.federation_domain_whitelist
):
raise FederationDeniedError(server_name)
# We linearize here to ensure that we don't try and download remote
# media multiple times concurrently
key = (server_name, media_id)
with (yield self.remote_media_linearizer.queue(key)):
responder, media_info = yield self._get_remote_media_impl(
server_name, media_id,
)
# Ensure we actually use the responder so that it releases resources
if responder:
with responder:
pass
defer.returnValue(media_info) defer.returnValue(media_info)
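The ``Linearizer`` is what serialises concurrent fetches of the same remote file. A minimal sketch of the pattern, assuming ``synapse.util.async.Linearizer`` behaves as it is used above (``do_download`` is a hypothetical helper):

    from twisted.internet import defer
    from synapse.util.async import Linearizer

    linearizer = Linearizer(name="media_remote")

    @defer.inlineCallbacks
    def fetch_once(server_name, media_id):
        key = (server_name, media_id)
        # Only one caller at a time holds the lock for a given key, so a
        # second request for the same media waits instead of re-downloading.
        with (yield linearizer.queue(key)):
            result = yield do_download(server_name, media_id)
        defer.returnValue(result)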
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_remote_media_impl(self, server_name, media_id): def _get_remote_media_impl(self, server_name, media_id):
"""Looks for media in local cache, if not there then attempt to
download from remote server.
Args:
server_name (str): Remote server_name where the media originated.
media_id (str): The media ID of the content (as defined by the
remote server).
Returns:
Deferred[(Responder, media_info)]
"""
media_info = yield self.store.get_cached_remote_media( media_info = yield self.store.get_cached_remote_media(
server_name, media_id server_name, media_id
) )
if not media_info:
media_info = yield self._download_remote_file( # file_id is the ID we use to track the file locally. If we've already
server_name, media_id # seen the file then reuse the existing ID, otherwise genereate a new
) # one.
elif media_info["quarantined_by"]: if media_info:
raise NotFoundError() file_id = media_info["filesystem_id"]
else: else:
self.recently_accessed_remotes.add((server_name, media_id)) file_id = random_string(24)
yield self.store.update_cached_last_access_time(
[(server_name, media_id)], self.clock.time_msec() file_info = FileInfo(server_name, file_id)
)
defer.returnValue(media_info) # If we have an entry in the DB, try and look for it
if media_info:
if media_info["quarantined_by"]:
logger.info("Media is quarantined")
raise NotFoundError()
responder = yield self.media_storage.fetch_media(file_info)
if responder:
defer.returnValue((responder, media_info))
# Failed to find the file anywhere, lets download it.
media_info = yield self._download_remote_file(
server_name, media_id, file_id
)
responder = yield self.media_storage.fetch_media(file_info)
defer.returnValue((responder, media_info))
@defer.inlineCallbacks @defer.inlineCallbacks
def _download_remote_file(self, server_name, media_id): def _download_remote_file(self, server_name, media_id, file_id):
file_id = random_string(24) """Attempt to download the remote file from the given server name,
using the given file_id as the local id.
fpath = self.filepaths.remote_media_filepath_rel( Args:
server_name, file_id server_name (str): Originating server
media_id (str): The media ID of the content (as defined by the
remote server). This is different than the file_id, which is
locally generated.
file_id (str): Local file ID
Returns:
Deferred[MediaInfo]
"""
file_info = FileInfo(
server_name=server_name,
file_id=file_id,
) )
fname = os.path.join(self.primary_base_path, fpath)
self._makedirs(fname)
try: with self.media_storage.store_into_file(file_info) as (f, fname, finish):
with open(fname, "wb") as f: request_path = "/".join((
request_path = "/".join(( "/_matrix/media/v1/download", server_name, media_id,
"/_matrix/media/v1/download", server_name, media_id, ))
)) try:
length, headers = yield self.client.get_file(
server_name, request_path, output_stream=f,
max_size=self.max_upload_size, args={
# tell the remote server to 404 if it doesn't
# recognise the server_name, to make sure we don't
# end up with a routing loop.
"allow_remote": "false",
}
)
except twisted.internet.error.DNSLookupError as e:
logger.warn("HTTP error fetching remote media %s/%s: %r",
server_name, media_id, e)
raise NotFoundError()
except HttpResponseException as e:
logger.warn("HTTP error fetching remote media %s/%s: %s",
server_name, media_id, e.response)
if e.code == twisted.web.http.NOT_FOUND:
raise SynapseError.from_http_response_exception(e)
raise SynapseError(502, "Failed to fetch remote media")
except SynapseError:
logger.exception("Failed to fetch remote media %s/%s",
server_name, media_id)
raise
except NotRetryingDestination:
logger.warn("Not retrying destination %r", server_name)
raise SynapseError(502, "Failed to fetch remote media")
except Exception:
logger.exception("Failed to fetch remote media %s/%s",
server_name, media_id)
raise SynapseError(502, "Failed to fetch remote media")
yield finish()
media_type = headers["Content-Type"][0]
time_now_ms = self.clock.time_msec()
content_disposition = headers.get("Content-Disposition", None)
if content_disposition:
_, params = cgi.parse_header(content_disposition[0],)
upload_name = None
# First check if there is a valid UTF-8 filename
upload_name_utf8 = params.get("filename*", None)
if upload_name_utf8:
if upload_name_utf8.lower().startswith("utf-8''"):
upload_name = upload_name_utf8[7:]
# If there isn't check for an ascii name.
if not upload_name:
upload_name_ascii = params.get("filename", None)
if upload_name_ascii and is_ascii(upload_name_ascii):
upload_name = upload_name_ascii
if upload_name:
upload_name = urlparse.unquote(upload_name)
try: try:
length, headers = yield self.client.get_file( upload_name = upload_name.decode("utf-8")
server_name, request_path, output_stream=f, except UnicodeDecodeError:
max_size=self.max_upload_size, args={ upload_name = None
# tell the remote server to 404 if it doesn't else:
# recognise the server_name, to make sure we don't upload_name = None
# end up with a routing loop.
"allow_remote": "false",
}
)
except twisted.internet.error.DNSLookupError as e:
logger.warn("HTTP error fetching remote media %s/%s: %r",
server_name, media_id, e)
raise NotFoundError()
except HttpResponseException as e: logger.info("Stored remote media in file %r", fname)
logger.warn("HTTP error fetching remote media %s/%s: %s",
server_name, media_id, e.response)
if e.code == twisted.web.http.NOT_FOUND:
raise SynapseError.from_http_response_exception(e)
raise SynapseError(502, "Failed to fetch remote media")
except SynapseError: yield self.store.store_cached_remote_media(
logger.exception("Failed to fetch remote media %s/%s", origin=server_name,
server_name, media_id) media_id=media_id,
raise media_type=media_type,
except NotRetryingDestination: time_now_ms=self.clock.time_msec(),
logger.warn("Not retrying destination %r", server_name) upload_name=upload_name,
raise SynapseError(502, "Failed to fetch remote media") media_length=length,
except Exception: filesystem_id=file_id,
logger.exception("Failed to fetch remote media %s/%s", )
server_name, media_id)
raise SynapseError(502, "Failed to fetch remote media")
yield self.copy_to_backup(fpath)
media_type = headers["Content-Type"][0]
time_now_ms = self.clock.time_msec()
content_disposition = headers.get("Content-Disposition", None)
if content_disposition:
_, params = cgi.parse_header(content_disposition[0],)
upload_name = None
# First check if there is a valid UTF-8 filename
upload_name_utf8 = params.get("filename*", None)
if upload_name_utf8:
if upload_name_utf8.lower().startswith("utf-8''"):
upload_name = upload_name_utf8[7:]
# If there isn't check for an ascii name.
if not upload_name:
upload_name_ascii = params.get("filename", None)
if upload_name_ascii and is_ascii(upload_name_ascii):
upload_name = upload_name_ascii
if upload_name:
upload_name = urlparse.unquote(upload_name)
try:
upload_name = upload_name.decode("utf-8")
except UnicodeDecodeError:
upload_name = None
else:
upload_name = None
logger.info("Stored remote media in file %r", fname)
yield self.store.store_cached_remote_media(
origin=server_name,
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
upload_name=upload_name,
media_length=length,
filesystem_id=file_id,
)
except Exception:
os.remove(fname)
raise
media_info = { media_info = {
"media_type": media_type, "media_type": media_type,
@ -323,7 +438,7 @@ class MediaRepository(object):
} }
yield self._generate_thumbnails( yield self._generate_thumbnails(
server_name, media_id, media_info server_name, media_id, file_id, media_type,
) )
defer.returnValue(media_info) defer.returnValue(media_info)
@ -368,11 +483,18 @@ class MediaRepository(object):
if t_byte_source: if t_byte_source:
try: try:
output_path = yield self.write_to_file_and_backup( file_info = FileInfo(
t_byte_source, server_name=None,
self.filepaths.local_media_thumbnail_rel( file_id=media_id,
media_id, t_width, t_height, t_type, t_method thumbnail=True,
) thumbnail_width=t_width,
thumbnail_height=t_height,
thumbnail_method=t_method,
thumbnail_type=t_type,
)
output_path = yield self.media_storage.store_file(
t_byte_source, file_info,
) )
finally: finally:
t_byte_source.close() t_byte_source.close()
@ -400,11 +522,18 @@ class MediaRepository(object):
if t_byte_source: if t_byte_source:
try: try:
output_path = yield self.write_to_file_and_backup( file_info = FileInfo(
t_byte_source, server_name=server_name,
self.filepaths.remote_media_thumbnail_rel( file_id=media_id,
server_name, file_id, t_width, t_height, t_type, t_method thumbnail=True,
) thumbnail_width=t_width,
thumbnail_height=t_height,
thumbnail_method=t_method,
thumbnail_type=t_type,
)
output_path = yield self.media_storage.store_file(
t_byte_source, file_info,
) )
finally: finally:
t_byte_source.close() t_byte_source.close()
@ -421,21 +550,22 @@ class MediaRepository(object):
defer.returnValue(output_path) defer.returnValue(output_path)
@defer.inlineCallbacks @defer.inlineCallbacks
def _generate_thumbnails(self, server_name, media_id, media_info, url_cache=False): def _generate_thumbnails(self, server_name, media_id, file_id, media_type,
url_cache=False):
"""Generate and store thumbnails for an image. """Generate and store thumbnails for an image.
Args: Args:
server_name(str|None): The server name if remote media, else None if local server_name (str|None): The server name if remote media, else None if local
media_id(str) media_id (str): The media ID of the content. (This is the same as
media_info(dict) the file_id for local content)
url_cache(bool): If we are thumbnailing images downloaded for the URL cache, file_id (str): Local file ID
media_type (str): The content type of the file
url_cache (bool): If we are thumbnailing images downloaded for the URL cache,
used exclusively by the url previewer used exclusively by the url previewer
Returns: Returns:
Deferred[dict]: Dict with "width" and "height" keys of original image Deferred[dict]: Dict with "width" and "height" keys of original image
""" """
media_type = media_info["media_type"]
file_id = media_info.get("filesystem_id")
requirements = self._get_thumbnail_requirements(media_type) requirements = self._get_thumbnail_requirements(media_type)
if not requirements: if not requirements:
return return
@ -472,20 +602,6 @@ class MediaRepository(object):
# Now we generate the thumbnails for each dimension, store it # Now we generate the thumbnails for each dimension, store it
for (t_width, t_height, t_type), t_method in thumbnails.iteritems(): for (t_width, t_height, t_type), t_method in thumbnails.iteritems():
# Work out the correct file name for thumbnail
if server_name:
file_path = self.filepaths.remote_media_thumbnail_rel(
server_name, file_id, t_width, t_height, t_type, t_method
)
elif url_cache:
file_path = self.filepaths.url_cache_thumbnail_rel(
media_id, t_width, t_height, t_type, t_method
)
else:
file_path = self.filepaths.local_media_thumbnail_rel(
media_id, t_width, t_height, t_type, t_method
)
# Generate the thumbnail # Generate the thumbnail
if t_method == "crop": if t_method == "crop":
t_byte_source = yield make_deferred_yieldable(threads.deferToThread( t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
@ -505,9 +621,19 @@ class MediaRepository(object):
continue continue
try: try:
# Write to disk file_info = FileInfo(
output_path = yield self.write_to_file_and_backup( server_name=server_name,
t_byte_source, file_path, file_id=file_id,
thumbnail=True,
thumbnail_width=t_width,
thumbnail_height=t_height,
thumbnail_method=t_method,
thumbnail_type=t_type,
url_cache=url_cache,
)
output_path = yield self.media_storage.store_file(
t_byte_source, file_info,
) )
finally: finally:
t_byte_source.close() t_byte_source.close()
@ -620,7 +746,11 @@ class MediaRepositoryResource(Resource):
self.putChild("upload", UploadResource(hs, media_repo)) self.putChild("upload", UploadResource(hs, media_repo))
self.putChild("download", DownloadResource(hs, media_repo)) self.putChild("download", DownloadResource(hs, media_repo))
self.putChild("thumbnail", ThumbnailResource(hs, media_repo)) self.putChild("thumbnail", ThumbnailResource(
hs, media_repo, media_repo.media_storage,
))
self.putChild("identicon", IdenticonResource()) self.putChild("identicon", IdenticonResource())
if hs.config.url_preview_enabled: if hs.config.url_preview_enabled:
self.putChild("preview_url", PreviewUrlResource(hs, media_repo)) self.putChild("preview_url", PreviewUrlResource(
hs, media_repo, media_repo.media_storage,
))
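Putting the pieces together, a rough sketch (outside the HomeServer config plumbing; paths are made up and the module paths are assumptions) of how a ``MediaStorage`` instance is assembled from the primary media path plus wrapped storage providers, mirroring the ``MediaRepository.__init__`` code above:

    from synapse.rest.media.v1.filepath import MediaFilePaths
    from synapse.rest.media.v1.media_storage import MediaStorage
    from synapse.rest.media.v1.storage_provider import (
        FileStorageProviderBackend, StorageProviderWrapper,
    )

    primary_base_path = "/var/lib/synapse/media_store"   # made-up path
    backup_directory = "/mnt/backup/media_store"         # made-up path

    filepaths = MediaFilePaths(primary_base_path)

    # `hs` is the HomeServer instance, as in the surrounding code.
    backend = FileStorageProviderBackend(hs, backup_directory)
    provider = StorageProviderWrapper(
        backend,
        store_local=True,
        store_remote=True,
        store_synchronous=True,
    )

    media_storage = MediaStorage(primary_base_path, filepaths, [provider])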


@ -0,0 +1,236 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer, threads
from twisted.protocols.basic import FileSender
from ._base import Responder
from synapse.util.logcontext import make_deferred_yieldable
import contextlib
import os
import logging
import shutil
import sys
logger = logging.getLogger(__name__)
class MediaStorage(object):
"""Responsible for storing/fetching files from local sources.
Args:
local_media_directory (str): Base path where we store media on disk
filepaths (MediaFilePaths)
storage_providers ([StorageProvider]): List of StorageProvider that are
used to fetch and store files.
"""
def __init__(self, local_media_directory, filepaths, storage_providers):
self.local_media_directory = local_media_directory
self.filepaths = filepaths
self.storage_providers = storage_providers
@defer.inlineCallbacks
def store_file(self, source, file_info):
"""Write `source` to the on disk media store, and also any other
configured storage providers
Args:
source: A file like object that should be written
file_info (FileInfo): Info about the file to store
Returns:
Deferred[str]: the file path written to in the primary media store
"""
path = self._file_info_to_path(file_info)
fname = os.path.join(self.local_media_directory, path)
dirname = os.path.dirname(fname)
if not os.path.exists(dirname):
os.makedirs(dirname)
# Write to the main repository
yield make_deferred_yieldable(threads.deferToThread(
_write_file_synchronously, source, fname,
))
defer.returnValue(fname)
@contextlib.contextmanager
def store_into_file(self, file_info):
"""Context manager used to get a file like object to write into, as
described by file_info.
Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
like object that can be written to, fname is the absolute path of file
on disk, and finish_cb is a function that returns a Deferred.
fname can be used to read the contents from after upload, e.g. to
generate thumbnails.
finish_cb must be called and waited on after the file has been
successfully written to. Should not be called if there was an
error.
Args:
file_info (FileInfo): Info about the file to store
Example:
with media_storage.store_into_file(info) as (f, fname, finish_cb):
# .. write into f ...
yield finish_cb()
"""
path = self._file_info_to_path(file_info)
fname = os.path.join(self.local_media_directory, path)
dirname = os.path.dirname(fname)
if not os.path.exists(dirname):
os.makedirs(dirname)
finished_called = [False]
@defer.inlineCallbacks
def finish():
for provider in self.storage_providers:
yield provider.store_file(path, file_info)
finished_called[0] = True
try:
with open(fname, "wb") as f:
yield f, fname, finish
except Exception:
t, v, tb = sys.exc_info()
try:
os.remove(fname)
except Exception:
pass
raise t, v, tb
if not finished_called[0]:
raise Exception("Finished callback not called")
@defer.inlineCallbacks
def fetch_media(self, file_info):
"""Attempts to fetch media described by file_info from the local cache
and configured storage providers.
Args:
file_info (FileInfo)
Returns:
Deferred[Responder|None]: Returns a Responder if the file was found,
otherwise None.
"""
path = self._file_info_to_path(file_info)
local_path = os.path.join(self.local_media_directory, path)
if os.path.exists(local_path):
defer.returnValue(FileResponder(open(local_path, "rb")))
for provider in self.storage_providers:
res = yield provider.fetch(path, file_info)
if res:
defer.returnValue(res)
defer.returnValue(None)
def _file_info_to_path(self, file_info):
"""Converts file_info into a relative path.
The path is suitable for storing files under a directory, e.g. used to
store files on local FS under the base media repository directory.
Args:
file_info (FileInfo)
Returns:
str
"""
if file_info.url_cache:
if file_info.thumbnail:
return self.filepaths.url_cache_thumbnail_rel(
media_id=file_info.file_id,
width=file_info.thumbnail_width,
height=file_info.thumbnail_height,
content_type=file_info.thumbnail_type,
method=file_info.thumbnail_method,
)
return self.filepaths.url_cache_filepath_rel(file_info.file_id)
if file_info.server_name:
if file_info.thumbnail:
return self.filepaths.remote_media_thumbnail_rel(
server_name=file_info.server_name,
file_id=file_info.file_id,
width=file_info.thumbnail_width,
height=file_info.thumbnail_height,
content_type=file_info.thumbnail_type,
method=file_info.thumbnail_method
)
return self.filepaths.remote_media_filepath_rel(
file_info.server_name, file_info.file_id,
)
if file_info.thumbnail:
return self.filepaths.local_media_thumbnail_rel(
media_id=file_info.file_id,
width=file_info.thumbnail_width,
height=file_info.thumbnail_height,
content_type=file_info.thumbnail_type,
method=file_info.thumbnail_method
)
return self.filepaths.local_media_filepath_rel(
file_info.file_id,
)
def _write_file_synchronously(source, fname):
"""Write `source` to the path `fname` synchronously. Should be called
from a thread.
Args:
source: A file like object to be written
fname (str): Path to write to
"""
dirname = os.path.dirname(fname)
if not os.path.exists(dirname):
os.makedirs(dirname)
source.seek(0) # Ensure we read from the start of the file
with open(fname, "wb") as f:
shutil.copyfileobj(source, f)
class FileResponder(Responder):
"""Wraps an open file that can be sent to a request.
Args:
open_file (file): A file like object to be streamed to the client,
is closed when finished streaming.
"""
def __init__(self, open_file):
self.open_file = open_file
@defer.inlineCallbacks
def write_to_consumer(self, consumer):
yield FileSender().beginFileTransfer(self.open_file, consumer)
def __exit__(self, exc_type, exc_val, exc_tb):
self.open_file.close()
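An end-to-end sketch of the two halves of ``MediaStorage`` working together, writing some bytes into the store and then streaming them back out; the helper function and its arguments are illustrative (``FileInfo`` and ``respond_404`` come from ``._base``):

    @defer.inlineCallbacks
    def cache_and_serve(media_storage, request, source_bytes, media_id):
        file_info = FileInfo(server_name=None, file_id=media_id)

        # Write into the local cache and any configured storage providers.
        with media_storage.store_into_file(file_info) as (f, fname, finish):
            f.write(source_bytes)
            yield finish()

        # Read it back, from the local cache or a provider, and stream it.
        responder = yield media_storage.fetch_media(file_info)
        if responder is None:
            respond_404(request)
            return
        with responder:
            yield responder.write_to_consumer(request)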


@ -17,6 +17,8 @@ from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer from twisted.internet import defer
from twisted.web.resource import Resource from twisted.web.resource import Resource
from ._base import FileInfo
from synapse.api.errors import ( from synapse.api.errors import (
SynapseError, Codes, SynapseError, Codes,
) )
@ -49,7 +51,7 @@ logger = logging.getLogger(__name__)
class PreviewUrlResource(Resource): class PreviewUrlResource(Resource):
isLeaf = True isLeaf = True
def __init__(self, hs, media_repo): def __init__(self, hs, media_repo, media_storage):
Resource.__init__(self) Resource.__init__(self)
self.auth = hs.get_auth() self.auth = hs.get_auth()
@ -62,6 +64,7 @@ class PreviewUrlResource(Resource):
self.client = SpiderHttpClient(hs) self.client = SpiderHttpClient(hs)
self.media_repo = media_repo self.media_repo = media_repo
self.primary_base_path = media_repo.primary_base_path self.primary_base_path = media_repo.primary_base_path
self.media_storage = media_storage
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
@ -182,8 +185,10 @@ class PreviewUrlResource(Resource):
logger.debug("got media_info of '%s'" % media_info) logger.debug("got media_info of '%s'" % media_info)
if _is_media(media_info['media_type']): if _is_media(media_info['media_type']):
file_id = media_info['filesystem_id']
dims = yield self.media_repo._generate_thumbnails( dims = yield self.media_repo._generate_thumbnails(
None, media_info['filesystem_id'], media_info, url_cache=True, None, file_id, file_id, media_info["media_type"],
url_cache=True,
) )
og = { og = {
@ -228,8 +233,10 @@ class PreviewUrlResource(Resource):
if _is_media(image_info['media_type']): if _is_media(image_info['media_type']):
# TODO: make sure we don't choke on white-on-transparent images # TODO: make sure we don't choke on white-on-transparent images
file_id = image_info['filesystem_id']
dims = yield self.media_repo._generate_thumbnails( dims = yield self.media_repo._generate_thumbnails(
None, image_info['filesystem_id'], image_info, url_cache=True, None, file_id, file_id, image_info["media_type"],
url_cache=True,
) )
if dims: if dims:
og["og:image:width"] = dims['width'] og["og:image:width"] = dims['width']
@ -273,19 +280,21 @@ class PreviewUrlResource(Resource):
file_id = datetime.date.today().isoformat() + '_' + random_string(16) file_id = datetime.date.today().isoformat() + '_' + random_string(16)
fpath = self.filepaths.url_cache_filepath_rel(file_id) file_info = FileInfo(
fname = os.path.join(self.primary_base_path, fpath) server_name=None,
self.media_repo._makedirs(fname) file_id=file_id,
url_cache=True,
)
try: try:
with open(fname, "wb") as f: with self.media_storage.store_into_file(file_info) as (f, fname, finish):
logger.debug("Trying to get url '%s'" % url) logger.debug("Trying to get url '%s'" % url)
length, headers, uri, code = yield self.client.get_file( length, headers, uri, code = yield self.client.get_file(
url, output_stream=f, max_size=self.max_spider_size, url, output_stream=f, max_size=self.max_spider_size,
) )
# FIXME: pass through 404s and other error messages nicely # FIXME: pass through 404s and other error messages nicely
yield self.media_repo.copy_to_backup(fpath) yield finish()
media_type = headers["Content-Type"][0] media_type = headers["Content-Type"][0]
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
@ -327,7 +336,6 @@ class PreviewUrlResource(Resource):
) )
except Exception as e: except Exception as e:
os.remove(fname)
raise SynapseError( raise SynapseError(
500, ("Failed to download content: %s" % e), 500, ("Failed to download content: %s" % e),
Codes.UNKNOWN Codes.UNKNOWN


@ -0,0 +1,140 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer, threads
from .media_storage import FileResponder
from synapse.config._base import Config
from synapse.util.logcontext import preserve_fn
import logging
import os
import shutil
logger = logging.getLogger(__name__)
class StorageProvider(object):
"""A storage provider is a service that can store uploaded media and
retrieve them.
"""
def store_file(self, path, file_info):
"""Store the file described by file_info. The actual contents can be
retrieved by reading the file in file_info.upload_path.
Args:
path (str): Relative path of file in local cache
file_info (FileInfo)
Returns:
Deferred
"""
pass
def fetch(self, path, file_info):
"""Attempt to fetch the file described by file_info and stream it
into writer.
Args:
path (str): Relative path of file in local cache
file_info (FileInfo)
Returns:
Deferred[Responder|None]: Returns a Responder if the provider has the file,
otherwise returns None.
"""
pass
class StorageProviderWrapper(StorageProvider):
"""Wraps a storage provider and provides various config options
Args:
backend (StorageProvider)
store_local (bool): Whether to store new local files or not.
store_synchronous (bool): Whether to wait for file to be successfully
uploaded, or to do the upload in the background.
store_remote (bool): Whether remote media should be uploaded
"""
def __init__(self, backend, store_local, store_synchronous, store_remote):
self.backend = backend
self.store_local = store_local
self.store_synchronous = store_synchronous
self.store_remote = store_remote
def store_file(self, path, file_info):
if not file_info.server_name and not self.store_local:
return defer.succeed(None)
if file_info.server_name and not self.store_remote:
return defer.succeed(None)
if self.store_synchronous:
return self.backend.store_file(path, file_info)
else:
# TODO: Handle errors.
preserve_fn(self.backend.store_file)(path, file_info)
return defer.succeed(None)
def fetch(self, path, file_info):
return self.backend.fetch(path, file_info)
class FileStorageProviderBackend(StorageProvider):
"""A storage provider that stores files in a directory on a filesystem.
Args:
hs (HomeServer)
config: The config returned by `parse_config`.
"""
def __init__(self, hs, config):
self.cache_directory = hs.config.media_store_path
self.base_directory = config
def store_file(self, path, file_info):
"""See StorageProvider.store_file"""
primary_fname = os.path.join(self.cache_directory, path)
backup_fname = os.path.join(self.base_directory, path)
dirname = os.path.dirname(backup_fname)
if not os.path.exists(dirname):
os.makedirs(dirname)
return threads.deferToThread(
shutil.copyfile, primary_fname, backup_fname,
)
def fetch(self, path, file_info):
"""See StorageProvider.fetch"""
backup_fname = os.path.join(self.base_directory, path)
if os.path.isfile(backup_fname):
return FileResponder(open(backup_fname, "rb"))
@staticmethod
def parse_config(config):
"""Called on startup to parse config supplied. This should parse
the config and raise if there is a problem.
The returned value is passed into the constructor.
In this case we only care about a single param, the directory, so let's
just pull that out.
"""
return Config.ensure_directory(config["directory"])
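Those two methods (plus ``parse_config``) are the whole provider interface. As a purely illustrative example, not part of this commit, a provider that merely records which paths it was asked to store might look like:

    class LoggingStorageProvider(StorageProvider):
        """Toy provider: remembers stored paths and never serves anything back."""

        def __init__(self, hs, config):
            self.seen_paths = []

        def store_file(self, path, file_info):
            self.seen_paths.append(path)
            logger.info("Would upload %s to some external store", path)
            return defer.succeed(None)

        def fetch(self, path, file_info):
            # We never hold the bytes, so fall through to the next provider.
            return defer.succeed(None)

        @staticmethod
        def parse_config(config):
            return config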


@ -14,7 +14,10 @@
# limitations under the License. # limitations under the License.
from ._base import parse_media_id, respond_404, respond_with_file from ._base import (
parse_media_id, respond_404, respond_with_file, FileInfo,
respond_with_responder,
)
from twisted.web.resource import Resource from twisted.web.resource import Resource
from synapse.http.servlet import parse_string, parse_integer from synapse.http.servlet import parse_string, parse_integer
from synapse.http.server import request_handler, set_cors_headers from synapse.http.server import request_handler, set_cors_headers
@ -30,12 +33,12 @@ logger = logging.getLogger(__name__)
class ThumbnailResource(Resource): class ThumbnailResource(Resource):
isLeaf = True isLeaf = True
def __init__(self, hs, media_repo): def __init__(self, hs, media_repo, media_storage):
Resource.__init__(self) Resource.__init__(self)
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.filepaths = media_repo.filepaths
self.media_repo = media_repo self.media_repo = media_repo
self.media_storage = media_storage
self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.server_name = hs.hostname self.server_name = hs.hostname
self.version_string = hs.version_string self.version_string = hs.version_string
@ -64,6 +67,7 @@ class ThumbnailResource(Resource):
yield self._respond_local_thumbnail( yield self._respond_local_thumbnail(
request, media_id, width, height, method, m_type request, media_id, width, height, method, m_type
) )
self.media_repo.mark_recently_accessed(None, media_id)
else: else:
if self.dynamic_thumbnails: if self.dynamic_thumbnails:
yield self._select_or_generate_remote_thumbnail( yield self._select_or_generate_remote_thumbnail(
@ -75,20 +79,20 @@ class ThumbnailResource(Resource):
request, server_name, media_id, request, server_name, media_id,
width, height, method, m_type width, height, method, m_type
) )
self.media_repo.mark_recently_accessed(server_name, media_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def _respond_local_thumbnail(self, request, media_id, width, height, def _respond_local_thumbnail(self, request, media_id, width, height,
method, m_type): method, m_type):
media_info = yield self.store.get_local_media(media_id) media_info = yield self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]: if not media_info:
respond_404(request)
return
if media_info["quarantined_by"]:
logger.info("Media is quarantined")
respond_404(request) respond_404(request)
return return
# if media_info["media_type"] == "image/svg+xml":
# file_path = self.filepaths.local_media_filepath(media_id)
# yield respond_with_file(request, media_info["media_type"], file_path)
# return
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id) thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
@ -96,42 +100,39 @@ class ThumbnailResource(Resource):
thumbnail_info = self._select_thumbnail( thumbnail_info = self._select_thumbnail(
width, height, method, m_type, thumbnail_infos width, height, method, m_type, thumbnail_infos
) )
t_width = thumbnail_info["thumbnail_width"]
t_height = thumbnail_info["thumbnail_height"]
t_type = thumbnail_info["thumbnail_type"]
t_method = thumbnail_info["thumbnail_method"]
if media_info["url_cache"]: file_info = FileInfo(
# TODO: Check the file still exists, if it doesn't we can redownload server_name=None, file_id=media_id,
# it from the url `media_info["url_cache"]` url_cache=media_info["url_cache"],
file_path = self.filepaths.url_cache_thumbnail( thumbnail=True,
media_id, t_width, t_height, t_type, t_method, thumbnail_width=thumbnail_info["thumbnail_width"],
) thumbnail_height=thumbnail_info["thumbnail_height"],
else: thumbnail_type=thumbnail_info["thumbnail_type"],
file_path = self.filepaths.local_media_thumbnail( thumbnail_method=thumbnail_info["thumbnail_method"],
media_id, t_width, t_height, t_type, t_method,
)
yield respond_with_file(request, t_type, file_path)
else:
yield self._respond_default_thumbnail(
request, media_info, width, height, method, m_type,
) )
t_type = file_info.thumbnail_type
t_length = thumbnail_info["thumbnail_length"]
responder = yield self.media_storage.fetch_media(file_info)
yield respond_with_responder(request, responder, t_type, t_length)
else:
logger.info("Couldn't find any generated thumbnails")
respond_404(request)
@defer.inlineCallbacks @defer.inlineCallbacks
def _select_or_generate_local_thumbnail(self, request, media_id, desired_width, def _select_or_generate_local_thumbnail(self, request, media_id, desired_width,
desired_height, desired_method, desired_height, desired_method,
desired_type): desired_type):
media_info = yield self.store.get_local_media(media_id) media_info = yield self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]: if not media_info:
respond_404(request)
return
if media_info["quarantined_by"]:
logger.info("Media is quarantined")
respond_404(request) respond_404(request)
return return
# if media_info["media_type"] == "image/svg+xml":
# file_path = self.filepaths.local_media_filepath(media_id)
# yield respond_with_file(request, media_info["media_type"], file_path)
# return
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id) thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
for info in thumbnail_infos: for info in thumbnail_infos:
@ -141,22 +142,25 @@ class ThumbnailResource(Resource):
t_type = info["thumbnail_type"] == desired_type t_type = info["thumbnail_type"] == desired_type
if t_w and t_h and t_method and t_type: if t_w and t_h and t_method and t_type:
if media_info["url_cache"]: file_info = FileInfo(
# TODO: Check the file still exists, if it doesn't we can redownload server_name=None, file_id=media_id,
# it from the url `media_info["url_cache"]` url_cache=media_info["url_cache"],
file_path = self.filepaths.url_cache_thumbnail( thumbnail=True,
media_id, desired_width, desired_height, desired_type, thumbnail_width=info["thumbnail_width"],
desired_method, thumbnail_height=info["thumbnail_height"],
) thumbnail_type=info["thumbnail_type"],
else: thumbnail_method=info["thumbnail_method"],
file_path = self.filepaths.local_media_thumbnail( )
media_id, desired_width, desired_height, desired_type,
desired_method,
)
yield respond_with_file(request, desired_type, file_path)
return
logger.debug("We don't have a local thumbnail of that size. Generating") t_type = file_info.thumbnail_type
t_length = info["thumbnail_length"]
responder = yield self.media_storage.fetch_media(file_info)
if responder:
yield respond_with_responder(request, responder, t_type, t_length)
return
logger.debug("We don't have a thumbnail of that size. Generating")
# Okay, so we generate one. # Okay, so we generate one.
file_path = yield self.media_repo.generate_local_exact_thumbnail( file_path = yield self.media_repo.generate_local_exact_thumbnail(
@ -166,21 +170,14 @@ class ThumbnailResource(Resource):
if file_path: if file_path:
yield respond_with_file(request, desired_type, file_path) yield respond_with_file(request, desired_type, file_path)
else: else:
yield self._respond_default_thumbnail( logger.warn("Failed to generate thumbnail")
request, media_info, desired_width, desired_height, respond_404(request)
desired_method, desired_type,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _select_or_generate_remote_thumbnail(self, request, server_name, media_id, def _select_or_generate_remote_thumbnail(self, request, server_name, media_id,
desired_width, desired_height, desired_width, desired_height,
desired_method, desired_type): desired_method, desired_type):
media_info = yield self.media_repo.get_remote_media(server_name, media_id) media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
# if media_info["media_type"] == "image/svg+xml":
# file_path = self.filepaths.remote_media_filepath(server_name, media_id)
# yield respond_with_file(request, media_info["media_type"], file_path)
# return
thumbnail_infos = yield self.store.get_remote_media_thumbnails( thumbnail_infos = yield self.store.get_remote_media_thumbnails(
server_name, media_id, server_name, media_id,
@ -195,14 +192,24 @@ class ThumbnailResource(Resource):
t_type = info["thumbnail_type"] == desired_type t_type = info["thumbnail_type"] == desired_type
if t_w and t_h and t_method and t_type: if t_w and t_h and t_method and t_type:
file_path = self.filepaths.remote_media_thumbnail( file_info = FileInfo(
server_name, file_id, desired_width, desired_height, server_name=server_name, file_id=media_info["filesystem_id"],
desired_type, desired_method, thumbnail=True,
thumbnail_width=info["thumbnail_width"],
thumbnail_height=info["thumbnail_height"],
thumbnail_type=info["thumbnail_type"],
thumbnail_method=info["thumbnail_method"],
) )
yield respond_with_file(request, desired_type, file_path)
return
logger.debug("We don't have a local thumbnail of that size. Generating") t_type = file_info.thumbnail_type
t_length = info["thumbnail_length"]
responder = yield self.media_storage.fetch_media(file_info)
if responder:
yield respond_with_responder(request, responder, t_type, t_length)
return
logger.debug("We don't have a thumbnail of that size. Generating")
# Okay, so we generate one. # Okay, so we generate one.
file_path = yield self.media_repo.generate_remote_exact_thumbnail( file_path = yield self.media_repo.generate_remote_exact_thumbnail(
@ -213,22 +220,16 @@ class ThumbnailResource(Resource):
if file_path: if file_path:
yield respond_with_file(request, desired_type, file_path) yield respond_with_file(request, desired_type, file_path)
else: else:
yield self._respond_default_thumbnail( logger.warn("Failed to generate thumbnail")
request, media_info, desired_width, desired_height, respond_404(request)
desired_method, desired_type,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _respond_remote_thumbnail(self, request, server_name, media_id, width, def _respond_remote_thumbnail(self, request, server_name, media_id, width,
height, method, m_type): height, method, m_type):
# TODO: Don't download the whole remote file # TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead. # We should proxy the thumbnail from the remote server instead of
media_info = yield self.media_repo.get_remote_media(server_name, media_id) # downloading the remote file and generating our own thumbnails.
media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
# if media_info["media_type"] == "image/svg+xml":
# file_path = self.filepaths.remote_media_filepath(server_name, media_id)
# yield respond_with_file(request, media_info["media_type"], file_path)
# return
thumbnail_infos = yield self.store.get_remote_media_thumbnails( thumbnail_infos = yield self.store.get_remote_media_thumbnails(
server_name, media_id, server_name, media_id,
@ -238,59 +239,23 @@ class ThumbnailResource(Resource):
thumbnail_info = self._select_thumbnail( thumbnail_info = self._select_thumbnail(
width, height, method, m_type, thumbnail_infos width, height, method, m_type, thumbnail_infos
) )
t_width = thumbnail_info["thumbnail_width"] file_info = FileInfo(
t_height = thumbnail_info["thumbnail_height"] server_name=server_name, file_id=media_info["filesystem_id"],
t_type = thumbnail_info["thumbnail_type"] thumbnail=True,
t_method = thumbnail_info["thumbnail_method"] thumbnail_width=thumbnail_info["thumbnail_width"],
file_id = thumbnail_info["filesystem_id"] thumbnail_height=thumbnail_info["thumbnail_height"],
thumbnail_type=thumbnail_info["thumbnail_type"],
thumbnail_method=thumbnail_info["thumbnail_method"],
)
t_type = file_info.thumbnail_type
t_length = thumbnail_info["thumbnail_length"] t_length = thumbnail_info["thumbnail_length"]
file_path = self.filepaths.remote_media_thumbnail( responder = yield self.media_storage.fetch_media(file_info)
server_name, file_id, t_width, t_height, t_type, t_method, yield respond_with_responder(request, responder, t_type, t_length)
)
yield respond_with_file(request, t_type, file_path, t_length)
else: else:
yield self._respond_default_thumbnail( logger.info("Failed to find any generated thumbnails")
request, media_info, width, height, method, m_type,
)
@defer.inlineCallbacks
def _respond_default_thumbnail(self, request, media_info, width, height,
method, m_type):
# XXX: how is this meant to work? store.get_default_thumbnails
# appears to always return [] so won't this always 404?
media_type = media_info["media_type"]
top_level_type = media_type.split("/")[0]
sub_type = media_type.split("/")[-1].split(";")[0]
thumbnail_infos = yield self.store.get_default_thumbnails(
top_level_type, sub_type,
)
if not thumbnail_infos:
thumbnail_infos = yield self.store.get_default_thumbnails(
top_level_type, "_default",
)
if not thumbnail_infos:
thumbnail_infos = yield self.store.get_default_thumbnails(
"_default", "_default",
)
if not thumbnail_infos:
respond_404(request) respond_404(request)
return
thumbnail_info = self._select_thumbnail(
width, height, "crop", m_type, thumbnail_infos
)
t_width = thumbnail_info["thumbnail_width"]
t_height = thumbnail_info["thumbnail_height"]
t_type = thumbnail_info["thumbnail_type"]
t_method = thumbnail_info["thumbnail_method"]
t_length = thumbnail_info["thumbnail_length"]
file_path = self.filepaths.default_thumbnail(
top_level_type, sub_type, t_width, t_height, t_type, t_method,
)
yield respond_with_file(request, t_type, file_path, t_length)
def _select_thumbnail(self, desired_width, desired_height, desired_method, def _select_thumbnail(self, desired_width, desired_height, desired_method,
desired_type, thumbnail_infos): desired_type, thumbnail_infos):


@ -307,6 +307,23 @@ class HomeServer(object):
**self.db_config.get("args", {}) **self.db_config.get("args", {})
) )
def get_db_conn(self, run_new_connection=True):
"""Makes a new connection to the database, skipping the db pool
Returns:
Connection: a connection object implementing the PEP-249 spec
"""
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
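A hedged sketch of the intended caller, e.g. one-off scripts or startup code that runs before the adbapi pool exists; the query is illustrative and ``hs`` is a HomeServer instance:

    db_conn = hs.get_db_conn()
    try:
        cur = db_conn.cursor()
        cur.execute("SELECT 1")      # any PEP-249 style usage works here
        print cur.fetchone()
    finally:
        db_conn.close()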
def build_media_repository_resource(self): def build_media_repository_resource(self):
# build the media repo resource. This indirects through the HomeServer # build the media repo resource. This indirects through the HomeServer
# to ensure that we only have a single instance of # to ensure that we only have a single instance of


@ -146,8 +146,20 @@ class StateHandler(object):
defer.returnValue(state) defer.returnValue(state)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_current_state_ids(self, room_id, event_type=None, state_key="", def get_current_state_ids(self, room_id, latest_event_ids=None):
latest_event_ids=None): """Get the current state, or the state at a set of events, for a room
Args:
room_id (str):
latest_event_ids (iterable[str]|None): if given, the forward
extremities to resolve. If None, we look them up from the
database (via a cache)
Returns:
Deferred[dict[(str, str), str)]]: the state dict, mapping from
(event_type, state_key) -> event_id
"""
if not latest_event_ids: if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
@ -155,10 +167,6 @@ class StateHandler(object):
ret = yield self.resolve_state_groups(room_id, latest_event_ids) ret = yield self.resolve_state_groups(room_id, latest_event_ids)
state = ret.state state = ret.state
if event_type:
defer.returnValue(state.get((event_type, state_key)))
return
defer.returnValue(state) defer.returnValue(state)
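A hedged sketch of a caller, showing the shape of the returned mapping (the room ID and event ID are made up):

    @defer.inlineCallbacks
    def get_room_name_event_id(state_handler, room_id):
        state_ids = yield state_handler.get_current_state_ids(room_id)
        # state_ids maps (event_type, state_key) -> event_id, e.g.
        #   ("m.room.name", "") -> "$15170000001234abcd:example.com"
        defer.returnValue(state_ids.get(("m.room.name", "")))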
@defer.inlineCallbacks @defer.inlineCallbacks
@ -341,7 +349,7 @@ class StateHandler(object):
if conflicted_state:
logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"):
- new_state = yield resolve_events(
+ new_state = yield resolve_events_with_factory(
state_groups_ids.values(),
state_map_factory=lambda ev_ids: self.store.get_events(
ev_ids, get_prev_content=False, check_redacted=False,
@ -404,7 +412,7 @@ class StateHandler(object):
}
with Measure(self.clock, "state._resolve_events"):
- new_state = resolve_events(state_set_ids, state_map)
+ new_state = resolve_events_with_state_map(state_set_ids, state_map)
new_state = {
key: state_map[ev_id] for key, ev_id in new_state.items()
@ -420,19 +428,17 @@ def _ordered_events(events):
return sorted(events, key=key_func)
- def resolve_events(state_sets, state_map_factory):
+ def resolve_events_with_state_map(state_sets, state_map):
"""
Args:
state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
- state_map_factory(dict|callable): If callable, then will be called
+ state_map(dict): a dict from event_id to event, for all events in
- with a list of event_ids that are needed, and should return with
+ state_sets.
a Deferred of dict of event_id to event. Otherwise, should be
a dict from event_id to event of all events in state_sets.
Returns
- dict[(str, str), synapse.events.FrozenEvent] is a map from
+ dict[(str, str), synapse.events.FrozenEvent]:
- (type, state_key) to event.
+ a map from (type, state_key) to event.
"""
if len(state_sets) == 1:
return state_sets[0]
@ -441,13 +447,6 @@ def resolve_events(state_sets, state_map_factory):
state_sets,
)
if callable(state_map_factory):
return _resolve_with_state_fac(
unconflicted_state, conflicted_state, state_map_factory
)
state_map = state_map_factory
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
@ -491,8 +490,26 @@ def _seperate(state_sets):
@defer.inlineCallbacks
- def _resolve_with_state_fac(unconflicted_state, conflicted_state,
+ def resolve_events_with_factory(state_sets, state_map_factory):
- state_map_factory):
+ """
Args:
state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
state_map_factory(func): will be called
with a list of event_ids that are needed, and should return with
a Deferred of dict of event_id to event.
Returns
Deferred[dict[(str, str), synapse.events.FrozenEvent]]:
a map from (type, state_key) to event.
"""
if len(state_sets) == 1:
defer.returnValue(state_sets[0])
unconflicted_state, conflicted_state = _seperate(
state_sets,
)
needed_events = set(
event_id
for event_ids in conflicted_state.itervalues()
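A hypothetical caller of resolve_events_with_factory would look roughly like the sketch below; the store method is assumed to mirror the events-store usage further down, returning a Deferred of {event_id: event}:

    # sketch only: resolves state across several state groups
    @defer.inlineCallbacks
    def _resolve(store, room_id, state_sets):
        new_state = yield resolve_events_with_factory(
            state_sets,
            state_map_factory=lambda event_ids: store.get_events(
                event_ids, get_prev_content=False, check_redacted=False,
            ),
        )
        defer.returnValue(new_state)  # {(event_type, state_key): event}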


@ -291,33 +291,33 @@ class SQLBaseStore(object):
@defer.inlineCallbacks
def runInteraction(self, desc, func, *args, **kwargs):
- """Wraps the .runInteraction() method on the underlying db_pool."""
+ """Starts a transaction on the database and runs a given function
current_context = LoggingContext.current_context()
- start_time = time.time() * 1000
+ Arguments:
desc (str): description of the transaction, for logging and metrics
func (func): callback function, which will be called with a
database transaction (twisted.enterprise.adbapi.Transaction) as
its first argument, followed by `args` and `kwargs`.
args (list): positional args to pass to `func`
kwargs (dict): named args to pass to `func`
Returns:
Deferred: The result of func
"""
current_context = LoggingContext.current_context()
after_callbacks = []
final_callbacks = []
def inner_func(conn, *args, **kwargs):
- with LoggingContext("runInteraction") as context:
+ return self._new_transaction(
- sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
+ conn, desc, after_callbacks, final_callbacks, current_context,
func, *args, **kwargs
- if self.database_engine.is_connection_closed(conn):
+ )
logger.debug("Reconnecting closed database connection")
conn.reconnect()
current_context.copy_to(context)
return self._new_transaction(
conn, desc, after_callbacks, final_callbacks, current_context,
func, *args, **kwargs
)
try:
- with PreserveLoggingContext():
+ result = yield self.runWithConnection(inner_func, *args, **kwargs)
result = yield self._db_pool.runWithConnection(
inner_func, *args, **kwargs
)
for after_callback, after_args, after_kwargs in after_callbacks:
after_callback(*after_args, **after_kwargs)
@ -329,14 +329,27 @@ class SQLBaseStore(object):
@defer.inlineCallbacks
def runWithConnection(self, func, *args, **kwargs):
- """Wraps the .runInteraction() method on the underlying db_pool."""
+ """Wraps the .runWithConnection() method on the underlying db_pool.
Arguments:
func (func): callback function, which will be called with a
database connection (twisted.enterprise.adbapi.Connection) as
its first argument, followed by `args` and `kwargs`.
args (list): positional args to pass to `func`
kwargs (dict): named args to pass to `func`
Returns:
Deferred: The result of func
"""
current_context = LoggingContext.current_context()
start_time = time.time() * 1000
def inner_func(conn, *args, **kwargs):
with LoggingContext("runWithConnection") as context:
- sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
+ sched_duration_ms = time.time() * 1000 - start_time
sql_scheduling_timer.inc_by(sched_duration_ms)
current_context.add_database_scheduled(sched_duration_ms)
if self.database_engine.is_connection_closed(conn):
logger.debug("Reconnecting closed database connection")
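As a rough sketch (not from this commit), store code drives runInteraction with a transaction callback; the table and names below are invented for illustration:

    # sketch only: `self` is assumed to be a SQLBaseStore subclass
    def _count_users_txn(txn):
        # txn behaves like a PEP-249 cursor
        txn.execute("SELECT COUNT(*) FROM users")
        (count,) = txn.fetchone()
        return count

    count = yield self.runInteraction("count_users", _count_users_txn)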


@ -27,7 +27,7 @@ from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
- from synapse.state import resolve_events
+ from synapse.state import resolve_events_with_factory
from synapse.util.caches.descriptors import cached
from synapse.types import get_domain_from_id
@ -110,7 +110,7 @@ class _EventPeristenceQueue(object):
end_item.events_and_contexts.extend(events_and_contexts)
return end_item.deferred.observe()
- deferred = ObservableDeferred(defer.Deferred())
+ deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
queue.append(self._EventPersistQueueItem(
events_and_contexts=events_and_contexts,
@ -146,18 +146,25 @@ class _EventPeristenceQueue(object):
try:
queue = self._get_drainining_queue(room_id)
for item in queue:
# handle_queue_loop runs in the sentinel logcontext, so
# there is no need to preserve_fn when running the
# callbacks on the deferred.
try:
ret = yield per_item_callback(item)
item.deferred.callback(ret)
- except Exception as e:
+ except Exception:
- item.deferred.errback(e)
+ item.deferred.errback()
finally:
queue = self._event_persist_queues.pop(room_id, None)
if queue:
self._event_persist_queues[room_id] = queue
self._currently_persisting_rooms.discard(room_id)
- preserve_fn(handle_queue_loop)()
+ # set handle_queue_loop off on the background. We don't want to
# attribute work done in it to the current request, so we drop the
# logcontext altogether.
with PreserveLoggingContext():
handle_queue_loop()
def _get_drainining_queue(self, room_id):
queue = self._event_persist_queues.setdefault(room_id, deque())
@ -528,6 +535,12 @@ class EventsStore(SQLBaseStore):
# the events we have yet to persist, so we need a slightly more
# complicated event lookup function than simply looking the events
# up in the db.
logger.info(
"Resolving state for %s with %i state sets",
room_id, len(state_sets),
)
events_map = {ev.event_id: ev for ev, _ in events_context}
@defer.inlineCallbacks
@ -550,7 +563,7 @@ class EventsStore(SQLBaseStore):
to_return.update(evs)
defer.returnValue(to_return)
- current_state = yield resolve_events(
+ current_state = yield resolve_events_with_factory(
state_sets,
state_map_factory=get_events,
)


@ -29,9 +29,6 @@ class MediaRepositoryStore(BackgroundUpdateStore):
where_clause='url_cache IS NOT NULL',
)
def get_default_thumbnails(self, top_level_type, sub_type):
return []
def get_local_media(self, media_id):
"""Get the metadata for a local piece of media
Returns:
@ -176,7 +173,14 @@ class MediaRepositoryStore(BackgroundUpdateStore):
desc="store_cached_remote_media",
)
- def update_cached_last_access_time(self, origin_id_tuples, time_ts):
+ def update_cached_last_access_time(self, local_media, remote_media, time_ms):
"""Updates the last access time of the given media
Args:
local_media (iterable[str]): Set of media_ids
remote_media (iterable[(str, str)]): Set of (server_name, media_id)
time_ms: Current time in milliseconds
"""
def update_cache_txn(txn):
sql = (
"UPDATE remote_media_cache SET last_access_ts = ?"
@ -184,8 +188,18 @@ class MediaRepositoryStore(BackgroundUpdateStore):
)
txn.executemany(sql, (
- (time_ts, media_origin, media_id)
+ (time_ms, media_origin, media_id)
- for media_origin, media_id in origin_id_tuples
+ for media_origin, media_id in remote_media
))
sql = (
"UPDATE local_media_repository SET last_access_ts = ?"
" WHERE media_id = ?"
)
txn.executemany(sql, (
(time_ms, media_id)
for media_id in local_media
))
return self.runInteraction("update_cached_last_access_time", update_cache_txn)
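A hypothetical call against the new signature, for illustration only (the media ids are invented):

    # sketch only: bumps last_access_ts for both local and remote media
    yield store.update_cached_last_access_time(
        local_media=["abcdefghij"],
        remote_media=[("example.org", "klmnopqrst")],
        time_ms=self.clock.time_msec(),
    )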


@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
- SCHEMA_VERSION = 46
+ SCHEMA_VERSION = 47
dir_path = os.path.abspath(os.path.dirname(__file__))


@ -591,7 +591,7 @@ class RoomStore(SQLBaseStore):
"""
UPDATE remote_media_cache
SET quarantined_by = ?
- WHERE media_origin AND media_id = ?
+ WHERE media_origin = ? AND media_id = ?
""",
(
(quarantined_by, origin, media_id)


@ -0,0 +1,16 @@
/* Copyright 2018 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE local_media_repository ADD COLUMN last_access_ts BIGINT;


@ -641,8 +641,12 @@ class UserDirectoryStore(SQLBaseStore):
"""
if self.hs.config.user_directory_search_all_users:
- join_clause = ""
+ # make s.user_id null to keep the ordering algorithm happy
- where_clause = "?<>''" # naughty hack to keep the same number of binds
+ join_clause = """
CROSS JOIN (SELECT NULL as user_id) AS s
"""
join_args = ()
where_clause = "1=1"
else:
join_clause = """
LEFT JOIN users_in_public_rooms AS p USING (user_id)
@ -651,6 +655,7 @@ class UserDirectoryStore(SQLBaseStore):
WHERE user_id = ? AND share_private
) AS s USING (user_id)
"""
join_args = (user_id,)
where_clause = "(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)"
if isinstance(self.database_engine, PostgresEngine):
@ -692,7 +697,7 @@ class UserDirectoryStore(SQLBaseStore):
avatar_url IS NULL
LIMIT ?
""" % (join_clause, where_clause)
- args = (user_id, full_query, exact_query, prefix_query, limit + 1,)
+ args = join_args + (full_query, exact_query, prefix_query, limit + 1,)
elif isinstance(self.database_engine, Sqlite3Engine):
search_query = _parse_query_sqlite(search_term)
@ -710,7 +715,7 @@ class UserDirectoryStore(SQLBaseStore):
avatar_url IS NULL
LIMIT ?
""" % (join_clause, where_clause)
- args = (user_id, search_query, limit + 1)
+ args = join_args + (search_query, limit + 1)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")


@ -0,0 +1,139 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import threads, reactor
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
import Queue
class BackgroundFileConsumer(object):
"""A consumer that writes to a file like object. Supports both push
and pull producers
Args:
file_obj (file): The file like object to write to. Closed when
finished.
"""
# For PushProducers pause if we have this many unwritten slices
_PAUSE_ON_QUEUE_SIZE = 5
# And resume once the size of the queue is less than this
_RESUME_ON_QUEUE_SIZE = 2
def __init__(self, file_obj):
self._file_obj = file_obj
# Producer we're registered with
self._producer = None
# True if PushProducer, false if PullProducer
self.streaming = False
# For PushProducers, indicates whether we've paused the producer and
# need to call resumeProducing before we get more data.
self._paused_producer = False
# Queue of slices of bytes to be written. When producer calls
# unregister a final None is sent.
self._bytes_queue = Queue.Queue()
# Deferred that is resolved when finished writing
self._finished_deferred = None
# If the _writer thread throws an exception it gets stored here.
self._write_exception = None
def registerProducer(self, producer, streaming):
"""Part of IConsumer interface
Args:
producer (IProducer)
streaming (bool): True if push based producer, False if pull
based.
"""
if self._producer:
raise Exception("registerProducer called twice")
self._producer = producer
self.streaming = streaming
self._finished_deferred = preserve_fn(threads.deferToThread)(self._writer)
if not streaming:
self._producer.resumeProducing()
def unregisterProducer(self):
"""Part of IProducer interface
"""
self._producer = None
if not self._finished_deferred.called:
self._bytes_queue.put_nowait(None)
def write(self, bytes):
"""Part of IProducer interface
"""
if self._write_exception:
raise self._write_exception
if self._finished_deferred.called:
raise Exception("consumer has closed")
self._bytes_queue.put_nowait(bytes)
# If this is a PushProducer and the queue is getting behind
# then we pause the producer.
if self.streaming and self._bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE:
self._paused_producer = True
self._producer.pauseProducing()
def _writer(self):
"""This is run in a background thread to write to the file.
"""
try:
while self._producer or not self._bytes_queue.empty():
# If we've paused the producer check if we should resume the
# producer.
if self._producer and self._paused_producer:
if self._bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE:
reactor.callFromThread(self._resume_paused_producer)
bytes = self._bytes_queue.get()
# If we get a None (or empty list) then that's a signal used
# to indicate we should check if we should stop.
if bytes:
self._file_obj.write(bytes)
# If its a pull producer then we need to explicitly ask for
# more stuff.
if not self.streaming and self._producer:
reactor.callFromThread(self._producer.resumeProducing)
except Exception as e:
self._write_exception = e
raise
finally:
self._file_obj.close()
def wait(self):
"""Returns a deferred that resolves when finished writing to file
"""
return make_deferred_yieldable(self._finished_deferred)
def _resume_paused_producer(self):
"""Gets called if we should resume producing after being paused
"""
if self._paused_producer and self._producer:
self._paused_producer = False
self._producer.resumeProducing()
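Roughly how the new consumer gets wired up, as a sketch under assumed names (the unit tests added later in this commit exercise the same pattern):

    # sketch only: `producer` is assumed to implement IPushProducer
    consumer = BackgroundFileConsumer(open("/tmp/media-download.tmp", "wb"))
    consumer.registerProducer(producer, streaming=True)
    # ... the producer streams data via consumer.write(...) ...
    consumer.unregisterProducer()
    yield consumer.wait()  # fires once the writer thread has closed the file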


@ -52,13 +52,17 @@ except Exception:
class LoggingContext(object):
"""Additional context for log formatting. Contexts are scoped within a
"with" block.
Args:
name (str): Name for the context for debugging.
"""
__slots__ = [
- "previous_context", "name", "usage_start", "usage_end", "main_thread",
+ "previous_context", "name", "ru_stime", "ru_utime",
- "__dict__", "tag", "alive",
+ "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms",
"usage_start", "usage_end",
"main_thread", "alive",
"request", "tag",
]
thread_local = threading.local()
@ -83,6 +87,9 @@ class LoggingContext(object):
def add_database_transaction(self, duration_ms):
pass
def add_database_scheduled(self, sched_ms):
pass
def __nonzero__(self):
return False
@ -94,9 +101,17 @@ class LoggingContext(object):
self.ru_stime = 0.
self.ru_utime = 0.
self.db_txn_count = 0
self.db_txn_duration = 0.
# ms spent waiting for db txns, excluding scheduling time
self.db_txn_duration_ms = 0
# ms spent waiting for db txns to be scheduled
self.db_sched_duration_ms = 0
self.usage_start = None
self.usage_end = None
self.main_thread = threading.current_thread()
self.request = None
self.tag = ""
self.alive = True
@ -105,7 +120,11 @@ class LoggingContext(object):
@classmethod
def current_context(cls):
- """Get the current logging context from thread local storage"""
+ """Get the current logging context from thread local storage
Returns:
LoggingContext: the current logging context
"""
return getattr(cls.thread_local, "current_context", cls.sentinel)
@classmethod
@ -155,11 +174,13 @@ class LoggingContext(object):
self.alive = False
def copy_to(self, record):
- """Copy fields from this context to the record"""
+ """Copy logging fields from this context to a log record or
- for key, value in self.__dict__.items():
+ another LoggingContext
- setattr(record, key, value)
+ """
- record.ru_utime, record.ru_stime = self.get_resource_usage()
+ # 'request' is the only field we currently use in the logger, so that's
# all we need to copy
record.request = self.request
def start(self):
if threading.current_thread() is not self.main_thread:
@ -194,7 +215,16 @@ class LoggingContext(object):
def add_database_transaction(self, duration_ms):
self.db_txn_count += 1
- self.db_txn_duration += duration_ms / 1000.
+ self.db_txn_duration_ms += duration_ms
def add_database_scheduled(self, sched_ms):
"""Record a use of the database pool
Args:
sched_ms (int): number of milliseconds it took us to get a
connection
"""
self.db_sched_duration_ms += sched_ms
class LoggingContextFilter(logging.Filter):


@ -27,25 +27,62 @@ logger = logging.getLogger(__name__)
metrics = synapse.metrics.get_metrics_for(__name__)
- block_timer = metrics.register_distribution(
+ # total number of times we have hit this block
- "block_timer",
+ block_counter = metrics.register_counter(
- labels=["block_name"]
+ "block_count",
labels=["block_name"],
alternative_names=(
# the following are all deprecated aliases for the same metric
metrics.name_prefix + x for x in (
"_block_timer:count",
"_block_ru_utime:count",
"_block_ru_stime:count",
"_block_db_txn_count:count",
"_block_db_txn_duration:count",
)
)
)
- block_ru_utime = metrics.register_distribution(
+ block_timer = metrics.register_counter(
- "block_ru_utime", labels=["block_name"]
+ "block_time_seconds",
labels=["block_name"],
alternative_names=(
metrics.name_prefix + "_block_timer:total",
),
)
- block_ru_stime = metrics.register_distribution(
+ block_ru_utime = metrics.register_counter(
- "block_ru_stime", labels=["block_name"]
+ "block_ru_utime_seconds", labels=["block_name"],
alternative_names=(
metrics.name_prefix + "_block_ru_utime:total",
),
)
- block_db_txn_count = metrics.register_distribution(
+ block_ru_stime = metrics.register_counter(
- "block_db_txn_count", labels=["block_name"]
+ "block_ru_stime_seconds", labels=["block_name"],
alternative_names=(
metrics.name_prefix + "_block_ru_stime:total",
),
)
- block_db_txn_duration = metrics.register_distribution(
+ block_db_txn_count = metrics.register_counter(
- "block_db_txn_duration", labels=["block_name"]
+ "block_db_txn_count", labels=["block_name"],
alternative_names=(
metrics.name_prefix + "_block_db_txn_count:total",
),
)
# seconds spent waiting for db txns, excluding scheduling time, in this block
block_db_txn_duration = metrics.register_counter(
"block_db_txn_duration_seconds", labels=["block_name"],
alternative_names=(
metrics.name_prefix + "_block_db_txn_duration:total",
),
)
# seconds spent waiting for a db connection, in this block
block_db_sched_duration = metrics.register_counter(
"block_db_sched_duration_seconds", labels=["block_name"],
)
@ -64,7 +101,9 @@ def measure_func(name):
class Measure(object):
__slots__ = [
"clock", "name", "start_context", "start", "new_context", "ru_utime",
- "ru_stime", "db_txn_count", "db_txn_duration", "created_context"
+ "ru_stime",
"db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms",
"created_context",
]
def __init__(self, clock, name):
@ -84,13 +123,16 @@ class Measure(object):
self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
self.db_txn_count = self.start_context.db_txn_count
- self.db_txn_duration = self.start_context.db_txn_duration
+ self.db_txn_duration_ms = self.start_context.db_txn_duration_ms
self.db_sched_duration_ms = self.start_context.db_sched_duration_ms
def __exit__(self, exc_type, exc_val, exc_tb):
if isinstance(exc_type, Exception) or not self.start_context:
return
duration = self.clock.time_msec() - self.start
block_counter.inc(self.name)
block_timer.inc_by(duration, self.name)
context = LoggingContext.current_context()
@ -114,7 +156,12 @@ class Measure(object):
context.db_txn_count - self.db_txn_count, self.name
)
block_db_txn_duration.inc_by(
- context.db_txn_duration - self.db_txn_duration, self.name
+ (context.db_txn_duration_ms - self.db_txn_duration_ms) / 1000.,
self.name
)
block_db_sched_duration.inc_by(
(context.db_sched_duration_ms - self.db_sched_duration_ms) / 1000.,
self.name
)
if self.created_context:


@ -26,6 +26,18 @@ logger = logging.getLogger(__name__)
class NotRetryingDestination(Exception):
def __init__(self, retry_last_ts, retry_interval, destination):
"""Raised by the limiter (and federation client) to indicate that we are
are deliberately not attempting to contact a given server.
Args:
retry_last_ts (int): the unix ts in milliseconds of our last attempt
to contact the server. 0 indicates that the last attempt was
successful or that we've never actually attempted to connect.
retry_interval (int): the time in milliseconds to wait until the next
attempt.
destination (str): the domain in question
"""
msg = "Not retrying server %s." % (destination,)
super(NotRetryingDestination, self).__init__(msg)
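For illustration only (the names outside this hunk are assumed), callers typically hit this exception via the retry limiter and simply skip the destination:

    # sketch only: get_retry_limiter lives alongside this class in retryutils
    try:
        limiter = yield get_retry_limiter(destination, clock, store)
    except NotRetryingDestination as e:
        logger.debug("%s: backing off", e)
        return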

synapse/util/threepids.py (new file, 48 lines)

@ -0,0 +1,48 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
logger = logging.getLogger(__name__)
def check_3pid_allowed(hs, medium, address):
"""Checks whether a given format of 3PID is allowed to be used on this HS
Args:
hs (synapse.server.HomeServer): server
medium (str): 3pid medium - e.g. email, msisdn
address (str): address within that medium (e.g. "wotan@matrix.org")
msisdns need to first have been canonicalised
Returns:
bool: whether the 3PID medium/address is allowed to be added to this HS
"""
if hs.config.allowed_local_3pids:
for constraint in hs.config.allowed_local_3pids:
logger.debug(
"Checking 3PID %s (%s) against %s (%s)",
address, medium, constraint['pattern'], constraint['medium'],
)
if (
medium == constraint['medium'] and
re.match(constraint['pattern'], address)
):
return True
else:
return True
return False
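To make the behaviour concrete, a hypothetical allowed_local_3pids config (invented for illustration) would interact with this helper as follows:

    # sketch only: assumes homeserver.yaml contains something like
    #   allowed_local_3pids:
    #     - medium: email
    #       pattern: ".*@vector\\.im"
    #     - medium: msisdn
    #       pattern: "\\+44"
    if not check_3pid_allowed(hs, "email", "someone@example.com"):
        raise SynapseError(403, "Third party identifier is not allowed")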


@ -12,9 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path
import re
import shutil
import tempfile
from synapse.config.homeserver import HomeServerConfig
from tests import unittest
@ -23,7 +26,6 @@ class ConfigGenerationTestCase(unittest.TestCase):
def setUp(self):
self.dir = tempfile.mkdtemp()
print self.dir
self.file = os.path.join(self.dir, "homeserver.yaml")
def tearDown(self):
@ -48,3 +50,16 @@ class ConfigGenerationTestCase(unittest.TestCase):
]),
set(os.listdir(self.dir))
)
self.assert_log_filename_is(
os.path.join(self.dir, "lemurs.win.log.config"),
os.path.join(os.getcwd(), "homeserver.log"),
)
def assert_log_filename_is(self, log_config_file, expected):
with open(log_config_file) as f:
config = f.read()
# find the 'filename' line
matches = re.findall("^\s*filename:\s*(.*)$", config, re.M)
self.assertEqual(1, len(matches))
self.assertEqual(matches[0], expected)


@ -68,7 +68,7 @@ class KeyringTestCase(unittest.TestCase):
def check_context(self, _, expected):
self.assertEquals(
- getattr(LoggingContext.current_context(), "test_key", None),
+ getattr(LoggingContext.current_context(), "request", None),
expected
)
@ -82,7 +82,7 @@ class KeyringTestCase(unittest.TestCase):
lookup_2_deferred = defer.Deferred()
with LoggingContext("one") as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
wait_1_deferred = kr.wait_for_previous_lookups(
["server1"],
@ -96,7 +96,7 @@ class KeyringTestCase(unittest.TestCase):
wait_1_deferred.addBoth(self.check_context, "one")
with LoggingContext("two") as context_two:
- context_two.test_key = "two"
+ context_two.request = "two"
# set off another wait. It should block because the first lookup
# hasn't yet completed.
@ -137,7 +137,7 @@ class KeyringTestCase(unittest.TestCase):
@defer.inlineCallbacks
def get_perspectives(**kwargs):
self.assertEquals(
- LoggingContext.current_context().test_key, "11",
+ LoggingContext.current_context().request, "11",
)
with logcontext.PreserveLoggingContext():
yield persp_deferred
@ -145,7 +145,7 @@ class KeyringTestCase(unittest.TestCase):
self.http_client.post_json.side_effect = get_perspectives
with LoggingContext("11") as context_11:
- context_11.test_key = "11"
+ context_11.request = "11"
# start off a first set of lookups
res_deferreds = kr.verify_json_objects_for_server(
@ -167,13 +167,13 @@ class KeyringTestCase(unittest.TestCase):
# wait a tick for it to send the request to the perspectives server
# (it first tries the datastore)
- yield async.sleep(0.005)
+ yield async.sleep(1)  # XXX find out why this takes so long!
self.http_client.post_json.assert_called_once()
self.assertIs(LoggingContext.current_context(), context_11)
context_12 = LoggingContext("12")
- context_12.test_key = "12"
+ context_12.request = "12"
with logcontext.PreserveLoggingContext(context_12):
# a second request for a server with outstanding requests
# should block rather than start a second call
@ -183,7 +183,7 @@ class KeyringTestCase(unittest.TestCase):
res_deferreds_2 = kr.verify_json_objects_for_server(
[("server10", json1)],
)
- yield async.sleep(0.005)
+ yield async.sleep(01)
self.http_client.post_json.assert_not_called()
res_deferreds_2[0].addBoth(self.check_context, None)
@ -211,7 +211,7 @@ class KeyringTestCase(unittest.TestCase):
sentinel_context = LoggingContext.current_context()
with LoggingContext("one") as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
defer = kr.verify_json_for_server("server9", {})
try:


@ -143,7 +143,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
except errors.SynapseError:
pass
@unittest.DEBUG
@defer.inlineCallbacks
def test_claim_one_time_key(self):
local_user = "@boris:" + self.hs.hostname


@ -15,6 +15,8 @@
from twisted.internet import defer, reactor
from tests import unittest
import tempfile
from mock import Mock, NonCallableMock
from tests.utils import setup_test_homeserver
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@ -41,7 +43,9 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
self.event_id = 0
server_factory = ReplicationStreamProtocolFactory(self.hs)
- listener = reactor.listenUNIX("\0xxx", server_factory)
+ # XXX: mktemp is unsafe and should never be used. but we're just a test.
path = tempfile.mktemp(prefix="base_slaved_store_test_case_socket")
listener = reactor.listenUNIX(path, server_factory)
self.addCleanup(listener.stopListening)
self.streamer = server_factory.streamer
@ -49,7 +53,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
client_factory = ReplicationClientFactory(
self.hs, "client_name", self.replication_handler
)
- client_connector = reactor.connectUNIX("\0xxx", client_factory)
+ client_connector = reactor.connectUNIX(path, client_factory)
self.addCleanup(client_factory.stopTrying)
self.addCleanup(client_connector.disconnect)


@ -515,9 +515,6 @@ class RoomsCreateTestCase(RestTestCase):
synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
def tearDown(self):
pass
@defer.inlineCallbacks
def test_post_room_no_keys(self):
# POST with no config keys, expect new room id


@ -49,6 +49,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
self.hs.get_device_handler = Mock(return_value=self.device_handler)
self.hs.config.enable_registration = True
self.hs.config.registrations_require_3pid = []
self.hs.config.auto_join_rooms = []
# init the thing we're testing


@ -0,0 +1,88 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.storage import UserDirectoryStore
from synapse.storage.roommember import ProfileInfo
from tests import unittest
from tests.utils import setup_test_homeserver
ALICE = "@alice:a"
BOB = "@bob:b"
BOBBY = "@bobby:a"
class UserDirectoryStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.hs = yield setup_test_homeserver()
self.store = UserDirectoryStore(None, self.hs)
# alice and bob are both in !room_id. bobby is not but shares
# a homeserver with alice.
yield self.store.add_profiles_to_user_dir(
"!room:id",
{
ALICE: ProfileInfo(None, "alice"),
BOB: ProfileInfo(None, "bob"),
BOBBY: ProfileInfo(None, "bobby")
},
)
yield self.store.add_users_to_public_room(
"!room:id",
[ALICE, BOB],
)
yield self.store.add_users_who_share_room(
"!room:id",
False,
(
(ALICE, BOB),
(BOB, ALICE),
),
)
@defer.inlineCallbacks
def test_search_user_dir(self):
# normally when alice searches the directory she should just find
# bob because bobby doesn't share a room with her.
r = yield self.store.search_user_dir(ALICE, "bob", 10)
self.assertFalse(r["limited"])
self.assertEqual(1, len(r["results"]))
self.assertDictEqual(r["results"][0], {
"user_id": BOB,
"display_name": "bob",
"avatar_url": None,
})
@defer.inlineCallbacks
def test_search_user_dir_all_users(self):
self.hs.config.user_directory_search_all_users = True
try:
r = yield self.store.search_user_dir(ALICE, "bob", 10)
self.assertFalse(r["limited"])
self.assertEqual(2, len(r["results"]))
self.assertDictEqual(r["results"][0], {
"user_id": BOB,
"display_name": "bob",
"avatar_url": None,
})
self.assertDictEqual(r["results"][1], {
"user_id": BOBBY,
"display_name": "bobby",
"avatar_url": None,
})
finally:
self.hs.config.user_directory_search_all_users = False


@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import twisted
from twisted.trial import unittest
import logging
@ -65,6 +65,10 @@ class TestCase(unittest.TestCase):
@around(self)
def setUp(orig):
# enable debugging of delayed calls - this means that we get a
# traceback when a unit test exits leaving things on the reactor.
twisted.internet.base.DelayedCall.debug = True
old_level = logging.getLogger().level
if old_level != level:


@ -0,0 +1,176 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer, reactor
from mock import NonCallableMock
from synapse.util.file_consumer import BackgroundFileConsumer
from tests import unittest
from StringIO import StringIO
import threading
class FileConsumerTests(unittest.TestCase):
@defer.inlineCallbacks
def test_pull_consumer(self):
string_file = StringIO()
consumer = BackgroundFileConsumer(string_file)
try:
producer = DummyPullProducer()
yield producer.register_with_consumer(consumer)
yield producer.write_and_wait("Foo")
self.assertEqual(string_file.getvalue(), "Foo")
yield producer.write_and_wait("Bar")
self.assertEqual(string_file.getvalue(), "FooBar")
finally:
consumer.unregisterProducer()
yield consumer.wait()
self.assertTrue(string_file.closed)
@defer.inlineCallbacks
def test_push_consumer(self):
string_file = BlockingStringWrite()
consumer = BackgroundFileConsumer(string_file)
try:
producer = NonCallableMock(spec_set=[])
consumer.registerProducer(producer, True)
consumer.write("Foo")
yield string_file.wait_for_n_writes(1)
self.assertEqual(string_file.buffer, "Foo")
consumer.write("Bar")
yield string_file.wait_for_n_writes(2)
self.assertEqual(string_file.buffer, "FooBar")
finally:
consumer.unregisterProducer()
yield consumer.wait()
self.assertTrue(string_file.closed)
@defer.inlineCallbacks
def test_push_producer_feedback(self):
string_file = BlockingStringWrite()
consumer = BackgroundFileConsumer(string_file)
try:
producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"])
resume_deferred = defer.Deferred()
producer.resumeProducing.side_effect = lambda: resume_deferred.callback(None)
consumer.registerProducer(producer, True)
number_writes = 0
with string_file.write_lock:
for _ in range(consumer._PAUSE_ON_QUEUE_SIZE):
consumer.write("Foo")
number_writes += 1
producer.pauseProducing.assert_called_once()
yield string_file.wait_for_n_writes(number_writes)
yield resume_deferred
producer.resumeProducing.assert_called_once()
finally:
consumer.unregisterProducer()
yield consumer.wait()
self.assertTrue(string_file.closed)
class DummyPullProducer(object):
def __init__(self):
self.consumer = None
self.deferred = defer.Deferred()
def resumeProducing(self):
d = self.deferred
self.deferred = defer.Deferred()
d.callback(None)
def write_and_wait(self, bytes):
d = self.deferred
self.consumer.write(bytes)
return d
def register_with_consumer(self, consumer):
d = self.deferred
self.consumer = consumer
self.consumer.registerProducer(self, False)
return d
class BlockingStringWrite(object):
def __init__(self):
self.buffer = ""
self.closed = False
self.write_lock = threading.Lock()
self._notify_write_deferred = None
self._number_of_writes = 0
def write(self, bytes):
with self.write_lock:
self.buffer += bytes
self._number_of_writes += 1
reactor.callFromThread(self._notify_write)
def close(self):
self.closed = True
def _notify_write(self):
"Called by write to indicate a write happened"
with self.write_lock:
if not self._notify_write_deferred:
return
d = self._notify_write_deferred
self._notify_write_deferred = None
d.callback(None)
@defer.inlineCallbacks
def wait_for_n_writes(self, n):
"Wait for n writes to have happened"
while True:
with self.write_lock:
if n <= self._number_of_writes:
return
if not self._notify_write_deferred:
self._notify_write_deferred = defer.Deferred()
d = self._notify_write_deferred
yield d


@ -12,12 +12,12 @@ class LoggingContextTestCase(unittest.TestCase):
def _check_test_key(self, value):
self.assertEquals(
- LoggingContext.current_context().test_key, value
+ LoggingContext.current_context().request, value
)
def test_with_context(self):
with LoggingContext() as context_one:
- context_one.test_key = "test"
+ context_one.request = "test"
self._check_test_key("test")
@defer.inlineCallbacks
@ -25,14 +25,14 @@ class LoggingContextTestCase(unittest.TestCase):
@defer.inlineCallbacks
def competing_callback():
with LoggingContext() as competing_context:
- competing_context.test_key = "competing"
+ competing_context.request = "competing"
yield sleep(0)
self._check_test_key("competing")
reactor.callLater(0, competing_callback)
with LoggingContext() as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
yield sleep(0)
self._check_test_key("one")
@ -43,14 +43,14 @@ class LoggingContextTestCase(unittest.TestCase):
@defer.inlineCallbacks
def cb():
- context_one.test_key = "one"
+ context_one.request = "one"
yield function()
self._check_test_key("one")
callback_completed[0] = True
with LoggingContext() as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
# fire off function, but don't wait on it.
logcontext.preserve_fn(cb)()
@ -107,7 +107,7 @@ class LoggingContextTestCase(unittest.TestCase):
sentinel_context = LoggingContext.current_context()
with LoggingContext() as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
d1 = logcontext.make_deferred_yieldable(blocking_function())
# make sure that the context was reset by make_deferred_yieldable
@ -124,7 +124,7 @@ class LoggingContextTestCase(unittest.TestCase):
argument isn't actually a deferred"""
with LoggingContext() as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
d1 = logcontext.make_deferred_yieldable("bum")
self._check_test_key("one")


@ -13,27 +13,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.http.server import HttpServer
from synapse.api.errors import cs_error, CodeMessageException, StoreError
from synapse.api.constants import EventTypes
from synapse.storage.prepare_database import prepare_database
from synapse.storage.engines import create_engine
from synapse.server import HomeServer
from synapse.federation.transport import server
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.logcontext import LoggingContext
from twisted.internet import defer, reactor
from twisted.enterprise.adbapi import ConnectionPool
from collections import namedtuple
from mock import patch, Mock
import hashlib
from inspect import getcallargs
import urllib
import urlparse
- from inspect import getcallargs
+ from mock import Mock, patch
from twisted.internet import defer, reactor
from synapse.api.errors import CodeMessageException, cs_error
from synapse.federation.transport import server
from synapse.http.server import HttpServer
from synapse.server import HomeServer
from synapse.storage import PostgresEngine
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
from synapse.util.logcontext import LoggingContext
from synapse.util.ratelimitutils import FederationRateLimiter
# set this to True to run the tests against postgres instead of sqlite.
# It requires you to have a local postgres database called synapse_test, within
# which ALL TABLES WILL BE DROPPED
USE_POSTGRES_FOR_TESTS = False
@defer.inlineCallbacks
@ -57,32 +58,70 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
config.worker_app = None
config.email_enable_notifs = False
config.block_non_admin_invites = False
config.federation_domain_whitelist = None
config.user_directory_search_all_users = False
# disable user directory updates, because they get done in the
# background, which upsets the test runner.
config.update_user_directory = False
config.use_frozen_dicts = True
config.database_config = {"name": "sqlite3"}
config.ldap_enabled = False
if "clock" not in kargs:
kargs["clock"] = MockClock()
if USE_POSTGRES_FOR_TESTS:
config.database_config = {
"name": "psycopg2",
"args": {
"database": "synapse_test",
"cp_min": 1,
"cp_max": 5,
},
}
else:
config.database_config = {
"name": "sqlite3",
"args": {
"database": ":memory:",
"cp_min": 1,
"cp_max": 1,
},
}
db_engine = create_engine(config.database_config)
# we need to configure the connection pool to run the on_new_connection
# function, so that we can test code that uses custom sqlite functions
# (like rank).
config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
if datastore is None:
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
hs = HomeServer(
- name, db_pool=db_pool, config=config,
+ name, config=config,
db_config=config.database_config,
version_string="Synapse/tests",
- database_engine=create_engine(config.database_config),
+ database_engine=db_engine,
get_db_conn=db_pool.get_db_conn,
room_list_handler=object(),
tls_server_context_factory=Mock(),
**kargs
)
db_conn = hs.get_db_conn()
# make sure that the database is empty
if isinstance(db_engine, PostgresEngine):
cur = db_conn.cursor()
cur.execute("SELECT tablename FROM pg_tables where schemaname='public'")
rows = cur.fetchall()
for r in rows:
cur.execute("DROP TABLE %s CASCADE" % r[0])
yield prepare_database(db_conn, db_engine, config)
hs.setup()
else:
hs = HomeServer(
name, db_pool=None, datastore=datastore, config=config,
version_string="Synapse/tests",
- database_engine=create_engine(config.database_config),
+ database_engine=db_engine,
room_list_handler=object(),
tls_server_context_factory=Mock(),
**kargs
@ -301,168 +340,6 @@ class MockClock(object):
return d
class SQLiteMemoryDbPool(ConnectionPool, object):
def __init__(self):
super(SQLiteMemoryDbPool, self).__init__(
"sqlite3", ":memory:",
cp_min=1,
cp_max=1,
)
self.config = Mock()
self.config.password_providers = []
self.config.database_config = {"name": "sqlite3"}
def prepare(self):
engine = self.create_engine()
return self.runWithConnection(
lambda conn: prepare_database(conn, engine, self.config)
)
def get_db_conn(self):
conn = self.connect()
engine = self.create_engine()
prepare_database(conn, engine, self.config)
return conn
def create_engine(self):
return create_engine(self.config.database_config)
class MemoryDataStore(object):
Room = namedtuple(
"Room",
["room_id", "is_public", "creator"]
)
def __init__(self):
self.tokens_to_users = {}
self.paths_to_content = {}
self.members = {}
self.rooms = {}
self.current_state = {}
self.events = []
class Snapshot(namedtuple("Snapshot", "room_id user_id membership_state")):
def fill_out_prev_events(self, event):
pass
def snapshot_room(self, room_id, user_id, state_type=None, state_key=None):
return self.Snapshot(
room_id, user_id, self.get_room_member(user_id, room_id)
)
def register(self, user_id, token, password_hash):
if user_id in self.tokens_to_users.values():
raise StoreError(400, "User in use.")
self.tokens_to_users[token] = user_id
def get_user_by_access_token(self, token):
try:
return {
"name": self.tokens_to_users[token],
}
except Exception:
raise StoreError(400, "User does not exist.")
def get_room(self, room_id):
try:
return self.rooms[room_id]
except Exception:
return None
def store_room(self, room_id, room_creator_user_id, is_public):
if room_id in self.rooms:
raise StoreError(409, "Conflicting room!")
room = MemoryDataStore.Room(
room_id=room_id,
is_public=is_public,
creator=room_creator_user_id
)
self.rooms[room_id] = room
def get_room_member(self, user_id, room_id):
return self.members.get(room_id, {}).get(user_id)
def get_room_members(self, room_id, membership=None):
if membership:
return [
v for k, v in self.members.get(room_id, {}).items()
if v.membership == membership
]
else:
return self.members.get(room_id, {}).values()
def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
return [
m[user_id] for m in self.members.values()
if user_id in m and m[user_id].membership in membership_list
]
def get_room_events_stream(self, user_id=None, from_key=None, to_key=None,
limit=0, with_feedback=False):
return ([], from_key) # TODO
def get_joined_hosts_for_room(self, room_id):
return defer.succeed([])
def persist_event(self, event):
if event.type == EventTypes.Member:
room_id = event.room_id
user = event.state_key
self.members.setdefault(room_id, {})[user] = event
if hasattr(event, "state_key"):
key = (event.room_id, event.type, event.state_key)
self.current_state[key] = event
self.events.append(event)
def get_current_state(self, room_id, event_type=None, state_key=""):
if event_type:
key = (room_id, event_type, state_key)
if self.current_state.get(key):
return [self.current_state.get(key)]
return None
else:
return [
e for e in self.current_state
if e[0] == room_id
]
def set_presence_state(self, user_localpart, state):
return defer.succeed({"state": 0})
def get_presence_list(self, user_localpart, accepted):
return []
def get_room_events_max_id(self):
return "s0" # TODO (erikj)
def get_send_event_level(self, room_id):
return defer.succeed(0)
def get_power_level(self, room_id, user_id):
return defer.succeed(0)
def get_add_state_level(self, room_id):
return defer.succeed(0)
def get_room_join_rule(self, room_id):
# TODO (erikj): This should be configurable
return defer.succeed("invite")
def get_ops_levels(self, room_id):
return defer.succeed((5, 5, 5))
def insert_client_ip(self, user, access_token, ip, user_agent):
return defer.succeed(None)
def _format_call(args, kwargs):
return ", ".join(
["%r" % (a) for a in args] +