Merge branch 'release-v0.27.0' of https://github.com/matrix-org/synapse

This commit is contained in:
Neil Johnson 2018-03-26 14:51:11 +01:00
commit aa3587fdd1
177 changed files with 8690 additions and 4366 deletions

2
.gitignore vendored
View File

@ -46,3 +46,5 @@ static/client/register/register_config.js
env/ env/
*.config *.config
.vscode/

View File

@ -1,10 +1,75 @@
Changes in synapse v0.27.0 (2018-03-26)
=======================================
No changes since v0.27.0-rc2
Changes in synapse v0.27.0-rc2 (2018-03-19)
===========================================
Pulls in v0.26.1
Bug fixes:
* Fix bug introduced in v0.27.0-rc1 that causes much increased memory usage in state cache (PR #3005)
Changes in synapse v0.26.1 (2018-03-15) Changes in synapse v0.26.1 (2018-03-15)
======================================= =======================================
Bug fixes: Bug fixes:
* Fix bug where an invalid event caused server to stop functioning correctly, * Fix bug where an invalid event caused server to stop functioning correctly,
due to parsing and serializing bugs in ujson library. due to parsing and serializing bugs in ujson library (PR #3008)
Changes in synapse v0.27.0-rc1 (2018-03-14)
===========================================
The common case for running Synapse is not to run separate workers, but for those that do, be aware that synctl no longer starts the main synapse when using ``-a`` option with workers. A new worker file should be added with ``worker_app: synapse.app.homeserver``.
This release also begins the process of renaming a number of the metrics
reported to prometheus. See `docs/metrics-howto.rst <docs/metrics-howto.rst#block-and-response-metrics-renamed-for-0-27-0>`_.
Note that the v0.28.0 release will remove the deprecated metric names.
Features:
* Add ability for ASes to override message send time (PR #2754)
* Add support for custom storage providers for media repository (PR #2867, #2777, #2783, #2789, #2791, #2804, #2812, #2814, #2857, #2868, #2767)
* Add purge API features, see `docs/admin_api/purge_history_api.rst <docs/admin_api/purge_history_api.rst>`_ for full details (PR #2858, #2867, #2882, #2946, #2962, #2943)
* Add support for whitelisting 3PIDs that users can register. (PR #2813)
* Add ``/room/{id}/event/{id}`` API (PR #2766)
* Add an admin API to get all the media in a room (PR #2818) Thanks to @turt2live!
* Add ``federation_domain_whitelist`` option (PR #2820, #2821)
Changes:
* Continue to factor out processing from main process and into worker processes. See updated `docs/workers.rst <docs/metrics-howto.rst>`_ (PR #2892 - #2904, #2913, #2920 - #2926, #2947, #2847, #2854, #2872, #2873, #2874, #2928, #2929, #2934, #2856, #2976 - #2984, #2987 - #2989, #2991 - #2993, #2995, #2784)
* Ensure state cache is used when persisting events (PR #2864, #2871, #2802, #2835, #2836, #2841, #2842, #2849)
* Change the default config to bind on both IPv4 and IPv6 on all platforms (PR #2435) Thanks to @silkeh!
* No longer require a specific version of saml2 (PR #2695) Thanks to @okurz!
* Remove ``verbosity``/``log_file`` from generated config (PR #2755)
* Add and improve metrics and logging (PR #2770, #2778, #2785, #2786, #2787, #2793, #2794, #2795, #2809, #2810, #2833, #2834, #2844, #2965, #2927, #2975, #2790, #2796, #2838)
* When using synctl with workers, don't start the main synapse automatically (PR #2774)
* Minor performance improvements (PR #2773, #2792)
* Use a connection pool for non-federation outbound connections (PR #2817)
* Make it possible to run unit tests against postgres (PR #2829)
* Update pynacl dependency to 1.2.1 or higher (PR #2888) Thanks to @bachp!
* Remove ability for AS users to call /events and /sync (PR #2948)
* Use bcrypt.checkpw (PR #2949) Thanks to @krombel!
Bug fixes:
* Fix broken ``ldap_config`` config option (PR #2683) Thanks to @seckrv!
* Fix error message when user is not allowed to unban (PR #2761) Thanks to @turt2live!
* Fix publicised groups GET API (singular) over federation (PR #2772)
* Fix user directory when using ``user_directory_search_all_users`` config option (PR #2803, #2831)
* Fix error on ``/publicRooms`` when no rooms exist (PR #2827)
* Fix bug in quarantine_media (PR #2837)
* Fix url_previews when no Content-Type is returned from URL (PR #2845)
* Fix rare race in sync API when joining room (PR #2944)
* Fix slow event search, switch back from GIST to GIN indexes (PR #2769, #2848)
Changes in synapse v0.26.0 (2018-01-05) Changes in synapse v0.26.0 (2018-01-05)

View File

@ -30,8 +30,12 @@ use github's pull request workflow to review the contribution, and either ask
you to make any refinements needed or merge it and make them ourselves. The you to make any refinements needed or merge it and make them ourselves. The
changes will then land on master when we next do a release. changes will then land on master when we next do a release.
We use Jenkins for continuous integration (http://matrix.org/jenkins), and We use `Jenkins <http://matrix.org/jenkins>`_ and
typically all pull requests get automatically tested Jenkins: if your change breaks the build, Jenkins will yell about it in #matrix-dev:matrix.org so please lurk there and keep an eye open. `Travis <https://travis-ci.org/matrix-org/synapse>`_ for continuous
integration. All pull requests to synapse get automatically tested by Travis;
the Jenkins builds require an adminstrator to start them. If your change
breaks the build, this will be shown in github, so please keep an eye on the
pull request for feedback.
Code style Code style
~~~~~~~~~~ ~~~~~~~~~~

View File

@ -0,0 +1,23 @@
# List all media in a room
This API gets a list of known media in a room.
The API is:
```
GET /_matrix/client/r0/admin/room/<room_id>/media
```
including an `access_token` of a server admin.
It returns a JSON body like the following:
```
{
"local": [
"mxc://localhost/xwvutsrqponmlkjihgfedcba",
"mxc://localhost/abcdefghijklmnopqrstuvwx"
],
"remote": [
"mxc://matrix.org/xwvutsrqponmlkjihgfedcba",
"mxc://matrix.org/abcdefghijklmnopqrstuvwx"
]
}
```

View File

@ -4,14 +4,60 @@ Purge History API
The purge history API allows server admins to purge historic events from their The purge history API allows server admins to purge historic events from their
database, reclaiming disk space. database, reclaiming disk space.
**NB!** This will not delete local events (locally sent messages content etc) from the database, but will remove lots of the metadata about them and does dramatically reduce the on disk space usage
Depending on the amount of history being purged a call to the API may take Depending on the amount of history being purged a call to the API may take
several minutes or longer. During this period users will not be able to several minutes or longer. During this period users will not be able to
paginate further back in the room from the point being purged from. paginate further back in the room from the point being purged from.
The API is simply: The API is:
``POST /_matrix/client/r0/admin/purge_history/<room_id>/<event_id>`` ``POST /_matrix/client/r0/admin/purge_history/<room_id>[/<event_id>]``
including an ``access_token`` of a server admin. including an ``access_token`` of a server admin.
By default, events sent by local users are not deleted, as they may represent
the only copies of this content in existence. (Events sent by remote users are
deleted.)
Room state data (such as joins, leaves, topic) is always preserved.
To delete local message events as well, set ``delete_local_events`` in the body:
.. code:: json
{
"delete_local_events": true
}
The caller must specify the point in the room to purge up to. This can be
specified by including an event_id in the URI, or by setting a
``purge_up_to_event_id`` or ``purge_up_to_ts`` in the request body. If an event
id is given, that event (and others at the same graph depth) will be retained.
If ``purge_up_to_ts`` is given, it should be a timestamp since the unix epoch,
in milliseconds.
The API starts the purge running, and returns immediately with a JSON body with
a purge id:
.. code:: json
{
"purge_id": "<opaque id>"
}
Purge status query
------------------
It is possible to poll for updates on recent purges with a second API;
``GET /_matrix/client/r0/admin/purge_history_status/<purge_id>``
(again, with a suitable ``access_token``). This API returns a JSON body like
the following:
.. code:: json
{
"status": "active"
}
The status will be one of ``active``, ``complete``, or ``failed``.

View File

@ -279,9 +279,9 @@ Obviously that option means that the operations done in
that might be fixed by setting a different logcontext via a ``with that might be fixed by setting a different logcontext via a ``with
LoggingContext(...)`` in ``background_operation``). LoggingContext(...)`` in ``background_operation``).
The second option is to use ``logcontext.preserve_fn``, which wraps a function The second option is to use ``logcontext.run_in_background``, which wraps a
so that it doesn't reset the logcontext even when it returns an incomplete function so that it doesn't reset the logcontext even when it returns an
deferred, and adds a callback to the returned deferred to reset the incomplete deferred, and adds a callback to the returned deferred to reset the
logcontext. In other words, it turns a function that follows the Synapse rules logcontext. In other words, it turns a function that follows the Synapse rules
about logcontexts and Deferreds into one which behaves more like an external about logcontexts and Deferreds into one which behaves more like an external
function — the opposite operation to that described in the previous section. function — the opposite operation to that described in the previous section.
@ -293,7 +293,7 @@ It can be used like this:
def do_request_handling(): def do_request_handling():
yield foreground_operation() yield foreground_operation()
logcontext.preserve_fn(background_operation)() logcontext.run_in_background(background_operation)
# this will now be logged against the request context # this will now be logged against the request context
logger.debug("Request handling complete") logger.debug("Request handling complete")

View File

@ -33,6 +33,53 @@ How to monitor Synapse metrics using Prometheus
Restart prometheus. Restart prometheus.
Block and response metrics renamed for 0.27.0
---------------------------------------------
Synapse 0.27.0 begins the process of rationalising the duplicate ``*:count``
metrics reported for the resource tracking for code blocks and HTTP requests.
At the same time, the corresponding ``*:total`` metrics are being renamed, as
the ``:total`` suffix no longer makes sense in the absence of a corresponding
``:count`` metric.
To enable a graceful migration path, this release just adds new names for the
metrics being renamed. A future release will remove the old ones.
The following table shows the new metrics, and the old metrics which they are
replacing.
==================================================== ===================================================
New name Old name
==================================================== ===================================================
synapse_util_metrics_block_count synapse_util_metrics_block_timer:count
synapse_util_metrics_block_count synapse_util_metrics_block_ru_utime:count
synapse_util_metrics_block_count synapse_util_metrics_block_ru_stime:count
synapse_util_metrics_block_count synapse_util_metrics_block_db_txn_count:count
synapse_util_metrics_block_count synapse_util_metrics_block_db_txn_duration:count
synapse_util_metrics_block_time_seconds synapse_util_metrics_block_timer:total
synapse_util_metrics_block_ru_utime_seconds synapse_util_metrics_block_ru_utime:total
synapse_util_metrics_block_ru_stime_seconds synapse_util_metrics_block_ru_stime:total
synapse_util_metrics_block_db_txn_count synapse_util_metrics_block_db_txn_count:total
synapse_util_metrics_block_db_txn_duration_seconds synapse_util_metrics_block_db_txn_duration:total
synapse_http_server_response_count synapse_http_server_requests
synapse_http_server_response_count synapse_http_server_response_time:count
synapse_http_server_response_count synapse_http_server_response_ru_utime:count
synapse_http_server_response_count synapse_http_server_response_ru_stime:count
synapse_http_server_response_count synapse_http_server_response_db_txn_count:count
synapse_http_server_response_count synapse_http_server_response_db_txn_duration:count
synapse_http_server_response_time_seconds synapse_http_server_response_time:total
synapse_http_server_response_ru_utime_seconds synapse_http_server_response_ru_utime:total
synapse_http_server_response_ru_stime_seconds synapse_http_server_response_ru_stime:total
synapse_http_server_response_db_txn_count synapse_http_server_response_db_txn_count:total
synapse_http_server_response_db_txn_duration_seconds synapse_http_server_response_db_txn_duration:total
==================================================== ===================================================
Standard Metric Names Standard Metric Names
--------------------- ---------------------
@ -42,7 +89,7 @@ have been changed to seconds, from miliseconds.
================================== ============================= ================================== =============================
New name Old name New name Old name
---------------------------------- ----------------------------- ================================== =============================
process_cpu_user_seconds_total process_resource_utime / 1000 process_cpu_user_seconds_total process_resource_utime / 1000
process_cpu_system_seconds_total process_resource_stime / 1000 process_cpu_system_seconds_total process_resource_stime / 1000
process_open_fds (no 'type' label) process_fds process_open_fds (no 'type' label) process_fds
@ -52,7 +99,7 @@ The python-specific counts of garbage collector performance have been renamed.
=========================== ====================== =========================== ======================
New name Old name New name Old name
--------------------------- ---------------------- =========================== ======================
python_gc_time reactor_gc_time python_gc_time reactor_gc_time
python_gc_unreachable_total reactor_gc_unreachable python_gc_unreachable_total reactor_gc_unreachable
python_gc_counts reactor_gc_counts python_gc_counts reactor_gc_counts
@ -62,7 +109,7 @@ The twisted-specific reactor metrics have been renamed.
==================================== ===================== ==================================== =====================
New name Old name New name Old name
------------------------------------ --------------------- ==================================== =====================
python_twisted_reactor_pending_calls reactor_pending_calls python_twisted_reactor_pending_calls reactor_pending_calls
python_twisted_reactor_tick_time reactor_tick_time python_twisted_reactor_tick_time reactor_tick_time
==================================== ===================== ==================================== =====================

View File

@ -30,17 +30,29 @@ requests made to the federation port. The caveats regarding running a
reverse-proxy on the federation port still apply (see reverse-proxy on the federation port still apply (see
https://github.com/matrix-org/synapse/blob/master/README.rst#reverse-proxying-the-federation-port). https://github.com/matrix-org/synapse/blob/master/README.rst#reverse-proxying-the-federation-port).
To enable workers, you need to add a replication listener to the master synapse, e.g.:: To enable workers, you need to add two replication listeners to the master
synapse, e.g.::
listeners: listeners:
# The TCP replication port
- port: 9092 - port: 9092
bind_address: '127.0.0.1' bind_address: '127.0.0.1'
type: replication type: replication
# The HTTP replication port
- port: 9093
bind_address: '127.0.0.1'
type: http
resources:
- names: [replication]
Under **no circumstances** should this replication API listener be exposed to the Under **no circumstances** should these replication API listeners be exposed to
public internet; it currently implements no authentication whatsoever and is the public internet; it currently implements no authentication whatsoever and is
unencrypted. unencrypted.
(Roughly, the TCP port is used for streaming data from the master to the
workers, and the HTTP port for the workers to send data to the main
synapse process.)
You then create a set of configs for the various worker processes. These You then create a set of configs for the various worker processes. These
should be worker configuration files, and should be stored in a dedicated should be worker configuration files, and should be stored in a dedicated
subdirectory, to allow synctl to manipulate them. subdirectory, to allow synctl to manipulate them.
@ -52,8 +64,13 @@ You should minimise the number of overrides though to maintain a usable config.
You must specify the type of worker application (``worker_app``). The currently You must specify the type of worker application (``worker_app``). The currently
available worker applications are listed below. You must also specify the available worker applications are listed below. You must also specify the
replication endpoint that it's talking to on the main synapse process replication endpoints that it's talking to on the main synapse process.
(``worker_replication_host`` and ``worker_replication_port``). ``worker_replication_host`` should specify the host of the main synapse,
``worker_replication_port`` should point to the TCP replication listener port and
``worker_replication_http_port`` should point to the HTTP replication port.
Currently, only the ``event_creator`` worker requires specifying
``worker_replication_http_port``.
For instance:: For instance::
@ -62,6 +79,7 @@ For instance::
# The replication listener on the synapse to talk to. # The replication listener on the synapse to talk to.
worker_replication_host: 127.0.0.1 worker_replication_host: 127.0.0.1
worker_replication_port: 9092 worker_replication_port: 9092
worker_replication_http_port: 9093
worker_listeners: worker_listeners:
- type: http - type: http
@ -207,3 +225,14 @@ the ``worker_main_http_uri`` setting in the frontend_proxy worker configuration
file. For example:: file. For example::
worker_main_http_uri: http://127.0.0.1:8008 worker_main_http_uri: http://127.0.0.1:8008
``synapse.app.event_creator``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Handles non-state event creation. It can handle REST endpoints matching::
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/send
It will create events locally and then send them on to the main synapse
instance to be persisted and handled.

View File

@ -0,0 +1,133 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Moves a list of remote media from one media store to another.
The input should be a list of media files to be moved, one per line. Each line
should be formatted::
<origin server>|<file id>
This can be extracted from postgres with::
psql --tuples-only -A -c "select media_origin, filesystem_id from
matrix.remote_media_cache where ..."
To use, pipe the above into::
PYTHON_PATH=. ./scripts/move_remote_media_to_new_store.py <source repo> <dest repo>
"""
from __future__ import print_function
import argparse
import logging
import sys
import os
import shutil
from synapse.rest.media.v1.filepath import MediaFilePaths
logger = logging.getLogger()
def main(src_repo, dest_repo):
src_paths = MediaFilePaths(src_repo)
dest_paths = MediaFilePaths(dest_repo)
for line in sys.stdin:
line = line.strip()
parts = line.split('|')
if len(parts) != 2:
print("Unable to parse input line %s" % line, file=sys.stderr)
exit(1)
move_media(parts[0], parts[1], src_paths, dest_paths)
def move_media(origin_server, file_id, src_paths, dest_paths):
"""Move the given file, and any thumbnails, to the dest repo
Args:
origin_server (str):
file_id (str):
src_paths (MediaFilePaths):
dest_paths (MediaFilePaths):
"""
logger.info("%s/%s", origin_server, file_id)
# check that the original exists
original_file = src_paths.remote_media_filepath(origin_server, file_id)
if not os.path.exists(original_file):
logger.warn(
"Original for %s/%s (%s) does not exist",
origin_server, file_id, original_file,
)
else:
mkdir_and_move(
original_file,
dest_paths.remote_media_filepath(origin_server, file_id),
)
# now look for thumbnails
original_thumb_dir = src_paths.remote_media_thumbnail_dir(
origin_server, file_id,
)
if not os.path.exists(original_thumb_dir):
return
mkdir_and_move(
original_thumb_dir,
dest_paths.remote_media_thumbnail_dir(origin_server, file_id)
)
def mkdir_and_move(original_file, dest_file):
dirname = os.path.dirname(dest_file)
if not os.path.exists(dirname):
logger.debug("mkdir %s", dirname)
os.makedirs(dirname)
logger.debug("mv %s %s", original_file, dest_file)
shutil.move(original_file, dest_file)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description=__doc__,
formatter_class = argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"-v", action='store_true', help='enable debug logging')
parser.add_argument(
"src_repo",
help="Path to source content repo",
)
parser.add_argument(
"dest_repo",
help="Path to source content repo",
)
args = parser.parse_args()
logging_config = {
"level": logging.DEBUG if args.v else logging.INFO,
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s"
}
logging.basicConfig(**logging_config)
main(args.src_repo, args.dest_repo)

View File

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.26.1" __version__ = "0.27.0"

View File

@ -46,6 +46,7 @@ class Codes(object):
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED" THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
THREEPID_IN_USE = "M_THREEPID_IN_USE" THREEPID_IN_USE = "M_THREEPID_IN_USE"
THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND" THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND"
THREEPID_DENIED = "M_THREEPID_DENIED"
INVALID_USERNAME = "M_INVALID_USERNAME" INVALID_USERNAME = "M_INVALID_USERNAME"
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED" SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
@ -140,6 +141,32 @@ class RegistrationError(SynapseError):
pass pass
class FederationDeniedError(SynapseError):
"""An error raised when the server tries to federate with a server which
is not on its federation whitelist.
Attributes:
destination (str): The destination which has been denied
"""
def __init__(self, destination):
"""Raised by federation client or server to indicate that we are
are deliberately not attempting to contact a given server because it is
not on our federation whitelist.
Args:
destination (str): the domain in question
"""
self.destination = destination
super(FederationDeniedError, self).__init__(
code=403,
msg="Federation denied with %s." % (self.destination,),
errcode=Codes.FORBIDDEN,
)
class InteractiveAuthIncompleteError(Exception): class InteractiveAuthIncompleteError(Exception):
"""An error raised when UI auth is not yet complete """An error raised when UI auth is not yet complete

View File

@ -25,7 +25,9 @@ except Exception:
from daemonize import Daemonize from daemonize import Daemonize
from synapse.util import PreserveLoggingContext from synapse.util import PreserveLoggingContext
from synapse.util.rlimit import change_resource_limit from synapse.util.rlimit import change_resource_limit
from twisted.internet import reactor from twisted.internet import error, reactor
logger = logging.getLogger(__name__)
def start_worker_reactor(appname, config): def start_worker_reactor(appname, config):
@ -120,3 +122,57 @@ def quit_with_error(error_string):
sys.stderr.write(" %s\n" % (line.rstrip(),)) sys.stderr.write(" %s\n" % (line.rstrip(),))
sys.stderr.write("*" * line_length + '\n') sys.stderr.write("*" * line_length + '\n')
sys.exit(1) sys.exit(1)
def listen_tcp(bind_addresses, port, factory, backlog=50):
"""
Create a TCP socket for a port and several addresses
"""
for address in bind_addresses:
try:
reactor.listenTCP(
port,
factory,
backlog,
address
)
except error.CannotListenError as e:
check_bind_error(e, address, bind_addresses)
def listen_ssl(bind_addresses, port, factory, context_factory, backlog=50):
"""
Create an SSL socket for a port and several addresses
"""
for address in bind_addresses:
try:
reactor.listenSSL(
port,
factory,
context_factory,
backlog,
address
)
except error.CannotListenError as e:
check_bind_error(e, address, bind_addresses)
def check_bind_error(e, address, bind_addresses):
"""
This method checks an exception occurred while binding on 0.0.0.0.
If :: is specified in the bind addresses a warning is shown.
The exception is still raised otherwise.
Binding on both 0.0.0.0 and :: causes an exception on Linux and macOS
because :: binds on both IPv4 and IPv6 (as per RFC 3493).
When binding on 0.0.0.0 after :: this can safely be ignored.
Args:
e (Exception): Exception that was caught.
address (str): Address on which binding was attempted.
bind_addresses (list): Addresses on which the service listens.
"""
if address == '0.0.0.0' and '::' in bind_addresses:
logger.warn('Failed to listen on 0.0.0.0, continuing because listening on [::]')
else:
raise e

View File

@ -49,19 +49,6 @@ class AppserviceSlaveStore(
class AppserviceServer(HomeServer): class AppserviceServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self): def setup(self):
logger.info("Setting up.") logger.info("Setting up.")
self.datastore = AppserviceSlaveStore(self.get_db_conn(), self) self.datastore = AppserviceSlaveStore(self.get_db_conn(), self)
@ -79,17 +66,16 @@ class AppserviceServer(HomeServer):
root_resource = create_resource_tree(resources, Resource()) root_resource = create_resource_tree(resources, Resource())
for address in bind_addresses: _base.listen_tcp(
reactor.listenTCP( bind_addresses,
port, port,
SynapseSite( SynapseSite(
"synapse.access.http.%s" % (site_tag,), "synapse.access.http.%s" % (site_tag,),
site_tag, site_tag,
listener_config, listener_config,
root_resource, root_resource,
),
interface=address
) )
)
logger.info("Synapse appservice now listening on port %d", port) logger.info("Synapse appservice now listening on port %d", port)
@ -98,18 +84,15 @@ class AppserviceServer(HomeServer):
if listener["type"] == "http": if listener["type"] == "http":
self._listen_http(listener) self._listen_http(listener)
elif listener["type"] == "manhole": elif listener["type"] == "manhole":
bind_addresses = listener["bind_addresses"] _base.listen_tcp(
listener["bind_addresses"],
for address in bind_addresses: listener["port"],
reactor.listenTCP( manhole(
listener["port"], username="matrix",
manhole( password="rabbithole",
username="matrix", globals={"hs": self},
password="rabbithole",
globals={"hs": self},
),
interface=address
) )
)
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warn("Unrecognized listener type: %s", listener["type"])

View File

@ -64,19 +64,6 @@ class ClientReaderSlavedStore(
class ClientReaderServer(HomeServer): class ClientReaderServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self): def setup(self):
logger.info("Setting up.") logger.info("Setting up.")
self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self) self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self)
@ -103,17 +90,16 @@ class ClientReaderServer(HomeServer):
root_resource = create_resource_tree(resources, Resource()) root_resource = create_resource_tree(resources, Resource())
for address in bind_addresses: _base.listen_tcp(
reactor.listenTCP( bind_addresses,
port, port,
SynapseSite( SynapseSite(
"synapse.access.http.%s" % (site_tag,), "synapse.access.http.%s" % (site_tag,),
site_tag, site_tag,
listener_config, listener_config,
root_resource, root_resource,
),
interface=address
) )
)
logger.info("Synapse client reader now listening on port %d", port) logger.info("Synapse client reader now listening on port %d", port)
@ -122,18 +108,16 @@ class ClientReaderServer(HomeServer):
if listener["type"] == "http": if listener["type"] == "http":
self._listen_http(listener) self._listen_http(listener)
elif listener["type"] == "manhole": elif listener["type"] == "manhole":
bind_addresses = listener["bind_addresses"] _base.listen_tcp(
listener["bind_addresses"],
for address in bind_addresses: listener["port"],
reactor.listenTCP( manhole(
listener["port"], username="matrix",
manhole( password="rabbithole",
username="matrix", globals={"hs": self},
password="rabbithole",
globals={"hs": self},
),
interface=address
) )
)
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warn("Unrecognized listener type: %s", listener["type"])
@ -172,7 +156,6 @@ def start(config_options):
) )
ss.setup() ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners) ss.start_listening(config.worker_listeners)
def start(): def start():

View File

@ -0,0 +1,189 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import sys
import synapse
from synapse import events
from synapse.app import _base
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.crypto import context_factory
from synapse.http.server import JsonResource
from synapse.http.site import SynapseSite
from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.profile import SlavedProfileStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.pushers import SlavedPusherStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.client.v1.room import (
RoomSendEventRestServlet, RoomMembershipRestServlet, RoomStateEventRestServlet,
JoinRoomAliasServlet,
)
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
from twisted.internet import reactor
from twisted.web.resource import Resource
logger = logging.getLogger("synapse.app.event_creator")
class EventCreatorSlavedStore(
DirectoryStore,
TransactionStore,
SlavedProfileStore,
SlavedAccountDataStore,
SlavedPusherStore,
SlavedReceiptsStore,
SlavedPushRuleStore,
SlavedDeviceStore,
SlavedClientIpStore,
SlavedApplicationServiceStore,
SlavedEventStore,
SlavedRegistrationStore,
RoomStore,
BaseSlavedStore,
):
pass
class EventCreatorServer(HomeServer):
def setup(self):
logger.info("Setting up.")
self.datastore = EventCreatorSlavedStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_addresses = listener_config["bind_addresses"]
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
elif name == "client":
resource = JsonResource(self, canonical_json=False)
RoomSendEventRestServlet(self).register(resource)
RoomMembershipRestServlet(self).register(resource)
RoomStateEventRestServlet(self).register(resource)
JoinRoomAliasServlet(self).register(resource)
resources.update({
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource,
"/_matrix/client/api/v1": resource,
})
root_resource = create_resource_tree(resources, Resource())
_base.listen_tcp(
bind_addresses,
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
)
)
logger.info("Synapse event creator now listening on port %d", port)
def start_listening(self, listeners):
for listener in listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
_base.listen_tcp(
listener["bind_addresses"],
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
)
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self)
def build_tcp_replication(self):
return ReplicationClientHandler(self.get_datastore())
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse event creator", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.event_creator"
assert config.worker_replication_http_port is not None
setup_logging(config, use_worker_options=True)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
ss = EventCreatorServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ss.setup()
ss.start_listening(config.worker_listeners)
def start():
ss.get_state_handler().start_caching()
ss.get_datastore().start_profiling()
reactor.callWhenRunning(start)
_base.start_worker_reactor("synapse-event-creator", config)
if __name__ == '__main__':
with LoggingContext("main"):
start(sys.argv[1:])

View File

@ -58,19 +58,6 @@ class FederationReaderSlavedStore(
class FederationReaderServer(HomeServer): class FederationReaderServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self): def setup(self):
logger.info("Setting up.") logger.info("Setting up.")
self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self) self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self)
@ -92,17 +79,16 @@ class FederationReaderServer(HomeServer):
root_resource = create_resource_tree(resources, Resource()) root_resource = create_resource_tree(resources, Resource())
for address in bind_addresses: _base.listen_tcp(
reactor.listenTCP( bind_addresses,
port, port,
SynapseSite( SynapseSite(
"synapse.access.http.%s" % (site_tag,), "synapse.access.http.%s" % (site_tag,),
site_tag, site_tag,
listener_config, listener_config,
root_resource, root_resource,
),
interface=address
) )
)
logger.info("Synapse federation reader now listening on port %d", port) logger.info("Synapse federation reader now listening on port %d", port)
@ -111,18 +97,15 @@ class FederationReaderServer(HomeServer):
if listener["type"] == "http": if listener["type"] == "http":
self._listen_http(listener) self._listen_http(listener)
elif listener["type"] == "manhole": elif listener["type"] == "manhole":
bind_addresses = listener["bind_addresses"] _base.listen_tcp(
listener["bind_addresses"],
for address in bind_addresses: listener["port"],
reactor.listenTCP( manhole(
listener["port"], username="matrix",
manhole( password="rabbithole",
username="matrix", globals={"hs": self},
password="rabbithole",
globals={"hs": self},
),
interface=address
) )
)
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warn("Unrecognized listener type: %s", listener["type"])
@ -161,7 +144,6 @@ def start(config_options):
) )
ss.setup() ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners) ss.start_listening(config.worker_listeners)
def start(): def start():

View File

@ -76,19 +76,6 @@ class FederationSenderSlaveStore(
class FederationSenderServer(HomeServer): class FederationSenderServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self): def setup(self):
logger.info("Setting up.") logger.info("Setting up.")
self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self) self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self)
@ -106,17 +93,16 @@ class FederationSenderServer(HomeServer):
root_resource = create_resource_tree(resources, Resource()) root_resource = create_resource_tree(resources, Resource())
for address in bind_addresses: _base.listen_tcp(
reactor.listenTCP( bind_addresses,
port, port,
SynapseSite( SynapseSite(
"synapse.access.http.%s" % (site_tag,), "synapse.access.http.%s" % (site_tag,),
site_tag, site_tag,
listener_config, listener_config,
root_resource, root_resource,
),
interface=address
) )
)
logger.info("Synapse federation_sender now listening on port %d", port) logger.info("Synapse federation_sender now listening on port %d", port)
@ -125,18 +111,15 @@ class FederationSenderServer(HomeServer):
if listener["type"] == "http": if listener["type"] == "http":
self._listen_http(listener) self._listen_http(listener)
elif listener["type"] == "manhole": elif listener["type"] == "manhole":
bind_addresses = listener["bind_addresses"] _base.listen_tcp(
listener["bind_addresses"],
for address in bind_addresses: listener["port"],
reactor.listenTCP( manhole(
listener["port"], username="matrix",
manhole( password="rabbithole",
username="matrix", globals={"hs": self},
password="rabbithole",
globals={"hs": self},
),
interface=address
) )
)
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warn("Unrecognized listener type: %s", listener["type"])

View File

@ -118,19 +118,6 @@ class FrontendProxySlavedStore(
class FrontendProxyServer(HomeServer): class FrontendProxyServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self): def setup(self):
logger.info("Setting up.") logger.info("Setting up.")
self.datastore = FrontendProxySlavedStore(self.get_db_conn(), self) self.datastore = FrontendProxySlavedStore(self.get_db_conn(), self)
@ -157,17 +144,16 @@ class FrontendProxyServer(HomeServer):
root_resource = create_resource_tree(resources, Resource()) root_resource = create_resource_tree(resources, Resource())
for address in bind_addresses: _base.listen_tcp(
reactor.listenTCP( bind_addresses,
port, port,
SynapseSite( SynapseSite(
"synapse.access.http.%s" % (site_tag,), "synapse.access.http.%s" % (site_tag,),
site_tag, site_tag,
listener_config, listener_config,
root_resource, root_resource,
),
interface=address
) )
)
logger.info("Synapse client reader now listening on port %d", port) logger.info("Synapse client reader now listening on port %d", port)
@ -176,18 +162,15 @@ class FrontendProxyServer(HomeServer):
if listener["type"] == "http": if listener["type"] == "http":
self._listen_http(listener) self._listen_http(listener)
elif listener["type"] == "manhole": elif listener["type"] == "manhole":
bind_addresses = listener["bind_addresses"] _base.listen_tcp(
listener["bind_addresses"],
for address in bind_addresses: listener["port"],
reactor.listenTCP( manhole(
listener["port"], username="matrix",
manhole( password="rabbithole",
username="matrix", globals={"hs": self},
password="rabbithole",
globals={"hs": self},
),
interface=address
) )
)
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warn("Unrecognized listener type: %s", listener["type"])
@ -228,7 +211,6 @@ def start(config_options):
) )
ss.setup() ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners) ss.start_listening(config.worker_listeners)
def start(): def start():

View File

@ -25,7 +25,7 @@ from synapse.api.urls import CONTENT_REPO_PREFIX, FEDERATION_PREFIX, \
LEGACY_MEDIA_PREFIX, MEDIA_PREFIX, SERVER_KEY_PREFIX, SERVER_KEY_V2_PREFIX, \ LEGACY_MEDIA_PREFIX, MEDIA_PREFIX, SERVER_KEY_PREFIX, SERVER_KEY_V2_PREFIX, \
STATIC_PREFIX, WEB_CLIENT_PREFIX STATIC_PREFIX, WEB_CLIENT_PREFIX
from synapse.app import _base from synapse.app import _base
from synapse.app._base import quit_with_error from synapse.app._base import quit_with_error, listen_ssl, listen_tcp
from synapse.config._base import ConfigError from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory from synapse.crypto import context_factory
@ -38,6 +38,7 @@ from synapse.metrics import register_memory_metrics
from synapse.metrics.resource import METRICS_PREFIX, MetricsResource from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
from synapse.python_dependencies import CONDITIONAL_REQUIREMENTS, \ from synapse.python_dependencies import CONDITIONAL_REQUIREMENTS, \
check_requirements check_requirements
from synapse.replication.http import ReplicationRestResource, REPLICATION_PREFIX
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.rest import ClientRestResource from synapse.rest import ClientRestResource
from synapse.rest.key.v1.server_key_resource import LocalKey from synapse.rest.key.v1.server_key_resource import LocalKey
@ -130,30 +131,29 @@ class SynapseHomeServer(HomeServer):
root_resource = create_resource_tree(resources, root_resource) root_resource = create_resource_tree(resources, root_resource)
if tls: if tls:
for address in bind_addresses: listen_ssl(
reactor.listenSSL( bind_addresses,
port, port,
SynapseSite( SynapseSite(
"synapse.access.https.%s" % (site_tag,), "synapse.access.https.%s" % (site_tag,),
site_tag, site_tag,
listener_config, listener_config,
root_resource, root_resource,
), ),
self.tls_server_context_factory, self.tls_server_context_factory,
interface=address )
)
else: else:
for address in bind_addresses: listen_tcp(
reactor.listenTCP( bind_addresses,
port, port,
SynapseSite( SynapseSite(
"synapse.access.http.%s" % (site_tag,), "synapse.access.http.%s" % (site_tag,),
site_tag, site_tag,
listener_config, listener_config,
root_resource, root_resource,
),
interface=address
) )
)
logger.info("Synapse now listening on port %d", port) logger.info("Synapse now listening on port %d", port)
def _configure_named_resource(self, name, compress=False): def _configure_named_resource(self, name, compress=False):
@ -220,6 +220,9 @@ class SynapseHomeServer(HomeServer):
if name == "metrics" and self.get_config().enable_metrics: if name == "metrics" and self.get_config().enable_metrics:
resources[METRICS_PREFIX] = MetricsResource(self) resources[METRICS_PREFIX] = MetricsResource(self)
if name == "replication":
resources[REPLICATION_PREFIX] = ReplicationRestResource(self)
return resources return resources
def start_listening(self): def start_listening(self):
@ -229,18 +232,15 @@ class SynapseHomeServer(HomeServer):
if listener["type"] == "http": if listener["type"] == "http":
self._listener_http(config, listener) self._listener_http(config, listener)
elif listener["type"] == "manhole": elif listener["type"] == "manhole":
bind_addresses = listener["bind_addresses"] listen_tcp(
listener["bind_addresses"],
for address in bind_addresses: listener["port"],
reactor.listenTCP( manhole(
listener["port"], username="matrix",
manhole( password="rabbithole",
username="matrix", globals={"hs": self},
password="rabbithole",
globals={"hs": self},
),
interface=address
) )
)
elif listener["type"] == "replication": elif listener["type"] == "replication":
bind_addresses = listener["bind_addresses"] bind_addresses = listener["bind_addresses"]
for address in bind_addresses: for address in bind_addresses:
@ -270,19 +270,6 @@ class SynapseHomeServer(HomeServer):
except IncorrectDatabaseSetup as e: except IncorrectDatabaseSetup as e:
quit_with_error(e.message) quit_with_error(e.message)
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(config_options): def setup(config_options):
""" """
@ -361,7 +348,7 @@ def setup(config_options):
hs.get_state_handler().start_caching() hs.get_state_handler().start_caching()
hs.get_datastore().start_profiling() hs.get_datastore().start_profiling()
hs.get_datastore().start_doing_background_updates() hs.get_datastore().start_doing_background_updates()
hs.get_replication_layer().start_get_pdu_cache() hs.get_federation_client().start_get_pdu_cache()
register_memory_metrics(hs) register_memory_metrics(hs)

View File

@ -60,19 +60,6 @@ class MediaRepositorySlavedStore(
class MediaRepositoryServer(HomeServer): class MediaRepositoryServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self): def setup(self):
logger.info("Setting up.") logger.info("Setting up.")
self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self) self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self)
@ -99,17 +86,16 @@ class MediaRepositoryServer(HomeServer):
root_resource = create_resource_tree(resources, Resource()) root_resource = create_resource_tree(resources, Resource())
for address in bind_addresses: _base.listen_tcp(
reactor.listenTCP( bind_addresses,
port, port,
SynapseSite( SynapseSite(
"synapse.access.http.%s" % (site_tag,), "synapse.access.http.%s" % (site_tag,),
site_tag, site_tag,
listener_config, listener_config,
root_resource, root_resource,
),
interface=address
) )
)
logger.info("Synapse media repository now listening on port %d", port) logger.info("Synapse media repository now listening on port %d", port)
@ -118,18 +104,15 @@ class MediaRepositoryServer(HomeServer):
if listener["type"] == "http": if listener["type"] == "http":
self._listen_http(listener) self._listen_http(listener)
elif listener["type"] == "manhole": elif listener["type"] == "manhole":
bind_addresses = listener["bind_addresses"] _base.listen_tcp(
listener["bind_addresses"],
for address in bind_addresses: listener["port"],
reactor.listenTCP( manhole(
listener["port"], username="matrix",
manhole( password="rabbithole",
username="matrix", globals={"hs": self},
password="rabbithole",
globals={"hs": self},
),
interface=address
) )
)
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warn("Unrecognized listener type: %s", listener["type"])
@ -175,7 +158,6 @@ def start(config_options):
) )
ss.setup() ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners) ss.start_listening(config.worker_listeners)
def start(): def start():

View File

@ -32,7 +32,6 @@ from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.storage.roommember import RoomMemberStore
from synapse.util.httpresourcetree import create_resource_tree from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext, preserve_fn from synapse.util.logcontext import LoggingContext, preserve_fn
from synapse.util.manhole import manhole from synapse.util.manhole import manhole
@ -75,25 +74,8 @@ class PusherSlaveStore(
DataStore.get_profile_displayname.__func__ DataStore.get_profile_displayname.__func__
) )
who_forgot_in_room = (
RoomMemberStore.__dict__["who_forgot_in_room"]
)
class PusherServer(HomeServer): class PusherServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self): def setup(self):
logger.info("Setting up.") logger.info("Setting up.")
self.datastore = PusherSlaveStore(self.get_db_conn(), self) self.datastore = PusherSlaveStore(self.get_db_conn(), self)
@ -114,17 +96,16 @@ class PusherServer(HomeServer):
root_resource = create_resource_tree(resources, Resource()) root_resource = create_resource_tree(resources, Resource())
for address in bind_addresses: _base.listen_tcp(
reactor.listenTCP( bind_addresses,
port, port,
SynapseSite( SynapseSite(
"synapse.access.http.%s" % (site_tag,), "synapse.access.http.%s" % (site_tag,),
site_tag, site_tag,
listener_config, listener_config,
root_resource, root_resource,
),
interface=address
) )
)
logger.info("Synapse pusher now listening on port %d", port) logger.info("Synapse pusher now listening on port %d", port)
@ -133,18 +114,15 @@ class PusherServer(HomeServer):
if listener["type"] == "http": if listener["type"] == "http":
self._listen_http(listener) self._listen_http(listener)
elif listener["type"] == "manhole": elif listener["type"] == "manhole":
bind_addresses = listener["bind_addresses"] _base.listen_tcp(
listener["bind_addresses"],
for address in bind_addresses: listener["port"],
reactor.listenTCP( manhole(
listener["port"], username="matrix",
manhole( password="rabbithole",
username="matrix", globals={"hs": self},
password="rabbithole",
globals={"hs": self},
),
interface=address
) )
)
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warn("Unrecognized listener type: %s", listener["type"])

View File

@ -62,8 +62,6 @@ logger = logging.getLogger("synapse.app.synchrotron")
class SynchrotronSlavedStore( class SynchrotronSlavedStore(
SlavedPushRuleStore,
SlavedEventStore,
SlavedReceiptsStore, SlavedReceiptsStore,
SlavedAccountDataStore, SlavedAccountDataStore,
SlavedApplicationServiceStore, SlavedApplicationServiceStore,
@ -73,14 +71,12 @@ class SynchrotronSlavedStore(
SlavedGroupServerStore, SlavedGroupServerStore,
SlavedDeviceInboxStore, SlavedDeviceInboxStore,
SlavedDeviceStore, SlavedDeviceStore,
SlavedPushRuleStore,
SlavedEventStore,
SlavedClientIpStore, SlavedClientIpStore,
RoomStore, RoomStore,
BaseSlavedStore, BaseSlavedStore,
): ):
who_forgot_in_room = (
RoomMemberStore.__dict__["who_forgot_in_room"]
)
did_forget = ( did_forget = (
RoomMemberStore.__dict__["did_forget"] RoomMemberStore.__dict__["did_forget"]
) )
@ -246,19 +242,6 @@ class SynchrotronApplicationService(object):
class SynchrotronServer(HomeServer): class SynchrotronServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self): def setup(self):
logger.info("Setting up.") logger.info("Setting up.")
self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self) self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self)
@ -288,17 +271,16 @@ class SynchrotronServer(HomeServer):
root_resource = create_resource_tree(resources, Resource()) root_resource = create_resource_tree(resources, Resource())
for address in bind_addresses: _base.listen_tcp(
reactor.listenTCP( bind_addresses,
port, port,
SynapseSite( SynapseSite(
"synapse.access.http.%s" % (site_tag,), "synapse.access.http.%s" % (site_tag,),
site_tag, site_tag,
listener_config, listener_config,
root_resource, root_resource,
),
interface=address
) )
)
logger.info("Synapse synchrotron now listening on port %d", port) logger.info("Synapse synchrotron now listening on port %d", port)
@ -307,18 +289,15 @@ class SynchrotronServer(HomeServer):
if listener["type"] == "http": if listener["type"] == "http":
self._listen_http(listener) self._listen_http(listener)
elif listener["type"] == "manhole": elif listener["type"] == "manhole":
bind_addresses = listener["bind_addresses"] _base.listen_tcp(
listener["bind_addresses"],
for address in bind_addresses: listener["port"],
reactor.listenTCP( manhole(
listener["port"], username="matrix",
manhole( password="rabbithole",
username="matrix", globals={"hs": self},
password="rabbithole",
globals={"hs": self},
),
interface=address
) )
)
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warn("Unrecognized listener type: %s", listener["type"])

View File

@ -184,6 +184,9 @@ def main():
worker_configfiles.append(worker_configfile) worker_configfiles.append(worker_configfile)
if options.all_processes: if options.all_processes:
# To start the main synapse with -a you need to add a worker file
# with worker_app == "synapse.app.homeserver"
start_stop_synapse = False
worker_configdir = options.all_processes worker_configdir = options.all_processes
if not os.path.isdir(worker_configdir): if not os.path.isdir(worker_configdir):
write( write(
@ -200,11 +203,29 @@ def main():
with open(worker_configfile) as stream: with open(worker_configfile) as stream:
worker_config = yaml.load(stream) worker_config = yaml.load(stream)
worker_app = worker_config["worker_app"] worker_app = worker_config["worker_app"]
worker_pidfile = worker_config["worker_pid_file"] if worker_app == "synapse.app.homeserver":
worker_daemonize = worker_config["worker_daemonize"] # We need to special case all of this to pick up options that may
assert worker_daemonize, "In config %r: expected '%s' to be True" % ( # be set in the main config file or in this worker config file.
worker_configfile, "worker_daemonize") worker_pidfile = (
worker_cache_factor = worker_config.get("synctl_cache_factor") worker_config.get("pid_file")
or pidfile
)
worker_cache_factor = worker_config.get("synctl_cache_factor") or cache_factor
daemonize = worker_config.get("daemonize") or config.get("daemonize")
assert daemonize, "Main process must have daemonize set to true"
# The master process doesn't support using worker_* config.
for key in worker_config:
if key == "worker_app": # But we allow worker_app
continue
assert not key.startswith("worker_"), \
"Main process cannot use worker_* config"
else:
worker_pidfile = worker_config["worker_pid_file"]
worker_daemonize = worker_config["worker_daemonize"]
assert worker_daemonize, "In config %r: expected '%s' to be True" % (
worker_configfile, "worker_daemonize")
worker_cache_factor = worker_config.get("synctl_cache_factor")
workers.append(Worker( workers.append(Worker(
worker_app, worker_configfile, worker_pidfile, worker_cache_factor, worker_app, worker_configfile, worker_pidfile, worker_cache_factor,
)) ))

View File

@ -92,19 +92,6 @@ class UserDirectorySlaveStore(
class UserDirectoryServer(HomeServer): class UserDirectoryServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self): def setup(self):
logger.info("Setting up.") logger.info("Setting up.")
self.datastore = UserDirectorySlaveStore(self.get_db_conn(), self) self.datastore = UserDirectorySlaveStore(self.get_db_conn(), self)
@ -131,17 +118,16 @@ class UserDirectoryServer(HomeServer):
root_resource = create_resource_tree(resources, Resource()) root_resource = create_resource_tree(resources, Resource())
for address in bind_addresses: _base.listen_tcp(
reactor.listenTCP( bind_addresses,
port, port,
SynapseSite( SynapseSite(
"synapse.access.http.%s" % (site_tag,), "synapse.access.http.%s" % (site_tag,),
site_tag, site_tag,
listener_config, listener_config,
root_resource, root_resource,
),
interface=address
) )
)
logger.info("Synapse user_dir now listening on port %d", port) logger.info("Synapse user_dir now listening on port %d", port)
@ -150,18 +136,15 @@ class UserDirectoryServer(HomeServer):
if listener["type"] == "http": if listener["type"] == "http":
self._listen_http(listener) self._listen_http(listener)
elif listener["type"] == "manhole": elif listener["type"] == "manhole":
bind_addresses = listener["bind_addresses"] _base.listen_tcp(
listener["bind_addresses"],
for address in bind_addresses: listener["port"],
reactor.listenTCP( manhole(
listener["port"], username="matrix",
manhole( password="rabbithole",
username="matrix", globals={"hs": self},
password="rabbithole",
globals={"hs": self},
),
interface=address
) )
)
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warn("Unrecognized listener type: %s", listener["type"])

View File

@ -28,27 +28,27 @@ DEFAULT_LOG_CONFIG = Template("""
version: 1 version: 1
formatters: formatters:
precise: precise:
format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s\ format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - \
- %(message)s' %(request)s - %(message)s'
filters: filters:
context: context:
(): synapse.util.logcontext.LoggingContextFilter (): synapse.util.logcontext.LoggingContextFilter
request: "" request: ""
handlers: handlers:
file: file:
class: logging.handlers.RotatingFileHandler class: logging.handlers.RotatingFileHandler
formatter: precise formatter: precise
filename: ${log_file} filename: ${log_file}
maxBytes: 104857600 maxBytes: 104857600
backupCount: 10 backupCount: 10
filters: [context] filters: [context]
console: console:
class: logging.StreamHandler class: logging.StreamHandler
formatter: precise formatter: precise
filters: [context] filters: [context]
loggers: loggers:
synapse: synapse:
@ -74,17 +74,10 @@ class LoggingConfig(Config):
self.log_file = self.abspath(config.get("log_file")) self.log_file = self.abspath(config.get("log_file"))
def default_config(self, config_dir_path, server_name, **kwargs): def default_config(self, config_dir_path, server_name, **kwargs):
log_file = self.abspath("homeserver.log")
log_config = self.abspath( log_config = self.abspath(
os.path.join(config_dir_path, server_name + ".log.config") os.path.join(config_dir_path, server_name + ".log.config")
) )
return """ return """
# Logging verbosity level. Ignored if log_config is specified.
verbose: 0
# File to write logging to. Ignored if log_config is specified.
log_file: "%(log_file)s"
# A yaml python logging config file # A yaml python logging config file
log_config: "%(log_config)s" log_config: "%(log_config)s"
""" % locals() """ % locals()
@ -123,9 +116,10 @@ class LoggingConfig(Config):
def generate_files(self, config): def generate_files(self, config):
log_config = config.get("log_config") log_config = config.get("log_config")
if log_config and not os.path.exists(log_config): if log_config and not os.path.exists(log_config):
log_file = self.abspath("homeserver.log")
with open(log_config, "wb") as log_config_file: with open(log_config, "wb") as log_config_file:
log_config_file.write( log_config_file.write(
DEFAULT_LOG_CONFIG.substitute(log_file=config["log_file"]) DEFAULT_LOG_CONFIG.substitute(log_file=log_file)
) )
@ -150,6 +144,9 @@ def setup_logging(config, use_worker_options=False):
) )
if log_config is None: if log_config is None:
# We don't have a logfile, so fall back to the 'verbosity' param from
# the config or cmdline. (Note that we generate a log config for new
# installs, so this will be an unusual case)
level = logging.INFO level = logging.INFO
level_for_storage = logging.INFO level_for_storage = logging.INFO
if config.verbosity: if config.verbosity:
@ -157,11 +154,10 @@ def setup_logging(config, use_worker_options=False):
if config.verbosity > 1: if config.verbosity > 1:
level_for_storage = logging.DEBUG level_for_storage = logging.DEBUG
# FIXME: we need a logging.WARN for a -q quiet option
logger = logging.getLogger('') logger = logging.getLogger('')
logger.setLevel(level) logger.setLevel(level)
logging.getLogger('synapse.storage').setLevel(level_for_storage) logging.getLogger('synapse.storage.SQL').setLevel(level_for_storage)
formatter = logging.Formatter(log_format) formatter = logging.Formatter(log_format)
if log_file: if log_file:

View File

@ -29,10 +29,10 @@ class PasswordAuthProviderConfig(Config):
# param. # param.
ldap_config = config.get("ldap_config", {}) ldap_config = config.get("ldap_config", {})
if ldap_config.get("enabled", False): if ldap_config.get("enabled", False):
providers.append[{ providers.append({
'module': LDAP_PROVIDER, 'module': LDAP_PROVIDER,
'config': ldap_config, 'config': ldap_config,
}] })
providers.extend(config.get("password_providers", [])) providers.extend(config.get("password_providers", []))
for provider in providers: for provider in providers:

View File

@ -31,6 +31,8 @@ class RegistrationConfig(Config):
strtobool(str(config["disable_registration"])) strtobool(str(config["disable_registration"]))
) )
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
self.registration_shared_secret = config.get("registration_shared_secret") self.registration_shared_secret = config.get("registration_shared_secret")
self.bcrypt_rounds = config.get("bcrypt_rounds", 12) self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
@ -52,13 +54,32 @@ class RegistrationConfig(Config):
# Enable registration for new users. # Enable registration for new users.
enable_registration: False enable_registration: False
# The user must provide all of the below types of 3PID when registering.
#
# registrations_require_3pid:
# - email
# - msisdn
# Mandate that users are only allowed to associate certain formats of
# 3PIDs with accounts on this server.
#
# allowed_local_3pids:
# - medium: email
# pattern: ".*@matrix\\.org"
# - medium: email
# pattern: ".*@vector\\.im"
# - medium: msisdn
# pattern: "\\+44"
# If set, allows registration by anyone who also has the shared # If set, allows registration by anyone who also has the shared
# secret, even if registration is otherwise disabled. # secret, even if registration is otherwise disabled.
registration_shared_secret: "%(registration_shared_secret)s" registration_shared_secret: "%(registration_shared_secret)s"
# Set the number of bcrypt rounds used to generate password hash. # Set the number of bcrypt rounds used to generate password hash.
# Larger numbers increase the work factor needed to generate the hash. # Larger numbers increase the work factor needed to generate the hash.
# The default number of rounds is 12. # The default number is 12 (which equates to 2^12 rounds).
# N.B. that increasing this will exponentially increase the time required
# to register or login - e.g. 24 => 2^24 rounds which will take >20 mins.
bcrypt_rounds: 12 bcrypt_rounds: 12
# Allows users to register as guests without a password/email/etc, and # Allows users to register as guests without a password/email/etc, and

View File

@ -16,6 +16,8 @@
from ._base import Config, ConfigError from ._base import Config, ConfigError
from collections import namedtuple from collections import namedtuple
from synapse.util.module_loader import load_module
MISSING_NETADDR = ( MISSING_NETADDR = (
"Missing netaddr library. This is required for URL preview API." "Missing netaddr library. This is required for URL preview API."
@ -36,6 +38,14 @@ ThumbnailRequirement = namedtuple(
"ThumbnailRequirement", ["width", "height", "method", "media_type"] "ThumbnailRequirement", ["width", "height", "method", "media_type"]
) )
MediaStorageProviderConfig = namedtuple(
"MediaStorageProviderConfig", (
"store_local", # Whether to store newly uploaded local files
"store_remote", # Whether to store newly downloaded remote files
"store_synchronous", # Whether to wait for successful storage for local uploads
),
)
def parse_thumbnail_requirements(thumbnail_sizes): def parse_thumbnail_requirements(thumbnail_sizes):
""" Takes a list of dictionaries with "width", "height", and "method" keys """ Takes a list of dictionaries with "width", "height", and "method" keys
@ -73,16 +83,61 @@ class ContentRepositoryConfig(Config):
self.media_store_path = self.ensure_directory(config["media_store_path"]) self.media_store_path = self.ensure_directory(config["media_store_path"])
self.backup_media_store_path = config.get("backup_media_store_path") backup_media_store_path = config.get("backup_media_store_path")
if self.backup_media_store_path:
self.backup_media_store_path = self.ensure_directory(
self.backup_media_store_path
)
self.synchronous_backup_media_store = config.get( synchronous_backup_media_store = config.get(
"synchronous_backup_media_store", False "synchronous_backup_media_store", False
) )
storage_providers = config.get("media_storage_providers", [])
if backup_media_store_path:
if storage_providers:
raise ConfigError(
"Cannot use both 'backup_media_store_path' and 'storage_providers'"
)
storage_providers = [{
"module": "file_system",
"store_local": True,
"store_synchronous": synchronous_backup_media_store,
"store_remote": True,
"config": {
"directory": backup_media_store_path,
}
}]
# This is a list of config that can be used to create the storage
# providers. The entries are tuples of (Class, class_config,
# MediaStorageProviderConfig), where Class is the class of the provider,
# the class_config the config to pass to it, and
# MediaStorageProviderConfig are options for StorageProviderWrapper.
#
# We don't create the storage providers here as not all workers need
# them to be started.
self.media_storage_providers = []
for provider_config in storage_providers:
# We special case the module "file_system" so as not to need to
# expose FileStorageProviderBackend
if provider_config["module"] == "file_system":
provider_config["module"] = (
"synapse.rest.media.v1.storage_provider"
".FileStorageProviderBackend"
)
provider_class, parsed_config = load_module(provider_config)
wrapper_config = MediaStorageProviderConfig(
provider_config.get("store_local", False),
provider_config.get("store_remote", False),
provider_config.get("store_synchronous", False),
)
self.media_storage_providers.append(
(provider_class, parsed_config, wrapper_config,)
)
self.uploads_path = self.ensure_directory(config["uploads_path"]) self.uploads_path = self.ensure_directory(config["uploads_path"])
self.dynamic_thumbnails = config["dynamic_thumbnails"] self.dynamic_thumbnails = config["dynamic_thumbnails"]
self.thumbnail_requirements = parse_thumbnail_requirements( self.thumbnail_requirements = parse_thumbnail_requirements(
@ -127,13 +182,19 @@ class ContentRepositoryConfig(Config):
# Directory where uploaded images and attachments are stored. # Directory where uploaded images and attachments are stored.
media_store_path: "%(media_store)s" media_store_path: "%(media_store)s"
# A secondary directory where uploaded images and attachments are # Media storage providers allow media to be stored in different
# stored as a backup. # locations.
# backup_media_store_path: "%(media_store)s" # media_storage_providers:
# - module: file_system
# Whether to wait for successful write to backup media store before # # Whether to write new local files.
# returning successfully. # store_local: false
# synchronous_backup_media_store: false # # Whether to write new remote media
# store_remote: false
# # Whether to block upload requests waiting for write to this
# # provider to complete
# store_synchronous: false
# config:
# directory: /mnt/some/other/directory
# Directory where in-progress uploads are stored. # Directory where in-progress uploads are stored.
uploads_path: "%(uploads_path)s" uploads_path: "%(uploads_path)s"

View File

@ -55,6 +55,17 @@ class ServerConfig(Config):
"block_non_admin_invites", False, "block_non_admin_invites", False,
) )
# FIXME: federation_domain_whitelist needs sytests
self.federation_domain_whitelist = None
federation_domain_whitelist = config.get(
"federation_domain_whitelist", None
)
# turn the whitelist into a hash for speed of lookup
if federation_domain_whitelist is not None:
self.federation_domain_whitelist = {}
for domain in federation_domain_whitelist:
self.federation_domain_whitelist[domain] = True
if self.public_baseurl is not None: if self.public_baseurl is not None:
if self.public_baseurl[-1] != '/': if self.public_baseurl[-1] != '/':
self.public_baseurl += '/' self.public_baseurl += '/'
@ -210,6 +221,17 @@ class ServerConfig(Config):
# (except those sent by local server admins). The default is False. # (except those sent by local server admins). The default is False.
# block_non_admin_invites: True # block_non_admin_invites: True
# Restrict federation to the following whitelist of domains.
# N.B. we recommend also firewalling your federation listener to limit
# inbound federation traffic as early as possible, rather than relying
# purely on this application-layer restriction. If not specified, the
# default is to whitelist everything.
#
# federation_domain_whitelist:
# - lon.example.com
# - nyc.example.com
# - syd.example.com
# List of ports that Synapse should listen on, their purpose and their # List of ports that Synapse should listen on, their purpose and their
# configuration. # configuration.
listeners: listeners:
@ -220,13 +242,12 @@ class ServerConfig(Config):
port: %(bind_port)s port: %(bind_port)s
# Local addresses to listen on. # Local addresses to listen on.
# This will listen on all IPv4 addresses by default. # On Linux and Mac OS, `::` will listen on all IPv4 and IPv6
# addresses by default. For most other OSes, this will only listen
# on IPv6.
bind_addresses: bind_addresses:
- '::'
- '0.0.0.0' - '0.0.0.0'
# Uncomment to listen on all IPv6 interfaces
# N.B: On at least Linux this will also listen on all IPv4
# addresses, so you will need to comment out the line above.
# - '::'
# This is a 'http' listener, allows us to specify 'resources'. # This is a 'http' listener, allows us to specify 'resources'.
type: http type: http
@ -264,7 +285,7 @@ class ServerConfig(Config):
# For when matrix traffic passes through loadbalancer that unwraps TLS. # For when matrix traffic passes through loadbalancer that unwraps TLS.
- port: %(unsecure_port)s - port: %(unsecure_port)s
tls: false tls: false
bind_addresses: ['0.0.0.0'] bind_addresses: ['::', '0.0.0.0']
type: http type: http
x_forwarded: false x_forwarded: false
@ -278,7 +299,7 @@ class ServerConfig(Config):
# Turn on the twisted ssh manhole service on localhost on the given # Turn on the twisted ssh manhole service on localhost on the given
# port. # port.
# - port: 9000 # - port: 9000
# bind_address: 127.0.0.1 # bind_addresses: ['::1', '127.0.0.1']
# type: manhole # type: manhole
""" % locals() """ % locals()

View File

@ -96,7 +96,7 @@ class TlsConfig(Config):
# certificates returned by this server match one of the fingerprints. # certificates returned by this server match one of the fingerprints.
# #
# Synapse automatically adds the fingerprint of its own certificate # Synapse automatically adds the fingerprint of its own certificate
# to the list. So if federation traffic is handle directly by synapse # to the list. So if federation traffic is handled directly by synapse
# then no modification to the list is required. # then no modification to the list is required.
# #
# If synapse is run behind a load balancer that handles the TLS then it # If synapse is run behind a load balancer that handles the TLS then it

View File

@ -23,13 +23,26 @@ class WorkerConfig(Config):
def read_config(self, config): def read_config(self, config):
self.worker_app = config.get("worker_app") self.worker_app = config.get("worker_app")
# Canonicalise worker_app so that master always has None
if self.worker_app == "synapse.app.homeserver":
self.worker_app = None
self.worker_listeners = config.get("worker_listeners") self.worker_listeners = config.get("worker_listeners")
self.worker_daemonize = config.get("worker_daemonize") self.worker_daemonize = config.get("worker_daemonize")
self.worker_pid_file = config.get("worker_pid_file") self.worker_pid_file = config.get("worker_pid_file")
self.worker_log_file = config.get("worker_log_file") self.worker_log_file = config.get("worker_log_file")
self.worker_log_config = config.get("worker_log_config") self.worker_log_config = config.get("worker_log_config")
# The host used to connect to the main synapse
self.worker_replication_host = config.get("worker_replication_host", None) self.worker_replication_host = config.get("worker_replication_host", None)
# The port on the main synapse for TCP replication
self.worker_replication_port = config.get("worker_replication_port", None) self.worker_replication_port = config.get("worker_replication_port", None)
# The port on the main synapse for HTTP replication endpoint
self.worker_replication_http_port = config.get("worker_replication_http_port")
self.worker_name = config.get("worker_name", self.worker_app) self.worker_name = config.get("worker_name", self.worker_app)
self.worker_main_http_uri = config.get("worker_main_http_uri", None) self.worker_main_http_uri = config.get("worker_main_http_uri", None)

View File

@ -319,7 +319,7 @@ def _is_membership_change_allowed(event, auth_events):
# TODO (erikj): Implement kicks. # TODO (erikj): Implement kicks.
if target_banned and user_level < ban_level: if target_banned and user_level < ban_level:
raise AuthError( raise AuthError(
403, "You cannot unban user &s." % (target_user_id,) 403, "You cannot unban user %s." % (target_user_id,)
) )
elif target_user_id != event.user_id: elif target_user_id != event.user_id:
kick_level = _get_named_level(auth_events, "kick", 50) kick_level = _get_named_level(auth_events, "kick", 50)

View File

@ -13,6 +13,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from frozendict import frozendict
class EventContext(object): class EventContext(object):
""" """
@ -25,7 +29,9 @@ class EventContext(object):
The current state map excluding the current event. The current state map excluding the current event.
(type, state_key) -> event_id (type, state_key) -> event_id
state_group (int): state group id state_group (int|None): state group id, if the state has been stored
as a state group. This is usually only None if e.g. the event is
an outlier.
rejected (bool|str): A rejection reason if the event was rejected, else rejected (bool|str): A rejection reason if the event was rejected, else
False False
@ -46,7 +52,6 @@ class EventContext(object):
"prev_state_ids", "prev_state_ids",
"state_group", "state_group",
"rejected", "rejected",
"push_actions",
"prev_group", "prev_group",
"delta_ids", "delta_ids",
"prev_state_events", "prev_state_events",
@ -61,7 +66,6 @@ class EventContext(object):
self.state_group = None self.state_group = None
self.rejected = False self.rejected = False
self.push_actions = []
# A previously persisted state group and a delta between that # A previously persisted state group and a delta between that
# and this state. # and this state.
@ -71,3 +75,98 @@ class EventContext(object):
self.prev_state_events = None self.prev_state_events = None
self.app_service = None self.app_service = None
def serialize(self, event):
"""Converts self to a type that can be serialized as JSON, and then
deserialized by `deserialize`
Args:
event (FrozenEvent): The event that this context relates to
Returns:
dict
"""
# We don't serialize the full state dicts, instead they get pulled out
# of the DB on the other side. However, the other side can't figure out
# the prev_state_ids, so if we're a state event we include the event
# id that we replaced in the state.
if event.is_state():
prev_state_id = self.prev_state_ids.get((event.type, event.state_key))
else:
prev_state_id = None
return {
"prev_state_id": prev_state_id,
"event_type": event.type,
"event_state_key": event.state_key if event.is_state() else None,
"state_group": self.state_group,
"rejected": self.rejected,
"prev_group": self.prev_group,
"delta_ids": _encode_state_dict(self.delta_ids),
"prev_state_events": self.prev_state_events,
"app_service_id": self.app_service.id if self.app_service else None
}
@staticmethod
@defer.inlineCallbacks
def deserialize(store, input):
"""Converts a dict that was produced by `serialize` back into a
EventContext.
Args:
store (DataStore): Used to convert AS ID to AS object
input (dict): A dict produced by `serialize`
Returns:
EventContext
"""
context = EventContext()
context.state_group = input["state_group"]
context.rejected = input["rejected"]
context.prev_group = input["prev_group"]
context.delta_ids = _decode_state_dict(input["delta_ids"])
context.prev_state_events = input["prev_state_events"]
# We use the state_group and prev_state_id stuff to pull the
# current_state_ids out of the DB and construct prev_state_ids.
prev_state_id = input["prev_state_id"]
event_type = input["event_type"]
event_state_key = input["event_state_key"]
context.current_state_ids = yield store.get_state_ids_for_group(
context.state_group,
)
if prev_state_id and event_state_key:
context.prev_state_ids = dict(context.current_state_ids)
context.prev_state_ids[(event_type, event_state_key)] = prev_state_id
else:
context.prev_state_ids = context.current_state_ids
app_service_id = input["app_service_id"]
if app_service_id:
context.app_service = store.get_app_service_by_id(app_service_id)
defer.returnValue(context)
def _encode_state_dict(state_dict):
"""Since dicts of (type, state_key) -> event_id cannot be serialized in
JSON we need to convert them to a form that can.
"""
if state_dict is None:
return None
return [
(etype, state_key, v)
for (etype, state_key), v in state_dict.iteritems()
]
def _decode_state_dict(input):
"""Decodes a state dict encoded using `_encode_state_dict` above
"""
if input is None:
return None
return frozendict({(etype, state_key,): v for etype, state_key, v in input})

View File

@ -15,11 +15,3 @@
""" This package includes all the federation specific logic. """ This package includes all the federation specific logic.
""" """
from .replication import ReplicationLayer
def initialize_http_replication(hs):
transport = hs.get_federation_transport_client()
return ReplicationLayer(hs, transport)

View File

@ -16,7 +16,9 @@ import logging
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.crypto.event_signing import check_event_content_hash from synapse.crypto.event_signing import check_event_content_hash
from synapse.events import FrozenEvent
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.http.servlet import assert_params_in_request
from synapse.util import unwrapFirstError, logcontext from synapse.util import unwrapFirstError, logcontext
from twisted.internet import defer from twisted.internet import defer
@ -25,7 +27,13 @@ logger = logging.getLogger(__name__)
class FederationBase(object): class FederationBase(object):
def __init__(self, hs): def __init__(self, hs):
self.hs = hs
self.server_name = hs.hostname
self.keyring = hs.get_keyring()
self.spam_checker = hs.get_spam_checker() self.spam_checker = hs.get_spam_checker()
self.store = hs.get_datastore()
self._clock = hs.get_clock()
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False, def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
@ -169,3 +177,28 @@ class FederationBase(object):
) )
return deferreds return deferreds
def event_from_pdu_json(pdu_json, outlier=False):
"""Construct a FrozenEvent from an event json received over federation
Args:
pdu_json (object): pdu as received over federation
outlier (bool): True to mark this event as an outlier
Returns:
FrozenEvent
Raises:
SynapseError: if the pdu is missing required fields
"""
# we could probably enforce a bunch of other fields here (room_id, sender,
# origin, etc etc)
assert_params_in_request(pdu_json, ('event_id', 'type'))
event = FrozenEvent(
pdu_json
)
event.internal_metadata.outlier = outlier
return event

View File

@ -14,28 +14,28 @@
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from .federation_base import FederationBase
from synapse.api.constants import Membership
from synapse.api.errors import (
CodeMessageException, HttpResponseException, SynapseError,
)
from synapse.util import unwrapFirstError, logcontext
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.events import FrozenEvent, builder
import synapse.metrics
from synapse.util.retryutils import NotRetryingDestination
import copy import copy
import itertools import itertools
import logging import logging
import random import random
from twisted.internet import defer
from synapse.api.constants import Membership
from synapse.api.errors import (
CodeMessageException, HttpResponseException, SynapseError, FederationDeniedError
)
from synapse.events import builder
from synapse.federation.federation_base import (
FederationBase,
event_from_pdu_json,
)
import synapse.metrics
from synapse.util import logcontext, unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.util.logutils import log_function
from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -58,6 +58,7 @@ class FederationClient(FederationBase):
self._clear_tried_cache, 60 * 1000, self._clear_tried_cache, 60 * 1000,
) )
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.transport_layer = hs.get_federation_transport_client()
def _clear_tried_cache(self): def _clear_tried_cache(self):
"""Clear pdu_destination_tried cache""" """Clear pdu_destination_tried cache"""
@ -184,7 +185,7 @@ class FederationClient(FederationBase):
logger.debug("backfill transaction_data=%s", repr(transaction_data)) logger.debug("backfill transaction_data=%s", repr(transaction_data))
pdus = [ pdus = [
self.event_from_pdu_json(p, outlier=False) event_from_pdu_json(p, outlier=False)
for p in transaction_data["pdus"] for p in transaction_data["pdus"]
] ]
@ -244,7 +245,7 @@ class FederationClient(FederationBase):
logger.debug("transaction_data %r", transaction_data) logger.debug("transaction_data %r", transaction_data)
pdu_list = [ pdu_list = [
self.event_from_pdu_json(p, outlier=outlier) event_from_pdu_json(p, outlier=outlier)
for p in transaction_data["pdus"] for p in transaction_data["pdus"]
] ]
@ -266,6 +267,9 @@ class FederationClient(FederationBase):
except NotRetryingDestination as e: except NotRetryingDestination as e:
logger.info(e.message) logger.info(e.message)
continue continue
except FederationDeniedError as e:
logger.info(e.message)
continue
except Exception as e: except Exception as e:
pdu_attempts[destination] = now pdu_attempts[destination] = now
@ -336,11 +340,11 @@ class FederationClient(FederationBase):
) )
pdus = [ pdus = [
self.event_from_pdu_json(p, outlier=True) for p in result["pdus"] event_from_pdu_json(p, outlier=True) for p in result["pdus"]
] ]
auth_chain = [ auth_chain = [
self.event_from_pdu_json(p, outlier=True) event_from_pdu_json(p, outlier=True)
for p in result.get("auth_chain", []) for p in result.get("auth_chain", [])
] ]
@ -441,7 +445,7 @@ class FederationClient(FederationBase):
) )
auth_chain = [ auth_chain = [
self.event_from_pdu_json(p, outlier=True) event_from_pdu_json(p, outlier=True)
for p in res["auth_chain"] for p in res["auth_chain"]
] ]
@ -570,12 +574,12 @@ class FederationClient(FederationBase):
logger.debug("Got content: %s", content) logger.debug("Got content: %s", content)
state = [ state = [
self.event_from_pdu_json(p, outlier=True) event_from_pdu_json(p, outlier=True)
for p in content.get("state", []) for p in content.get("state", [])
] ]
auth_chain = [ auth_chain = [
self.event_from_pdu_json(p, outlier=True) event_from_pdu_json(p, outlier=True)
for p in content.get("auth_chain", []) for p in content.get("auth_chain", [])
] ]
@ -650,7 +654,7 @@ class FederationClient(FederationBase):
logger.debug("Got response to send_invite: %s", pdu_dict) logger.debug("Got response to send_invite: %s", pdu_dict)
pdu = self.event_from_pdu_json(pdu_dict) pdu = event_from_pdu_json(pdu_dict)
# Check signatures are correct. # Check signatures are correct.
pdu = yield self._check_sigs_and_hash(pdu) pdu = yield self._check_sigs_and_hash(pdu)
@ -740,7 +744,7 @@ class FederationClient(FederationBase):
) )
auth_chain = [ auth_chain = [
self.event_from_pdu_json(e) event_from_pdu_json(e)
for e in content["auth_chain"] for e in content["auth_chain"]
] ]
@ -788,7 +792,7 @@ class FederationClient(FederationBase):
) )
events = [ events = [
self.event_from_pdu_json(e) event_from_pdu_json(e)
for e in content.get("events", []) for e in content.get("events", [])
] ]
@ -805,15 +809,6 @@ class FederationClient(FederationBase):
defer.returnValue(signed_events) defer.returnValue(signed_events)
def event_from_pdu_json(self, pdu_json, outlier=False):
event = FrozenEvent(
pdu_json
)
event.internal_metadata.outlier = outlier
return event
@defer.inlineCallbacks @defer.inlineCallbacks
def forward_third_party_invite(self, destinations, room_id, event_dict): def forward_third_party_invite(self, destinations, room_id, event_dict):
for destination in destinations: for destination in destinations:

View File

@ -12,25 +12,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer import logging
from .federation_base import FederationBase
from .units import Transaction, Edu
from synapse.util import async
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.util.logutils import log_function
from synapse.util.caches.response_cache import ResponseCache
from synapse.events import FrozenEvent
from synapse.types import get_domain_from_id
import synapse.metrics
from synapse.api.errors import AuthError, FederationError, SynapseError
from synapse.crypto.event_signing import compute_event_signature
import simplejson as json import simplejson as json
import logging from twisted.internet import defer
from synapse.api.errors import AuthError, FederationError, SynapseError, NotFoundError
from synapse.crypto.event_signing import compute_event_signature
from synapse.federation.federation_base import (
FederationBase,
event_from_pdu_json,
)
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
import synapse.metrics
from synapse.types import get_domain_from_id
from synapse.util import async
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.util.logutils import log_function
# when processing incoming transactions, we try to handle multiple rooms in # when processing incoming transactions, we try to handle multiple rooms in
# parallel, up to this limit. # parallel, up to this limit.
@ -53,50 +54,19 @@ class FederationServer(FederationBase):
super(FederationServer, self).__init__(hs) super(FederationServer, self).__init__(hs)
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.handler = hs.get_handlers().federation_handler
self._server_linearizer = async.Linearizer("fed_server") self._server_linearizer = async.Linearizer("fed_server")
self._transaction_linearizer = async.Linearizer("fed_txn_handler") self._transaction_linearizer = async.Linearizer("fed_txn_handler")
self.transaction_actions = TransactionActions(self.store)
self.registry = hs.get_federation_registry()
# We cache responses to state queries, as they take a while and often # We cache responses to state queries, as they take a while and often
# come in waves. # come in waves.
self._state_resp_cache = ResponseCache(hs, timeout_ms=30000) self._state_resp_cache = ResponseCache(hs, timeout_ms=30000)
def set_handler(self, handler):
"""Sets the handler that the replication layer will use to communicate
receipt of new PDUs from other home servers. The required methods are
documented on :py:class:`.ReplicationHandler`.
"""
self.handler = handler
def register_edu_handler(self, edu_type, handler):
if edu_type in self.edu_handlers:
raise KeyError("Already have an EDU handler for %s" % (edu_type,))
self.edu_handlers[edu_type] = handler
def register_query_handler(self, query_type, handler):
"""Sets the handler callable that will be used to handle an incoming
federation Query of the given type.
Args:
query_type (str): Category name of the query, which should match
the string used by make_query.
handler (callable): Invoked to handle incoming queries of this type
handler is invoked as:
result = handler(args)
where 'args' is a dict mapping strings to strings of the query
arguments. It should return a Deferred that will eventually yield an
object to encode as JSON.
"""
if query_type in self.query_handlers:
raise KeyError(
"Already have a Query handler for %s" % (query_type,)
)
self.query_handlers[query_type] = handler
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_backfill_request(self, origin, room_id, versions, limit): def on_backfill_request(self, origin, room_id, versions, limit):
@ -172,7 +142,7 @@ class FederationServer(FederationBase):
p["age_ts"] = request_time - int(p["age"]) p["age_ts"] = request_time - int(p["age"])
del p["age"] del p["age"]
event = self.event_from_pdu_json(p) event = event_from_pdu_json(p)
room_id = event.room_id room_id = event.room_id
pdus_by_room.setdefault(room_id, []).append(event) pdus_by_room.setdefault(room_id, []).append(event)
@ -230,16 +200,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
def received_edu(self, origin, edu_type, content): def received_edu(self, origin, edu_type, content):
received_edus_counter.inc() received_edus_counter.inc()
yield self.registry.on_edu(edu_type, origin, content)
if edu_type in self.edu_handlers:
try:
yield self.edu_handlers[edu_type](origin, content)
except SynapseError as e:
logger.info("Failed to handle edu %r: %r", edu_type, e)
except Exception as e:
logger.exception("Failed to handle edu %r", edu_type)
else:
logger.warn("Received EDU of type %s with no handler", edu_type)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
@ -329,14 +290,8 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_query_request(self, query_type, args): def on_query_request(self, query_type, args):
received_queries_counter.inc(query_type) received_queries_counter.inc(query_type)
resp = yield self.registry.on_query(query_type, args)
if query_type in self.query_handlers: defer.returnValue((200, resp))
response = yield self.query_handlers[query_type](args)
defer.returnValue((200, response))
else:
defer.returnValue(
(404, "No handler for Query type '%s'" % (query_type,))
)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_make_join_request(self, room_id, user_id): def on_make_join_request(self, room_id, user_id):
@ -346,7 +301,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_invite_request(self, origin, content): def on_invite_request(self, origin, content):
pdu = self.event_from_pdu_json(content) pdu = event_from_pdu_json(content)
ret_pdu = yield self.handler.on_invite_request(origin, pdu) ret_pdu = yield self.handler.on_invite_request(origin, pdu)
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)})) defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
@ -354,7 +309,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_send_join_request(self, origin, content): def on_send_join_request(self, origin, content):
logger.debug("on_send_join_request: content: %s", content) logger.debug("on_send_join_request: content: %s", content)
pdu = self.event_from_pdu_json(content) pdu = event_from_pdu_json(content)
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
res_pdus = yield self.handler.on_send_join_request(origin, pdu) res_pdus = yield self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
@ -374,7 +329,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_send_leave_request(self, origin, content): def on_send_leave_request(self, origin, content):
logger.debug("on_send_leave_request: content: %s", content) logger.debug("on_send_leave_request: content: %s", content)
pdu = self.event_from_pdu_json(content) pdu = event_from_pdu_json(content)
logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures) logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
yield self.handler.on_send_leave_request(origin, pdu) yield self.handler.on_send_leave_request(origin, pdu)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -411,7 +366,7 @@ class FederationServer(FederationBase):
""" """
with (yield self._server_linearizer.queue((origin, room_id))): with (yield self._server_linearizer.queue((origin, room_id))):
auth_chain = [ auth_chain = [
self.event_from_pdu_json(e) event_from_pdu_json(e)
for e in content["auth_chain"] for e in content["auth_chain"]
] ]
@ -586,15 +541,6 @@ class FederationServer(FederationBase):
def __str__(self): def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name return "<ReplicationLayer(%s)>" % self.server_name
def event_from_pdu_json(self, pdu_json, outlier=False):
event = FrozenEvent(
pdu_json
)
event.internal_metadata.outlier = outlier
return event
@defer.inlineCallbacks @defer.inlineCallbacks
def exchange_third_party_invite( def exchange_third_party_invite(
self, self,
@ -617,3 +563,66 @@ class FederationServer(FederationBase):
origin, room_id, event_dict origin, room_id, event_dict
) )
defer.returnValue(ret) defer.returnValue(ret)
class FederationHandlerRegistry(object):
"""Allows classes to register themselves as handlers for a given EDU or
query type for incoming federation traffic.
"""
def __init__(self):
self.edu_handlers = {}
self.query_handlers = {}
def register_edu_handler(self, edu_type, handler):
"""Sets the handler callable that will be used to handle an incoming
federation EDU of the given type.
Args:
edu_type (str): The type of the incoming EDU to register handler for
handler (Callable[[str, dict]]): A callable invoked on incoming EDU
of the given type. The arguments are the origin server name and
the EDU contents.
"""
if edu_type in self.edu_handlers:
raise KeyError("Already have an EDU handler for %s" % (edu_type,))
self.edu_handlers[edu_type] = handler
def register_query_handler(self, query_type, handler):
"""Sets the handler callable that will be used to handle an incoming
federation query of the given type.
Args:
query_type (str): Category name of the query, which should match
the string used by make_query.
handler (Callable[[dict], Deferred[dict]]): Invoked to handle
incoming queries of this type. The return will be yielded
on and the result used as the response to the query request.
"""
if query_type in self.query_handlers:
raise KeyError(
"Already have a Query handler for %s" % (query_type,)
)
self.query_handlers[query_type] = handler
@defer.inlineCallbacks
def on_edu(self, edu_type, origin, content):
handler = self.edu_handlers.get(edu_type)
if not handler:
logger.warn("No handler registered for EDU type %s", edu_type)
try:
yield handler(origin, content)
except SynapseError as e:
logger.info("Failed to handle edu %r: %r", edu_type, e)
except Exception as e:
logger.exception("Failed to handle edu %r", edu_type)
def on_query(self, query_type, args):
handler = self.query_handlers.get(query_type)
if not handler:
logger.warn("No handler registered for query type %s", query_type)
raise NotFoundError("No handler for Query type '%s'" % (query_type,))
return handler(args)

View File

@ -1,73 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This layer is responsible for replicating with remote home servers using
a given transport.
"""
from .federation_client import FederationClient
from .federation_server import FederationServer
from .persistence import TransactionActions
import logging
logger = logging.getLogger(__name__)
class ReplicationLayer(FederationClient, FederationServer):
"""This layer is responsible for replicating with remote home servers over
the given transport. I.e., does the sending and receiving of PDUs to
remote home servers.
The layer communicates with the rest of the server via a registered
ReplicationHandler.
In more detail, the layer:
* Receives incoming data and processes it into transactions and pdus.
* Fetches any PDUs it thinks it might have missed.
* Keeps the current state for contexts up to date by applying the
suitable conflict resolution.
* Sends outgoing pdus wrapped in transactions.
* Fills out the references to previous pdus/transactions appropriately
for outgoing data.
"""
def __init__(self, hs, transport_layer):
self.server_name = hs.hostname
self.keyring = hs.get_keyring()
self.transport_layer = transport_layer
self.federation_client = self
self.store = hs.get_datastore()
self.handler = None
self.edu_handlers = {}
self.query_handlers = {}
self._clock = hs.get_clock()
self.transaction_actions = TransactionActions(self.store)
self.hs = hs
super(ReplicationLayer, self).__init__(hs)
def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name

View File

@ -19,7 +19,7 @@ from twisted.internet import defer
from .persistence import TransactionActions from .persistence import TransactionActions
from .units import Transaction, Edu from .units import Transaction, Edu
from synapse.api.errors import HttpResponseException from synapse.api.errors import HttpResponseException, FederationDeniedError
from synapse.util import logcontext, PreserveLoggingContext from synapse.util import logcontext, PreserveLoggingContext
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
@ -42,6 +42,8 @@ sent_edus_counter = client_metrics.register_counter("sent_edus")
sent_transactions_counter = client_metrics.register_counter("sent_transactions") sent_transactions_counter = client_metrics.register_counter("sent_transactions")
events_processed_counter = client_metrics.register_counter("events_processed")
class TransactionQueue(object): class TransactionQueue(object):
"""This class makes sure we only have one transaction in flight at """This class makes sure we only have one transaction in flight at
@ -205,6 +207,8 @@ class TransactionQueue(object):
self._send_pdu(event, destinations) self._send_pdu(event, destinations)
events_processed_counter.inc_by(len(events))
yield self.store.update_federation_out_pos( yield self.store.update_federation_out_pos(
"events", next_token "events", next_token
) )
@ -486,6 +490,8 @@ class TransactionQueue(object):
(e.retry_last_ts + e.retry_interval) / 1000.0 (e.retry_last_ts + e.retry_interval) / 1000.0
), ),
) )
except FederationDeniedError as e:
logger.info(e)
except Exception as e: except Exception as e:
logger.warn( logger.warn(
"TX [%s] Failed to send transaction: %s", "TX [%s] Failed to send transaction: %s",

View File

@ -212,6 +212,9 @@ class TransportLayerClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server. to retry this server.
Fails with ``FederationDeniedError`` if the remote destination
is not in our federation whitelist
""" """
valid_memberships = {Membership.JOIN, Membership.LEAVE} valid_memberships = {Membership.JOIN, Membership.LEAVE}
if membership not in valid_memberships: if membership not in valid_memberships:

View File

@ -16,7 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError, FederationDeniedError
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.http.servlet import ( from synapse.http.servlet import (
parse_json_object_from_request, parse_integer_from_args, parse_string_from_args, parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
@ -81,6 +81,7 @@ class Authenticator(object):
self.keyring = hs.get_keyring() self.keyring = hs.get_keyring()
self.server_name = hs.hostname self.server_name = hs.hostname
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
# A method just so we can pass 'self' as the authenticator to the Servlets # A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks @defer.inlineCallbacks
@ -92,6 +93,12 @@ class Authenticator(object):
"signatures": {}, "signatures": {},
} }
if (
self.federation_domain_whitelist is not None and
self.server_name not in self.federation_domain_whitelist
):
raise FederationDeniedError(self.server_name)
if content is not None: if content is not None:
json_request["content"] = content json_request["content"] = content
@ -1183,7 +1190,7 @@ GROUP_ATTESTATION_SERVLET_CLASSES = (
def register_servlets(hs, resource, authenticator, ratelimiter): def register_servlets(hs, resource, authenticator, ratelimiter):
for servletclass in FEDERATION_SERVLET_CLASSES: for servletclass in FEDERATION_SERVLET_CLASSES:
servletclass( servletclass(
handler=hs.get_replication_layer(), handler=hs.get_federation_server(),
authenticator=authenticator, authenticator=authenticator,
ratelimiter=ratelimiter, ratelimiter=ratelimiter,
server_name=hs.hostname, server_name=hs.hostname,

View File

@ -17,7 +17,6 @@ from .register import RegistrationHandler
from .room import ( from .room import (
RoomCreationHandler, RoomContextHandler, RoomCreationHandler, RoomContextHandler,
) )
from .room_member import RoomMemberHandler
from .message import MessageHandler from .message import MessageHandler
from .federation import FederationHandler from .federation import FederationHandler
from .directory import DirectoryHandler from .directory import DirectoryHandler
@ -49,7 +48,6 @@ class Handlers(object):
self.registration_handler = RegistrationHandler(hs) self.registration_handler = RegistrationHandler(hs)
self.message_handler = MessageHandler(hs) self.message_handler = MessageHandler(hs)
self.room_creation_handler = RoomCreationHandler(hs) self.room_creation_handler = RoomCreationHandler(hs)
self.room_member_handler = RoomMemberHandler(hs)
self.federation_handler = FederationHandler(hs) self.federation_handler = FederationHandler(hs)
self.directory_handler = DirectoryHandler(hs) self.directory_handler = DirectoryHandler(hs)
self.admin_handler = AdminHandler(hs) self.admin_handler = AdminHandler(hs)

View File

@ -158,7 +158,7 @@ class BaseHandler(object):
# homeserver. # homeserver.
requester = synapse.types.create_requester( requester = synapse.types.create_requester(
target_user, is_guest=True) target_user, is_guest=True)
handler = self.hs.get_handlers().room_member_handler handler = self.hs.get_room_member_handler()
yield handler.update_membership( yield handler.update_membership(
requester, requester,
target_user, target_user,

View File

@ -15,6 +15,7 @@
from twisted.internet import defer from twisted.internet import defer
import synapse
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
@ -23,6 +24,10 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
metrics = synapse.metrics.get_metrics_for(__name__)
events_processed_counter = metrics.register_counter("events_processed")
def log_failure(failure): def log_failure(failure):
logger.error( logger.error(
@ -103,6 +108,8 @@ class ApplicationServicesHandler(object):
service, event service, event
) )
events_processed_counter.inc_by(len(events))
yield self.store.set_appservice_last_pos(upper_bound) yield self.store.set_appservice_last_pos(upper_bound)
finally: finally:
self.is_processing = False self.is_processing = False

View File

@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer, threads
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
@ -25,6 +25,7 @@ from synapse.module_api import ModuleApi
from synapse.types import UserID from synapse.types import UserID
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logcontext import make_deferred_yieldable
from twisted.web.client import PartialDownloadError from twisted.web.client import PartialDownloadError
@ -714,7 +715,7 @@ class AuthHandler(BaseHandler):
if not lookupres: if not lookupres:
defer.returnValue(None) defer.returnValue(None)
(user_id, password_hash) = lookupres (user_id, password_hash) = lookupres
result = self.validate_hash(password, password_hash) result = yield self.validate_hash(password, password_hash)
if not result: if not result:
logger.warn("Failed password login for user %s", user_id) logger.warn("Failed password login for user %s", user_id)
defer.returnValue(None) defer.returnValue(None)
@ -842,10 +843,13 @@ class AuthHandler(BaseHandler):
password (str): Password to hash. password (str): Password to hash.
Returns: Returns:
Hashed password (str). Deferred(str): Hashed password.
""" """
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper, def _do_hash():
bcrypt.gensalt(self.bcrypt_rounds)) return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
bcrypt.gensalt(self.bcrypt_rounds))
return make_deferred_yieldable(threads.deferToThread(_do_hash))
def validate_hash(self, password, stored_hash): def validate_hash(self, password, stored_hash):
"""Validates that self.hash(password) == stored_hash. """Validates that self.hash(password) == stored_hash.
@ -855,13 +859,19 @@ class AuthHandler(BaseHandler):
stored_hash (str): Expected hash value. stored_hash (str): Expected hash value.
Returns: Returns:
Whether self.hash(password) == stored_hash (bool). Deferred(bool): Whether self.hash(password) == stored_hash.
""" """
def _do_validate_hash():
return bcrypt.checkpw(
password.encode('utf8') + self.hs.config.password_pepper,
stored_hash.encode('utf8')
)
if stored_hash: if stored_hash:
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper, return make_deferred_yieldable(threads.deferToThread(_do_validate_hash))
stored_hash.encode('utf8')) == stored_hash
else: else:
return False return defer.succeed(False)
class MacaroonGeneartor(object): class MacaroonGeneartor(object):

View File

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
from synapse.api import errors from synapse.api import errors
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import FederationDeniedError
from synapse.util import stringutils from synapse.util import stringutils
from synapse.util.async import Linearizer from synapse.util.async import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
@ -36,14 +37,15 @@ class DeviceHandler(BaseHandler):
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self.federation_sender = hs.get_federation_sender() self.federation_sender = hs.get_federation_sender()
self.federation = hs.get_replication_layer()
self._edu_updater = DeviceListEduUpdater(hs, self) self._edu_updater = DeviceListEduUpdater(hs, self)
self.federation.register_edu_handler( federation_registry = hs.get_federation_registry()
federation_registry.register_edu_handler(
"m.device_list_update", self._edu_updater.incoming_device_list_update, "m.device_list_update", self._edu_updater.incoming_device_list_update,
) )
self.federation.register_query_handler( federation_registry.register_query_handler(
"user_devices", self.on_federation_query_user_devices, "user_devices", self.on_federation_query_user_devices,
) )
@ -429,7 +431,7 @@ class DeviceListEduUpdater(object):
def __init__(self, hs, device_handler): def __init__(self, hs, device_handler):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.federation = hs.get_replication_layer() self.federation = hs.get_federation_client()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.device_handler = device_handler self.device_handler = device_handler
@ -513,6 +515,9 @@ class DeviceListEduUpdater(object):
# This makes it more likely that the device lists will # This makes it more likely that the device lists will
# eventually become consistent. # eventually become consistent.
return return
except FederationDeniedError as e:
logger.info(e)
return
except Exception: except Exception:
# TODO: Remember that we are now out of sync and try again # TODO: Remember that we are now out of sync and try again
# later # later

View File

@ -17,7 +17,8 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.types import get_domain_from_id from synapse.api.errors import SynapseError
from synapse.types import get_domain_from_id, UserID
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
@ -33,10 +34,10 @@ class DeviceMessageHandler(object):
""" """
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id self.is_mine = hs.is_mine
self.federation = hs.get_federation_sender() self.federation = hs.get_federation_sender()
hs.get_replication_layer().register_edu_handler( hs.get_federation_registry().register_edu_handler(
"m.direct_to_device", self.on_direct_to_device_edu "m.direct_to_device", self.on_direct_to_device_edu
) )
@ -52,6 +53,12 @@ class DeviceMessageHandler(object):
message_type = content["type"] message_type = content["type"]
message_id = content["message_id"] message_id = content["message_id"]
for user_id, by_device in content["messages"].items(): for user_id, by_device in content["messages"].items():
# we use UserID.from_string to catch invalid user ids
if not self.is_mine(UserID.from_string(user_id)):
logger.warning("Request for keys for non-local user %s",
user_id)
raise SynapseError(400, "Not a user here")
messages_by_device = { messages_by_device = {
device_id: { device_id: {
"content": message_content, "content": message_content,
@ -77,7 +84,8 @@ class DeviceMessageHandler(object):
local_messages = {} local_messages = {}
remote_messages = {} remote_messages = {}
for user_id, by_device in messages.items(): for user_id, by_device in messages.items():
if self.is_mine_id(user_id): # we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
messages_by_device = { messages_by_device = {
device_id: { device_id: {
"content": message_content, "content": message_content,

View File

@ -34,9 +34,10 @@ class DirectoryHandler(BaseHandler):
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.appservice_handler = hs.get_application_service_handler() self.appservice_handler = hs.get_application_service_handler()
self.event_creation_handler = hs.get_event_creation_handler()
self.federation = hs.get_replication_layer() self.federation = hs.get_federation_client()
self.federation.register_query_handler( hs.get_federation_registry().register_query_handler(
"directory", self.on_directory_query "directory", self.on_directory_query
) )
@ -249,8 +250,7 @@ class DirectoryHandler(BaseHandler):
def send_room_alias_update_event(self, requester, user_id, room_id): def send_room_alias_update_event(self, requester, user_id, room_id):
aliases = yield self.store.get_aliases_for_room(room_id) aliases = yield self.store.get_aliases_for_room(room_id)
msg_handler = self.hs.get_handlers().message_handler yield self.event_creation_handler.create_and_send_nonmember_event(
yield msg_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.Aliases, "type": EventTypes.Aliases,
@ -272,8 +272,7 @@ class DirectoryHandler(BaseHandler):
if not alias_event or alias_event.content.get("alias", "") != alias_str: if not alias_event or alias_event.content.get("alias", "") != alias_str:
return return
msg_handler = self.hs.get_handlers().message_handler yield self.event_creation_handler.create_and_send_nonmember_event(
yield msg_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.CanonicalAlias, "type": EventTypes.CanonicalAlias,

View File

@ -19,8 +19,10 @@ import logging
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError, CodeMessageException from synapse.api.errors import (
from synapse.types import get_domain_from_id SynapseError, CodeMessageException, FederationDeniedError,
)
from synapse.types import get_domain_from_id, UserID
from synapse.util.logcontext import preserve_fn, make_deferred_yieldable from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
@ -30,15 +32,15 @@ logger = logging.getLogger(__name__)
class E2eKeysHandler(object): class E2eKeysHandler(object):
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.federation = hs.get_replication_layer() self.federation = hs.get_federation_client()
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.is_mine_id = hs.is_mine_id self.is_mine = hs.is_mine
self.clock = hs.get_clock() self.clock = hs.get_clock()
# doesn't really work as part of the generic query API, because the # doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the # query request requires an object POST, but we abuse the
# "query handler" interface. # "query handler" interface.
self.federation.register_query_handler( hs.get_federation_registry().register_query_handler(
"client_keys", self.on_federation_query_client_keys "client_keys", self.on_federation_query_client_keys
) )
@ -70,7 +72,8 @@ class E2eKeysHandler(object):
remote_queries = {} remote_queries = {}
for user_id, device_ids in device_keys_query.items(): for user_id, device_ids in device_keys_query.items():
if self.is_mine_id(user_id): # we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
local_query[user_id] = device_ids local_query[user_id] = device_ids
else: else:
remote_queries[user_id] = device_ids remote_queries[user_id] = device_ids
@ -139,6 +142,10 @@ class E2eKeysHandler(object):
failures[destination] = { failures[destination] = {
"status": 503, "message": "Not ready for retry", "status": 503, "message": "Not ready for retry",
} }
except FederationDeniedError as e:
failures[destination] = {
"status": 403, "message": "Federation Denied",
}
except Exception as e: except Exception as e:
# include ConnectionRefused and other errors # include ConnectionRefused and other errors
failures[destination] = { failures[destination] = {
@ -170,7 +177,8 @@ class E2eKeysHandler(object):
result_dict = {} result_dict = {}
for user_id, device_ids in query.items(): for user_id, device_ids in query.items():
if not self.is_mine_id(user_id): # we use UserID.from_string to catch invalid user ids
if not self.is_mine(UserID.from_string(user_id)):
logger.warning("Request for keys for non-local user %s", logger.warning("Request for keys for non-local user %s",
user_id) user_id)
raise SynapseError(400, "Not a user here") raise SynapseError(400, "Not a user here")
@ -213,7 +221,8 @@ class E2eKeysHandler(object):
remote_queries = {} remote_queries = {}
for user_id, device_keys in query.get("one_time_keys", {}).items(): for user_id, device_keys in query.get("one_time_keys", {}).items():
if self.is_mine_id(user_id): # we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
for device_id, algorithm in device_keys.items(): for device_id, algorithm in device_keys.items():
local_query.append((user_id, device_id, algorithm)) local_query.append((user_id, device_id, algorithm))
else: else:

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -22,6 +23,7 @@ from ._base import BaseHandler
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, FederationError, StoreError, CodeMessageException, SynapseError, AuthError, FederationError, StoreError, CodeMessageException, SynapseError,
FederationDeniedError,
) )
from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
@ -66,7 +68,7 @@ class FederationHandler(BaseHandler):
self.hs = hs self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.replication_layer = hs.get_replication_layer() self.replication_layer = hs.get_federation_client()
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
self.server_name = hs.hostname self.server_name = hs.hostname
self.keyring = hs.get_keyring() self.keyring = hs.get_keyring()
@ -74,8 +76,7 @@ class FederationHandler(BaseHandler):
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self.pusher_pool = hs.get_pusherpool() self.pusher_pool = hs.get_pusherpool()
self.spam_checker = hs.get_spam_checker() self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
self.replication_layer.set_handler(self)
# When joining a room we need to queue any events for that room up # When joining a room we need to queue any events for that room up
self.room_queues = {} self.room_queues = {}
@ -782,6 +783,9 @@ class FederationHandler(BaseHandler):
except NotRetryingDestination as e: except NotRetryingDestination as e:
logger.info(e.message) logger.info(e.message)
continue continue
except FederationDeniedError as e:
logger.info(e)
continue
except Exception as e: except Exception as e:
logger.exception( logger.exception(
"Failed to backfill from %s because %s", "Failed to backfill from %s because %s",
@ -804,13 +808,12 @@ class FederationHandler(BaseHandler):
event_ids = list(extremities.keys()) event_ids = list(extremities.keys())
logger.debug("calling resolve_state_groups in _maybe_backfill") logger.debug("calling resolve_state_groups in _maybe_backfill")
resolve = logcontext.preserve_fn(
self.state_handler.resolve_state_groups_for_events
)
states = yield logcontext.make_deferred_yieldable(defer.gatherResults( states = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[ [resolve(room_id, [e]) for e in event_ids],
logcontext.preserve_fn(self.state_handler.resolve_state_groups)( consumeErrors=True,
room_id, [e]
)
for e in event_ids
], consumeErrors=True,
)) ))
states = dict(zip(event_ids, [s.state for s in states])) states = dict(zip(event_ids, [s.state for s in states]))
@ -1004,8 +1007,7 @@ class FederationHandler(BaseHandler):
}) })
try: try:
message_handler = self.hs.get_handlers().message_handler event, context = yield self.event_creation_handler.create_new_client_event(
event, context = yield message_handler._create_new_client_event(
builder=builder, builder=builder,
) )
except AuthError as e: except AuthError as e:
@ -1245,8 +1247,7 @@ class FederationHandler(BaseHandler):
"state_key": user_id, "state_key": user_id,
}) })
message_handler = self.hs.get_handlers().message_handler event, context = yield self.event_creation_handler.create_new_client_event(
event, context = yield message_handler._create_new_client_event(
builder=builder, builder=builder,
) )
@ -1444,16 +1445,24 @@ class FederationHandler(BaseHandler):
auth_events=auth_events, auth_events=auth_events,
) )
if not event.internal_metadata.is_outlier() and not backfilled: try:
yield self.action_generator.handle_push_actions_for_event( if not event.internal_metadata.is_outlier() and not backfilled:
event, context yield self.action_generator.handle_push_actions_for_event(
) event, context
)
event_stream_id, max_stream_id = yield self.store.persist_event( event_stream_id, max_stream_id = yield self.store.persist_event(
event, event,
context=context, context=context,
backfilled=backfilled, backfilled=backfilled,
) )
except: # noqa: E722, as we reraise the exception this is fine.
# Ensure that we actually remove the entries in the push actions
# staging area
logcontext.preserve_fn(
self.store.remove_push_actions_from_staging
)(event.event_id)
raise
if not backfilled: if not backfilled:
# this intentionally does not yield: we don't care about the result # this intentionally does not yield: we don't care about the result
@ -1828,8 +1837,8 @@ class FederationHandler(BaseHandler):
current_state = set(e.event_id for e in auth_events.values()) current_state = set(e.event_id for e in auth_events.values())
different_auth = event_auth_events - current_state different_auth = event_auth_events - current_state
self._update_context_for_auth_events( yield self._update_context_for_auth_events(
context, auth_events, event_key, event, context, auth_events, event_key,
) )
if different_auth and not event.internal_metadata.is_outlier(): if different_auth and not event.internal_metadata.is_outlier():
@ -1910,8 +1919,8 @@ class FederationHandler(BaseHandler):
# 4. Look at rejects and their proofs. # 4. Look at rejects and their proofs.
# TODO. # TODO.
self._update_context_for_auth_events( yield self._update_context_for_auth_events(
context, auth_events, event_key, event, context, auth_events, event_key,
) )
try: try:
@ -1920,11 +1929,15 @@ class FederationHandler(BaseHandler):
logger.warn("Failed auth resolution for %r because %s", event, e) logger.warn("Failed auth resolution for %r because %s", event, e)
raise e raise e
def _update_context_for_auth_events(self, context, auth_events, @defer.inlineCallbacks
def _update_context_for_auth_events(self, event, context, auth_events,
event_key): event_key):
"""Update the state_ids in an event context after auth event resolution """Update the state_ids in an event context after auth event resolution,
storing the changes as a new state group.
Args: Args:
event (Event): The event we're handling the context for
context (synapse.events.snapshot.EventContext): event context context (synapse.events.snapshot.EventContext): event context
to be updated to be updated
@ -1947,7 +1960,13 @@ class FederationHandler(BaseHandler):
context.prev_state_ids.update({ context.prev_state_ids.update({
k: a.event_id for k, a in auth_events.iteritems() k: a.event_id for k, a in auth_events.iteritems()
}) })
context.state_group = self.store.get_next_state_group() context.state_group = yield self.store.store_state_group(
event.event_id,
event.room_id,
prev_group=context.prev_group,
delta_ids=context.delta_ids,
current_state_ids=context.current_state_ids,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def construct_auth_difference(self, local_auth, remote_auth): def construct_auth_difference(self, local_auth, remote_auth):
@ -2117,8 +2136,7 @@ class FederationHandler(BaseHandler):
if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)): if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
builder = self.event_builder_factory.new(event_dict) builder = self.event_builder_factory.new(event_dict)
EventValidator().validate_new(builder) EventValidator().validate_new(builder)
message_handler = self.hs.get_handlers().message_handler event, context = yield self.event_creation_handler.create_new_client_event(
event, context = yield message_handler._create_new_client_event(
builder=builder builder=builder
) )
@ -2133,7 +2151,7 @@ class FederationHandler(BaseHandler):
raise e raise e
yield self._check_signature(event, context) yield self._check_signature(event, context)
member_handler = self.hs.get_handlers().room_member_handler member_handler = self.hs.get_room_member_handler()
yield member_handler.send_membership_event(None, event, context) yield member_handler.send_membership_event(None, event, context)
else: else:
destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id)) destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id))
@ -2156,8 +2174,7 @@ class FederationHandler(BaseHandler):
""" """
builder = self.event_builder_factory.new(event_dict) builder = self.event_builder_factory.new(event_dict)
message_handler = self.hs.get_handlers().message_handler event, context = yield self.event_creation_handler.create_new_client_event(
event, context = yield message_handler._create_new_client_event(
builder=builder, builder=builder,
) )
@ -2178,7 +2195,7 @@ class FederationHandler(BaseHandler):
# TODO: Make sure the signatures actually are correct. # TODO: Make sure the signatures actually are correct.
event.signatures.update(returned_invite.signatures) event.signatures.update(returned_invite.signatures)
member_handler = self.hs.get_handlers().room_member_handler member_handler = self.hs.get_room_member_handler()
yield member_handler.send_membership_event(None, event, context) yield member_handler.send_membership_event(None, event, context)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -2207,8 +2224,9 @@ class FederationHandler(BaseHandler):
builder = self.event_builder_factory.new(event_dict) builder = self.event_builder_factory.new(event_dict)
EventValidator().validate_new(builder) EventValidator().validate_new(builder)
message_handler = self.hs.get_handlers().message_handler event, context = yield self.event_creation_handler.create_new_client_event(
event, context = yield message_handler._create_new_client_event(builder=builder) builder=builder,
)
defer.returnValue((event, context)) defer.returnValue((event, context))
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -383,11 +383,12 @@ class GroupsLocalHandler(object):
defer.returnValue({"groups": result}) defer.returnValue({"groups": result})
else: else:
result = yield self.transport_client.get_publicised_groups_for_user( bulk_result = yield self.transport_client.bulk_get_publicised_groups(
get_domain_from_id(user_id), user_id get_domain_from_id(user_id), [user_id],
) )
result = bulk_result.get("users", {}).get(user_id)
# TODO: Verify attestations # TODO: Verify attestations
defer.returnValue(result) defer.returnValue({"groups": result})
@defer.inlineCallbacks @defer.inlineCallbacks
def bulk_get_publicised_groups(self, user_ids, proxy=True): def bulk_get_publicised_groups(self, user_ids, proxy=True):

View File

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd # Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2017 New Vector Ltd # Copyright 2017 - 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,7 +13,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer, reactor
from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes, SynapseError from synapse.api.errors import AuthError, Codes, SynapseError
@ -24,10 +25,12 @@ from synapse.types import (
UserID, RoomAlias, RoomStreamToken, UserID, RoomAlias, RoomStreamToken,
) )
from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter
from synapse.util.logcontext import preserve_fn from synapse.util.logcontext import preserve_fn, run_in_background
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
from synapse.util.frozenutils import unfreeze from synapse.util.frozenutils import unfreeze
from synapse.util.stringutils import random_string
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
from synapse.replication.http.send_event import send_event_to_master
from ._base import BaseHandler from ._base import BaseHandler
@ -40,6 +43,36 @@ import simplejson
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PurgeStatus(object):
"""Object tracking the status of a purge request
This class contains information on the progress of a purge request, for
return by get_purge_status.
Attributes:
status (int): Tracks whether this request has completed. One of
STATUS_{ACTIVE,COMPLETE,FAILED}
"""
STATUS_ACTIVE = 0
STATUS_COMPLETE = 1
STATUS_FAILED = 2
STATUS_TEXT = {
STATUS_ACTIVE: "active",
STATUS_COMPLETE: "complete",
STATUS_FAILED: "failed",
}
def __init__(self):
self.status = PurgeStatus.STATUS_ACTIVE
def asdict(self):
return {
"status": PurgeStatus.STATUS_TEXT[self.status]
}
class MessageHandler(BaseHandler): class MessageHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
@ -47,32 +80,89 @@ class MessageHandler(BaseHandler):
self.hs = hs self.hs = hs
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.validator = EventValidator()
self.profile_handler = hs.get_profile_handler()
self.pagination_lock = ReadWriteLock() self.pagination_lock = ReadWriteLock()
self._purges_in_progress_by_room = set()
# map from purge id to PurgeStatus
self._purges_by_id = {}
self.pusher_pool = hs.get_pusherpool() def start_purge_history(self, room_id, topological_ordering,
delete_local_events=False):
"""Start off a history purge on a room.
# We arbitrarily limit concurrent event creation for a room to 5. Args:
# This is to stop us from diverging history *too* much. room_id (str): The room to purge from
self.limiter = Limiter(max_count=5)
self.action_generator = hs.get_action_generator() topological_ordering (int): minimum topo ordering to preserve
delete_local_events (bool): True to delete local events as well as
remote ones
self.spam_checker = hs.get_spam_checker() Returns:
str: unique ID for this purge transaction.
"""
if room_id in self._purges_in_progress_by_room:
raise SynapseError(
400,
"History purge already in progress for %s" % (room_id, ),
)
purge_id = random_string(16)
# we log the purge_id here so that it can be tied back to the
# request id in the log lines.
logger.info("[purge] starting purge_id %s", purge_id)
self._purges_by_id[purge_id] = PurgeStatus()
run_in_background(
self._purge_history,
purge_id, room_id, topological_ordering, delete_local_events,
)
return purge_id
@defer.inlineCallbacks @defer.inlineCallbacks
def purge_history(self, room_id, event_id): def _purge_history(self, purge_id, room_id, topological_ordering,
event = yield self.store.get_event(event_id) delete_local_events):
"""Carry out a history purge on a room.
if event.room_id != room_id: Args:
raise SynapseError(400, "Event is for wrong room.") purge_id (str): The id for this purge
room_id (str): The room to purge from
topological_ordering (int): minimum topo ordering to preserve
delete_local_events (bool): True to delete local events as well as
remote ones
depth = event.depth Returns:
Deferred
"""
self._purges_in_progress_by_room.add(room_id)
try:
with (yield self.pagination_lock.write(room_id)):
yield self.store.purge_history(
room_id, topological_ordering, delete_local_events,
)
logger.info("[purge] complete")
self._purges_by_id[purge_id].status = PurgeStatus.STATUS_COMPLETE
except Exception:
logger.error("[purge] failed: %s", Failure().getTraceback().rstrip())
self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED
finally:
self._purges_in_progress_by_room.discard(room_id)
with (yield self.pagination_lock.write(room_id)): # remove the purge from the list 24 hours after it completes
yield self.store.delete_old_state(room_id, depth) def clear_purge():
del self._purges_by_id[purge_id]
reactor.callLater(24 * 3600, clear_purge)
def get_purge_status(self, purge_id):
"""Get the current status of an active purge
Args:
purge_id (str): purge_id returned by start_purge_history
Returns:
PurgeStatus|None
"""
return self._purges_by_id.get(purge_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_messages(self, requester, room_id=None, pagin_config=None, def get_messages(self, requester, room_id=None, pagin_config=None,
@ -182,166 +272,6 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk) defer.returnValue(chunk)
@defer.inlineCallbacks
def create_event(self, requester, event_dict, token_id=None, txn_id=None,
prev_event_ids=None):
"""
Given a dict from a client, create a new event.
Creates an FrozenEvent object, filling out auth_events, prev_events,
etc.
Adds display names to Join membership events.
Args:
requester
event_dict (dict): An entire event
token_id (str)
txn_id (str)
prev_event_ids (list): The prev event ids to use when creating the event
Returns:
Tuple of created event (FrozenEvent), Context
"""
builder = self.event_builder_factory.new(event_dict)
with (yield self.limiter.queue(builder.room_id)):
self.validator.validate_new(builder)
if builder.type == EventTypes.Member:
membership = builder.content.get("membership", None)
target = UserID.from_string(builder.state_key)
if membership in {Membership.JOIN, Membership.INVITE}:
# If event doesn't include a display name, add one.
profile = self.profile_handler
content = builder.content
try:
if "displayname" not in content:
content["displayname"] = yield profile.get_displayname(target)
if "avatar_url" not in content:
content["avatar_url"] = yield profile.get_avatar_url(target)
except Exception as e:
logger.info(
"Failed to get profile information for %r: %s",
target, e
)
if token_id is not None:
builder.internal_metadata.token_id = token_id
if txn_id is not None:
builder.internal_metadata.txn_id = txn_id
event, context = yield self._create_new_client_event(
builder=builder,
requester=requester,
prev_event_ids=prev_event_ids,
)
defer.returnValue((event, context))
@defer.inlineCallbacks
def send_nonmember_event(self, requester, event, context, ratelimit=True):
"""
Persists and notifies local clients and federation of an event.
Args:
event (FrozenEvent) the event to send.
context (Context) the context of the event.
ratelimit (bool): Whether to rate limit this send.
is_guest (bool): Whether the sender is a guest.
"""
if event.type == EventTypes.Member:
raise SynapseError(
500,
"Tried to send member event through non-member codepath"
)
# We check here if we are currently being rate limited, so that we
# don't do unnecessary work. We check again just before we actually
# send the event.
yield self.ratelimit(requester, update=False)
user = UserID.from_string(event.sender)
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if event.is_state():
prev_state = yield self.deduplicate_state_event(event, context)
if prev_state is not None:
defer.returnValue(prev_state)
yield self.handle_new_client_event(
requester=requester,
event=event,
context=context,
ratelimit=ratelimit,
)
if event.type == EventTypes.Message:
presence = self.hs.get_presence_handler()
# We don't want to block sending messages on any presence code. This
# matters as sometimes presence code can take a while.
preserve_fn(presence.bump_presence_active_time)(user)
@defer.inlineCallbacks
def deduplicate_state_event(self, event, context):
"""
Checks whether event is in the latest resolved state in context.
If so, returns the version of the event in context.
Otherwise, returns None.
"""
prev_event_id = context.prev_state_ids.get((event.type, event.state_key))
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
if not prev_event:
return
if prev_event and event.user_id == prev_event.user_id:
prev_content = encode_canonical_json(prev_event.content)
next_content = encode_canonical_json(event.content)
if prev_content == next_content:
defer.returnValue(prev_event)
return
@defer.inlineCallbacks
def create_and_send_nonmember_event(
self,
requester,
event_dict,
ratelimit=True,
txn_id=None
):
"""
Creates an event, then sends it.
See self.create_event and self.send_nonmember_event.
"""
event, context = yield self.create_event(
requester,
event_dict,
token_id=requester.access_token_id,
txn_id=txn_id
)
spam_error = self.spam_checker.check_event_for_spam(event)
if spam_error:
if not isinstance(spam_error, basestring):
spam_error = "Spam is not permitted here"
raise SynapseError(
403, spam_error, Codes.FORBIDDEN
)
yield self.send_nonmember_event(
requester,
event,
context,
ratelimit=ratelimit,
)
defer.returnValue(event)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_room_data(self, user_id=None, room_id=None, def get_room_data(self, user_id=None, room_id=None,
event_type=None, state_key="", is_guest=False): event_type=None, state_key="", is_guest=False):
@ -470,9 +400,189 @@ class MessageHandler(BaseHandler):
for user_id, profile in users_with_profile.iteritems() for user_id, profile in users_with_profile.iteritems()
}) })
@measure_func("_create_new_client_event")
class EventCreationHandler(object):
def __init__(self, hs):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
self.validator = EventValidator()
self.profile_handler = hs.get_profile_handler()
self.event_builder_factory = hs.get_event_builder_factory()
self.server_name = hs.hostname
self.ratelimiter = hs.get_ratelimiter()
self.notifier = hs.get_notifier()
self.config = hs.config
self.http_client = hs.get_simple_http_client()
# This is only used to get at ratelimit function, and maybe_kick_guest_users
self.base_handler = BaseHandler(hs)
self.pusher_pool = hs.get_pusherpool()
# We arbitrarily limit concurrent event creation for a room to 5.
# This is to stop us from diverging history *too* much.
self.limiter = Limiter(max_count=5)
self.action_generator = hs.get_action_generator()
self.spam_checker = hs.get_spam_checker()
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_new_client_event(self, builder, requester=None, prev_event_ids=None): def create_event(self, requester, event_dict, token_id=None, txn_id=None,
prev_event_ids=None):
"""
Given a dict from a client, create a new event.
Creates an FrozenEvent object, filling out auth_events, prev_events,
etc.
Adds display names to Join membership events.
Args:
requester
event_dict (dict): An entire event
token_id (str)
txn_id (str)
prev_event_ids (list): The prev event ids to use when creating the event
Returns:
Tuple of created event (FrozenEvent), Context
"""
builder = self.event_builder_factory.new(event_dict)
with (yield self.limiter.queue(builder.room_id)):
self.validator.validate_new(builder)
if builder.type == EventTypes.Member:
membership = builder.content.get("membership", None)
target = UserID.from_string(builder.state_key)
if membership in {Membership.JOIN, Membership.INVITE}:
# If event doesn't include a display name, add one.
profile = self.profile_handler
content = builder.content
try:
if "displayname" not in content:
content["displayname"] = yield profile.get_displayname(target)
if "avatar_url" not in content:
content["avatar_url"] = yield profile.get_avatar_url(target)
except Exception as e:
logger.info(
"Failed to get profile information for %r: %s",
target, e
)
if token_id is not None:
builder.internal_metadata.token_id = token_id
if txn_id is not None:
builder.internal_metadata.txn_id = txn_id
event, context = yield self.create_new_client_event(
builder=builder,
requester=requester,
prev_event_ids=prev_event_ids,
)
defer.returnValue((event, context))
@defer.inlineCallbacks
def send_nonmember_event(self, requester, event, context, ratelimit=True):
"""
Persists and notifies local clients and federation of an event.
Args:
event (FrozenEvent) the event to send.
context (Context) the context of the event.
ratelimit (bool): Whether to rate limit this send.
is_guest (bool): Whether the sender is a guest.
"""
if event.type == EventTypes.Member:
raise SynapseError(
500,
"Tried to send member event through non-member codepath"
)
user = UserID.from_string(event.sender)
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if event.is_state():
prev_state = yield self.deduplicate_state_event(event, context)
if prev_state is not None:
defer.returnValue(prev_state)
yield self.handle_new_client_event(
requester=requester,
event=event,
context=context,
ratelimit=ratelimit,
)
@defer.inlineCallbacks
def deduplicate_state_event(self, event, context):
"""
Checks whether event is in the latest resolved state in context.
If so, returns the version of the event in context.
Otherwise, returns None.
"""
prev_event_id = context.prev_state_ids.get((event.type, event.state_key))
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
if not prev_event:
return
if prev_event and event.user_id == prev_event.user_id:
prev_content = encode_canonical_json(prev_event.content)
next_content = encode_canonical_json(event.content)
if prev_content == next_content:
defer.returnValue(prev_event)
return
@defer.inlineCallbacks
def create_and_send_nonmember_event(
self,
requester,
event_dict,
ratelimit=True,
txn_id=None
):
"""
Creates an event, then sends it.
See self.create_event and self.send_nonmember_event.
"""
event, context = yield self.create_event(
requester,
event_dict,
token_id=requester.access_token_id,
txn_id=txn_id
)
spam_error = self.spam_checker.check_event_for_spam(event)
if spam_error:
if not isinstance(spam_error, basestring):
spam_error = "Spam is not permitted here"
raise SynapseError(
403, spam_error, Codes.FORBIDDEN
)
yield self.send_nonmember_event(
requester,
event,
context,
ratelimit=ratelimit,
)
defer.returnValue(event)
@measure_func("create_new_client_event")
@defer.inlineCallbacks
def create_new_client_event(self, builder, requester=None, prev_event_ids=None):
if prev_event_ids: if prev_event_ids:
prev_events = yield self.store.add_event_hashes(prev_event_ids) prev_events = yield self.store.add_event_hashes(prev_event_ids)
prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids) prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids)
@ -509,9 +619,7 @@ class MessageHandler(BaseHandler):
builder.prev_events = prev_events builder.prev_events = prev_events
builder.depth = depth builder.depth = depth
state_handler = self.state_handler context = yield self.state.compute_event_context(builder)
context = yield state_handler.compute_event_context(builder)
if requester: if requester:
context.app_service = requester.app_service context.app_service = requester.app_service
@ -546,12 +654,21 @@ class MessageHandler(BaseHandler):
event, event,
context, context,
ratelimit=True, ratelimit=True,
extra_users=[] extra_users=[],
): ):
# We now need to go and hit out to wherever we need to hit out to. """Processes a new event. This includes checking auth, persisting it,
notifying users, sending to remote servers, etc.
if ratelimit: If called from a worker will hit out to the master process for final
yield self.ratelimit(requester) processing.
Args:
requester (Requester)
event (FrozenEvent)
context (EventContext)
ratelimit (bool)
extra_users (list(UserID)): Any extra users to notify about event
"""
try: try:
yield self.auth.check_from_context(event, context) yield self.auth.check_from_context(event, context)
@ -567,7 +684,58 @@ class MessageHandler(BaseHandler):
logger.exception("Failed to encode content: %r", event.content) logger.exception("Failed to encode content: %r", event.content)
raise raise
yield self.maybe_kick_guest_users(event, context) yield self.action_generator.handle_push_actions_for_event(
event, context
)
try:
# If we're a worker we need to hit out to the master.
if self.config.worker_app:
yield send_event_to_master(
self.http_client,
host=self.config.worker_replication_host,
port=self.config.worker_replication_http_port,
requester=requester,
event=event,
context=context,
ratelimit=ratelimit,
extra_users=extra_users,
)
return
yield self.persist_and_notify_client_event(
requester,
event,
context,
ratelimit=ratelimit,
extra_users=extra_users,
)
except: # noqa: E722, as we reraise the exception this is fine.
# Ensure that we actually remove the entries in the push actions
# staging area, if we calculated them.
preserve_fn(self.store.remove_push_actions_from_staging)(event.event_id)
raise
@defer.inlineCallbacks
def persist_and_notify_client_event(
self,
requester,
event,
context,
ratelimit=True,
extra_users=[],
):
"""Called when we have fully built the event, have already
calculated the push actions for the event, and checked auth.
This should only be run on master.
"""
assert not self.config.worker_app
if ratelimit:
yield self.base_handler.ratelimit(requester)
yield self.base_handler.maybe_kick_guest_users(event, context)
if event.type == EventTypes.CanonicalAlias: if event.type == EventTypes.CanonicalAlias:
# Check the alias is acually valid (at this time at least) # Check the alias is acually valid (at this time at least)
@ -660,10 +828,6 @@ class MessageHandler(BaseHandler):
"Changing the room create event is forbidden", "Changing the room create event is forbidden",
) )
yield self.action_generator.handle_push_actions_for_event(
event, context
)
(event_stream_id, max_stream_id) = yield self.store.persist_event( (event_stream_id, max_stream_id) = yield self.store.persist_event(
event, context=context event, context=context
) )
@ -683,3 +847,9 @@ class MessageHandler(BaseHandler):
) )
preserve_fn(_notify)() preserve_fn(_notify)()
if event.type == EventTypes.Message:
presence = self.hs.get_presence_handler()
# We don't want to block sending messages on any presence code. This
# matters as sometimes presence code can take a while.
preserve_fn(presence.bump_presence_active_time)(requester.user)

View File

@ -93,29 +93,30 @@ class PresenceHandler(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.wheel_timer = WheelTimer() self.wheel_timer = WheelTimer()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.replication = hs.get_replication_layer()
self.federation = hs.get_federation_sender() self.federation = hs.get_federation_sender()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.replication.register_edu_handler( federation_registry = hs.get_federation_registry()
federation_registry.register_edu_handler(
"m.presence", self.incoming_presence "m.presence", self.incoming_presence
) )
self.replication.register_edu_handler( federation_registry.register_edu_handler(
"m.presence_invite", "m.presence_invite",
lambda origin, content: self.invite_presence( lambda origin, content: self.invite_presence(
observed_user=UserID.from_string(content["observed_user"]), observed_user=UserID.from_string(content["observed_user"]),
observer_user=UserID.from_string(content["observer_user"]), observer_user=UserID.from_string(content["observer_user"]),
) )
) )
self.replication.register_edu_handler( federation_registry.register_edu_handler(
"m.presence_accept", "m.presence_accept",
lambda origin, content: self.accept_presence( lambda origin, content: self.accept_presence(
observed_user=UserID.from_string(content["observed_user"]), observed_user=UserID.from_string(content["observed_user"]),
observer_user=UserID.from_string(content["observer_user"]), observer_user=UserID.from_string(content["observer_user"]),
) )
) )
self.replication.register_edu_handler( federation_registry.register_edu_handler(
"m.presence_deny", "m.presence_deny",
lambda origin, content: self.deny_presence( lambda origin, content: self.deny_presence(
observed_user=UserID.from_string(content["observed_user"]), observed_user=UserID.from_string(content["observed_user"]),

View File

@ -31,14 +31,17 @@ class ProfileHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(ProfileHandler, self).__init__(hs) super(ProfileHandler, self).__init__(hs)
self.federation = hs.get_replication_layer() self.federation = hs.get_federation_client()
self.federation.register_query_handler( hs.get_federation_registry().register_query_handler(
"profile", self.on_profile_query "profile", self.on_profile_query
) )
self.user_directory_handler = hs.get_user_directory_handler() self.user_directory_handler = hs.get_user_directory_handler()
self.clock.looping_call(self._update_remote_profile_cache, self.PROFILE_UPDATE_MS) if hs.config.worker_app is None:
self.clock.looping_call(
self._update_remote_profile_cache, self.PROFILE_UPDATE_MS,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_profile(self, user_id): def get_profile(self, user_id):
@ -233,7 +236,7 @@ class ProfileHandler(BaseHandler):
) )
for room_id in room_ids: for room_id in room_ids:
handler = self.hs.get_handlers().room_member_handler handler = self.hs.get_room_member_handler()
try: try:
# Assume the target_user isn't a guest, # Assume the target_user isn't a guest,
# because we don't let guests set profile or avatar data. # because we don't let guests set profile or avatar data.

View File

@ -41,9 +41,9 @@ class ReadMarkerHandler(BaseHandler):
""" """
with (yield self.read_marker_linearizer.queue((room_id, user_id))): with (yield self.read_marker_linearizer.queue((room_id, user_id))):
account_data = yield self.store.get_account_data_for_room(user_id, room_id) existing_read_marker = yield self.store.get_account_data_for_room_and_type(
user_id, room_id, "m.fully_read",
existing_read_marker = account_data.get("m.fully_read", None) )
should_update = True should_update = True

View File

@ -35,7 +35,7 @@ class ReceiptsHandler(BaseHandler):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.hs = hs self.hs = hs
self.federation = hs.get_federation_sender() self.federation = hs.get_federation_sender()
hs.get_replication_layer().register_edu_handler( hs.get_federation_registry().register_edu_handler(
"m.receipt", self._received_remote_receipt "m.receipt", self._received_remote_receipt
) )
self.clock = self.hs.get_clock() self.clock = self.hs.get_clock()

View File

@ -25,6 +25,7 @@ from synapse.http.client import CaptchaServerHttpClient
from synapse import types from synapse import types
from synapse.types import UserID from synapse.types import UserID
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.threepids import check_3pid_allowed
from ._base import BaseHandler from ._base import BaseHandler
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -131,7 +132,7 @@ class RegistrationHandler(BaseHandler):
yield run_on_reactor() yield run_on_reactor()
password_hash = None password_hash = None
if password: if password:
password_hash = self.auth_handler().hash(password) password_hash = yield self.auth_handler().hash(password)
if localpart: if localpart:
yield self.check_username(localpart, guest_access_token=guest_access_token) yield self.check_username(localpart, guest_access_token=guest_access_token)
@ -293,7 +294,7 @@ class RegistrationHandler(BaseHandler):
""" """
for c in threepidCreds: for c in threepidCreds:
logger.info("validating theeepidcred sid %s on id server %s", logger.info("validating threepidcred sid %s on id server %s",
c['sid'], c['idServer']) c['sid'], c['idServer'])
try: try:
identity_handler = self.hs.get_handlers().identity_handler identity_handler = self.hs.get_handlers().identity_handler
@ -307,6 +308,11 @@ class RegistrationHandler(BaseHandler):
logger.info("got threepid with medium '%s' and address '%s'", logger.info("got threepid with medium '%s' and address '%s'",
threepid['medium'], threepid['address']) threepid['medium'], threepid['address'])
if not check_3pid_allowed(self.hs, threepid['medium'], threepid['address']):
raise RegistrationError(
403, "Third party identifier is not allowed"
)
@defer.inlineCallbacks @defer.inlineCallbacks
def bind_emails(self, user_id, threepidCreds): def bind_emails(self, user_id, threepidCreds):
"""Links emails with a user ID and informs an identity server. """Links emails with a user ID and informs an identity server.
@ -440,16 +446,34 @@ class RegistrationHandler(BaseHandler):
return self.hs.get_auth_handler() return self.hs.get_auth_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def guest_access_token_for(self, medium, address, inviter_user_id): def get_or_register_3pid_guest(self, medium, address, inviter_user_id):
"""Get a guest access token for a 3PID, creating a guest account if
one doesn't already exist.
Args:
medium (str)
address (str)
inviter_user_id (str): The user ID who is trying to invite the
3PID
Returns:
Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
3PID guest account.
"""
access_token = yield self.store.get_3pid_guest_access_token(medium, address) access_token = yield self.store.get_3pid_guest_access_token(medium, address)
if access_token: if access_token:
defer.returnValue(access_token) user_info = yield self.auth.get_user_by_access_token(
access_token
)
_, access_token = yield self.register( defer.returnValue((user_info["user"].to_string(), access_token))
user_id, access_token = yield self.register(
generate_token=True, generate_token=True,
make_guest=True make_guest=True
) )
access_token = yield self.store.save_or_get_3pid_guest_access_token( access_token = yield self.store.save_or_get_3pid_guest_access_token(
medium, address, access_token, inviter_user_id medium, address, access_token, inviter_user_id
) )
defer.returnValue(access_token)
defer.returnValue((user_id, access_token))

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd # Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -64,6 +65,7 @@ class RoomCreationHandler(BaseHandler):
super(RoomCreationHandler, self).__init__(hs) super(RoomCreationHandler, self).__init__(hs)
self.spam_checker = hs.get_spam_checker() self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def create_room(self, requester, config, ratelimit=True): def create_room(self, requester, config, ratelimit=True):
@ -163,13 +165,11 @@ class RoomCreationHandler(BaseHandler):
creation_content = config.get("creation_content", {}) creation_content = config.get("creation_content", {})
msg_handler = self.hs.get_handlers().message_handler room_member_handler = self.hs.get_room_member_handler()
room_member_handler = self.hs.get_handlers().room_member_handler
yield self._send_events_for_new_room( yield self._send_events_for_new_room(
requester, requester,
room_id, room_id,
msg_handler,
room_member_handler, room_member_handler,
preset_config=preset_config, preset_config=preset_config,
invite_list=invite_list, invite_list=invite_list,
@ -181,7 +181,7 @@ class RoomCreationHandler(BaseHandler):
if "name" in config: if "name" in config:
name = config["name"] name = config["name"]
yield msg_handler.create_and_send_nonmember_event( yield self.event_creation_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.Name, "type": EventTypes.Name,
@ -194,7 +194,7 @@ class RoomCreationHandler(BaseHandler):
if "topic" in config: if "topic" in config:
topic = config["topic"] topic = config["topic"]
yield msg_handler.create_and_send_nonmember_event( yield self.event_creation_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.Topic, "type": EventTypes.Topic,
@ -224,7 +224,7 @@ class RoomCreationHandler(BaseHandler):
id_server = invite_3pid["id_server"] id_server = invite_3pid["id_server"]
address = invite_3pid["address"] address = invite_3pid["address"]
medium = invite_3pid["medium"] medium = invite_3pid["medium"]
yield self.hs.get_handlers().room_member_handler.do_3pid_invite( yield self.hs.get_room_member_handler().do_3pid_invite(
room_id, room_id,
requester.user, requester.user,
medium, medium,
@ -249,7 +249,6 @@ class RoomCreationHandler(BaseHandler):
self, self,
creator, # A Requester object. creator, # A Requester object.
room_id, room_id,
msg_handler,
room_member_handler, room_member_handler,
preset_config, preset_config,
invite_list, invite_list,
@ -272,7 +271,7 @@ class RoomCreationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def send(etype, content, **kwargs): def send(etype, content, **kwargs):
event = create(etype, content, **kwargs) event = create(etype, content, **kwargs)
yield msg_handler.create_and_send_nonmember_event( yield self.event_creation_handler.create_and_send_nonmember_event(
creator, creator,
event, event,
ratelimit=False ratelimit=False
@ -476,12 +475,9 @@ class RoomEventSource(object):
user.to_string() user.to_string()
) )
if app_service: if app_service:
events, end_key = yield self.store.get_appservice_room_stream( # We no longer support AS users using /sync directly.
service=app_service, # See https://github.com/matrix-org/matrix-doc/issues/1144
from_key=from_key, raise NotImplementedError()
to_key=to_key,
limit=limit,
)
else: else:
room_events = yield self.store.get_membership_changes_for_user( room_events = yield self.store.get_membership_changes_for_user(
user.to_string(), from_key, to_key user.to_string(), from_key, to_key

View File

@ -203,7 +203,8 @@ class RoomListHandler(BaseHandler):
if limit: if limit:
step = limit + 1 step = limit + 1
else: else:
step = len(rooms_to_scan) # step cannot be zero
step = len(rooms_to_scan) if len(rooms_to_scan) != 0 else 1
chunk = [] chunk = []
for i in xrange(0, len(rooms_to_scan), step): for i in xrange(0, len(rooms_to_scan), step):
@ -408,7 +409,7 @@ class RoomListHandler(BaseHandler):
def _get_remote_list_cached(self, server_name, limit=None, since_token=None, def _get_remote_list_cached(self, server_name, limit=None, since_token=None,
search_filter=None, include_all_networks=False, search_filter=None, include_all_networks=False,
third_party_instance_id=None,): third_party_instance_id=None,):
repl_layer = self.hs.get_replication_layer() repl_layer = self.hs.get_federation_client()
if search_filter: if search_filter:
# We can't cache when asking for search # We can't cache when asking for search
return repl_layer.get_public_rooms( return repl_layer.get_public_rooms(

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd # Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,7 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import abc
import logging import logging
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
@ -29,32 +30,121 @@ from synapse.api.errors import AuthError, SynapseError, Codes
from synapse.types import UserID, RoomID from synapse.types import UserID, RoomID
from synapse.util.async import Linearizer from synapse.util.async import Linearizer
from synapse.util.distributor import user_left_room, user_joined_room from synapse.util.distributor import user_left_room, user_joined_room
from ._base import BaseHandler
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
id_server_scheme = "https://" id_server_scheme = "https://"
class RoomMemberHandler(BaseHandler): class RoomMemberHandler(object):
# TODO(paul): This handler currently contains a messy conflation of # TODO(paul): This handler currently contains a messy conflation of
# low-level API that works on UserID objects and so on, and REST-level # low-level API that works on UserID objects and so on, and REST-level
# API that takes ID strings and returns pagination chunks. These concerns # API that takes ID strings and returns pagination chunks. These concerns
# ought to be separated out a lot better. # ought to be separated out a lot better.
def __init__(self, hs): __metaclass__ = abc.ABCMeta
super(RoomMemberHandler, self).__init__(hs)
def __init__(self, hs):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.state_handler = hs.get_state_handler()
self.config = hs.config
self.simple_http_client = hs.get_simple_http_client()
self.federation_handler = hs.get_handlers().federation_handler
self.directory_handler = hs.get_handlers().directory_handler
self.registration_handler = hs.get_handlers().registration_handler
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
self.event_creation_hander = hs.get_event_creation_handler()
self.member_linearizer = Linearizer(name="member") self.member_linearizer = Linearizer(name="member")
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.spam_checker = hs.get_spam_checker() self.spam_checker = hs.get_spam_checker()
self.distributor = hs.get_distributor() @abc.abstractmethod
self.distributor.declare("user_joined_room") def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
self.distributor.declare("user_left_room") """Try and join a room that this server is not in
Args:
requester (Requester)
remote_room_hosts (list[str]): List of servers that can be used
to join via.
room_id (str): Room that we are trying to join
user (UserID): User who is trying to join
content (dict): A dict that should be used as the content of the
join event.
Returns:
Deferred
"""
raise NotImplementedError()
@abc.abstractmethod
def _remote_reject_invite(self, remote_room_hosts, room_id, target):
"""Attempt to reject an invite for a room this server is not in. If we
fail to do so we locally mark the invite as rejected.
Args:
requester (Requester)
remote_room_hosts (list[str]): List of servers to use to try and
reject invite
room_id (str)
target (UserID): The user rejecting the invite
Returns:
Deferred[dict]: A dictionary to be returned to the client, may
include event_id etc, or nothing if we locally rejected
"""
raise NotImplementedError()
@abc.abstractmethod
def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
"""Get a guest access token for a 3PID, creating a guest account if
one doesn't already exist.
Args:
requester (Requester)
medium (str)
address (str)
inviter_user_id (str): The user ID who is trying to invite the
3PID
Returns:
Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
3PID guest account.
"""
raise NotImplementedError()
@abc.abstractmethod
def _user_joined_room(self, target, room_id):
"""Notifies distributor on master process that the user has joined the
room.
Args:
target (UserID)
room_id (str)
Returns:
Deferred|None
"""
raise NotImplementedError()
@abc.abstractmethod
def _user_left_room(self, target, room_id):
"""Notifies distributor on master process that the user has left the
room.
Args:
target (UserID)
room_id (str)
Returns:
Deferred|None
"""
raise NotImplementedError()
@defer.inlineCallbacks @defer.inlineCallbacks
def _local_membership_update( def _local_membership_update(
@ -66,13 +156,12 @@ class RoomMemberHandler(BaseHandler):
): ):
if content is None: if content is None:
content = {} content = {}
msg_handler = self.hs.get_handlers().message_handler
content["membership"] = membership content["membership"] = membership
if requester.is_guest: if requester.is_guest:
content["kind"] = "guest" content["kind"] = "guest"
event, context = yield msg_handler.create_event( event, context = yield self.event_creation_hander.create_event(
requester, requester,
{ {
"type": EventTypes.Member, "type": EventTypes.Member,
@ -90,12 +179,14 @@ class RoomMemberHandler(BaseHandler):
) )
# Check if this event matches the previous membership event for the user. # Check if this event matches the previous membership event for the user.
duplicate = yield msg_handler.deduplicate_state_event(event, context) duplicate = yield self.event_creation_hander.deduplicate_state_event(
event, context,
)
if duplicate is not None: if duplicate is not None:
# Discard the new event since this membership change is a no-op. # Discard the new event since this membership change is a no-op.
defer.returnValue(duplicate) defer.returnValue(duplicate)
yield msg_handler.handle_new_client_event( yield self.event_creation_hander.handle_new_client_event(
requester, requester,
event, event,
context, context,
@ -117,32 +208,15 @@ class RoomMemberHandler(BaseHandler):
prev_member_event = yield self.store.get_event(prev_member_event_id) prev_member_event = yield self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined: if newly_joined:
yield user_joined_room(self.distributor, target, room_id) yield self._user_joined_room(target, room_id)
elif event.membership == Membership.LEAVE: elif event.membership == Membership.LEAVE:
if prev_member_event_id: if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id) prev_member_event = yield self.store.get_event(prev_member_event_id)
if prev_member_event.membership == Membership.JOIN: if prev_member_event.membership == Membership.JOIN:
user_left_room(self.distributor, target, room_id) yield self._user_left_room(target, room_id)
defer.returnValue(event) defer.returnValue(event)
@defer.inlineCallbacks
def remote_join(self, remote_room_hosts, room_id, user, content):
if len(remote_room_hosts) == 0:
raise SynapseError(404, "No known servers")
# We don't do an auth check if we are doing an invite
# join dance for now, since we're kinda implicitly checking
# that we are allowed to join when we decide whether or not we
# need to do the invite/join dance.
yield self.hs.get_handlers().federation_handler.do_invite_join(
remote_room_hosts,
room_id,
user.to_string(),
content,
)
yield user_joined_room(self.distributor, user, room_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def update_membership( def update_membership(
self, self,
@ -201,8 +275,7 @@ class RoomMemberHandler(BaseHandler):
# if this is a join with a 3pid signature, we may need to turn a 3pid # if this is a join with a 3pid signature, we may need to turn a 3pid
# invite into a normal invite before we can handle the join. # invite into a normal invite before we can handle the join.
if third_party_signed is not None: if third_party_signed is not None:
replication = self.hs.get_replication_layer() yield self.federation_handler.exchange_third_party_invite(
yield replication.exchange_third_party_invite(
third_party_signed["sender"], third_party_signed["sender"],
target.to_string(), target.to_string(),
room_id, room_id,
@ -223,7 +296,7 @@ class RoomMemberHandler(BaseHandler):
requester.user, requester.user,
) )
if not is_requester_admin: if not is_requester_admin:
if self.hs.config.block_non_admin_invites: if self.config.block_non_admin_invites:
logger.info( logger.info(
"Blocking invite: user is not admin and non-admin " "Blocking invite: user is not admin and non-admin "
"invites disabled" "invites disabled"
@ -282,7 +355,7 @@ class RoomMemberHandler(BaseHandler):
raise AuthError(403, "Guest access not allowed") raise AuthError(403, "Guest access not allowed")
if not is_host_in_room: if not is_host_in_room:
inviter = yield self.get_inviter(target.to_string(), room_id) inviter = yield self._get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter): if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain) remote_room_hosts.append(inviter.domain)
@ -296,15 +369,15 @@ class RoomMemberHandler(BaseHandler):
if requester.is_guest: if requester.is_guest:
content["kind"] = "guest" content["kind"] = "guest"
ret = yield self.remote_join( ret = yield self._remote_join(
remote_room_hosts, room_id, target, content requester, remote_room_hosts, room_id, target, content
) )
defer.returnValue(ret) defer.returnValue(ret)
elif effective_membership_state == Membership.LEAVE: elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room: if not is_host_in_room:
# perhaps we've been invited # perhaps we've been invited
inviter = yield self.get_inviter(target.to_string(), room_id) inviter = yield self._get_inviter(target.to_string(), room_id)
if not inviter: if not inviter:
raise SynapseError(404, "Not a known room") raise SynapseError(404, "Not a known room")
@ -318,28 +391,10 @@ class RoomMemberHandler(BaseHandler):
else: else:
# send the rejection to the inviter's HS. # send the rejection to the inviter's HS.
remote_room_hosts = remote_room_hosts + [inviter.domain] remote_room_hosts = remote_room_hosts + [inviter.domain]
fed_handler = self.hs.get_handlers().federation_handler res = yield self._remote_reject_invite(
try: requester, remote_room_hosts, room_id, target,
ret = yield fed_handler.do_remotely_reject_invite( )
remote_room_hosts, defer.returnValue(res)
room_id,
target.to_string(),
)
defer.returnValue(ret)
except Exception as e:
# if we were unable to reject the exception, just mark
# it as rejected on our end and plough ahead.
#
# The 'except' clause is very broad, but we need to
# capture everything from DNS failures upwards
#
logger.warn("Failed to reject invite: %s", e)
yield self.store.locally_reject_invite(
target.to_string(), room_id
)
defer.returnValue({})
res = yield self._local_membership_update( res = yield self._local_membership_update(
requester=requester, requester=requester,
@ -394,8 +449,9 @@ class RoomMemberHandler(BaseHandler):
else: else:
requester = synapse.types.create_requester(target_user) requester = synapse.types.create_requester(target_user)
message_handler = self.hs.get_handlers().message_handler prev_event = yield self.event_creation_hander.deduplicate_state_event(
prev_event = yield message_handler.deduplicate_state_event(event, context) event, context,
)
if prev_event is not None: if prev_event is not None:
return return
@ -412,7 +468,7 @@ class RoomMemberHandler(BaseHandler):
if is_blocked: if is_blocked:
raise SynapseError(403, "This room has been blocked on this server") raise SynapseError(403, "This room has been blocked on this server")
yield message_handler.handle_new_client_event( yield self.event_creation_hander.handle_new_client_event(
requester, requester,
event, event,
context, context,
@ -434,12 +490,12 @@ class RoomMemberHandler(BaseHandler):
prev_member_event = yield self.store.get_event(prev_member_event_id) prev_member_event = yield self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined: if newly_joined:
yield user_joined_room(self.distributor, target_user, room_id) yield self._user_joined_room(target_user, room_id)
elif event.membership == Membership.LEAVE: elif event.membership == Membership.LEAVE:
if prev_member_event_id: if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id) prev_member_event = yield self.store.get_event(prev_member_event_id)
if prev_member_event.membership == Membership.JOIN: if prev_member_event.membership == Membership.JOIN:
user_left_room(self.distributor, target_user, room_id) yield self._user_left_room(target_user, room_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def _can_guest_join(self, current_state_ids): def _can_guest_join(self, current_state_ids):
@ -473,7 +529,7 @@ class RoomMemberHandler(BaseHandler):
Raises: Raises:
SynapseError if room alias could not be found. SynapseError if room alias could not be found.
""" """
directory_handler = self.hs.get_handlers().directory_handler directory_handler = self.directory_handler
mapping = yield directory_handler.get_association(room_alias) mapping = yield directory_handler.get_association(room_alias)
if not mapping: if not mapping:
@ -485,7 +541,7 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue((RoomID.from_string(room_id), servers)) defer.returnValue((RoomID.from_string(room_id), servers))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_inviter(self, user_id, room_id): def _get_inviter(self, user_id, room_id):
invite = yield self.store.get_invite_for_user_in_room( invite = yield self.store.get_invite_for_user_in_room(
user_id=user_id, user_id=user_id,
room_id=room_id, room_id=room_id,
@ -504,7 +560,7 @@ class RoomMemberHandler(BaseHandler):
requester, requester,
txn_id txn_id
): ):
if self.hs.config.block_non_admin_invites: if self.config.block_non_admin_invites:
is_requester_admin = yield self.auth.is_server_admin( is_requester_admin = yield self.auth.is_server_admin(
requester.user, requester.user,
) )
@ -551,7 +607,7 @@ class RoomMemberHandler(BaseHandler):
str: the matrix ID of the 3pid, or None if it is not recognized. str: the matrix ID of the 3pid, or None if it is not recognized.
""" """
try: try:
data = yield self.hs.get_simple_http_client().get_json( data = yield self.simple_http_client.get_json(
"%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,), "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,),
{ {
"medium": medium, "medium": medium,
@ -562,7 +618,7 @@ class RoomMemberHandler(BaseHandler):
if "mxid" in data: if "mxid" in data:
if "signatures" not in data: if "signatures" not in data:
raise AuthError(401, "No signatures on 3pid binding") raise AuthError(401, "No signatures on 3pid binding")
self.verify_any_signature(data, id_server) yield self._verify_any_signature(data, id_server)
defer.returnValue(data["mxid"]) defer.returnValue(data["mxid"])
except IOError as e: except IOError as e:
@ -570,11 +626,11 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue(None) defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def verify_any_signature(self, data, server_hostname): def _verify_any_signature(self, data, server_hostname):
if server_hostname not in data["signatures"]: if server_hostname not in data["signatures"]:
raise AuthError(401, "No signature from server %s" % (server_hostname,)) raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items(): for key_name, signature in data["signatures"][server_hostname].items():
key_data = yield self.hs.get_simple_http_client().get_json( key_data = yield self.simple_http_client.get_json(
"%s%s/_matrix/identity/api/v1/pubkey/%s" % "%s%s/_matrix/identity/api/v1/pubkey/%s" %
(id_server_scheme, server_hostname, key_name,), (id_server_scheme, server_hostname, key_name,),
) )
@ -599,7 +655,7 @@ class RoomMemberHandler(BaseHandler):
user, user,
txn_id txn_id
): ):
room_state = yield self.hs.get_state_handler().get_current_state(room_id) room_state = yield self.state_handler.get_current_state(room_id)
inviter_display_name = "" inviter_display_name = ""
inviter_avatar_url = "" inviter_avatar_url = ""
@ -630,6 +686,7 @@ class RoomMemberHandler(BaseHandler):
token, public_keys, fallback_public_key, display_name = ( token, public_keys, fallback_public_key, display_name = (
yield self._ask_id_server_for_third_party_invite( yield self._ask_id_server_for_third_party_invite(
requester=requester,
id_server=id_server, id_server=id_server,
medium=medium, medium=medium,
address=address, address=address,
@ -644,8 +701,7 @@ class RoomMemberHandler(BaseHandler):
) )
) )
msg_handler = self.hs.get_handlers().message_handler yield self.event_creation_hander.create_and_send_nonmember_event(
yield msg_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.ThirdPartyInvite, "type": EventTypes.ThirdPartyInvite,
@ -667,6 +723,7 @@ class RoomMemberHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _ask_id_server_for_third_party_invite( def _ask_id_server_for_third_party_invite(
self, self,
requester,
id_server, id_server,
medium, medium,
address, address,
@ -683,6 +740,7 @@ class RoomMemberHandler(BaseHandler):
Asks an identity server for a third party invite. Asks an identity server for a third party invite.
Args: Args:
requester (Requester)
id_server (str): hostname + optional port for the identity server. id_server (str): hostname + optional port for the identity server.
medium (str): The literal string "email". medium (str): The literal string "email".
address (str): The third party address being invited. address (str): The third party address being invited.
@ -724,24 +782,20 @@ class RoomMemberHandler(BaseHandler):
"sender_avatar_url": inviter_avatar_url, "sender_avatar_url": inviter_avatar_url,
} }
if self.hs.config.invite_3pid_guest: if self.config.invite_3pid_guest:
registration_handler = self.hs.get_handlers().registration_handler guest_access_token, guest_user_id = yield self.get_or_register_3pid_guest(
guest_access_token = yield registration_handler.guest_access_token_for( requester=requester,
medium=medium, medium=medium,
address=address, address=address,
inviter_user_id=inviter_user_id, inviter_user_id=inviter_user_id,
) )
guest_user_info = yield self.hs.get_auth().get_user_by_access_token(
guest_access_token
)
invite_config.update({ invite_config.update({
"guest_access_token": guest_access_token, "guest_access_token": guest_access_token,
"guest_user_id": guest_user_info["user"].to_string(), "guest_user_id": guest_user_id,
}) })
data = yield self.hs.get_simple_http_client().post_urlencoded_get_json( data = yield self.simple_http_client.post_urlencoded_get_json(
is_url, is_url,
invite_config invite_config
) )
@ -763,27 +817,6 @@ class RoomMemberHandler(BaseHandler):
display_name = data["display_name"] display_name = data["display_name"]
defer.returnValue((token, public_keys, fallback_public_key, display_name)) defer.returnValue((token, public_keys, fallback_public_key, display_name))
@defer.inlineCallbacks
def forget(self, user, room_id):
user_id = user.to_string()
member = yield self.state_handler.get_current_state(
room_id=room_id,
event_type=EventTypes.Member,
state_key=user_id
)
membership = member.membership if member else None
if membership is not None and membership not in [
Membership.LEAVE, Membership.BAN
]:
raise SynapseError(400, "User %s in room %s" % (
user_id, room_id
))
if membership:
yield self.store.forget(user_id, room_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def _is_host_in_room(self, current_state_ids): def _is_host_in_room(self, current_state_ids):
# Have we just created the room, and is this about to be the very # Have we just created the room, and is this about to be the very
@ -805,3 +838,94 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue(True) defer.returnValue(True)
defer.returnValue(False) defer.returnValue(False)
class RoomMemberMasterHandler(RoomMemberHandler):
def __init__(self, hs):
super(RoomMemberMasterHandler, self).__init__(hs)
self.distributor = hs.get_distributor()
self.distributor.declare("user_joined_room")
self.distributor.declare("user_left_room")
@defer.inlineCallbacks
def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
"""Implements RoomMemberHandler._remote_join
"""
if len(remote_room_hosts) == 0:
raise SynapseError(404, "No known servers")
# We don't do an auth check if we are doing an invite
# join dance for now, since we're kinda implicitly checking
# that we are allowed to join when we decide whether or not we
# need to do the invite/join dance.
yield self.federation_handler.do_invite_join(
remote_room_hosts,
room_id,
user.to_string(),
content,
)
yield self._user_joined_room(user, room_id)
@defer.inlineCallbacks
def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target):
"""Implements RoomMemberHandler._remote_reject_invite
"""
fed_handler = self.federation_handler
try:
ret = yield fed_handler.do_remotely_reject_invite(
remote_room_hosts,
room_id,
target.to_string(),
)
defer.returnValue(ret)
except Exception as e:
# if we were unable to reject the exception, just mark
# it as rejected on our end and plough ahead.
#
# The 'except' clause is very broad, but we need to
# capture everything from DNS failures upwards
#
logger.warn("Failed to reject invite: %s", e)
yield self.store.locally_reject_invite(
target.to_string(), room_id
)
defer.returnValue({})
def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
"""Implements RoomMemberHandler.get_or_register_3pid_guest
"""
rg = self.registration_handler
return rg.get_or_register_3pid_guest(medium, address, inviter_user_id)
def _user_joined_room(self, target, room_id):
"""Implements RoomMemberHandler._user_joined_room
"""
return user_joined_room(self.distributor, target, room_id)
def _user_left_room(self, target, room_id):
"""Implements RoomMemberHandler._user_left_room
"""
return user_left_room(self.distributor, target, room_id)
@defer.inlineCallbacks
def forget(self, user, room_id):
user_id = user.to_string()
member = yield self.state_handler.get_current_state(
room_id=room_id,
event_type=EventTypes.Member,
state_key=user_id
)
membership = member.membership if member else None
if membership is not None and membership not in [
Membership.LEAVE, Membership.BAN
]:
raise SynapseError(400, "User %s in room %s" % (
user_id, room_id
))
if membership:
yield self.store.forget(user_id, room_id)

View File

@ -0,0 +1,102 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.handlers.room_member import RoomMemberHandler
from synapse.replication.http.membership import (
remote_join, remote_reject_invite, get_or_register_3pid_guest,
notify_user_membership_change,
)
logger = logging.getLogger(__name__)
class RoomMemberWorkerHandler(RoomMemberHandler):
@defer.inlineCallbacks
def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
"""Implements RoomMemberHandler._remote_join
"""
if len(remote_room_hosts) == 0:
raise SynapseError(404, "No known servers")
ret = yield remote_join(
self.simple_http_client,
host=self.config.worker_replication_host,
port=self.config.worker_replication_http_port,
requester=requester,
remote_room_hosts=remote_room_hosts,
room_id=room_id,
user_id=user.to_string(),
content=content,
)
yield self._user_joined_room(user, room_id)
defer.returnValue(ret)
def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target):
"""Implements RoomMemberHandler._remote_reject_invite
"""
return remote_reject_invite(
self.simple_http_client,
host=self.config.worker_replication_host,
port=self.config.worker_replication_http_port,
requester=requester,
remote_room_hosts=remote_room_hosts,
room_id=room_id,
user_id=target.to_string(),
)
def _user_joined_room(self, target, room_id):
"""Implements RoomMemberHandler._user_joined_room
"""
return notify_user_membership_change(
self.simple_http_client,
host=self.config.worker_replication_host,
port=self.config.worker_replication_http_port,
user_id=target.to_string(),
room_id=room_id,
change="joined",
)
def _user_left_room(self, target, room_id):
"""Implements RoomMemberHandler._user_left_room
"""
return notify_user_membership_change(
self.simple_http_client,
host=self.config.worker_replication_host,
port=self.config.worker_replication_http_port,
user_id=target.to_string(),
room_id=room_id,
change="left",
)
def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
"""Implements RoomMemberHandler.get_or_register_3pid_guest
"""
return get_or_register_3pid_guest(
self.simple_http_client,
host=self.config.worker_replication_host,
port=self.config.worker_replication_http_port,
requester=requester,
medium=medium,
address=address,
inviter_user_id=inviter_user_id,
)

View File

@ -31,7 +31,7 @@ class SetPasswordHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def set_password(self, user_id, newpassword, requester=None): def set_password(self, user_id, newpassword, requester=None):
password_hash = self._auth_handler.hash(newpassword) password_hash = yield self._auth_handler.hash(newpassword)
except_device_id = requester.device_id if requester else None except_device_id = requester.device_id if requester else None
except_access_token_id = requester.access_token_id if requester else None except_access_token_id = requester.access_token_id if requester else None

View File

@ -235,10 +235,10 @@ class SyncHandler(object):
defer.returnValue(rules) defer.returnValue(rules)
@defer.inlineCallbacks @defer.inlineCallbacks
def ephemeral_by_room(self, sync_config, now_token, since_token=None): def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None):
"""Get the ephemeral events for each room the user is in """Get the ephemeral events for each room the user is in
Args: Args:
sync_config (SyncConfig): The flags, filters and user for the sync. sync_result_builder(SyncResultBuilder)
now_token (StreamToken): Where the server is currently up to. now_token (StreamToken): Where the server is currently up to.
since_token (StreamToken): Where the server was when the client since_token (StreamToken): Where the server was when the client
last synced. last synced.
@ -248,10 +248,12 @@ class SyncHandler(object):
typing events for that room. typing events for that room.
""" """
sync_config = sync_result_builder.sync_config
with Measure(self.clock, "ephemeral_by_room"): with Measure(self.clock, "ephemeral_by_room"):
typing_key = since_token.typing_key if since_token else "0" typing_key = since_token.typing_key if since_token else "0"
room_ids = yield self.store.get_rooms_for_user(sync_config.user.to_string()) room_ids = sync_result_builder.joined_room_ids
typing_source = self.event_sources.sources["typing"] typing_source = self.event_sources.sources["typing"]
typing, typing_key = yield typing_source.get_new_events( typing, typing_key = yield typing_source.get_new_events(
@ -565,10 +567,22 @@ class SyncHandler(object):
# Always use the `now_token` in `SyncResultBuilder` # Always use the `now_token` in `SyncResultBuilder`
now_token = yield self.event_sources.get_current_token() now_token = yield self.event_sources.get_current_token()
user_id = sync_config.user.to_string()
app_service = self.store.get_app_service_by_user_id(user_id)
if app_service:
# We no longer support AS users using /sync directly.
# See https://github.com/matrix-org/matrix-doc/issues/1144
raise NotImplementedError()
else:
joined_room_ids = yield self.get_rooms_for_user_at(
user_id, now_token.room_stream_id,
)
sync_result_builder = SyncResultBuilder( sync_result_builder = SyncResultBuilder(
sync_config, full_state, sync_config, full_state,
since_token=since_token, since_token=since_token,
now_token=now_token, now_token=now_token,
joined_room_ids=joined_room_ids,
) )
account_data_by_room = yield self._generate_sync_entry_for_account_data( account_data_by_room = yield self._generate_sync_entry_for_account_data(
@ -603,7 +617,6 @@ class SyncHandler(object):
device_id = sync_config.device_id device_id = sync_config.device_id
one_time_key_counts = {} one_time_key_counts = {}
if device_id: if device_id:
user_id = sync_config.user.to_string()
one_time_key_counts = yield self.store.count_e2e_one_time_keys( one_time_key_counts = yield self.store.count_e2e_one_time_keys(
user_id, device_id user_id, device_id
) )
@ -891,7 +904,7 @@ class SyncHandler(object):
ephemeral_by_room = {} ephemeral_by_room = {}
else: else:
now_token, ephemeral_by_room = yield self.ephemeral_by_room( now_token, ephemeral_by_room = yield self.ephemeral_by_room(
sync_result_builder.sync_config, sync_result_builder,
now_token=sync_result_builder.now_token, now_token=sync_result_builder.now_token,
since_token=sync_result_builder.since_token, since_token=sync_result_builder.since_token,
) )
@ -996,15 +1009,8 @@ class SyncHandler(object):
if rooms_changed: if rooms_changed:
defer.returnValue(True) defer.returnValue(True)
app_service = self.store.get_app_service_by_user_id(user_id)
if app_service:
rooms = yield self.store.get_app_service_rooms(app_service)
joined_room_ids = set(r.room_id for r in rooms)
else:
joined_room_ids = yield self.store.get_rooms_for_user(user_id)
stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream
for room_id in joined_room_ids: for room_id in sync_result_builder.joined_room_ids:
if self.store.has_room_changed_since(room_id, stream_id): if self.store.has_room_changed_since(room_id, stream_id):
defer.returnValue(True) defer.returnValue(True)
defer.returnValue(False) defer.returnValue(False)
@ -1028,13 +1034,6 @@ class SyncHandler(object):
assert since_token assert since_token
app_service = self.store.get_app_service_by_user_id(user_id)
if app_service:
rooms = yield self.store.get_app_service_rooms(app_service)
joined_room_ids = set(r.room_id for r in rooms)
else:
joined_room_ids = yield self.store.get_rooms_for_user(user_id)
# Get a list of membership change events that have happened. # Get a list of membership change events that have happened.
rooms_changed = yield self.store.get_membership_changes_for_user( rooms_changed = yield self.store.get_membership_changes_for_user(
user_id, since_token.room_key, now_token.room_key user_id, since_token.room_key, now_token.room_key
@ -1057,7 +1056,7 @@ class SyncHandler(object):
# we do send down the room, and with full state, where necessary # we do send down the room, and with full state, where necessary
old_state_ids = None old_state_ids = None
if room_id in joined_room_ids and non_joins: if room_id in sync_result_builder.joined_room_ids and non_joins:
# Always include if the user (re)joined the room, especially # Always include if the user (re)joined the room, especially
# important so that device list changes are calculated correctly. # important so that device list changes are calculated correctly.
# If there are non join member events, but we are still in the room, # If there are non join member events, but we are still in the room,
@ -1067,7 +1066,7 @@ class SyncHandler(object):
# User is in the room so we don't need to do the invite/leave checks # User is in the room so we don't need to do the invite/leave checks
continue continue
if room_id in joined_room_ids or has_join: if room_id in sync_result_builder.joined_room_ids or has_join:
old_state_ids = yield self.get_state_at(room_id, since_token) old_state_ids = yield self.get_state_at(room_id, since_token)
old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None) old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
old_mem_ev = None old_mem_ev = None
@ -1079,7 +1078,7 @@ class SyncHandler(object):
newly_joined_rooms.append(room_id) newly_joined_rooms.append(room_id)
# If user is in the room then we don't need to do the invite/leave checks # If user is in the room then we don't need to do the invite/leave checks
if room_id in joined_room_ids: if room_id in sync_result_builder.joined_room_ids:
continue continue
if not non_joins: if not non_joins:
@ -1146,7 +1145,7 @@ class SyncHandler(object):
# Get all events for rooms we're currently joined to. # Get all events for rooms we're currently joined to.
room_to_events = yield self.store.get_room_events_stream_for_rooms( room_to_events = yield self.store.get_room_events_stream_for_rooms(
room_ids=joined_room_ids, room_ids=sync_result_builder.joined_room_ids,
from_key=since_token.room_key, from_key=since_token.room_key,
to_key=now_token.room_key, to_key=now_token.room_key,
limit=timeline_limit + 1, limit=timeline_limit + 1,
@ -1154,7 +1153,7 @@ class SyncHandler(object):
# We loop through all room ids, even if there are no new events, in case # We loop through all room ids, even if there are no new events, in case
# there are non room events taht we need to notify about. # there are non room events taht we need to notify about.
for room_id in joined_room_ids: for room_id in sync_result_builder.joined_room_ids:
room_entry = room_to_events.get(room_id, None) room_entry = room_to_events.get(room_id, None)
if room_entry: if room_entry:
@ -1362,6 +1361,54 @@ class SyncHandler(object):
else: else:
raise Exception("Unrecognized rtype: %r", room_builder.rtype) raise Exception("Unrecognized rtype: %r", room_builder.rtype)
@defer.inlineCallbacks
def get_rooms_for_user_at(self, user_id, stream_ordering):
"""Get set of joined rooms for a user at the given stream ordering.
The stream ordering *must* be recent, otherwise this may throw an
exception if older than a month. (This function is called with the
current token, which should be perfectly fine).
Args:
user_id (str)
stream_ordering (int)
ReturnValue:
Deferred[frozenset[str]]: Set of room_ids the user is in at given
stream_ordering.
"""
joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering(
user_id,
)
joined_room_ids = set()
# We need to check that the stream ordering of the join for each room
# is before the stream_ordering asked for. This might not be the case
# if the user joins a room between us getting the current token and
# calling `get_rooms_for_user_with_stream_ordering`.
# If the membership's stream ordering is after the given stream
# ordering, we need to go and work out if the user was in the room
# before.
for room_id, membership_stream_ordering in joined_rooms:
if membership_stream_ordering <= stream_ordering:
joined_room_ids.add(room_id)
continue
logger.info("User joined room after current token: %s", room_id)
extrems = yield self.store.get_forward_extremeties_for_room(
room_id, stream_ordering,
)
users_in_room = yield self.state.get_current_user_in_room(
room_id, extrems,
)
if user_id in users_in_room:
joined_room_ids.add(room_id)
joined_room_ids = frozenset(joined_room_ids)
defer.returnValue(joined_room_ids)
def _action_has_highlight(actions): def _action_has_highlight(actions):
for action in actions: for action in actions:
@ -1411,7 +1458,8 @@ def _calculate_state(timeline_contains, timeline_start, previous, current):
class SyncResultBuilder(object): class SyncResultBuilder(object):
"Used to help build up a new SyncResult for a user" "Used to help build up a new SyncResult for a user"
def __init__(self, sync_config, full_state, since_token, now_token): def __init__(self, sync_config, full_state, since_token, now_token,
joined_room_ids):
""" """
Args: Args:
sync_config(SyncConfig) sync_config(SyncConfig)
@ -1423,6 +1471,7 @@ class SyncResultBuilder(object):
self.full_state = full_state self.full_state = full_state
self.since_token = since_token self.since_token = since_token
self.now_token = now_token self.now_token = now_token
self.joined_room_ids = joined_room_ids
self.presence = [] self.presence = []
self.account_data = [] self.account_data = []

View File

@ -56,7 +56,7 @@ class TypingHandler(object):
self.federation = hs.get_federation_sender() self.federation = hs.get_federation_sender()
hs.get_replication_layer().register_edu_handler("m.typing", self._recv_edu) hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu)
hs.get_distributor().observe("user_left_room", self.user_left_room) hs.get_distributor().observe("user_left_room", self.user_left_room)

View File

@ -18,6 +18,7 @@ from OpenSSL.SSL import VERIFY_NONE
from synapse.api.errors import ( from synapse.api.errors import (
CodeMessageException, MatrixCodeMessageException, SynapseError, Codes, CodeMessageException, MatrixCodeMessageException, SynapseError, Codes,
) )
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.logcontext import make_deferred_yieldable from synapse.util.logcontext import make_deferred_yieldable
from synapse.util import logcontext from synapse.util import logcontext
import synapse.metrics import synapse.metrics
@ -30,6 +31,7 @@ from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.web.client import ( from twisted.web.client import (
BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent, BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent,
readBody, PartialDownloadError, readBody, PartialDownloadError,
HTTPConnectionPool,
) )
from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer
from twisted.web.http import PotentialDataLoss from twisted.web.http import PotentialDataLoss
@ -64,13 +66,23 @@ class SimpleHttpClient(object):
""" """
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
pool = HTTPConnectionPool(reactor)
# the pusher makes lots of concurrent SSL connections to sygnal, and
# tends to do so in batches, so we need to allow the pool to keep lots
# of idle connections around.
pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5))
pool.cachedConnectionTimeout = 2 * 60
# The default context factory in Twisted 14.0.0 (which we require) is # The default context factory in Twisted 14.0.0 (which we require) is
# BrowserLikePolicyForHTTPS which will do regular cert validation # BrowserLikePolicyForHTTPS which will do regular cert validation
# 'like a browser' # 'like a browser'
self.agent = Agent( self.agent = Agent(
reactor, reactor,
connectTimeout=15, connectTimeout=15,
contextFactory=hs.get_http_client_context_factory() contextFactory=hs.get_http_client_context_factory(),
pool=pool,
) )
self.user_agent = hs.version_string self.user_agent = hs.version_string
self.clock = hs.get_clock() self.clock = hs.get_clock()

View File

@ -357,8 +357,7 @@ def _get_hosts_for_srv_record(dns_client, host):
def eb(res, record_type): def eb(res, record_type):
if res.check(DNSNameError): if res.check(DNSNameError):
return [] return []
logger.warn("Error looking up %s for %s: %s", logger.warn("Error looking up %s for %s: %s", record_type, host, res)
record_type, host, res, res.value)
return res return res
# no logcontexts here, so we can safely fire these off and gatherResults # no logcontexts here, so we can safely fire these off and gatherResults

View File

@ -27,7 +27,7 @@ import synapse.metrics
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from synapse.api.errors import ( from synapse.api.errors import (
SynapseError, Codes, HttpResponseException, SynapseError, Codes, HttpResponseException, FederationDeniedError,
) )
from signedjson.sign import sign_json from signedjson.sign import sign_json
@ -123,11 +123,22 @@ class MatrixFederationHttpClient(object):
Fails with ``HTTPRequestException``: if we get an HTTP response Fails with ``HTTPRequestException``: if we get an HTTP response
code >= 300. code >= 300.
Fails with ``NotRetryingDestination`` if we are not yet ready Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server. to retry this server.
Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist
(May also fail with plenty of other Exceptions for things like DNS (May also fail with plenty of other Exceptions for things like DNS
failures, connection failures, SSL failures.) failures, connection failures, SSL failures.)
""" """
if (
self.hs.config.federation_domain_whitelist and
destination not in self.hs.config.federation_domain_whitelist
):
raise FederationDeniedError(destination)
limiter = yield synapse.util.retryutils.get_retry_limiter( limiter = yield synapse.util.retryutils.get_retry_limiter(
destination, destination,
self.clock, self.clock,
@ -308,6 +319,9 @@ class MatrixFederationHttpClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server. to retry this server.
Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist
""" """
if not json_data_callback: if not json_data_callback:
@ -368,6 +382,9 @@ class MatrixFederationHttpClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server. to retry this server.
Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist
""" """
def body_callback(method, url_bytes, headers_dict): def body_callback(method, url_bytes, headers_dict):
@ -422,6 +439,9 @@ class MatrixFederationHttpClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server. to retry this server.
Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist
""" """
logger.debug("get_json args: %s", args) logger.debug("get_json args: %s", args)
@ -475,6 +495,9 @@ class MatrixFederationHttpClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server. to retry this server.
Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist
""" """
response = yield self._request( response = yield self._request(
@ -518,6 +541,9 @@ class MatrixFederationHttpClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server. to retry this server.
Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist
""" """
encoded_args = {} encoded_args = {}

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -42,36 +43,75 @@ logger = logging.getLogger(__name__)
metrics = synapse.metrics.get_metrics_for(__name__) metrics = synapse.metrics.get_metrics_for(__name__)
incoming_requests_counter = metrics.register_counter( # total number of responses served, split by method/servlet/tag
"requests", response_count = metrics.register_counter(
"response_count",
labels=["method", "servlet", "tag"], labels=["method", "servlet", "tag"],
alternative_names=(
# the following are all deprecated aliases for the same metric
metrics.name_prefix + x for x in (
"_requests",
"_response_time:count",
"_response_ru_utime:count",
"_response_ru_stime:count",
"_response_db_txn_count:count",
"_response_db_txn_duration:count",
)
)
) )
requests_counter = metrics.register_counter(
"requests_received",
labels=["method", "servlet", ],
)
outgoing_responses_counter = metrics.register_counter( outgoing_responses_counter = metrics.register_counter(
"responses", "responses",
labels=["method", "code"], labels=["method", "code"],
) )
response_timer = metrics.register_distribution( response_timer = metrics.register_counter(
"response_time", "response_time_seconds",
labels=["method", "servlet", "tag"] labels=["method", "servlet", "tag"],
alternative_names=(
metrics.name_prefix + "_response_time:total",
),
) )
response_ru_utime = metrics.register_distribution( response_ru_utime = metrics.register_counter(
"response_ru_utime", labels=["method", "servlet", "tag"] "response_ru_utime_seconds", labels=["method", "servlet", "tag"],
alternative_names=(
metrics.name_prefix + "_response_ru_utime:total",
),
) )
response_ru_stime = metrics.register_distribution( response_ru_stime = metrics.register_counter(
"response_ru_stime", labels=["method", "servlet", "tag"] "response_ru_stime_seconds", labels=["method", "servlet", "tag"],
alternative_names=(
metrics.name_prefix + "_response_ru_stime:total",
),
) )
response_db_txn_count = metrics.register_distribution( response_db_txn_count = metrics.register_counter(
"response_db_txn_count", labels=["method", "servlet", "tag"] "response_db_txn_count", labels=["method", "servlet", "tag"],
alternative_names=(
metrics.name_prefix + "_response_db_txn_count:total",
),
) )
response_db_txn_duration = metrics.register_distribution( # seconds spent waiting for db txns, excluding scheduling time, when processing
"response_db_txn_duration", labels=["method", "servlet", "tag"] # this request
response_db_txn_duration = metrics.register_counter(
"response_db_txn_duration_seconds", labels=["method", "servlet", "tag"],
alternative_names=(
metrics.name_prefix + "_response_db_txn_duration:total",
),
) )
# seconds spent waiting for a db connection, when processing this request
response_db_sched_duration = metrics.register_counter(
"response_db_sched_duration_seconds", labels=["method", "servlet", "tag"]
)
_next_request_id = 0 _next_request_id = 0
@ -107,7 +147,12 @@ def wrap_request_handler(request_handler, include_metrics=False):
with LoggingContext(request_id) as request_context: with LoggingContext(request_id) as request_context:
with Measure(self.clock, "wrapped_request_handler"): with Measure(self.clock, "wrapped_request_handler"):
request_metrics = RequestMetrics() request_metrics = RequestMetrics()
request_metrics.start(self.clock, name=self.__class__.__name__) # we start the request metrics timer here with an initial stab
# at the servlet name. For most requests that name will be
# JsonResource (or a subclass), and JsonResource._async_render
# will update it once it picks a servlet.
servlet_name = self.__class__.__name__
request_metrics.start(self.clock, name=servlet_name)
request_context.request = request_id request_context.request = request_id
with request.processing(): with request.processing():
@ -116,6 +161,7 @@ def wrap_request_handler(request_handler, include_metrics=False):
if include_metrics: if include_metrics:
yield request_handler(self, request, request_metrics) yield request_handler(self, request, request_metrics)
else: else:
requests_counter.inc(request.method, servlet_name)
yield request_handler(self, request) yield request_handler(self, request)
except CodeMessageException as e: except CodeMessageException as e:
code = e.code code = e.code
@ -191,7 +237,7 @@ class JsonResource(HttpServer, resource.Resource):
""" This implements the HttpServer interface and provides JSON support for """ This implements the HttpServer interface and provides JSON support for
Resources. Resources.
Register callbacks via register_path() Register callbacks via register_paths()
Callbacks can return a tuple of status code and a dict in which case the Callbacks can return a tuple of status code and a dict in which case the
the dict will automatically be sent to the client as a JSON object. the dict will automatically be sent to the client as a JSON object.
@ -238,57 +284,62 @@ class JsonResource(HttpServer, resource.Resource):
This checks if anyone has registered a callback for that method and This checks if anyone has registered a callback for that method and
path. path.
""" """
callback, group_dict = self._get_handler_for_request(request)
servlet_instance = getattr(callback, "__self__", None)
if servlet_instance is not None:
servlet_classname = servlet_instance.__class__.__name__
else:
servlet_classname = "%r" % callback
request_metrics.name = servlet_classname
requests_counter.inc(request.method, servlet_classname)
# Now trigger the callback. If it returns a response, we send it
# here. If it throws an exception, that is handled by the wrapper
# installed by @request_handler.
kwargs = intern_dict({
name: urllib.unquote(value).decode("UTF-8") if value else value
for name, value in group_dict.items()
})
callback_return = yield callback(request, **kwargs)
if callback_return is not None:
code, response = callback_return
self._send_response(request, code, response)
def _get_handler_for_request(self, request):
"""Finds a callback method to handle the given request
Args:
request (twisted.web.http.Request):
Returns:
Tuple[Callable, dict[str, str]]: callback method, and the dict
mapping keys to path components as specified in the handler's
path match regexp.
The callback will normally be a method registered via
register_paths, so will return (possibly via Deferred) either
None, or a tuple of (http code, response body).
"""
if request.method == "OPTIONS": if request.method == "OPTIONS":
self._send_response(request, 200, {}) return _options_handler, {}
return
# Loop through all the registered callbacks to check if the method # Loop through all the registered callbacks to check if the method
# and path regex match # and path regex match
for path_entry in self.path_regexs.get(request.method, []): for path_entry in self.path_regexs.get(request.method, []):
m = path_entry.pattern.match(request.path) m = path_entry.pattern.match(request.path)
if not m: if m:
continue # We found a match!
return path_entry.callback, m.groupdict()
# We found a match! Trigger callback and then return the
# returned response. We pass both the request and any
# matched groups from the regex to the callback.
callback = path_entry.callback
kwargs = intern_dict({
name: urllib.unquote(value).decode("UTF-8") if value else value
for name, value in m.groupdict().items()
})
callback_return = yield callback(request, **kwargs)
if callback_return is not None:
code, response = callback_return
self._send_response(request, code, response)
servlet_instance = getattr(callback, "__self__", None)
if servlet_instance is not None:
servlet_classname = servlet_instance.__class__.__name__
else:
servlet_classname = "%r" % callback
request_metrics.name = servlet_classname
return
# Huh. No one wanted to handle that? Fiiiiiine. Send 400. # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
raise UnrecognizedRequestError() return _unrecognised_request_handler, {}
def _send_response(self, request, code, response_json_object, def _send_response(self, request, code, response_json_object,
response_code_message=None): response_code_message=None):
# could alternatively use request.notifyFinish() and flip a flag when
# the Deferred fires, but since the flag is RIGHT THERE it seems like
# a waste.
if request._disconnected:
logger.warn(
"Not sending response to request %s, already disconnected.",
request)
return
outgoing_responses_counter.inc(request.method, str(code)) outgoing_responses_counter.inc(request.method, str(code))
# TODO: Only enable CORS for the requests that need it. # TODO: Only enable CORS for the requests that need it.
@ -302,6 +353,34 @@ class JsonResource(HttpServer, resource.Resource):
) )
def _options_handler(request):
"""Request handler for OPTIONS requests
This is a request handler suitable for return from
_get_handler_for_request. It returns a 200 and an empty body.
Args:
request (twisted.web.http.Request):
Returns:
Tuple[int, dict]: http code, response body.
"""
return 200, {}
def _unrecognised_request_handler(request):
"""Request handler for unrecognised requests
This is a request handler suitable for return from
_get_handler_for_request. It actually just raises an
UnrecognizedRequestError.
Args:
request (twisted.web.http.Request):
"""
raise UnrecognizedRequestError()
class RequestMetrics(object): class RequestMetrics(object):
def start(self, clock, name): def start(self, clock, name):
self.start = clock.time_msec() self.start = clock.time_msec()
@ -322,7 +401,7 @@ class RequestMetrics(object):
) )
return return
incoming_requests_counter.inc(request.method, self.name, tag) response_count.inc(request.method, self.name, tag)
response_timer.inc_by( response_timer.inc_by(
clock.time_msec() - self.start, request.method, clock.time_msec() - self.start, request.method,
@ -341,7 +420,10 @@ class RequestMetrics(object):
context.db_txn_count, request.method, self.name, tag context.db_txn_count, request.method, self.name, tag
) )
response_db_txn_duration.inc_by( response_db_txn_duration.inc_by(
context.db_txn_duration, request.method, self.name, tag context.db_txn_duration_ms / 1000., request.method, self.name, tag
)
response_db_sched_duration.inc_by(
context.db_sched_duration_ms / 1000., request.method, self.name, tag
) )
@ -364,6 +446,15 @@ class RootRedirect(resource.Resource):
def respond_with_json(request, code, json_object, send_cors=False, def respond_with_json(request, code, json_object, send_cors=False,
response_code_message=None, pretty_print=False, response_code_message=None, pretty_print=False,
version_string="", canonical_json=True): version_string="", canonical_json=True):
# could alternatively use request.notifyFinish() and flip a flag when
# the Deferred fires, but since the flag is RIGHT THERE it seems like
# a waste.
if request._disconnected:
logger.warn(
"Not sending response to request %s, already disconnected.",
request)
return
if pretty_print: if pretty_print:
json_bytes = encode_pretty_printed_json(json_object) + "\n" json_bytes = encode_pretty_printed_json(json_object) + "\n"
else: else:

View File

@ -148,11 +148,13 @@ def parse_string_from_args(args, name, default=None, required=False,
return default return default
def parse_json_value_from_request(request): def parse_json_value_from_request(request, allow_empty_body=False):
"""Parse a JSON value from the body of a twisted HTTP request. """Parse a JSON value from the body of a twisted HTTP request.
Args: Args:
request: the twisted HTTP request. request: the twisted HTTP request.
allow_empty_body (bool): if True, an empty body will be accepted and
turned into None
Returns: Returns:
The JSON value. The JSON value.
@ -165,6 +167,9 @@ def parse_json_value_from_request(request):
except Exception: except Exception:
raise SynapseError(400, "Error reading JSON content.") raise SynapseError(400, "Error reading JSON content.")
if not content_bytes and allow_empty_body:
return None
try: try:
content = simplejson.loads(content_bytes) content = simplejson.loads(content_bytes)
except Exception as e: except Exception as e:
@ -174,17 +179,24 @@ def parse_json_value_from_request(request):
return content return content
def parse_json_object_from_request(request): def parse_json_object_from_request(request, allow_empty_body=False):
"""Parse a JSON object from the body of a twisted HTTP request. """Parse a JSON object from the body of a twisted HTTP request.
Args: Args:
request: the twisted HTTP request. request: the twisted HTTP request.
allow_empty_body (bool): if True, an empty body will be accepted and
turned into an empty dict.
Raises: Raises:
SynapseError if the request body couldn't be decoded as JSON or SynapseError if the request body couldn't be decoded as JSON or
if it wasn't a JSON object. if it wasn't a JSON object.
""" """
content = parse_json_value_from_request(request) content = parse_json_value_from_request(
request, allow_empty_body=allow_empty_body,
)
if allow_empty_body and content is None:
return {}
if type(content) != dict: if type(content) != dict:
message = "Content must be a JSON object." message = "Content must be a JSON object."

View File

@ -66,14 +66,15 @@ class SynapseRequest(Request):
context = LoggingContext.current_context() context = LoggingContext.current_context()
ru_utime, ru_stime = context.get_resource_usage() ru_utime, ru_stime = context.get_resource_usage()
db_txn_count = context.db_txn_count db_txn_count = context.db_txn_count
db_txn_duration = context.db_txn_duration db_txn_duration_ms = context.db_txn_duration_ms
db_sched_duration_ms = context.db_sched_duration_ms
except Exception: except Exception:
ru_utime, ru_stime = (0, 0) ru_utime, ru_stime = (0, 0)
db_txn_count, db_txn_duration = (0, 0) db_txn_count, db_txn_duration_ms = (0, 0)
self.site.access_logger.info( self.site.access_logger.info(
"%s - %s - {%s}" "%s - %s - {%s}"
" Processed request: %dms (%dms, %dms) (%dms/%d)" " Processed request: %dms (%dms, %dms) (%dms/%dms/%d)"
" %sB %s \"%s %s %s\" \"%s\"", " %sB %s \"%s %s %s\" \"%s\"",
self.getClientIP(), self.getClientIP(),
self.site.site_tag, self.site.site_tag,
@ -81,7 +82,8 @@ class SynapseRequest(Request):
int(time.time() * 1000) - self.start_time, int(time.time() * 1000) - self.start_time,
int(ru_utime * 1000), int(ru_utime * 1000),
int(ru_stime * 1000), int(ru_stime * 1000),
int(db_txn_duration * 1000), db_sched_duration_ms,
db_txn_duration_ms,
int(db_txn_count), int(db_txn_count),
self.sentLength, self.sentLength,
self.code, self.code,

View File

@ -57,15 +57,31 @@ class Metrics(object):
return metric return metric
def register_counter(self, *args, **kwargs): def register_counter(self, *args, **kwargs):
"""
Returns:
CounterMetric
"""
return self._register(CounterMetric, *args, **kwargs) return self._register(CounterMetric, *args, **kwargs)
def register_callback(self, *args, **kwargs): def register_callback(self, *args, **kwargs):
"""
Returns:
CallbackMetric
"""
return self._register(CallbackMetric, *args, **kwargs) return self._register(CallbackMetric, *args, **kwargs)
def register_distribution(self, *args, **kwargs): def register_distribution(self, *args, **kwargs):
"""
Returns:
DistributionMetric
"""
return self._register(DistributionMetric, *args, **kwargs) return self._register(DistributionMetric, *args, **kwargs)
def register_cache(self, *args, **kwargs): def register_cache(self, *args, **kwargs):
"""
Returns:
CacheMetric
"""
return self._register(CacheMetric, *args, **kwargs) return self._register(CacheMetric, *args, **kwargs)
@ -146,10 +162,15 @@ def runUntilCurrentTimer(func):
num_pending += 1 num_pending += 1
num_pending += len(reactor.threadCallQueue) num_pending += len(reactor.threadCallQueue)
start = time.time() * 1000 start = time.time() * 1000
ret = func(*args, **kwargs) ret = func(*args, **kwargs)
end = time.time() * 1000 end = time.time() * 1000
# record the amount of wallclock time spent running pending calls.
# This is a proxy for the actual amount of time between reactor polls,
# since about 25% of time is actually spent running things triggered by
# I/O events, but that is harder to capture without rewriting half the
# reactor.
tick_time.inc_by(end - start) tick_time.inc_by(end - start)
pending_calls_metric.inc_by(num_pending) pending_calls_metric.inc_by(num_pending)

View File

@ -15,18 +15,38 @@
from itertools import chain from itertools import chain
import logging
logger = logging.getLogger(__name__)
# TODO(paul): I can't believe Python doesn't have one of these def flatten(items):
def map_concat(func, items): """Flatten a list of lists
# flatten a list-of-lists
return list(chain.from_iterable(map(func, items))) Args:
items: iterable[iterable[X]]
Returns:
list[X]: flattened list
"""
return list(chain.from_iterable(items))
class BaseMetric(object): class BaseMetric(object):
"""Base class for metrics which report a single value per label set
"""
def __init__(self, name, labels=[]): def __init__(self, name, labels=[], alternative_names=[]):
self.name = name """
Args:
name (str): principal name for this metric
labels (list(str)): names of the labels which will be reported
for this metric
alternative_names (iterable(str)): list of alternative names for
this metric. This can be useful to provide a migration path
when renaming metrics.
"""
self._names = [name] + list(alternative_names)
self.labels = labels # OK not to clone as we never write it self.labels = labels # OK not to clone as we never write it
def dimension(self): def dimension(self):
@ -36,7 +56,7 @@ class BaseMetric(object):
return not len(self.labels) return not len(self.labels)
def _render_labelvalue(self, value): def _render_labelvalue(self, value):
# TODO: some kind of value escape # TODO: escape backslashes, quotes and newlines
return '"%s"' % (value) return '"%s"' % (value)
def _render_key(self, values): def _render_key(self, values):
@ -47,19 +67,60 @@ class BaseMetric(object):
for k, v in zip(self.labels, values)]) for k, v in zip(self.labels, values)])
) )
def _render_for_labels(self, label_values, value):
"""Render this metric for a single set of labels
Args:
label_values (list[str]): values for each of the labels
value: value of the metric at with these labels
Returns:
iterable[str]: rendered metric
"""
rendered_labels = self._render_key(label_values)
return (
"%s%s %.12g" % (name, rendered_labels, value)
for name in self._names
)
def render(self):
"""Render this metric
Each metric is rendered as:
name{label1="val1",label2="val2"} value
https://prometheus.io/docs/instrumenting/exposition_formats/#text-format-details
Returns:
iterable[str]: rendered metrics
"""
raise NotImplementedError()
class CounterMetric(BaseMetric): class CounterMetric(BaseMetric):
"""The simplest kind of metric; one that stores a monotonically-increasing """The simplest kind of metric; one that stores a monotonically-increasing
integer that counts events.""" value that counts events or running totals.
Example use cases for Counters:
- Number of requests processed
- Number of items that were inserted into a queue
- Total amount of data that a system has processed
Counters can only go up (and be reset when the process restarts).
"""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(CounterMetric, self).__init__(*args, **kwargs) super(CounterMetric, self).__init__(*args, **kwargs)
# dict[list[str]]: value for each set of label values. the keys are the
# label values, in the same order as the labels in self.labels.
#
# (if the metric is a scalar, the (single) key is the empty list).
self.counts = {} self.counts = {}
# Scalar metrics are never empty # Scalar metrics are never empty
if self.is_scalar(): if self.is_scalar():
self.counts[()] = 0 self.counts[()] = 0.
def inc_by(self, incr, *values): def inc_by(self, incr, *values):
if len(values) != self.dimension(): if len(values) != self.dimension():
@ -77,11 +138,11 @@ class CounterMetric(BaseMetric):
def inc(self, *values): def inc(self, *values):
self.inc_by(1, *values) self.inc_by(1, *values)
def render_item(self, k):
return ["%s%s %d" % (self.name, self._render_key(k), self.counts[k])]
def render(self): def render(self):
return map_concat(self.render_item, sorted(self.counts.keys())) return flatten(
self._render_for_labels(k, self.counts[k])
for k in sorted(self.counts.keys())
)
class CallbackMetric(BaseMetric): class CallbackMetric(BaseMetric):
@ -95,13 +156,19 @@ class CallbackMetric(BaseMetric):
self.callback = callback self.callback = callback
def render(self): def render(self):
value = self.callback() try:
value = self.callback()
except Exception:
logger.exception("Failed to render %s", self.name)
return ["# FAILED to render " + self.name]
if self.is_scalar(): if self.is_scalar():
return ["%s %.12g" % (self.name, value)] return list(self._render_for_labels([], value))
return ["%s%s %.12g" % (self.name, self._render_key(k), value[k]) return flatten(
for k in sorted(value.keys())] self._render_for_labels(k, value[k])
for k in sorted(value.keys())
)
class DistributionMetric(object): class DistributionMetric(object):
@ -126,7 +193,9 @@ class DistributionMetric(object):
class CacheMetric(object): class CacheMetric(object):
__slots__ = ("name", "cache_name", "hits", "misses", "size_callback") __slots__ = (
"name", "cache_name", "hits", "misses", "evicted_size", "size_callback",
)
def __init__(self, name, size_callback, cache_name): def __init__(self, name, size_callback, cache_name):
self.name = name self.name = name
@ -134,6 +203,7 @@ class CacheMetric(object):
self.hits = 0 self.hits = 0
self.misses = 0 self.misses = 0
self.evicted_size = 0
self.size_callback = size_callback self.size_callback = size_callback
@ -143,6 +213,9 @@ class CacheMetric(object):
def inc_misses(self): def inc_misses(self):
self.misses += 1 self.misses += 1
def inc_evictions(self, size=1):
self.evicted_size += size
def render(self): def render(self):
size = self.size_callback() size = self.size_callback()
hits = self.hits hits = self.hits
@ -152,6 +225,9 @@ class CacheMetric(object):
"""%s:hits{name="%s"} %d""" % (self.name, self.cache_name, hits), """%s:hits{name="%s"} %d""" % (self.name, self.cache_name, hits),
"""%s:total{name="%s"} %d""" % (self.name, self.cache_name, total), """%s:total{name="%s"} %d""" % (self.name, self.cache_name, total),
"""%s:size{name="%s"} %d""" % (self.name, self.cache_name, size), """%s:size{name="%s"} %d""" % (self.name, self.cache_name, size),
"""%s:evicted_size{name="%s"} %d""" % (
self.name, self.cache_name, self.evicted_size
),
] ]

View File

@ -40,10 +40,6 @@ class ActionGenerator(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_push_actions_for_event(self, event, context): def handle_push_actions_for_event(self, event, context):
with Measure(self.clock, "action_for_event_by_user"): with Measure(self.clock, "action_for_event_by_user"):
actions_by_user = yield self.bulk_evaluator.action_for_event_by_user( yield self.bulk_evaluator.action_for_event_by_user(
event, context event, context
) )
context.push_actions = [
(uid, actions) for uid, actions in actions_by_user.iteritems()
]

View File

@ -137,11 +137,11 @@ class BulkPushRuleEvaluator(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def action_for_event_by_user(self, event, context): def action_for_event_by_user(self, event, context):
"""Given an event and context, evaluate the push rules and return """Given an event and context, evaluate the push rules and insert the
the results results into the event_push_actions_staging table.
Returns: Returns:
dict of user_id -> action Deferred
""" """
rules_by_user = yield self._get_rules_for_event(event, context) rules_by_user = yield self._get_rules_for_event(event, context)
actions_by_user = {} actions_by_user = {}
@ -190,9 +190,16 @@ class BulkPushRuleEvaluator(object):
if matches: if matches:
actions = [x for x in rule['actions'] if x != 'dont_notify'] actions = [x for x in rule['actions'] if x != 'dont_notify']
if actions and 'notify' in actions: if actions and 'notify' in actions:
# Push rules say we should notify the user of this event
actions_by_user[uid] = actions actions_by_user[uid] = actions
break break
defer.returnValue(actions_by_user)
# Mark in the DB staging area the push actions for users who should be
# notified for this event. (This will then get handled when we persist
# the event)
yield self.store.add_push_actions_to_staging(
event.event_id, actions_by_user,
)
def _condition_checker(evaluator, conditions, uid, display_name, cache): def _condition_checker(evaluator, conditions, uid, display_name, cache):

View File

@ -13,21 +13,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
from synapse.push import PusherConfigException
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.internet.error import AlreadyCalled, AlreadyCancelled from twisted.internet.error import AlreadyCalled, AlreadyCancelled
import logging
import push_rule_evaluator import push_rule_evaluator
import push_tools import push_tools
import synapse
from synapse.push import PusherConfigException
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
metrics = synapse.metrics.get_metrics_for(__name__)
http_push_processed_counter = metrics.register_counter(
"http_pushes_processed",
)
http_push_failed_counter = metrics.register_counter(
"http_pushes_failed",
)
class HttpPusher(object): class HttpPusher(object):
INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes
@ -152,9 +161,16 @@ class HttpPusher(object):
self.user_id, self.last_stream_ordering, self.max_stream_ordering self.user_id, self.last_stream_ordering, self.max_stream_ordering
) )
logger.info(
"Processing %i unprocessed push actions for %s starting at "
"stream_ordering %s",
len(unprocessed), self.name, self.last_stream_ordering,
)
for push_action in unprocessed: for push_action in unprocessed:
processed = yield self._process_one(push_action) processed = yield self._process_one(push_action)
if processed: if processed:
http_push_processed_counter.inc()
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action['stream_ordering'] self.last_stream_ordering = push_action['stream_ordering']
yield self.store.update_pusher_last_stream_ordering_and_success( yield self.store.update_pusher_last_stream_ordering_and_success(
@ -169,6 +185,7 @@ class HttpPusher(object):
self.failing_since self.failing_since
) )
else: else:
http_push_failed_counter.inc()
if not self.failing_since: if not self.failing_since:
self.failing_since = self.clock.time_msec() self.failing_since = self.clock.time_msec()
yield self.store.update_pusher_failing_since( yield self.store.update_pusher_failing_since(
@ -316,7 +333,10 @@ class HttpPusher(object):
try: try:
resp = yield self.http_client.post_json_get_json(self.url, notification_dict) resp = yield self.http_client.post_json_get_json(self.url, notification_dict)
except Exception: except Exception:
logger.warn("Failed to push %s ", self.url) logger.warn(
"Failed to push event %s to %s",
event.event_id, self.name, exc_info=True,
)
defer.returnValue(False) defer.returnValue(False)
rejected = [] rejected = []
if 'rejected' in resp: if 'rejected' in resp:
@ -325,7 +345,7 @@ class HttpPusher(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _send_badge(self, badge): def _send_badge(self, badge):
logger.info("Sending updated badge count %d to %r", badge, self.user_id) logger.info("Sending updated badge count %d to %s", badge, self.name)
d = { d = {
'notification': { 'notification': {
'id': '', 'id': '',
@ -347,7 +367,10 @@ class HttpPusher(object):
try: try:
resp = yield self.http_client.post_json_get_json(self.url, d) resp = yield self.http_client.post_json_get_json(self.url, d)
except Exception: except Exception:
logger.exception("Failed to push %s ", self.url) logger.warn(
"Failed to send badge count to %s",
self.name, exc_info=True,
)
defer.returnValue(False) defer.returnValue(False)
rejected = [] rejected = []
if 'rejected' in resp: if 'rejected' in resp:

View File

@ -24,19 +24,19 @@ REQUIREMENTS = {
"unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"], "unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"],
"canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"], "canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],
"signedjson>=1.0.0": ["signedjson>=1.0.0"], "signedjson>=1.0.0": ["signedjson>=1.0.0"],
"pynacl==0.3.0": ["nacl==0.3.0", "nacl.bindings"], "pynacl>=1.2.1": ["nacl>=1.2.1", "nacl.bindings"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"], "service_identity>=1.0.0": ["service_identity>=1.0.0"],
"Twisted>=16.0.0": ["twisted>=16.0.0"], "Twisted>=16.0.0": ["twisted>=16.0.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"], "pyopenssl>=0.14": ["OpenSSL>=0.14"],
"pyyaml": ["yaml"], "pyyaml": ["yaml"],
"pyasn1": ["pyasn1"], "pyasn1": ["pyasn1"],
"daemonize": ["daemonize"], "daemonize": ["daemonize"],
"bcrypt": ["bcrypt"], "bcrypt": ["bcrypt>=3.1.0"],
"pillow": ["PIL"], "pillow": ["PIL"],
"pydenticon": ["pydenticon"], "pydenticon": ["pydenticon"],
"ujson": ["ujson"], "ujson": ["ujson"],
"blist": ["blist"], "blist": ["blist"],
"pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"], "pysaml2>=3.0.0": ["saml2>=3.0.0"],
"pymacaroons-pynacl": ["pymacaroons"], "pymacaroons-pynacl": ["pymacaroons"],
"msgpack-python>=0.3.0": ["msgpack"], "msgpack-python>=0.3.0": ["msgpack"],
"phonenumbers>=8.2.0": ["phonenumbers"], "phonenumbers>=8.2.0": ["phonenumbers"],

View File

@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.http.server import JsonResource
from synapse.replication.http import membership, send_event
REPLICATION_PREFIX = "/_synapse/replication"
class ReplicationRestResource(JsonResource):
def __init__(self, hs):
JsonResource.__init__(self, hs, canonical_json=False)
self.register_servlets(hs)
def register_servlets(self, hs):
send_event.register_servlets(hs, self)
membership.register_servlets(hs, self)

View File

@ -0,0 +1,334 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
from twisted.internet import defer
from synapse.api.errors import SynapseError, MatrixCodeMessageException
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import Requester, UserID
from synapse.util.distributor import user_left_room, user_joined_room
logger = logging.getLogger(__name__)
@defer.inlineCallbacks
def remote_join(client, host, port, requester, remote_room_hosts,
room_id, user_id, content):
"""Ask the master to do a remote join for the given user to the given room
Args:
client (SimpleHttpClient)
host (str): host of master
port (int): port on master listening for HTTP replication
requester (Requester)
remote_room_hosts (list[str]): Servers to try and join via
room_id (str)
user_id (str)
content (dict): The event content to use for the join event
Returns:
Deferred
"""
uri = "http://%s:%s/_synapse/replication/remote_join" % (host, port)
payload = {
"requester": requester.serialize(),
"remote_room_hosts": remote_room_hosts,
"room_id": room_id,
"user_id": user_id,
"content": content,
}
try:
result = yield client.post_json_get_json(uri, payload)
except MatrixCodeMessageException as e:
# We convert to SynapseError as we know that it was a SynapseError
# on the master process that we should send to the client. (And
# importantly, not stack traces everywhere)
raise SynapseError(e.code, e.msg, e.errcode)
defer.returnValue(result)
@defer.inlineCallbacks
def remote_reject_invite(client, host, port, requester, remote_room_hosts,
room_id, user_id):
"""Ask master to reject the invite for the user and room.
Args:
client (SimpleHttpClient)
host (str): host of master
port (int): port on master listening for HTTP replication
requester (Requester)
remote_room_hosts (list[str]): Servers to try and reject via
room_id (str)
user_id (str)
Returns:
Deferred
"""
uri = "http://%s:%s/_synapse/replication/remote_reject_invite" % (host, port)
payload = {
"requester": requester.serialize(),
"remote_room_hosts": remote_room_hosts,
"room_id": room_id,
"user_id": user_id,
}
try:
result = yield client.post_json_get_json(uri, payload)
except MatrixCodeMessageException as e:
# We convert to SynapseError as we know that it was a SynapseError
# on the master process that we should send to the client. (And
# importantly, not stack traces everywhere)
raise SynapseError(e.code, e.msg, e.errcode)
defer.returnValue(result)
@defer.inlineCallbacks
def get_or_register_3pid_guest(client, host, port, requester,
medium, address, inviter_user_id):
"""Ask the master to get/create a guest account for given 3PID.
Args:
client (SimpleHttpClient)
host (str): host of master
port (int): port on master listening for HTTP replication
requester (Requester)
medium (str)
address (str)
inviter_user_id (str): The user ID who is trying to invite the
3PID
Returns:
Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
3PID guest account.
"""
uri = "http://%s:%s/_synapse/replication/get_or_register_3pid_guest" % (host, port)
payload = {
"requester": requester.serialize(),
"medium": medium,
"address": address,
"inviter_user_id": inviter_user_id,
}
try:
result = yield client.post_json_get_json(uri, payload)
except MatrixCodeMessageException as e:
# We convert to SynapseError as we know that it was a SynapseError
# on the master process that we should send to the client. (And
# importantly, not stack traces everywhere)
raise SynapseError(e.code, e.msg, e.errcode)
defer.returnValue(result)
@defer.inlineCallbacks
def notify_user_membership_change(client, host, port, user_id, room_id, change):
"""Notify master that a user has joined or left the room
Args:
client (SimpleHttpClient)
host (str): host of master
port (int): port on master listening for HTTP replication.
user_id (str)
room_id (str)
change (str): Either "join" or "left"
Returns:
Deferred
"""
assert change in ("joined", "left")
uri = "http://%s:%s/_synapse/replication/user_%s_room" % (host, port, change)
payload = {
"user_id": user_id,
"room_id": room_id,
}
try:
result = yield client.post_json_get_json(uri, payload)
except MatrixCodeMessageException as e:
# We convert to SynapseError as we know that it was a SynapseError
# on the master process that we should send to the client. (And
# importantly, not stack traces everywhere)
raise SynapseError(e.code, e.msg, e.errcode)
defer.returnValue(result)
class ReplicationRemoteJoinRestServlet(RestServlet):
PATTERNS = [re.compile("^/_synapse/replication/remote_join$")]
def __init__(self, hs):
super(ReplicationRemoteJoinRestServlet, self).__init__()
self.federation_handler = hs.get_handlers().federation_handler
self.store = hs.get_datastore()
self.clock = hs.get_clock()
@defer.inlineCallbacks
def on_POST(self, request):
content = parse_json_object_from_request(request)
remote_room_hosts = content["remote_room_hosts"]
room_id = content["room_id"]
user_id = content["user_id"]
event_content = content["content"]
requester = Requester.deserialize(self.store, content["requester"])
if requester.user:
request.authenticated_entity = requester.user.to_string()
logger.info(
"remote_join: %s into room: %s",
user_id, room_id,
)
yield self.federation_handler.do_invite_join(
remote_room_hosts,
room_id,
user_id,
event_content,
)
defer.returnValue((200, {}))
class ReplicationRemoteRejectInviteRestServlet(RestServlet):
PATTERNS = [re.compile("^/_synapse/replication/remote_reject_invite$")]
def __init__(self, hs):
super(ReplicationRemoteRejectInviteRestServlet, self).__init__()
self.federation_handler = hs.get_handlers().federation_handler
self.store = hs.get_datastore()
self.clock = hs.get_clock()
@defer.inlineCallbacks
def on_POST(self, request):
content = parse_json_object_from_request(request)
remote_room_hosts = content["remote_room_hosts"]
room_id = content["room_id"]
user_id = content["user_id"]
requester = Requester.deserialize(self.store, content["requester"])
if requester.user:
request.authenticated_entity = requester.user.to_string()
logger.info(
"remote_reject_invite: %s out of room: %s",
user_id, room_id,
)
try:
event = yield self.federation_handler.do_remotely_reject_invite(
remote_room_hosts,
room_id,
user_id,
)
ret = event.get_pdu_json()
except Exception as e:
# if we were unable to reject the exception, just mark
# it as rejected on our end and plough ahead.
#
# The 'except' clause is very broad, but we need to
# capture everything from DNS failures upwards
#
logger.warn("Failed to reject invite: %s", e)
yield self.store.locally_reject_invite(
user_id, room_id
)
ret = {}
defer.returnValue((200, ret))
class ReplicationRegister3PIDGuestRestServlet(RestServlet):
PATTERNS = [re.compile("^/_synapse/replication/get_or_register_3pid_guest$")]
def __init__(self, hs):
super(ReplicationRegister3PIDGuestRestServlet, self).__init__()
self.registeration_handler = hs.get_handlers().registration_handler
self.store = hs.get_datastore()
self.clock = hs.get_clock()
@defer.inlineCallbacks
def on_POST(self, request):
content = parse_json_object_from_request(request)
medium = content["medium"]
address = content["address"]
inviter_user_id = content["inviter_user_id"]
requester = Requester.deserialize(self.store, content["requester"])
if requester.user:
request.authenticated_entity = requester.user.to_string()
logger.info("get_or_register_3pid_guest: %r", content)
ret = yield self.registeration_handler.get_or_register_3pid_guest(
medium, address, inviter_user_id,
)
defer.returnValue((200, ret))
class ReplicationUserJoinedLeftRoomRestServlet(RestServlet):
PATTERNS = [re.compile("^/_synapse/replication/user_(?P<change>joined|left)_room$")]
def __init__(self, hs):
super(ReplicationUserJoinedLeftRoomRestServlet, self).__init__()
self.registeration_handler = hs.get_handlers().registration_handler
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.distributor = hs.get_distributor()
def on_POST(self, request, change):
content = parse_json_object_from_request(request)
user_id = content["user_id"]
room_id = content["room_id"]
logger.info("user membership change: %s in %s", user_id, room_id)
user = UserID.from_string(user_id)
if change == "joined":
user_joined_room(self.distributor, user, room_id)
elif change == "left":
user_left_room(self.distributor, user, room_id)
else:
raise Exception("Unrecognized change: %r", change)
return (200, {})
def register_servlets(hs, http_server):
ReplicationRemoteJoinRestServlet(hs).register(http_server)
ReplicationRemoteRejectInviteRestServlet(hs).register(http_server)
ReplicationRegister3PIDGuestRestServlet(hs).register(http_server)
ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server)

View File

@ -0,0 +1,166 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import (
SynapseError, MatrixCodeMessageException, CodeMessageException,
)
from synapse.events import FrozenEvent
from synapse.events.snapshot import EventContext
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.util.async import sleep
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.util.metrics import Measure
from synapse.types import Requester, UserID
import logging
import re
logger = logging.getLogger(__name__)
@defer.inlineCallbacks
def send_event_to_master(client, host, port, requester, event, context,
ratelimit, extra_users):
"""Send event to be handled on the master
Args:
client (SimpleHttpClient)
host (str): host of master
port (int): port on master listening for HTTP replication
requester (Requester)
event (FrozenEvent)
context (EventContext)
ratelimit (bool)
extra_users (list(UserID)): Any extra users to notify about event
"""
uri = "http://%s:%s/_synapse/replication/send_event/%s" % (
host, port, event.event_id,
)
payload = {
"event": event.get_pdu_json(),
"internal_metadata": event.internal_metadata.get_dict(),
"rejected_reason": event.rejected_reason,
"context": context.serialize(event),
"requester": requester.serialize(),
"ratelimit": ratelimit,
"extra_users": [u.to_string() for u in extra_users],
}
try:
# We keep retrying the same request for timeouts. This is so that we
# have a good idea that the request has either succeeded or failed on
# the master, and so whether we should clean up or not.
while True:
try:
result = yield client.put_json(uri, payload)
break
except CodeMessageException as e:
if e.code != 504:
raise
logger.warn("send_event request timed out")
# If we timed out we probably don't need to worry about backing
# off too much, but lets just wait a little anyway.
yield sleep(1)
except MatrixCodeMessageException as e:
# We convert to SynapseError as we know that it was a SynapseError
# on the master process that we should send to the client. (And
# importantly, not stack traces everywhere)
raise SynapseError(e.code, e.msg, e.errcode)
defer.returnValue(result)
class ReplicationSendEventRestServlet(RestServlet):
"""Handles events newly created on workers, including persisting and
notifying.
The API looks like:
POST /_synapse/replication/send_event/:event_id
{
"event": { .. serialized event .. },
"internal_metadata": { .. serialized internal_metadata .. },
"rejected_reason": .., // The event.rejected_reason field
"context": { .. serialized event context .. },
"requester": { .. serialized requester .. },
"ratelimit": true,
"extra_users": [],
}
"""
PATTERNS = [re.compile("^/_synapse/replication/send_event/(?P<event_id>[^/]+)$")]
def __init__(self, hs):
super(ReplicationSendEventRestServlet, self).__init__()
self.event_creation_handler = hs.get_event_creation_handler()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
# The responses are tiny, so we may as well cache them for a while
self.response_cache = ResponseCache(hs, timeout_ms=30 * 60 * 1000)
def on_PUT(self, request, event_id):
result = self.response_cache.get(event_id)
if not result:
result = self.response_cache.set(
event_id,
self._handle_request(request)
)
else:
logger.warn("Returning cached response")
return make_deferred_yieldable(result)
@preserve_fn
@defer.inlineCallbacks
def _handle_request(self, request):
with Measure(self.clock, "repl_send_event_parse"):
content = parse_json_object_from_request(request)
event_dict = content["event"]
internal_metadata = content["internal_metadata"]
rejected_reason = content["rejected_reason"]
event = FrozenEvent(event_dict, internal_metadata, rejected_reason)
requester = Requester.deserialize(self.store, content["requester"])
context = yield EventContext.deserialize(self.store, content["context"])
ratelimit = content["ratelimit"]
extra_users = [UserID.from_string(u) for u in content["extra_users"]]
if requester.user:
request.authenticated_entity = requester.user.to_string()
logger.info(
"Got event to send with ID: %s into room: %s",
event.event_id, event.room_id,
)
yield self.event_creation_handler.persist_and_notify_client_event(
requester, event, context,
ratelimit=ratelimit,
extra_users=extra_users,
)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
ReplicationSendEventRestServlet(hs).register(http_server)

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd # Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,50 +14,20 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import BaseSlavedStore from synapse.replication.slave.storage._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore from synapse.storage.account_data import AccountDataWorkerStore
from synapse.storage.account_data import AccountDataStore from synapse.storage.tags import TagsWorkerStore
from synapse.storage.tags import TagsStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedAccountDataStore(BaseSlavedStore): class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(SlavedAccountDataStore, self).__init__(db_conn, hs)
self._account_data_id_gen = SlavedIdTracker( self._account_data_id_gen = SlavedIdTracker(
db_conn, "account_data_max_stream_id", "stream_id", db_conn, "account_data_max_stream_id", "stream_id",
) )
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache",
self._account_data_id_gen.get_current_token(),
)
get_account_data_for_user = ( super(SlavedAccountDataStore, self).__init__(db_conn, hs)
AccountDataStore.__dict__["get_account_data_for_user"]
)
get_global_account_data_by_type_for_users = (
AccountDataStore.__dict__["get_global_account_data_by_type_for_users"]
)
get_global_account_data_by_type_for_user = (
AccountDataStore.__dict__["get_global_account_data_by_type_for_user"]
)
get_tags_for_user = TagsStore.__dict__["get_tags_for_user"]
get_tags_for_room = (
DataStore.get_tags_for_room.__func__
)
get_account_data_for_room = (
DataStore.get_account_data_for_room.__func__
)
get_updated_tags = DataStore.get_updated_tags.__func__
get_updated_account_data_for_user = (
DataStore.get_updated_account_data_for_user.__func__
)
def get_max_account_data_stream_id(self): def get_max_account_data_stream_id(self):
return self._account_data_id_gen.get_current_token() return self._account_data_id_gen.get_current_token()
@ -85,6 +56,10 @@ class SlavedAccountDataStore(BaseSlavedStore):
(row.data_type, row.user_id,) (row.data_type, row.user_id,)
) )
self.get_account_data_for_user.invalidate((row.user_id,)) self.get_account_data_for_user.invalidate((row.user_id,))
self.get_account_data_for_room.invalidate((row.user_id, row.room_id,))
self.get_account_data_for_room_and_type.invalidate(
(row.user_id, row.room_id, row.data_type,),
)
self._account_data_stream_cache.entity_has_changed( self._account_data_stream_cache.entity_has_changed(
row.user_id, token row.user_id, token
) )

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,33 +14,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import BaseSlavedStore from synapse.storage.appservice import (
from synapse.storage import DataStore ApplicationServiceWorkerStore, ApplicationServiceTransactionWorkerStore,
from synapse.config.appservice import load_appservices )
from synapse.storage.appservice import _make_exclusive_regex
class SlavedApplicationServiceStore(BaseSlavedStore): class SlavedApplicationServiceStore(ApplicationServiceTransactionWorkerStore,
def __init__(self, db_conn, hs): ApplicationServiceWorkerStore):
super(SlavedApplicationServiceStore, self).__init__(db_conn, hs) pass
self.services_cache = load_appservices(
hs.config.server_name,
hs.config.app_service_config_files
)
self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
get_app_service_by_token = DataStore.get_app_service_by_token.__func__
get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__
get_app_services = DataStore.get_app_services.__func__
get_new_events_for_appservice = DataStore.get_new_events_for_appservice.__func__
create_appservice_txn = DataStore.create_appservice_txn.__func__
get_appservices_by_state = DataStore.get_appservices_by_state.__func__
get_oldest_unsent_txn = DataStore.get_oldest_unsent_txn.__func__
_get_last_txn = DataStore._get_last_txn.__func__
complete_appservice_txn = DataStore.complete_appservice_txn.__func__
get_appservice_state = DataStore.get_appservice_state.__func__
set_appservice_last_pos = DataStore.set_appservice_last_pos.__func__
set_appservice_state = DataStore.set_appservice_state.__func__
get_if_app_services_interested_in_user = (
DataStore.get_if_app_services_interested_in_user.__func__
)

View File

@ -14,10 +14,8 @@
# limitations under the License. # limitations under the License.
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from synapse.storage.directory import DirectoryStore from synapse.storage.directory import DirectoryWorkerStore
class DirectoryStore(BaseSlavedStore): class DirectoryStore(DirectoryWorkerStore, BaseSlavedStore):
get_aliases_for_room = DirectoryStore.__dict__[ pass
"get_aliases_for_room"
]

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd # Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -15,13 +16,13 @@
import logging import logging
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.storage import DataStore from synapse.storage.event_federation import EventFederationWorkerStore
from synapse.storage.event_federation import EventFederationStore from synapse.storage.event_push_actions import EventPushActionsWorkerStore
from synapse.storage.event_push_actions import EventPushActionsStore from synapse.storage.events_worker import EventsWorkerStore
from synapse.storage.roommember import RoomMemberStore from synapse.storage.roommember import RoomMemberWorkerStore
from synapse.storage.state import StateGroupReadStore from synapse.storage.state import StateGroupWorkerStore
from synapse.storage.stream import StreamStore from synapse.storage.stream import StreamWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.storage.signatures import SignatureWorkerStore
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker from ._slaved_id_tracker import SlavedIdTracker
@ -37,138 +38,33 @@ logger = logging.getLogger(__name__)
# the method descriptor on the DataStore and chuck them into our class. # the method descriptor on the DataStore and chuck them into our class.
class SlavedEventStore(StateGroupReadStore, BaseSlavedStore): class SlavedEventStore(EventFederationWorkerStore,
RoomMemberWorkerStore,
EventPushActionsWorkerStore,
StreamWorkerStore,
EventsWorkerStore,
StateGroupWorkerStore,
SignatureWorkerStore,
BaseSlavedStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(SlavedEventStore, self).__init__(db_conn, hs)
self._stream_id_gen = SlavedIdTracker( self._stream_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering", db_conn, "events", "stream_ordering",
) )
self._backfill_id_gen = SlavedIdTracker( self._backfill_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering", step=-1 db_conn, "events", "stream_ordering", step=-1
) )
events_max = self._stream_id_gen.get_current_token()
event_cache_prefill, min_event_val = self._get_cache_dict(
db_conn, "events",
entity_column="room_id",
stream_column="stream_ordering",
max_value=events_max,
)
self._events_stream_cache = StreamChangeCache(
"EventsRoomStreamChangeCache", min_event_val,
prefilled_cache=event_cache_prefill,
)
self._membership_stream_cache = StreamChangeCache(
"MembershipStreamChangeCache", events_max,
)
self.stream_ordering_month_ago = 0 super(SlavedEventStore, self).__init__(db_conn, hs)
self._stream_order_on_start = self.get_room_max_stream_ordering()
# Cached functions can't be accessed through a class instance so we need # Cached functions can't be accessed through a class instance so we need
# to reach inside the __dict__ to extract them. # to reach inside the __dict__ to extract them.
get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"]
get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"]
get_hosts_in_room = RoomMemberStore.__dict__["get_hosts_in_room"]
get_users_who_share_room_with_user = (
RoomMemberStore.__dict__["get_users_who_share_room_with_user"]
)
get_latest_event_ids_in_room = EventFederationStore.__dict__[
"get_latest_event_ids_in_room"
]
get_invited_rooms_for_user = RoomMemberStore.__dict__[
"get_invited_rooms_for_user"
]
get_unread_event_push_actions_by_room_for_user = (
EventPushActionsStore.__dict__["get_unread_event_push_actions_by_room_for_user"]
)
_get_unread_counts_by_receipt_txn = (
DataStore._get_unread_counts_by_receipt_txn.__func__
)
_get_unread_counts_by_pos_txn = (
DataStore._get_unread_counts_by_pos_txn.__func__
)
get_recent_event_ids_for_room = (
StreamStore.__dict__["get_recent_event_ids_for_room"]
)
_get_joined_hosts_cache = RoomMemberStore.__dict__["_get_joined_hosts_cache"]
has_room_changed_since = DataStore.has_room_changed_since.__func__
get_unread_push_actions_for_user_in_range_for_http = ( def get_room_max_stream_ordering(self):
DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__ return self._stream_id_gen.get_current_token()
)
get_unread_push_actions_for_user_in_range_for_email = (
DataStore.get_unread_push_actions_for_user_in_range_for_email.__func__
)
get_push_action_users_in_range = (
DataStore.get_push_action_users_in_range.__func__
)
get_event = DataStore.get_event.__func__
get_events = DataStore.get_events.__func__
get_rooms_for_user_where_membership_is = (
DataStore.get_rooms_for_user_where_membership_is.__func__
)
get_membership_changes_for_user = (
DataStore.get_membership_changes_for_user.__func__
)
get_room_events_max_id = DataStore.get_room_events_max_id.__func__
get_room_events_stream_for_room = (
DataStore.get_room_events_stream_for_room.__func__
)
get_events_around = DataStore.get_events_around.__func__
get_joined_users_from_state = DataStore.get_joined_users_from_state.__func__
get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__
_get_joined_users_from_context = (
RoomMemberStore.__dict__["_get_joined_users_from_context"]
)
get_joined_hosts = DataStore.get_joined_hosts.__func__ def get_room_min_stream_ordering(self):
_get_joined_hosts = RoomMemberStore.__dict__["_get_joined_hosts"] return self._backfill_id_gen.get_current_token()
get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__
get_room_events_stream_for_rooms = (
DataStore.get_room_events_stream_for_rooms.__func__
)
is_host_joined = RoomMemberStore.__dict__["is_host_joined"]
get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__
_set_before_and_after = staticmethod(DataStore._set_before_and_after)
_get_events = DataStore._get_events.__func__
_get_events_from_cache = DataStore._get_events_from_cache.__func__
_invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__
_enqueue_events = DataStore._enqueue_events.__func__
_do_fetch = DataStore._do_fetch.__func__
_fetch_event_rows = DataStore._fetch_event_rows.__func__
_get_event_from_row = DataStore._get_event_from_row.__func__
_get_rooms_for_user_where_membership_is_txn = (
DataStore._get_rooms_for_user_where_membership_is_txn.__func__
)
_get_events_around_txn = DataStore._get_events_around_txn.__func__
get_backfill_events = DataStore.get_backfill_events.__func__
_get_backfill_events = DataStore._get_backfill_events.__func__
get_missing_events = DataStore.get_missing_events.__func__
_get_missing_events = DataStore._get_missing_events.__func__
get_auth_chain = DataStore.get_auth_chain.__func__
get_auth_chain_ids = DataStore.get_auth_chain_ids.__func__
_get_auth_chain_ids_txn = DataStore._get_auth_chain_ids_txn.__func__
get_room_max_stream_ordering = DataStore.get_room_max_stream_ordering.__func__
get_forward_extremeties_for_room = (
DataStore.get_forward_extremeties_for_room.__func__
)
_get_forward_extremeties_for_room = (
EventFederationStore.__dict__["_get_forward_extremeties_for_room"]
)
get_all_new_events_stream = DataStore.get_all_new_events_stream.__func__
get_federation_out_pos = DataStore.get_federation_out_pos.__func__
update_federation_out_pos = DataStore.update_federation_out_pos.__func__
def stream_positions(self): def stream_positions(self):
result = super(SlavedEventStore, self).stream_positions() result = super(SlavedEventStore, self).stream_positions()

View File

@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.storage.profile import ProfileWorkerStore
class SlavedProfileStore(ProfileWorkerStore, BaseSlavedStore):
pass

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -15,29 +16,15 @@
from .events import SlavedEventStore from .events import SlavedEventStore
from ._slaved_id_tracker import SlavedIdTracker from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore from synapse.storage.push_rule import PushRulesWorkerStore
from synapse.storage.push_rule import PushRuleStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedPushRuleStore(SlavedEventStore): class SlavedPushRuleStore(PushRulesWorkerStore, SlavedEventStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(SlavedPushRuleStore, self).__init__(db_conn, hs)
self._push_rules_stream_id_gen = SlavedIdTracker( self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id", db_conn, "push_rules_stream", "stream_id",
) )
self.push_rules_stream_cache = StreamChangeCache( super(SlavedPushRuleStore, self).__init__(db_conn, hs)
"PushRulesStreamChangeCache",
self._push_rules_stream_id_gen.get_current_token(),
)
get_push_rules_for_user = PushRuleStore.__dict__["get_push_rules_for_user"]
get_push_rules_enabled_for_user = (
PushRuleStore.__dict__["get_push_rules_enabled_for_user"]
)
have_push_rules_changed_for_user = (
DataStore.have_push_rules_changed_for_user.__func__
)
def get_push_rules_stream_token(self): def get_push_rules_stream_token(self):
return ( return (
@ -45,6 +32,9 @@ class SlavedPushRuleStore(SlavedEventStore):
self._stream_id_gen.get_current_token(), self._stream_id_gen.get_current_token(),
) )
def get_max_push_rules_stream_id(self):
return self._push_rules_stream_id_gen.get_current_token()
def stream_positions(self): def stream_positions(self):
result = super(SlavedPushRuleStore, self).stream_positions() result = super(SlavedPushRuleStore, self).stream_positions()
result["push_rules"] = self._push_rules_stream_id_gen.get_current_token() result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd # Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -16,10 +17,10 @@
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore from synapse.storage.pusher import PusherWorkerStore
class SlavedPusherStore(BaseSlavedStore): class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(SlavedPusherStore, self).__init__(db_conn, hs) super(SlavedPusherStore, self).__init__(db_conn, hs)
@ -28,13 +29,6 @@ class SlavedPusherStore(BaseSlavedStore):
extra_tables=[("deleted_pushers", "stream_id")], extra_tables=[("deleted_pushers", "stream_id")],
) )
get_all_pushers = DataStore.get_all_pushers.__func__
get_pushers_by = DataStore.get_pushers_by.__func__
get_pushers_by_app_id_and_pushkey = (
DataStore.get_pushers_by_app_id_and_pushkey.__func__
)
_decode_pushers_rows = DataStore._decode_pushers_rows.__func__
def stream_positions(self): def stream_positions(self):
result = super(SlavedPusherStore, self).stream_positions() result = super(SlavedPusherStore, self).stream_positions()
result["pushers"] = self._pushers_id_gen.get_current_token() result["pushers"] = self._pushers_id_gen.get_current_token()

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd # Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -16,9 +17,7 @@
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore from synapse.storage.receipts import ReceiptsWorkerStore
from synapse.storage.receipts import ReceiptsStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
# So, um, we want to borrow a load of functions intended for reading from # So, um, we want to borrow a load of functions intended for reading from
# a DataStore, but we don't want to take functions that either write to the # a DataStore, but we don't want to take functions that either write to the
@ -29,36 +28,19 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
# the method descriptor on the DataStore and chuck them into our class. # the method descriptor on the DataStore and chuck them into our class.
class SlavedReceiptsStore(BaseSlavedStore): class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(SlavedReceiptsStore, self).__init__(db_conn, hs) # We instantiate this first as the ReceiptsWorkerStore constructor
# needs to be able to call get_max_receipt_stream_id
self._receipts_id_gen = SlavedIdTracker( self._receipts_id_gen = SlavedIdTracker(
db_conn, "receipts_linearized", "stream_id" db_conn, "receipts_linearized", "stream_id"
) )
self._receipts_stream_cache = StreamChangeCache( super(SlavedReceiptsStore, self).__init__(db_conn, hs)
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
)
get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"] def get_max_receipt_stream_id(self):
get_linearized_receipts_for_room = ( return self._receipts_id_gen.get_current_token()
ReceiptsStore.__dict__["get_linearized_receipts_for_room"]
)
_get_linearized_receipts_for_rooms = (
ReceiptsStore.__dict__["_get_linearized_receipts_for_rooms"]
)
get_last_receipt_event_id_for_user = (
ReceiptsStore.__dict__["get_last_receipt_event_id_for_user"]
)
get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__
get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__
get_linearized_receipts_for_rooms = (
DataStore.get_linearized_receipts_for_rooms.__func__
)
def stream_positions(self): def stream_positions(self):
result = super(SlavedReceiptsStore, self).stream_positions() result = super(SlavedReceiptsStore, self).stream_positions()
@ -71,6 +53,8 @@ class SlavedReceiptsStore(BaseSlavedStore):
self.get_last_receipt_event_id_for_user.invalidate( self.get_last_receipt_event_id_for_user.invalidate(
(user_id, room_id, receipt_type) (user_id, room_id, receipt_type)
) )
self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
self.get_receipts_for_room.invalidate((room_id, receipt_type))
def process_replication_rows(self, stream_name, token, rows): def process_replication_rows(self, stream_name, token, rows):
if stream_name == "receipts": if stream_name == "receipts":

View File

@ -14,20 +14,8 @@
# limitations under the License. # limitations under the License.
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from synapse.storage import DataStore from synapse.storage.registration import RegistrationWorkerStore
from synapse.storage.registration import RegistrationStore
class SlavedRegistrationStore(BaseSlavedStore): class SlavedRegistrationStore(RegistrationWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs): pass
super(SlavedRegistrationStore, self).__init__(db_conn, hs)
# TODO: use the cached version and invalidate deleted tokens
get_user_by_access_token = RegistrationStore.__dict__[
"get_user_by_access_token"
]
_query_for_auth = DataStore._query_for_auth.__func__
get_user_by_id = RegistrationStore.__dict__[
"get_user_by_id"
]

View File

@ -14,32 +14,19 @@
# limitations under the License. # limitations under the License.
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from synapse.storage import DataStore from synapse.storage.room import RoomWorkerStore
from synapse.storage.room import RoomStore
from ._slaved_id_tracker import SlavedIdTracker from ._slaved_id_tracker import SlavedIdTracker
class RoomStore(BaseSlavedStore): class RoomStore(RoomWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(RoomStore, self).__init__(db_conn, hs) super(RoomStore, self).__init__(db_conn, hs)
self._public_room_id_gen = SlavedIdTracker( self._public_room_id_gen = SlavedIdTracker(
db_conn, "public_room_list_stream", "stream_id" db_conn, "public_room_list_stream", "stream_id"
) )
get_public_room_ids = DataStore.get_public_room_ids.__func__ def get_current_public_room_stream_id(self):
get_current_public_room_stream_id = ( return self._public_room_id_gen.get_current_token()
DataStore.get_current_public_room_stream_id.__func__
)
get_public_room_ids_at_stream_id = (
RoomStore.__dict__["get_public_room_ids_at_stream_id"]
)
get_public_room_ids_at_stream_id_txn = (
DataStore.get_public_room_ids_at_stream_id_txn.__func__
)
get_published_at_stream_id_txn = (
DataStore.get_published_at_stream_id_txn.__func__
)
get_public_room_changes = DataStore.get_public_room_changes.__func__
def stream_positions(self): def stream_positions(self):
result = super(RoomStore, self).stream_positions() result = super(RoomStore, self).stream_positions()

View File

@ -517,25 +517,28 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.send_error("Wrong remote") self.send_error("Wrong remote")
def on_RDATA(self, cmd): def on_RDATA(self, cmd):
stream_name = cmd.stream_name
inbound_rdata_count.inc(stream_name)
try: try:
row = STREAMS_MAP[cmd.stream_name].ROW_TYPE(*cmd.row) row = STREAMS_MAP[stream_name].ROW_TYPE(*cmd.row)
except Exception: except Exception:
logger.exception( logger.exception(
"[%s] Failed to parse RDATA: %r %r", "[%s] Failed to parse RDATA: %r %r",
self.id(), cmd.stream_name, cmd.row self.id(), stream_name, cmd.row
) )
raise raise
if cmd.token is None: if cmd.token is None:
# I.e. this is part of a batch of updates for this stream. Batch # I.e. this is part of a batch of updates for this stream. Batch
# until we get an update for the stream with a non None token # until we get an update for the stream with a non None token
self.pending_batches.setdefault(cmd.stream_name, []).append(row) self.pending_batches.setdefault(stream_name, []).append(row)
else: else:
# Check if this is the last of a batch of updates # Check if this is the last of a batch of updates
rows = self.pending_batches.pop(cmd.stream_name, []) rows = self.pending_batches.pop(stream_name, [])
rows.append(row) rows.append(row)
self.handler.on_rdata(cmd.stream_name, cmd.token, rows) self.handler.on_rdata(stream_name, cmd.token, rows)
def on_POSITION(self, cmd): def on_POSITION(self, cmd):
self.handler.on_position(cmd.stream_name, cmd.token) self.handler.on_position(cmd.stream_name, cmd.token)
@ -644,3 +647,9 @@ metrics.register_callback(
}, },
labels=["command", "name", "conn_id"], labels=["command", "name", "conn_id"],
) )
# number of updates received for each RDATA stream
inbound_rdata_count = metrics.register_counter(
"inbound_rdata_count",
labels=["stream_name"],
)

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -16,7 +17,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.api.errors import AuthError, SynapseError from synapse.api.errors import AuthError, SynapseError, Codes, NotFoundError
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
@ -113,12 +114,18 @@ class PurgeMediaCacheRestServlet(ClientV1RestServlet):
class PurgeHistoryRestServlet(ClientV1RestServlet): class PurgeHistoryRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns( PATTERNS = client_path_patterns(
"/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" "/admin/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?"
) )
def __init__(self, hs): def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer)
"""
super(PurgeHistoryRestServlet, self).__init__(hs) super(PurgeHistoryRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.store = hs.get_datastore()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, event_id): def on_POST(self, request, room_id, event_id):
@ -128,9 +135,93 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
if not is_admin: if not is_admin:
raise AuthError(403, "You are not a server admin") raise AuthError(403, "You are not a server admin")
yield self.handlers.message_handler.purge_history(room_id, event_id) body = parse_json_object_from_request(request, allow_empty_body=True)
defer.returnValue((200, {})) delete_local_events = bool(body.get("delete_local_events", False))
# establish the topological ordering we should keep events from. The
# user can provide an event_id in the URL or the request body, or can
# provide a timestamp in the request body.
if event_id is None:
event_id = body.get('purge_up_to_event_id')
if event_id is not None:
event = yield self.store.get_event(event_id)
if event.room_id != room_id:
raise SynapseError(400, "Event is for wrong room.")
depth = event.depth
logger.info(
"[purge] purging up to depth %i (event_id %s)",
depth, event_id,
)
elif 'purge_up_to_ts' in body:
ts = body['purge_up_to_ts']
if not isinstance(ts, int):
raise SynapseError(
400, "purge_up_to_ts must be an int",
errcode=Codes.BAD_JSON,
)
stream_ordering = (
yield self.store.find_first_stream_ordering_after_ts(ts)
)
(_, depth, _) = (
yield self.store.get_room_event_after_stream_ordering(
room_id, stream_ordering,
)
)
logger.info(
"[purge] purging up to depth %i (received_ts %i => "
"stream_ordering %i)",
depth, ts, stream_ordering,
)
else:
raise SynapseError(
400,
"must specify purge_up_to_event_id or purge_up_to_ts",
errcode=Codes.BAD_JSON,
)
purge_id = yield self.handlers.message_handler.start_purge_history(
room_id, depth,
delete_local_events=delete_local_events,
)
defer.returnValue((200, {
"purge_id": purge_id,
}))
class PurgeHistoryStatusRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns(
"/admin/purge_history_status/(?P<purge_id>[^/]+)"
)
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer)
"""
super(PurgeHistoryStatusRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, purge_id):
requester = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(requester.user)
if not is_admin:
raise AuthError(403, "You are not a server admin")
purge_status = self.handlers.message_handler.get_purge_status(purge_id)
if purge_status is None:
raise NotFoundError("purge id '%s' not found" % purge_id)
defer.returnValue((200, purge_status.asdict()))
class DeactivateAccountRestServlet(ClientV1RestServlet): class DeactivateAccountRestServlet(ClientV1RestServlet):
@ -171,6 +262,8 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id): def on_POST(self, request, room_id):
@ -203,8 +296,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
) )
new_room_id = info["room_id"] new_room_id = info["room_id"]
msg_handler = self.handlers.message_handler yield self.event_creation_handler.create_and_send_nonmember_event(
yield msg_handler.create_and_send_nonmember_event(
room_creator_requester, room_creator_requester,
{ {
"type": "m.room.message", "type": "m.room.message",
@ -230,7 +322,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
logger.info("Kicking %r from %r...", user_id, room_id) logger.info("Kicking %r from %r...", user_id, room_id)
target_requester = create_requester(user_id) target_requester = create_requester(user_id)
yield self.handlers.room_member_handler.update_membership( yield self.room_member_handler.update_membership(
requester=target_requester, requester=target_requester,
target=target_requester.user, target=target_requester.user,
room_id=room_id, room_id=room_id,
@ -239,9 +331,9 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
ratelimit=False ratelimit=False
) )
yield self.handlers.room_member_handler.forget(target_requester.user, room_id) yield self.room_member_handler.forget(target_requester.user, room_id)
yield self.handlers.room_member_handler.update_membership( yield self.room_member_handler.update_membership(
requester=target_requester, requester=target_requester,
target=target_requester.user, target=target_requester.user,
room_id=new_room_id, room_id=new_room_id,
@ -289,6 +381,27 @@ class QuarantineMediaInRoom(ClientV1RestServlet):
defer.returnValue((200, {"num_quarantined": num_quarantined})) defer.returnValue((200, {"num_quarantined": num_quarantined}))
class ListMediaInRoom(ClientV1RestServlet):
"""Lists all of the media in a given room.
"""
PATTERNS = client_path_patterns("/admin/room/(?P<room_id>[^/]+)/media")
def __init__(self, hs):
super(ListMediaInRoom, self).__init__(hs)
self.store = hs.get_datastore()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(requester.user)
if not is_admin:
raise AuthError(403, "You are not a server admin")
local_mxcs, remote_mxcs = yield self.store.get_media_mxcs_in_room(room_id)
defer.returnValue((200, {"local": local_mxcs, "remote": remote_mxcs}))
class ResetPasswordRestServlet(ClientV1RestServlet): class ResetPasswordRestServlet(ClientV1RestServlet):
"""Post request to allow an administrator reset password for a user. """Post request to allow an administrator reset password for a user.
This needs user to have administrator access in Synapse. This needs user to have administrator access in Synapse.
@ -479,6 +592,7 @@ class SearchUsersRestServlet(ClientV1RestServlet):
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
WhoisRestServlet(hs).register(http_server) WhoisRestServlet(hs).register(http_server)
PurgeMediaCacheRestServlet(hs).register(http_server) PurgeMediaCacheRestServlet(hs).register(http_server)
PurgeHistoryStatusRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server) DeactivateAccountRestServlet(hs).register(http_server)
PurgeHistoryRestServlet(hs).register(http_server) PurgeHistoryRestServlet(hs).register(http_server)
UsersRestServlet(hs).register(http_server) UsersRestServlet(hs).register(http_server)
@ -487,3 +601,4 @@ def register_servlets(hs, http_server):
SearchUsersRestServlet(hs).register(http_server) SearchUsersRestServlet(hs).register(http_server)
ShutdownRoomRestServlet(hs).register(http_server) ShutdownRoomRestServlet(hs).register(http_server)
QuarantineMediaInRoom(hs).register(http_server) QuarantineMediaInRoom(hs).register(http_server)
ListMediaInRoom(hs).register(http_server)

View File

@ -191,19 +191,25 @@ class LoginRestServlet(ClientV1RestServlet):
# convert threepid identifiers to user IDs # convert threepid identifiers to user IDs
if identifier["type"] == "m.id.thirdparty": if identifier["type"] == "m.id.thirdparty":
if 'medium' not in identifier or 'address' not in identifier: address = identifier.get('address')
medium = identifier.get('medium')
if medium is None or address is None:
raise SynapseError(400, "Invalid thirdparty identifier") raise SynapseError(400, "Invalid thirdparty identifier")
address = identifier['address'] if medium == 'email':
if identifier['medium'] == 'email':
# For emails, transform the address to lowercase. # For emails, transform the address to lowercase.
# We store all email addreses as lowercase in the DB. # We store all email addreses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py) # (See add_threepid in synapse/handlers/auth.py)
address = address.lower() address = address.lower()
user_id = yield self.hs.get_datastore().get_user_id_by_threepid( user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
identifier['medium'], address medium, address,
) )
if not user_id: if not user_id:
logger.warn(
"unknown 3pid identifier medium %s, address %r",
medium, address,
)
raise LoginError(403, "", errcode=Codes.FORBIDDEN) raise LoginError(403, "", errcode=Codes.FORBIDDEN)
identifier = { identifier = {

View File

@ -70,10 +70,15 @@ class RegisterRestServlet(ClientV1RestServlet):
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
def on_GET(self, request): def on_GET(self, request):
require_email = 'email' in self.hs.config.registrations_require_3pid
require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
flows = []
if self.hs.config.enable_registration_captcha: if self.hs.config.enable_registration_captcha:
return ( # only support the email-only flow if we don't require MSISDN 3PIDs
200, if not require_msisdn:
{"flows": [ flows.extend([
{ {
"type": LoginType.RECAPTCHA, "type": LoginType.RECAPTCHA,
"stages": [ "stages": [
@ -82,27 +87,34 @@ class RegisterRestServlet(ClientV1RestServlet):
LoginType.PASSWORD LoginType.PASSWORD
] ]
}, },
])
# only support 3PIDless registration if no 3PIDs are required
if not require_email and not require_msisdn:
flows.extend([
{ {
"type": LoginType.RECAPTCHA, "type": LoginType.RECAPTCHA,
"stages": [LoginType.RECAPTCHA, LoginType.PASSWORD] "stages": [LoginType.RECAPTCHA, LoginType.PASSWORD]
} }
]} ])
)
else: else:
return ( # only support the email-only flow if we don't require MSISDN 3PIDs
200, if require_email or not require_msisdn:
{"flows": [ flows.extend([
{ {
"type": LoginType.EMAIL_IDENTITY, "type": LoginType.EMAIL_IDENTITY,
"stages": [ "stages": [
LoginType.EMAIL_IDENTITY, LoginType.PASSWORD LoginType.EMAIL_IDENTITY, LoginType.PASSWORD
] ]
}, }
])
# only support 3PIDless registration if no 3PIDs are required
if not require_email and not require_msisdn:
flows.extend([
{ {
"type": LoginType.PASSWORD "type": LoginType.PASSWORD
} }
]} ])
) return (200, {"flows": flows})
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -82,6 +83,8 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(RoomStateEventRestServlet, self).__init__(hs) super(RoomStateEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.event_creation_hander = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
def register(self, http_server): def register(self, http_server):
# /room/$roomid/state/$eventtype # /room/$roomid/state/$eventtype
@ -154,7 +157,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
if event_type == EventTypes.Member: if event_type == EventTypes.Member:
membership = content.get("membership", None) membership = content.get("membership", None)
event = yield self.handlers.room_member_handler.update_membership( event = yield self.room_member_handler.update_membership(
requester, requester,
target=UserID.from_string(state_key), target=UserID.from_string(state_key),
room_id=room_id, room_id=room_id,
@ -162,15 +165,16 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
content=content, content=content,
) )
else: else:
msg_handler = self.handlers.message_handler event, context = yield self.event_creation_hander.create_event(
event, context = yield msg_handler.create_event(
requester, requester,
event_dict, event_dict,
token_id=requester.access_token_id, token_id=requester.access_token_id,
txn_id=txn_id, txn_id=txn_id,
) )
yield msg_handler.send_nonmember_event(requester, event, context) yield self.event_creation_hander.send_nonmember_event(
requester, event, context,
)
ret = {} ret = {}
if event: if event:
@ -183,7 +187,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(RoomSendEventRestServlet, self).__init__(hs) super(RoomSendEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.event_creation_hander = hs.get_event_creation_handler()
def register(self, http_server): def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id] # /rooms/$roomid/send/$event_type[/$txn_id]
@ -195,15 +199,19 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
msg_handler = self.handlers.message_handler event_dict = {
event = yield msg_handler.create_and_send_nonmember_event( "type": event_type,
"content": content,
"room_id": room_id,
"sender": requester.user.to_string(),
}
if 'ts' in request.args and requester.app_service:
event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)
event = yield self.event_creation_hander.create_and_send_nonmember_event(
requester, requester,
{ event_dict,
"type": event_type,
"content": content,
"room_id": room_id,
"sender": requester.user.to_string(),
},
txn_id=txn_id, txn_id=txn_id,
) )
@ -222,7 +230,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
class JoinRoomAliasServlet(ClientV1RestServlet): class JoinRoomAliasServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(JoinRoomAliasServlet, self).__init__(hs) super(JoinRoomAliasServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.room_member_handler = hs.get_room_member_handler()
def register(self, http_server): def register(self, http_server):
# /join/$room_identifier[/$txn_id] # /join/$room_identifier[/$txn_id]
@ -250,7 +258,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
except Exception: except Exception:
remote_room_hosts = None remote_room_hosts = None
elif RoomAlias.is_valid(room_identifier): elif RoomAlias.is_valid(room_identifier):
handler = self.handlers.room_member_handler handler = self.room_member_handler
room_alias = RoomAlias.from_string(room_identifier) room_alias = RoomAlias.from_string(room_identifier)
room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias) room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias)
room_id = room_id.to_string() room_id = room_id.to_string()
@ -259,7 +267,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
room_identifier, room_identifier,
)) ))
yield self.handlers.room_member_handler.update_membership( yield self.room_member_handler.update_membership(
requester=requester, requester=requester,
target=requester.user, target=requester.user,
room_id=room_id, room_id=room_id,
@ -487,13 +495,35 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
defer.returnValue((200, content)) defer.returnValue((200, content))
class RoomEventContext(ClientV1RestServlet): class RoomEventServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns(
"/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$"
)
def __init__(self, hs):
super(RoomEventServlet, self).__init__(hs)
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
@defer.inlineCallbacks
def on_GET(self, request, room_id, event_id):
requester = yield self.auth.get_user_by_req(request)
event = yield self.event_handler.get_event(requester.user, event_id)
time_now = self.clock.time_msec()
if event:
defer.returnValue((200, serialize_event(event, time_now)))
else:
defer.returnValue((404, "Event not found."))
class RoomEventContextServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns( PATTERNS = client_path_patterns(
"/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$" "/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$"
) )
def __init__(self, hs): def __init__(self, hs):
super(RoomEventContext, self).__init__(hs) super(RoomEventContextServlet, self).__init__(hs)
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
@ -533,7 +563,7 @@ class RoomEventContext(ClientV1RestServlet):
class RoomForgetRestServlet(ClientV1RestServlet): class RoomForgetRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(RoomForgetRestServlet, self).__init__(hs) super(RoomForgetRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.room_member_handler = hs.get_room_member_handler()
def register(self, http_server): def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget") PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
@ -546,7 +576,7 @@ class RoomForgetRestServlet(ClientV1RestServlet):
allow_guest=False, allow_guest=False,
) )
yield self.handlers.room_member_handler.forget( yield self.room_member_handler.forget(
user=requester.user, user=requester.user,
room_id=room_id, room_id=room_id,
) )
@ -564,12 +594,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(RoomMembershipRestServlet, self).__init__(hs) super(RoomMembershipRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.room_member_handler = hs.get_room_member_handler()
def register(self, http_server): def register(self, http_server):
# /rooms/$roomid/[invite|join|leave] # /rooms/$roomid/[invite|join|leave]
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/" PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
"(?P<membership_action>join|invite|leave|ban|unban|kick|forget)") "(?P<membership_action>join|invite|leave|ban|unban|kick)")
register_txn_path(self, PATTERNS, http_server) register_txn_path(self, PATTERNS, http_server)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -593,7 +623,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
content = {} content = {}
if membership_action == "invite" and self._has_3pid_invite_keys(content): if membership_action == "invite" and self._has_3pid_invite_keys(content):
yield self.handlers.room_member_handler.do_3pid_invite( yield self.room_member_handler.do_3pid_invite(
room_id, room_id,
requester.user, requester.user,
content["medium"], content["medium"],
@ -615,7 +645,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
if 'reason' in content and membership_action in ['kick', 'ban']: if 'reason' in content and membership_action in ['kick', 'ban']:
event_content = {'reason': content['reason']} event_content = {'reason': content['reason']}
yield self.handlers.room_member_handler.update_membership( yield self.room_member_handler.update_membership(
requester=requester, requester=requester,
target=target, target=target,
room_id=room_id, room_id=room_id,
@ -643,6 +673,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(RoomRedactEventRestServlet, self).__init__(hs) super(RoomRedactEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.event_creation_handler = hs.get_event_creation_handler()
def register(self, http_server): def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)") PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
@ -653,8 +684,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
msg_handler = self.handlers.message_handler event = yield self.event_creation_handler.create_and_send_nonmember_event(
event = yield msg_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.Redaction, "type": EventTypes.Redaction,
@ -803,4 +833,5 @@ def register_servlets(hs, http_server):
RoomTypingRestServlet(hs).register(http_server) RoomTypingRestServlet(hs).register(http_server)
SearchRestServlet(hs).register(http_server) SearchRestServlet(hs).register(http_server)
JoinedRoomsRestServlet(hs).register(http_server) JoinedRoomsRestServlet(hs).register(http_server)
RoomEventContext(hs).register(http_server) RoomEventServlet(hs).register(http_server)
RoomEventContextServlet(hs).register(http_server)

View File

@ -26,6 +26,7 @@ from synapse.http.servlet import (
) )
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import check_3pid_allowed
from ._base import client_v2_patterns, interactive_auth_handler from ._base import client_v2_patterns, interactive_auth_handler
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -47,6 +48,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
'id_server', 'client_secret', 'email', 'send_attempt' 'id_server', 'client_secret', 'email', 'send_attempt'
]) ])
if not check_3pid_allowed(self.hs, "email", body['email']):
raise SynapseError(
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email'] 'email', body['email']
) )
@ -78,6 +84,11 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(body['country'], body['phone_number']) msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
raise SynapseError(
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
)
existingUid = yield self.datastore.get_user_id_by_threepid( existingUid = yield self.datastore.get_user_id_by_threepid(
'msisdn', msisdn 'msisdn', msisdn
) )
@ -217,6 +228,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
if absent: if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
if not check_3pid_allowed(self.hs, "email", body['email']):
raise SynapseError(
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
)
existingUid = yield self.datastore.get_user_id_by_threepid( existingUid = yield self.datastore.get_user_id_by_threepid(
'email', body['email'] 'email', body['email']
) )
@ -255,6 +271,11 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(body['country'], body['phone_number']) msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
raise SynapseError(
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
)
existingUid = yield self.datastore.get_user_id_by_threepid( existingUid = yield self.datastore.get_user_id_by_threepid(
'msisdn', msisdn 'msisdn', msisdn
) )

View File

@ -26,6 +26,7 @@ from synapse.http.servlet import (
RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string
) )
from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import check_3pid_allowed
from ._base import client_v2_patterns, interactive_auth_handler from ._base import client_v2_patterns, interactive_auth_handler
@ -70,6 +71,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
'id_server', 'client_secret', 'email', 'send_attempt' 'id_server', 'client_secret', 'email', 'send_attempt'
]) ])
if not check_3pid_allowed(self.hs, "email", body['email']):
raise SynapseError(
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email'] 'email', body['email']
) )
@ -105,6 +111,11 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(body['country'], body['phone_number']) msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
raise SynapseError(
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'msisdn', msisdn 'msisdn', msisdn
) )
@ -172,7 +183,7 @@ class RegisterRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_handlers().registration_handler self.registration_handler = hs.get_handlers().registration_handler
self.identity_handler = hs.get_handlers().identity_handler self.identity_handler = hs.get_handlers().identity_handler
self.room_member_handler = hs.get_handlers().room_member_handler self.room_member_handler = hs.get_room_member_handler()
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator() self.macaroon_gen = hs.get_macaroon_generator()
@ -305,31 +316,67 @@ class RegisterRestServlet(RestServlet):
if 'x_show_msisdn' in body and body['x_show_msisdn']: if 'x_show_msisdn' in body and body['x_show_msisdn']:
show_msisdn = True show_msisdn = True
# FIXME: need a better error than "no auth flow found" for scenarios
# where we required 3PID for registration but the user didn't give one
require_email = 'email' in self.hs.config.registrations_require_3pid
require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
flows = []
if self.hs.config.enable_registration_captcha: if self.hs.config.enable_registration_captcha:
flows = [ # only support 3PIDless registration if no 3PIDs are required
[LoginType.RECAPTCHA], if not require_email and not require_msisdn:
[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA], flows.extend([[LoginType.RECAPTCHA]])
] # only support the email-only flow if we don't require MSISDN 3PIDs
if not require_msisdn:
flows.extend([[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]])
if show_msisdn: if show_msisdn:
# only support the MSISDN-only flow if we don't require email 3PIDs
if not require_email:
flows.extend([[LoginType.MSISDN, LoginType.RECAPTCHA]])
# always let users provide both MSISDN & email
flows.extend([ flows.extend([
[LoginType.MSISDN, LoginType.RECAPTCHA],
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA], [LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
]) ])
else: else:
flows = [ # only support 3PIDless registration if no 3PIDs are required
[LoginType.DUMMY], if not require_email and not require_msisdn:
[LoginType.EMAIL_IDENTITY], flows.extend([[LoginType.DUMMY]])
] # only support the email-only flow if we don't require MSISDN 3PIDs
if not require_msisdn:
flows.extend([[LoginType.EMAIL_IDENTITY]])
if show_msisdn: if show_msisdn:
# only support the MSISDN-only flow if we don't require email 3PIDs
if not require_email or require_msisdn:
flows.extend([[LoginType.MSISDN]])
# always let users provide both MSISDN & email
flows.extend([ flows.extend([
[LoginType.MSISDN], [LoginType.MSISDN, LoginType.EMAIL_IDENTITY]
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
]) ])
auth_result, params, session_id = yield self.auth_handler.check_auth( auth_result, params, session_id = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request) flows, body, self.hs.get_ip_from_request(request)
) )
# Check that we're not trying to register a denied 3pid.
#
# the user-facing checks will probably already have happened in
# /register/email/requestToken when we requested a 3pid, but that's not
# guaranteed.
if auth_result:
for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
if login_type in auth_result:
medium = auth_result[login_type]['medium']
address = auth_result[login_type]['address']
if not check_3pid_allowed(self.hs, medium, address):
raise SynapseError(
403, "Third party identifier is not allowed",
Codes.THREEPID_DENIED,
)
if registered_user_id is not None: if registered_user_id is not None:
logger.info( logger.info(
"Already registered user ID %r for this session", "Already registered user ID %r for this session",

View File

@ -93,6 +93,7 @@ class RemoteKey(Resource):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.version_string = hs.version_string self.version_string = hs.version_string
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
def render_GET(self, request): def render_GET(self, request):
self.async_render_GET(request) self.async_render_GET(request)
@ -137,6 +138,13 @@ class RemoteKey(Resource):
logger.info("Handling query for keys %r", query) logger.info("Handling query for keys %r", query)
store_queries = [] store_queries = []
for server_name, key_ids in query.items(): for server_name, key_ids in query.items():
if (
self.federation_domain_whitelist is not None and
server_name not in self.federation_domain_whitelist
):
logger.debug("Federation denied with %s", server_name)
continue
if not key_ids: if not key_ids:
key_ids = (None,) key_ids = (None,)
for key_id in key_ids: for key_id in key_ids:

View File

@ -70,38 +70,11 @@ def respond_with_file(request, media_type, file_path,
logger.debug("Responding with %r", file_path) logger.debug("Responding with %r", file_path)
if os.path.isfile(file_path): if os.path.isfile(file_path):
request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
if upload_name:
if is_ascii(upload_name):
request.setHeader(
b"Content-Disposition",
b"inline; filename=%s" % (
urllib.quote(upload_name.encode("utf-8")),
),
)
else:
request.setHeader(
b"Content-Disposition",
b"inline; filename*=utf-8''%s" % (
urllib.quote(upload_name.encode("utf-8")),
),
)
# cache for at least a day.
# XXX: we might want to turn this off for data we don't want to
# recommend caching as it's sensitive or private - or at least
# select private. don't bother setting Expires as all our
# clients are smart enough to be happy with Cache-Control
request.setHeader(
b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
)
if file_size is None: if file_size is None:
stat = os.stat(file_path) stat = os.stat(file_path)
file_size = stat.st_size file_size = stat.st_size
request.setHeader( add_file_headers(request, media_type, file_size, upload_name)
b"Content-Length", b"%d" % (file_size,)
)
with open(file_path, "rb") as f: with open(file_path, "rb") as f:
yield logcontext.make_deferred_yieldable( yield logcontext.make_deferred_yieldable(
@ -111,3 +84,118 @@ def respond_with_file(request, media_type, file_path,
finish_request(request) finish_request(request)
else: else:
respond_404(request) respond_404(request)
def add_file_headers(request, media_type, file_size, upload_name):
"""Adds the correct response headers in preparation for responding with the
media.
Args:
request (twisted.web.http.Request)
media_type (str): The media/content type.
file_size (int): Size in bytes of the media, if known.
upload_name (str): The name of the requested file, if any.
"""
request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
if upload_name:
if is_ascii(upload_name):
request.setHeader(
b"Content-Disposition",
b"inline; filename=%s" % (
urllib.quote(upload_name.encode("utf-8")),
),
)
else:
request.setHeader(
b"Content-Disposition",
b"inline; filename*=utf-8''%s" % (
urllib.quote(upload_name.encode("utf-8")),
),
)
# cache for at least a day.
# XXX: we might want to turn this off for data we don't want to
# recommend caching as it's sensitive or private - or at least
# select private. don't bother setting Expires as all our
# clients are smart enough to be happy with Cache-Control
request.setHeader(
b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
)
request.setHeader(
b"Content-Length", b"%d" % (file_size,)
)
@defer.inlineCallbacks
def respond_with_responder(request, responder, media_type, file_size, upload_name=None):
"""Responds to the request with given responder. If responder is None then
returns 404.
Args:
request (twisted.web.http.Request)
responder (Responder|None)
media_type (str): The media/content type.
file_size (int|None): Size in bytes of the media. If not known it should be None
upload_name (str|None): The name of the requested file, if any.
"""
if not responder:
respond_404(request)
return
add_file_headers(request, media_type, file_size, upload_name)
with responder:
yield responder.write_to_consumer(request)
finish_request(request)
class Responder(object):
"""Represents a response that can be streamed to the requester.
Responder is a context manager which *must* be used, so that any resources
held can be cleaned up.
"""
def write_to_consumer(self, consumer):
"""Stream response into consumer
Args:
consumer (IConsumer)
Returns:
Deferred: Resolves once the response has finished being written
"""
pass
def __enter__(self):
pass
def __exit__(self, exc_type, exc_val, exc_tb):
pass
class FileInfo(object):
"""Details about a requested/uploaded file.
Attributes:
server_name (str): The server name where the media originated from,
or None if local.
file_id (str): The local ID of the file. For local files this is the
same as the media_id
url_cache (bool): If the file is for the url preview cache
thumbnail (bool): Whether the file is a thumbnail or not.
thumbnail_width (int)
thumbnail_height (int)
thumbnail_method (str)
thumbnail_type (str): Content type of thumbnail, e.g. image/png
"""
def __init__(self, server_name, file_id, url_cache=False,
thumbnail=False, thumbnail_width=None, thumbnail_height=None,
thumbnail_method=None, thumbnail_type=None):
self.server_name = server_name
self.file_id = file_id
self.url_cache = url_cache
self.thumbnail = thumbnail
self.thumbnail_width = thumbnail_width
self.thumbnail_height = thumbnail_height
self.thumbnail_method = thumbnail_method
self.thumbnail_type = thumbnail_type

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import synapse.http.servlet import synapse.http.servlet
from ._base import parse_media_id, respond_with_file, respond_404 from ._base import parse_media_id, respond_404
from twisted.web.resource import Resource from twisted.web.resource import Resource
from synapse.http.server import request_handler, set_cors_headers from synapse.http.server import request_handler, set_cors_headers
@ -32,12 +32,12 @@ class DownloadResource(Resource):
def __init__(self, hs, media_repo): def __init__(self, hs, media_repo):
Resource.__init__(self) Resource.__init__(self)
self.filepaths = media_repo.filepaths
self.media_repo = media_repo self.media_repo = media_repo
self.server_name = hs.hostname self.server_name = hs.hostname
self.store = hs.get_datastore()
self.version_string = hs.version_string # Both of these are expected by @request_handler()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.version_string = hs.version_string
def render_GET(self, request): def render_GET(self, request):
self._async_render_GET(request) self._async_render_GET(request)
@ -57,59 +57,16 @@ class DownloadResource(Resource):
) )
server_name, media_id, name = parse_media_id(request) server_name, media_id, name = parse_media_id(request)
if server_name == self.server_name: if server_name == self.server_name:
yield self._respond_local_file(request, media_id, name) yield self.media_repo.get_local_media(request, media_id, name)
else: else:
yield self._respond_remote_file( allow_remote = synapse.http.servlet.parse_boolean(
request, server_name, media_id, name request, "allow_remote", default=True)
) if not allow_remote:
logger.info(
"Rejecting request for remote media %s/%s due to allow_remote",
server_name, media_id,
)
respond_404(request)
return
@defer.inlineCallbacks yield self.media_repo.get_remote_media(request, server_name, media_id, name)
def _respond_local_file(self, request, media_id, name):
media_info = yield self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]:
respond_404(request)
return
media_type = media_info["media_type"]
media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
if media_info["url_cache"]:
# TODO: Check the file still exists, if it doesn't we can redownload
# it from the url `media_info["url_cache"]`
file_path = self.filepaths.url_cache_filepath(media_id)
else:
file_path = self.filepaths.local_media_filepath(media_id)
yield respond_with_file(
request, media_type, file_path, media_length,
upload_name=upload_name,
)
@defer.inlineCallbacks
def _respond_remote_file(self, request, server_name, media_id, name):
# don't forward requests for remote media if allow_remote is false
allow_remote = synapse.http.servlet.parse_boolean(
request, "allow_remote", default=True)
if not allow_remote:
logger.info(
"Rejecting request for remote media %s/%s due to allow_remote",
server_name, media_id,
)
respond_404(request)
return
media_info = yield self.media_repo.get_remote_media(server_name, media_id)
media_type = media_info["media_type"]
media_length = media_info["media_length"]
filesystem_id = media_info["filesystem_id"]
upload_name = name if name else media_info["upload_name"]
file_path = self.filepaths.remote_media_filepath(
server_name, filesystem_id
)
yield respond_with_file(
request, media_type, file_path, media_length,
upload_name=upload_name,
)

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -18,6 +19,7 @@ import twisted.internet.error
import twisted.web.http import twisted.web.http
from twisted.web.resource import Resource from twisted.web.resource import Resource
from ._base import respond_404, FileInfo, respond_with_responder
from .upload_resource import UploadResource from .upload_resource import UploadResource
from .download_resource import DownloadResource from .download_resource import DownloadResource
from .thumbnail_resource import ThumbnailResource from .thumbnail_resource import ThumbnailResource
@ -25,15 +27,18 @@ from .identicon_resource import IdenticonResource
from .preview_url_resource import PreviewUrlResource from .preview_url_resource import PreviewUrlResource
from .filepath import MediaFilePaths from .filepath import MediaFilePaths
from .thumbnailer import Thumbnailer from .thumbnailer import Thumbnailer
from .storage_provider import StorageProviderWrapper
from .media_storage import MediaStorage
from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from synapse.api.errors import SynapseError, HttpResponseException, \ from synapse.api.errors import (
NotFoundError SynapseError, HttpResponseException, NotFoundError, FederationDeniedError,
)
from synapse.util.async import Linearizer from synapse.util.async import Linearizer
from synapse.util.stringutils import is_ascii from synapse.util.stringutils import is_ascii
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn from synapse.util.logcontext import make_deferred_yieldable
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
import os import os
@ -47,7 +52,7 @@ import urlparse
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
UPDATE_RECENTLY_ACCESSED_REMOTES_TS = 60 * 1000 UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
class MediaRepository(object): class MediaRepository(object):
@ -63,96 +68,62 @@ class MediaRepository(object):
self.primary_base_path = hs.config.media_store_path self.primary_base_path = hs.config.media_store_path
self.filepaths = MediaFilePaths(self.primary_base_path) self.filepaths = MediaFilePaths(self.primary_base_path)
self.backup_base_path = hs.config.backup_media_store_path
self.synchronous_backup_media_store = hs.config.synchronous_backup_media_store
self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.thumbnail_requirements = hs.config.thumbnail_requirements self.thumbnail_requirements = hs.config.thumbnail_requirements
self.remote_media_linearizer = Linearizer(name="media_remote") self.remote_media_linearizer = Linearizer(name="media_remote")
self.recently_accessed_remotes = set() self.recently_accessed_remotes = set()
self.recently_accessed_locals = set()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
# List of StorageProviders where we should search for media and
# potentially upload to.
storage_providers = []
for clz, provider_config, wrapper_config in hs.config.media_storage_providers:
backend = clz(hs, provider_config)
provider = StorageProviderWrapper(
backend,
store_local=wrapper_config.store_local,
store_remote=wrapper_config.store_remote,
store_synchronous=wrapper_config.store_synchronous,
)
storage_providers.append(provider)
self.media_storage = MediaStorage(
self.primary_base_path, self.filepaths, storage_providers,
)
self.clock.looping_call( self.clock.looping_call(
self._update_recently_accessed_remotes, self._update_recently_accessed,
UPDATE_RECENTLY_ACCESSED_REMOTES_TS UPDATE_RECENTLY_ACCESSED_TS,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def _update_recently_accessed_remotes(self): def _update_recently_accessed(self):
media = self.recently_accessed_remotes remote_media = self.recently_accessed_remotes
self.recently_accessed_remotes = set() self.recently_accessed_remotes = set()
local_media = self.recently_accessed_locals
self.recently_accessed_locals = set()
yield self.store.update_cached_last_access_time( yield self.store.update_cached_last_access_time(
media, self.clock.time_msec() local_media, remote_media, self.clock.time_msec()
) )
@staticmethod def mark_recently_accessed(self, server_name, media_id):
def _makedirs(filepath): """Mark the given media as recently accessed.
dirname = os.path.dirname(filepath)
if not os.path.exists(dirname):
os.makedirs(dirname)
@staticmethod
def _write_file_synchronously(source, fname):
"""Write `source` to the path `fname` synchronously. Should be called
from a thread.
Args: Args:
source: A file like object to be written server_name (str|None): Origin server of media, or None if local
fname (str): Path to write to media_id (str): The media ID of the content
""" """
MediaRepository._makedirs(fname) if server_name:
source.seek(0) # Ensure we read from the start of the file self.recently_accessed_remotes.add((server_name, media_id))
with open(fname, "wb") as f: else:
shutil.copyfileobj(source, f) self.recently_accessed_locals.add(media_id)
@defer.inlineCallbacks
def write_to_file_and_backup(self, source, path):
"""Write `source` to the on disk media store, and also the backup store
if configured.
Args:
source: A file like object that should be written
path (str): Relative path to write file to
Returns:
Deferred[str]: the file path written to in the primary media store
"""
fname = os.path.join(self.primary_base_path, path)
# Write to the main repository
yield make_deferred_yieldable(threads.deferToThread(
self._write_file_synchronously, source, fname,
))
# Write to backup repository
yield self.copy_to_backup(path)
defer.returnValue(fname)
@defer.inlineCallbacks
def copy_to_backup(self, path):
"""Copy a file from the primary to backup media store, if configured.
Args:
path(str): Relative path to write file to
"""
if self.backup_base_path:
primary_fname = os.path.join(self.primary_base_path, path)
backup_fname = os.path.join(self.backup_base_path, path)
# We can either wait for successful writing to the backup repository
# or write in the background and immediately return
if self.synchronous_backup_media_store:
yield make_deferred_yieldable(threads.deferToThread(
shutil.copyfile, primary_fname, backup_fname,
))
else:
preserve_fn(threads.deferToThread)(
shutil.copyfile, primary_fname, backup_fname,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def create_content(self, media_type, upload_name, content, content_length, def create_content(self, media_type, upload_name, content, content_length,
@ -171,10 +142,13 @@ class MediaRepository(object):
""" """
media_id = random_string(24) media_id = random_string(24)
fname = yield self.write_to_file_and_backup( file_info = FileInfo(
content, self.filepaths.local_media_filepath_rel(media_id) server_name=None,
file_id=media_id,
) )
fname = yield self.media_storage.store_file(content, file_info)
logger.info("Stored local media in file %r", fname) logger.info("Stored local media in file %r", fname)
yield self.store.store_local_media( yield self.store.store_local_media(
@ -185,134 +159,275 @@ class MediaRepository(object):
media_length=content_length, media_length=content_length,
user_id=auth_user, user_id=auth_user,
) )
media_info = {
"media_type": media_type,
"media_length": content_length,
}
yield self._generate_thumbnails(None, media_id, media_info) yield self._generate_thumbnails(
None, media_id, media_id, media_type,
)
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id)) defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_remote_media(self, server_name, media_id): def get_local_media(self, request, media_id, name):
"""Responds to reqests for local media, if exists, or returns 404.
Args:
request(twisted.web.http.Request)
media_id (str): The media ID of the content. (This is the same as
the file_id for local content.)
name (str|None): Optional name that, if specified, will be used as
the filename in the Content-Disposition header of the response.
Returns:
Deferred: Resolves once a response has successfully been written
to request
"""
media_info = yield self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]:
respond_404(request)
return
self.mark_recently_accessed(None, media_id)
media_type = media_info["media_type"]
media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
url_cache = media_info["url_cache"]
file_info = FileInfo(
None, media_id,
url_cache=url_cache,
)
responder = yield self.media_storage.fetch_media(file_info)
yield respond_with_responder(
request, responder, media_type, media_length, upload_name,
)
@defer.inlineCallbacks
def get_remote_media(self, request, server_name, media_id, name):
"""Respond to requests for remote media.
Args:
request(twisted.web.http.Request)
server_name (str): Remote server_name where the media originated.
media_id (str): The media ID of the content (as defined by the
remote server).
name (str|None): Optional name that, if specified, will be used as
the filename in the Content-Disposition header of the response.
Returns:
Deferred: Resolves once a response has successfully been written
to request
"""
if (
self.federation_domain_whitelist is not None and
server_name not in self.federation_domain_whitelist
):
raise FederationDeniedError(server_name)
self.mark_recently_accessed(server_name, media_id)
# We linearize here to ensure that we don't try and download remote
# media multiple times concurrently
key = (server_name, media_id) key = (server_name, media_id)
with (yield self.remote_media_linearizer.queue(key)): with (yield self.remote_media_linearizer.queue(key)):
media_info = yield self._get_remote_media_impl(server_name, media_id) responder, media_info = yield self._get_remote_media_impl(
server_name, media_id,
)
# We deliberately stream the file outside the lock
if responder:
media_type = media_info["media_type"]
media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
yield respond_with_responder(
request, responder, media_type, media_length, upload_name,
)
else:
respond_404(request)
@defer.inlineCallbacks
def get_remote_media_info(self, server_name, media_id):
"""Gets the media info associated with the remote file, downloading
if necessary.
Args:
server_name (str): Remote server_name where the media originated.
media_id (str): The media ID of the content (as defined by the
remote server).
Returns:
Deferred[dict]: The media_info of the file
"""
if (
self.federation_domain_whitelist is not None and
server_name not in self.federation_domain_whitelist
):
raise FederationDeniedError(server_name)
# We linearize here to ensure that we don't try and download remote
# media multiple times concurrently
key = (server_name, media_id)
with (yield self.remote_media_linearizer.queue(key)):
responder, media_info = yield self._get_remote_media_impl(
server_name, media_id,
)
# Ensure we actually use the responder so that it releases resources
if responder:
with responder:
pass
defer.returnValue(media_info) defer.returnValue(media_info)
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_remote_media_impl(self, server_name, media_id): def _get_remote_media_impl(self, server_name, media_id):
"""Looks for media in local cache, if not there then attempt to
download from remote server.
Args:
server_name (str): Remote server_name where the media originated.
media_id (str): The media ID of the content (as defined by the
remote server).
Returns:
Deferred[(Responder, media_info)]
"""
media_info = yield self.store.get_cached_remote_media( media_info = yield self.store.get_cached_remote_media(
server_name, media_id server_name, media_id
) )
if not media_info:
media_info = yield self._download_remote_file( # file_id is the ID we use to track the file locally. If we've already
server_name, media_id # seen the file then reuse the existing ID, otherwise genereate a new
) # one.
elif media_info["quarantined_by"]: if media_info:
raise NotFoundError() file_id = media_info["filesystem_id"]
else: else:
self.recently_accessed_remotes.add((server_name, media_id)) file_id = random_string(24)
yield self.store.update_cached_last_access_time(
[(server_name, media_id)], self.clock.time_msec() file_info = FileInfo(server_name, file_id)
)
defer.returnValue(media_info) # If we have an entry in the DB, try and look for it
if media_info:
if media_info["quarantined_by"]:
logger.info("Media is quarantined")
raise NotFoundError()
responder = yield self.media_storage.fetch_media(file_info)
if responder:
defer.returnValue((responder, media_info))
# Failed to find the file anywhere, lets download it.
media_info = yield self._download_remote_file(
server_name, media_id, file_id
)
responder = yield self.media_storage.fetch_media(file_info)
defer.returnValue((responder, media_info))
@defer.inlineCallbacks @defer.inlineCallbacks
def _download_remote_file(self, server_name, media_id): def _download_remote_file(self, server_name, media_id, file_id):
file_id = random_string(24) """Attempt to download the remote file from the given server name,
using the given file_id as the local id.
fpath = self.filepaths.remote_media_filepath_rel( Args:
server_name, file_id server_name (str): Originating server
media_id (str): The media ID of the content (as defined by the
remote server). This is different than the file_id, which is
locally generated.
file_id (str): Local file ID
Returns:
Deferred[MediaInfo]
"""
file_info = FileInfo(
server_name=server_name,
file_id=file_id,
) )
fname = os.path.join(self.primary_base_path, fpath)
self._makedirs(fname)
try: with self.media_storage.store_into_file(file_info) as (f, fname, finish):
with open(fname, "wb") as f: request_path = "/".join((
request_path = "/".join(( "/_matrix/media/v1/download", server_name, media_id,
"/_matrix/media/v1/download", server_name, media_id, ))
)) try:
length, headers = yield self.client.get_file(
server_name, request_path, output_stream=f,
max_size=self.max_upload_size, args={
# tell the remote server to 404 if it doesn't
# recognise the server_name, to make sure we don't
# end up with a routing loop.
"allow_remote": "false",
}
)
except twisted.internet.error.DNSLookupError as e:
logger.warn("HTTP error fetching remote media %s/%s: %r",
server_name, media_id, e)
raise NotFoundError()
except HttpResponseException as e:
logger.warn("HTTP error fetching remote media %s/%s: %s",
server_name, media_id, e.response)
if e.code == twisted.web.http.NOT_FOUND:
raise SynapseError.from_http_response_exception(e)
raise SynapseError(502, "Failed to fetch remote media")
except SynapseError:
logger.exception("Failed to fetch remote media %s/%s",
server_name, media_id)
raise
except NotRetryingDestination:
logger.warn("Not retrying destination %r", server_name)
raise SynapseError(502, "Failed to fetch remote media")
except Exception:
logger.exception("Failed to fetch remote media %s/%s",
server_name, media_id)
raise SynapseError(502, "Failed to fetch remote media")
yield finish()
media_type = headers["Content-Type"][0]
time_now_ms = self.clock.time_msec()
content_disposition = headers.get("Content-Disposition", None)
if content_disposition:
_, params = cgi.parse_header(content_disposition[0],)
upload_name = None
# First check if there is a valid UTF-8 filename
upload_name_utf8 = params.get("filename*", None)
if upload_name_utf8:
if upload_name_utf8.lower().startswith("utf-8''"):
upload_name = upload_name_utf8[7:]
# If there isn't check for an ascii name.
if not upload_name:
upload_name_ascii = params.get("filename", None)
if upload_name_ascii and is_ascii(upload_name_ascii):
upload_name = upload_name_ascii
if upload_name:
upload_name = urlparse.unquote(upload_name)
try: try:
length, headers = yield self.client.get_file( upload_name = upload_name.decode("utf-8")
server_name, request_path, output_stream=f, except UnicodeDecodeError:
max_size=self.max_upload_size, args={ upload_name = None
# tell the remote server to 404 if it doesn't else:
# recognise the server_name, to make sure we don't upload_name = None
# end up with a routing loop.
"allow_remote": "false",
}
)
except twisted.internet.error.DNSLookupError as e:
logger.warn("HTTP error fetching remote media %s/%s: %r",
server_name, media_id, e)
raise NotFoundError()
except HttpResponseException as e: logger.info("Stored remote media in file %r", fname)
logger.warn("HTTP error fetching remote media %s/%s: %s",
server_name, media_id, e.response)
if e.code == twisted.web.http.NOT_FOUND:
raise SynapseError.from_http_response_exception(e)
raise SynapseError(502, "Failed to fetch remote media")
except SynapseError: yield self.store.store_cached_remote_media(
logger.exception("Failed to fetch remote media %s/%s", origin=server_name,
server_name, media_id) media_id=media_id,
raise media_type=media_type,
except NotRetryingDestination: time_now_ms=self.clock.time_msec(),
logger.warn("Not retrying destination %r", server_name) upload_name=upload_name,
raise SynapseError(502, "Failed to fetch remote media") media_length=length,
except Exception: filesystem_id=file_id,
logger.exception("Failed to fetch remote media %s/%s", )
server_name, media_id)
raise SynapseError(502, "Failed to fetch remote media")
yield self.copy_to_backup(fpath)
media_type = headers["Content-Type"][0]
time_now_ms = self.clock.time_msec()
content_disposition = headers.get("Content-Disposition", None)
if content_disposition:
_, params = cgi.parse_header(content_disposition[0],)
upload_name = None
# First check if there is a valid UTF-8 filename
upload_name_utf8 = params.get("filename*", None)
if upload_name_utf8:
if upload_name_utf8.lower().startswith("utf-8''"):
upload_name = upload_name_utf8[7:]
# If there isn't check for an ascii name.
if not upload_name:
upload_name_ascii = params.get("filename", None)
if upload_name_ascii and is_ascii(upload_name_ascii):
upload_name = upload_name_ascii
if upload_name:
upload_name = urlparse.unquote(upload_name)
try:
upload_name = upload_name.decode("utf-8")
except UnicodeDecodeError:
upload_name = None
else:
upload_name = None
logger.info("Stored remote media in file %r", fname)
yield self.store.store_cached_remote_media(
origin=server_name,
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
upload_name=upload_name,
media_length=length,
filesystem_id=file_id,
)
except Exception:
os.remove(fname)
raise
media_info = { media_info = {
"media_type": media_type, "media_type": media_type,
@ -323,7 +438,7 @@ class MediaRepository(object):
} }
yield self._generate_thumbnails( yield self._generate_thumbnails(
server_name, media_id, media_info server_name, media_id, file_id, media_type,
) )
defer.returnValue(media_info) defer.returnValue(media_info)
@ -357,8 +472,10 @@ class MediaRepository(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def generate_local_exact_thumbnail(self, media_id, t_width, t_height, def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
t_method, t_type): t_method, t_type, url_cache):
input_path = self.filepaths.local_media_filepath(media_id) input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
None, media_id, url_cache=url_cache,
))
thumbnailer = Thumbnailer(input_path) thumbnailer = Thumbnailer(input_path)
t_byte_source = yield make_deferred_yieldable(threads.deferToThread( t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
@ -368,11 +485,19 @@ class MediaRepository(object):
if t_byte_source: if t_byte_source:
try: try:
output_path = yield self.write_to_file_and_backup( file_info = FileInfo(
t_byte_source, server_name=None,
self.filepaths.local_media_thumbnail_rel( file_id=media_id,
media_id, t_width, t_height, t_type, t_method url_cache=url_cache,
) thumbnail=True,
thumbnail_width=t_width,
thumbnail_height=t_height,
thumbnail_method=t_method,
thumbnail_type=t_type,
)
output_path = yield self.media_storage.store_file(
t_byte_source, file_info,
) )
finally: finally:
t_byte_source.close() t_byte_source.close()
@ -390,7 +515,9 @@ class MediaRepository(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def generate_remote_exact_thumbnail(self, server_name, file_id, media_id, def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
t_width, t_height, t_method, t_type): t_width, t_height, t_method, t_type):
input_path = self.filepaths.remote_media_filepath(server_name, file_id) input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
server_name, file_id, url_cache=False,
))
thumbnailer = Thumbnailer(input_path) thumbnailer = Thumbnailer(input_path)
t_byte_source = yield make_deferred_yieldable(threads.deferToThread( t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
@ -400,11 +527,18 @@ class MediaRepository(object):
if t_byte_source: if t_byte_source:
try: try:
output_path = yield self.write_to_file_and_backup( file_info = FileInfo(
t_byte_source, server_name=server_name,
self.filepaths.remote_media_thumbnail_rel( file_id=media_id,
server_name, file_id, t_width, t_height, t_type, t_method thumbnail=True,
) thumbnail_width=t_width,
thumbnail_height=t_height,
thumbnail_method=t_method,
thumbnail_type=t_type,
)
output_path = yield self.media_storage.store_file(
t_byte_source, file_info,
) )
finally: finally:
t_byte_source.close() t_byte_source.close()
@ -421,31 +555,29 @@ class MediaRepository(object):
defer.returnValue(output_path) defer.returnValue(output_path)
@defer.inlineCallbacks @defer.inlineCallbacks
def _generate_thumbnails(self, server_name, media_id, media_info, url_cache=False): def _generate_thumbnails(self, server_name, media_id, file_id, media_type,
url_cache=False):
"""Generate and store thumbnails for an image. """Generate and store thumbnails for an image.
Args: Args:
server_name(str|None): The server name if remote media, else None if local server_name (str|None): The server name if remote media, else None if local
media_id(str) media_id (str): The media ID of the content. (This is the same as
media_info(dict) the file_id for local content)
url_cache(bool): If we are thumbnailing images downloaded for the URL cache, file_id (str): Local file ID
media_type (str): The content type of the file
url_cache (bool): If we are thumbnailing images downloaded for the URL cache,
used exclusively by the url previewer used exclusively by the url previewer
Returns: Returns:
Deferred[dict]: Dict with "width" and "height" keys of original image Deferred[dict]: Dict with "width" and "height" keys of original image
""" """
media_type = media_info["media_type"]
file_id = media_info.get("filesystem_id")
requirements = self._get_thumbnail_requirements(media_type) requirements = self._get_thumbnail_requirements(media_type)
if not requirements: if not requirements:
return return
if server_name: input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
input_path = self.filepaths.remote_media_filepath(server_name, file_id) server_name, file_id, url_cache=url_cache,
elif url_cache: ))
input_path = self.filepaths.url_cache_filepath(media_id)
else:
input_path = self.filepaths.local_media_filepath(media_id)
thumbnailer = Thumbnailer(input_path) thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width m_width = thumbnailer.width
@ -472,20 +604,6 @@ class MediaRepository(object):
# Now we generate the thumbnails for each dimension, store it # Now we generate the thumbnails for each dimension, store it
for (t_width, t_height, t_type), t_method in thumbnails.iteritems(): for (t_width, t_height, t_type), t_method in thumbnails.iteritems():
# Work out the correct file name for thumbnail
if server_name:
file_path = self.filepaths.remote_media_thumbnail_rel(
server_name, file_id, t_width, t_height, t_type, t_method
)
elif url_cache:
file_path = self.filepaths.url_cache_thumbnail_rel(
media_id, t_width, t_height, t_type, t_method
)
else:
file_path = self.filepaths.local_media_thumbnail_rel(
media_id, t_width, t_height, t_type, t_method
)
# Generate the thumbnail # Generate the thumbnail
if t_method == "crop": if t_method == "crop":
t_byte_source = yield make_deferred_yieldable(threads.deferToThread( t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
@ -505,9 +623,19 @@ class MediaRepository(object):
continue continue
try: try:
# Write to disk file_info = FileInfo(
output_path = yield self.write_to_file_and_backup( server_name=server_name,
t_byte_source, file_path, file_id=file_id,
thumbnail=True,
thumbnail_width=t_width,
thumbnail_height=t_height,
thumbnail_method=t_method,
thumbnail_type=t_type,
url_cache=url_cache,
)
output_path = yield self.media_storage.store_file(
t_byte_source, file_info,
) )
finally: finally:
t_byte_source.close() t_byte_source.close()
@ -620,7 +748,11 @@ class MediaRepositoryResource(Resource):
self.putChild("upload", UploadResource(hs, media_repo)) self.putChild("upload", UploadResource(hs, media_repo))
self.putChild("download", DownloadResource(hs, media_repo)) self.putChild("download", DownloadResource(hs, media_repo))
self.putChild("thumbnail", ThumbnailResource(hs, media_repo)) self.putChild("thumbnail", ThumbnailResource(
hs, media_repo, media_repo.media_storage,
))
self.putChild("identicon", IdenticonResource()) self.putChild("identicon", IdenticonResource())
if hs.config.url_preview_enabled: if hs.config.url_preview_enabled:
self.putChild("preview_url", PreviewUrlResource(hs, media_repo)) self.putChild("preview_url", PreviewUrlResource(
hs, media_repo, media_repo.media_storage,
))

Some files were not shown because too many files have changed in this diff Show More