diff --git a/CHANGES.rst b/CHANGES.rst index 7ebb42b0f..49673ccce 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,3 +1,50 @@ +Changes in synapse v0.17.1 (2016-08-24) +======================================= + +Changes: + +* Delete old received_transactions rows (PR #1038) +* Pass through user-supplied content in /join/$room_id (PR #1039) + + +Bug fixes: + +* Fix bug with backfill (PR #1040) + + +Changes in synapse v0.17.1-rc1 (2016-08-22) +=========================================== + +Features: + +* Add notification API (PR #1028) + + +Changes: + +* Don't print stack traces when failing to get remote keys (PR #996) +* Various federation /event/ perf improvements (PR #998) +* Only process one local membership event per room at a time (PR #1005) +* Move default display name push rule (PR #1011, #1023) +* Fix up preview URL API. Add tests. (PR #1015) +* Set ``Content-Security-Policy`` on media repo (PR #1021) +* Make notify_interested_services faster (PR #1022) +* Add usage stats to prometheus monitoring (PR #1037) + + +Bug fixes: + +* Fix token login (PR #993) +* Fix CAS login (PR #994, #995) +* Fix /sync to not clobber status_msg (PR #997) +* Fix redacted state events to include prev_content (PR #1003) +* Fix some bugs in the auth/ldap handler (PR #1007) +* Fix backfill request to limit URI length, so that remotes don't reject the + requests due to path length limits (PR #1012) +* Fix AS push code to not send duplicate events (PR #1025) + + + Changes in synapse v0.17.0 (2016-08-08) ======================================= diff --git a/README.rst b/README.rst index d65867083..172dd4dfa 100644 --- a/README.rst +++ b/README.rst @@ -95,7 +95,7 @@ Synapse is the reference python/twisted Matrix homeserver implementation. System requirements: - POSIX-compliant system (tested on Linux & OS X) - Python 2.7 -- At least 512 MB RAM. 
+- At least 1GB of free RAM if you want to join large public rooms like #matrix:matrix.org Synapse is written in python but some of the libraries is uses are written in C. So before we can install synapse itself we need a working C compiler and the diff --git a/docs/workers.rst b/docs/workers.rst new file mode 100644 index 000000000..4eb05b0e5 --- /dev/null +++ b/docs/workers.rst @@ -0,0 +1,97 @@ +Scaling synapse via workers +--------------------------- + +Synapse has experimental support for splitting out functionality into +multiple separate python processes, helping greatly with scalability. These +processes are called 'workers', and are (eventually) intended to scale +horizontally independently. + +All processes continue to share the same database instance, and as such, workers +only work with postgres based synapse deployments (sharing a single sqlite +across multiple processes is a recipe for disaster, plus you should be using +postgres anyway if you care about scalability). + +The workers communicate with the master synapse process via a synapse-specific +HTTP protocol called 'replication' - analogous to MySQL or Postgres style +database replication; feeding a stream of relevant data to the workers so they +can be kept in sync with the main synapse process and database state. + +To enable workers, you need to add a replication listener to the master synapse, e.g.:: + + listeners: + - port: 9092 + bind_address: '127.0.0.1' + type: http + tls: false + x_forwarded: false + resources: + - names: [replication] + compress: false + +Under **no circumstances** should this replication API listener be exposed to the +public internet; it currently implements no authentication whatsoever and is +unencrypted HTTP. + +You then create a set of configs for the various worker processes. These +worker configuration files should be stored in a dedicated subdirectory, to allow +synctl to manipulate them. 
+ +The currently available worker applications are: + * synapse.app.pusher - handles sending push notifications to sygnal and email + * synapse.app.synchrotron - handles /sync endpoints. Can scale horizontally through multiple instances. + * synapse.app.appservice - handles output traffic to Application Services + * synapse.app.federation_reader - handles receiving federation traffic (including public_rooms API) + * synapse.app.media_repository - handles the media repository. + +Each worker configuration file inherits the configuration of the main homeserver +configuration file. You can then override configuration specific to that worker, +e.g. the HTTP listener that it provides (if any); logging configuration; etc. +You should minimise the number of overrides though to maintain a usable config. + +You must specify the type of worker application (worker_app) and the replication +endpoint that it's talking to on the main synapse process (worker_replication_url). + +For instance:: + + worker_app: synapse.app.synchrotron + + # The replication listener on the synapse to talk to. + worker_replication_url: http://127.0.0.1:9092/_synapse/replication + + worker_listeners: + - type: http + port: 8083 + resources: + - names: + - client + + worker_daemonize: True + worker_pid_file: /home/matrix/synapse/synchrotron.pid + worker_log_config: /home/matrix/synapse/config/synchrotron_log_config.yaml + +...is a full configuration for a synchrotron worker instance, which will expose a +plain HTTP /sync endpoint on port 8083 separately from the /sync endpoint provided +by the main synapse. + +Obviously you should configure your loadbalancer to route the /sync endpoint to +the synchrotron instance(s) in this case. 
+ +Finally, to actually run your worker-based synapse, you must pass synctl the -a +commandline option to tell it to operate on all the worker configurations found +in the given directory, e.g.:: + + synctl -a $CONFIG/workers start + +Currently one should always restart all workers when restarting or upgrading +synapse, unless you explicitly know it's safe not to. For instance, restarting +synapse without restarting all the synchrotrons may result in broken typing +notifications. + +To manipulate a specific worker, you pass the -w option to synctl:: + + synctl -w $CONFIG/workers/synchrotron.yaml restart + +All of the above is highly experimental and subject to change as Synapse evolves, +but documenting it here to help folks needing highly scalable Synapses similar +to the one running matrix.org! + diff --git a/jenkins-unittests.sh b/jenkins-unittests.sh index 6b0c296cf..4c2f103e8 100755 --- a/jenkins-unittests.sh +++ b/jenkins-unittests.sh @@ -25,5 +25,6 @@ rm .coverage* || echo "No coverage files to remove" tox --notest -e py27 TOX_BIN=$WORKSPACE/.tox/py27/bin python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install +$TOX_BIN/pip install lxml tox -e py27 diff --git a/jenkins/prepare_synapse.sh b/jenkins/prepare_synapse.sh index 237223c81..6c26c5842 100755 --- a/jenkins/prepare_synapse.sh +++ b/jenkins/prepare_synapse.sh @@ -14,6 +14,7 @@ fi tox -e py27 --notest -v TOX_BIN=$TOX_DIR/py27/bin +$TOX_BIN/pip install setuptools python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install $TOX_BIN/pip install lxml $TOX_BIN/pip install psycopg2 diff --git a/synapse/__init__.py b/synapse/__init__.py index a63ee565c..43bf78f88 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -16,4 +16,4 @@ """ This is a reference implementation of a Matrix home server. 
""" -__version__ = "0.17.0" +__version__ = "0.17.1" diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 59db76deb..0db26fcfd 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -675,27 +675,18 @@ class Auth(object): try: macaroon = pymacaroons.Macaroon.deserialize(macaroon_str) - user_prefix = "user_id = " - user = None - user_id = None - guest = False - for caveat in macaroon.caveats: - if caveat.caveat_id.startswith(user_prefix): - user_id = caveat.caveat_id[len(user_prefix):] - user = UserID.from_string(user_id) - elif caveat.caveat_id == "guest = true": - guest = True + user_id = self.get_user_id_from_macaroon(macaroon) + user = UserID.from_string(user_id) self.validate_macaroon( macaroon, rights, self.hs.config.expire_access_token, user_id=user_id, ) - if user is None: - raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon", - errcode=Codes.UNKNOWN_TOKEN - ) + guest = False + for caveat in macaroon.caveats: + if caveat.caveat_id == "guest = true": + guest = True if guest: ret = { @@ -743,6 +734,29 @@ class Auth(object): errcode=Codes.UNKNOWN_TOKEN ) + def get_user_id_from_macaroon(self, macaroon): + """Retrieve the user_id given by the caveats on the macaroon. + + Does *not* validate the macaroon. + + Args: + macaroon (pymacaroons.Macaroon): The macaroon to validate + + Returns: + (str) user id + + Raises: + AuthError if there is no user_id caveat in the macaroon + """ + user_prefix = "user_id = " + for caveat in macaroon.caveats: + if caveat.caveat_id.startswith(user_prefix): + return caveat.caveat_id[len(user_prefix):] + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon", + errcode=Codes.UNKNOWN_TOKEN + ) + def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id): """ validate that a Macaroon is understood by and was signed by this server. @@ -754,6 +768,7 @@ class Auth(object): verify_expiry(bool): Whether to verify whether the macaroon has expired. 
This should really always be True, but no clients currently implement token refresh, so we can't enforce expiry yet. + user_id (str): The user_id required """ v = pymacaroons.Verifier() v.satisfy_exact("gen = 1") diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py new file mode 100644 index 000000000..57587aed2 --- /dev/null +++ b/synapse/app/appservice.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import synapse + +from synapse.server import HomeServer +from synapse.config._base import ConfigError +from synapse.config.logger import setup_logging +from synapse.config.homeserver import HomeServerConfig +from synapse.http.site import SynapseSite +from synapse.metrics.resource import MetricsResource, METRICS_PREFIX +from synapse.replication.slave.storage.directory import DirectoryStore +from synapse.replication.slave.storage.events import SlavedEventStore +from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore +from synapse.replication.slave.storage.registration import SlavedRegistrationStore +from synapse.storage.engines import create_engine +from synapse.util.async import sleep +from synapse.util.httpresourcetree import create_resource_tree +from synapse.util.logcontext import LoggingContext +from synapse.util.manhole import manhole +from synapse.util.rlimit import change_resource_limit +from synapse.util.versionstring import get_version_string + +from twisted.internet import reactor, defer +from twisted.web.resource import Resource + +from daemonize import Daemonize + +import sys +import logging +import gc + +logger = logging.getLogger("synapse.app.appservice") + + +class AppserviceSlaveStore( + DirectoryStore, SlavedEventStore, SlavedApplicationServiceStore, + SlavedRegistrationStore, +): + pass + + +class AppserviceServer(HomeServer): + def get_db_conn(self, run_new_connection=True): + # Any param beginning with cp_ is a parameter for adbapi, and should + # not be passed to the database engine. 
+ db_params = { + k: v for k, v in self.db_config.get("args", {}).items() + if not k.startswith("cp_") + } + db_conn = self.database_engine.module.connect(**db_params) + + if run_new_connection: + self.database_engine.on_new_connection(db_conn) + return db_conn + + def setup(self): + logger.info("Setting up.") + self.datastore = AppserviceSlaveStore(self.get_db_conn(), self) + logger.info("Finished setting up.") + + def _listen_http(self, listener_config): + port = listener_config["port"] + bind_address = listener_config.get("bind_address", "") + site_tag = listener_config.get("tag", port) + resources = {} + for res in listener_config["resources"]: + for name in res["names"]: + if name == "metrics": + resources[METRICS_PREFIX] = MetricsResource(self) + + root_resource = create_resource_tree(resources, Resource()) + reactor.listenTCP( + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + ), + interface=bind_address + ) + logger.info("Synapse appservice now listening on port %d", port) + + def start_listening(self, listeners): + for listener in listeners: + if listener["type"] == "http": + self._listen_http(listener) + elif listener["type"] == "manhole": + reactor.listenTCP( + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, + ), + interface=listener.get("bind_address", '127.0.0.1') + ) + else: + logger.warn("Unrecognized listener type: %s", listener["type"]) + + @defer.inlineCallbacks + def replicate(self): + http_client = self.get_simple_http_client() + store = self.get_datastore() + replication_url = self.config.worker_replication_url + appservice_handler = self.get_application_service_handler() + + @defer.inlineCallbacks + def replicate(results): + stream = results.get("events") + if stream: + max_stream_id = stream["position"] + yield appservice_handler.notify_interested_services(max_stream_id) + + while True: + try: + args = store.stream_positions() + 
args["timeout"] = 30000 + result = yield http_client.get_json(replication_url, args=args) + yield store.process_replication(result) + replicate(result) + except: + logger.exception("Error replicating from %r", replication_url) + yield sleep(30) + + +def start(config_options): + try: + config = HomeServerConfig.load_config( + "Synapse appservice", config_options + ) + except ConfigError as e: + sys.stderr.write("\n" + e.message + "\n") + sys.exit(1) + + assert config.worker_app == "synapse.app.appservice" + + setup_logging(config.worker_log_config, config.worker_log_file) + + database_engine = create_engine(config.database_config) + + if config.notify_appservices: + sys.stderr.write( + "\nThe appservices must be disabled in the main synapse process" + "\nbefore they can be run in a separate worker." + "\nPlease add ``notify_appservices: false`` to the main config" + "\n" + ) + sys.exit(1) + + # Force the pushers to start since they will be disabled in the main config + config.notify_appservices = True + + ps = AppserviceServer( + config.server_name, + db_config=config.database_config, + config=config, + version_string="Synapse/" + get_version_string(synapse), + database_engine=database_engine, + ) + + ps.setup() + ps.start_listening(config.worker_listeners) + + def run(): + with LoggingContext("run"): + logger.info("Running") + change_resource_limit(config.soft_file_limit) + if config.gc_thresholds: + gc.set_threshold(*config.gc_thresholds) + reactor.run() + + def start(): + ps.replicate() + ps.get_datastore().start_profiling() + + reactor.callWhenRunning(start) + + if config.worker_daemonize: + daemon = Daemonize( + app="synapse-appservice", + pid=config.worker_pid_file, + action=run, + auto_close_fds=False, + verbose=True, + logger=logger, + ) + daemon.start() + else: + run() + + +if __name__ == '__main__': + with LoggingContext("main"): + start(sys.argv[1:]) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 40e6f6523..54f35900f 100755 --- 
a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -51,7 +51,7 @@ from synapse.api.urls import ( from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory from synapse.util.logcontext import LoggingContext -from synapse.metrics import register_memory_metrics +from synapse.metrics import register_memory_metrics, get_metrics_for from synapse.metrics.resource import MetricsResource, METRICS_PREFIX from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX from synapse.federation.transport.server import TransportLayerServer @@ -385,6 +385,8 @@ def run(hs): start_time = hs.get_clock().time() + stats = {} + @defer.inlineCallbacks def phone_stats_home(): logger.info("Gathering stats for reporting") @@ -393,7 +395,10 @@ def run(hs): if uptime < 0: uptime = 0 - stats = {} + # If the stats directory is empty then this is the first time we've + # reported stats. + first_time = not stats + stats["homeserver"] = hs.config.server_name stats["timestamp"] = now stats["uptime_seconds"] = uptime @@ -406,6 +411,25 @@ def run(hs): daily_messages = yield hs.get_datastore().count_daily_messages() if daily_messages is not None: stats["daily_messages"] = daily_messages + else: + stats.pop("daily_messages", None) + + if first_time: + # Add callbacks to report the synapse stats as metrics whenever + # prometheus requests them, typically every 30s. + # As some of the stats are expensive to calculate we only update + # them when synapse phones home to matrix.org every 24 hours. 
+ metrics = get_metrics_for("synapse.usage") + metrics.add_callback("timestamp", lambda: stats["timestamp"]) + metrics.add_callback("uptime_seconds", lambda: stats["uptime_seconds"]) + metrics.add_callback("total_users", lambda: stats["total_users"]) + metrics.add_callback("total_room_count", lambda: stats["total_room_count"]) + metrics.add_callback( + "daily_active_users", lambda: stats["daily_active_users"] + ) + metrics.add_callback( + "daily_messages", lambda: stats.get("daily_messages", 0) + ) logger.info("Reporting stats to matrix.org: %s" % (stats,)) try: diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py new file mode 100644 index 000000000..9d4c4a075 --- /dev/null +++ b/synapse/app/media_repository.py @@ -0,0 +1,212 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import synapse + +from synapse.config._base import ConfigError +from synapse.config.homeserver import HomeServerConfig +from synapse.config.logger import setup_logging +from synapse.http.site import SynapseSite +from synapse.metrics.resource import MetricsResource, METRICS_PREFIX +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore +from synapse.replication.slave.storage.registration import SlavedRegistrationStore +from synapse.rest.media.v0.content_repository import ContentRepoResource +from synapse.rest.media.v1.media_repository import MediaRepositoryResource +from synapse.server import HomeServer +from synapse.storage.client_ips import ClientIpStore +from synapse.storage.engines import create_engine +from synapse.storage.media_repository import MediaRepositoryStore +from synapse.util.async import sleep +from synapse.util.httpresourcetree import create_resource_tree +from synapse.util.logcontext import LoggingContext +from synapse.util.manhole import manhole +from synapse.util.rlimit import change_resource_limit +from synapse.util.versionstring import get_version_string +from synapse.api.urls import ( + CONTENT_REPO_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX +) +from synapse.crypto import context_factory + + +from twisted.internet import reactor, defer +from twisted.web.resource import Resource + +from daemonize import Daemonize + +import sys +import logging +import gc + +logger = logging.getLogger("synapse.app.media_repository") + + +class MediaRepositorySlavedStore( + SlavedApplicationServiceStore, + SlavedRegistrationStore, + BaseSlavedStore, + MediaRepositoryStore, + ClientIpStore, +): + pass + + +class MediaRepositoryServer(HomeServer): + def get_db_conn(self, run_new_connection=True): + # Any param beginning with cp_ is a parameter for adbapi, and should + # not be passed to the database engine. 
+ db_params = { + k: v for k, v in self.db_config.get("args", {}).items() + if not k.startswith("cp_") + } + db_conn = self.database_engine.module.connect(**db_params) + + if run_new_connection: + self.database_engine.on_new_connection(db_conn) + return db_conn + + def setup(self): + logger.info("Setting up.") + self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self) + logger.info("Finished setting up.") + + def _listen_http(self, listener_config): + port = listener_config["port"] + bind_address = listener_config.get("bind_address", "") + site_tag = listener_config.get("tag", port) + resources = {} + for res in listener_config["resources"]: + for name in res["names"]: + if name == "metrics": + resources[METRICS_PREFIX] = MetricsResource(self) + elif name == "media": + media_repo = MediaRepositoryResource(self) + resources.update({ + MEDIA_PREFIX: media_repo, + LEGACY_MEDIA_PREFIX: media_repo, + CONTENT_REPO_PREFIX: ContentRepoResource( + self, self.config.uploads_path + ), + }) + + root_resource = create_resource_tree(resources, Resource()) + reactor.listenTCP( + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + ), + interface=bind_address + ) + logger.info("Synapse media repository now listening on port %d", port) + + def start_listening(self, listeners): + for listener in listeners: + if listener["type"] == "http": + self._listen_http(listener) + elif listener["type"] == "manhole": + reactor.listenTCP( + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, + ), + interface=listener.get("bind_address", '127.0.0.1') + ) + else: + logger.warn("Unrecognized listener type: %s", listener["type"]) + + @defer.inlineCallbacks + def replicate(self): + http_client = self.get_simple_http_client() + store = self.get_datastore() + replication_url = self.config.worker_replication_url + + while True: + try: + args = store.stream_positions() + args["timeout"] 
= 30000 + result = yield http_client.get_json(replication_url, args=args) + yield store.process_replication(result) + except: + logger.exception("Error replicating from %r", replication_url) + yield sleep(5) + + +def start(config_options): + try: + config = HomeServerConfig.load_config( + "Synapse media repository", config_options + ) + except ConfigError as e: + sys.stderr.write("\n" + e.message + "\n") + sys.exit(1) + + assert config.worker_app == "synapse.app.media_repository" + + setup_logging(config.worker_log_config, config.worker_log_file) + + database_engine = create_engine(config.database_config) + + tls_server_context_factory = context_factory.ServerContextFactory(config) + + ss = MediaRepositoryServer( + config.server_name, + db_config=config.database_config, + tls_server_context_factory=tls_server_context_factory, + config=config, + version_string="Synapse/" + get_version_string(synapse), + database_engine=database_engine, + ) + + ss.setup() + ss.get_handlers() + ss.start_listening(config.worker_listeners) + + def run(): + with LoggingContext("run"): + logger.info("Running") + change_resource_limit(config.soft_file_limit) + if config.gc_thresholds: + gc.set_threshold(*config.gc_thresholds) + reactor.run() + + def start(): + ss.get_datastore().start_profiling() + ss.replicate() + + reactor.callWhenRunning(start) + + if config.worker_daemonize: + daemon = Daemonize( + app="synapse-media-repository", + pid=config.worker_pid_file, + action=run, + auto_close_fds=False, + verbose=True, + logger=logger, + ) + daemon.start() + else: + run() + + +if __name__ == '__main__': + with LoggingContext("main"): + start(sys.argv[1:]) diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index c8dde0fcb..8d755a4b3 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -80,11 +80,6 @@ class PusherSlaveStore( DataStore.get_profile_displayname.__func__ ) - # XXX: This is a bit broken because we don't persist forgotten rooms - # in a way that they can be 
streamed. This means that we don't have a - # way to invalidate the forgotten rooms cache correctly. - # For now we expire the cache every 10 minutes. - BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000 who_forgot_in_room = ( RoomMemberStore.__dict__["who_forgot_in_room"] ) @@ -168,7 +163,6 @@ class PusherServer(HomeServer): store = self.get_datastore() replication_url = self.config.worker_replication_url pusher_pool = self.get_pusherpool() - clock = self.get_clock() def stop_pusher(user_id, app_id, pushkey): key = "%s:%s" % (app_id, pushkey) @@ -220,21 +214,11 @@ class PusherServer(HomeServer): min_stream_id, max_stream_id, affected_room_ids ) - def expire_broken_caches(): - store.who_forgot_in_room.invalidate_all() - - next_expire_broken_caches_ms = 0 while True: try: args = store.stream_positions() args["timeout"] = 30000 result = yield http_client.get_json(replication_url, args=args) - now_ms = clock.time_msec() - if now_ms > next_expire_broken_caches_ms: - expire_broken_caches() - next_expire_broken_caches_ms = ( - now_ms + store.BROKEN_CACHE_EXPIRY_MS - ) yield store.process_replication(result) poke_pushers(result) except: diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 215ccfd52..e3173533e 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -26,6 +26,7 @@ from synapse.http.site import SynapseSite from synapse.http.server import JsonResource from synapse.metrics.resource import MetricsResource, METRICS_PREFIX from synapse.rest.client.v2_alpha import sync +from synapse.rest.client.v1 import events from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore @@ -74,11 +75,6 @@ class SynchrotronSlavedStore( BaseSlavedStore, ClientIpStore, # After BaseSlavedStore because the constructor is different ): - # XXX: This is a bit broken because we don't persist forgotten rooms - 
# in a way that they can be streamed. This means that we don't have a - # way to invalidate the forgotten rooms cache correctly. - # For now we expire the cache every 10 minutes. - BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000 who_forgot_in_room = ( RoomMemberStore.__dict__["who_forgot_in_room"] ) @@ -89,17 +85,23 @@ class SynchrotronSlavedStore( get_presence_list_accepted = PresenceStore.__dict__[ "get_presence_list_accepted" ] + get_presence_list_observers_accepted = PresenceStore.__dict__[ + "get_presence_list_observers_accepted" + ] + UPDATE_SYNCING_USERS_MS = 10 * 1000 class SynchrotronPresence(object): def __init__(self, hs): + self.is_mine_id = hs.is_mine_id self.http_client = hs.get_simple_http_client() self.store = hs.get_datastore() self.user_to_num_current_syncs = {} self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users" self.clock = hs.get_clock() + self.notifier = hs.get_notifier() active_presence = self.store.take_presence_startup_info() self.user_to_current_state = { @@ -119,11 +121,13 @@ class SynchrotronPresence(object): reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown) - def set_state(self, user, state): + def set_state(self, user, state, ignore_status_msg=False): # TODO Hows this supposed to work? 
pass get_states = PresenceHandler.get_states.__func__ + get_state = PresenceHandler.get_state.__func__ + _get_interested_parties = PresenceHandler._get_interested_parties.__func__ current_state_for_users = PresenceHandler.current_state_for_users.__func__ @defer.inlineCallbacks @@ -194,19 +198,39 @@ class SynchrotronPresence(object): self._need_to_send_sync = False yield self._send_syncing_users_now() + @defer.inlineCallbacks + def notify_from_replication(self, states, stream_id): + parties = yield self._get_interested_parties( + states, calculate_remote_hosts=False + ) + room_ids_to_states, users_to_states, _ = parties + + self.notifier.on_new_event( + "presence_key", stream_id, rooms=room_ids_to_states.keys(), + users=users_to_states.keys() + ) + + @defer.inlineCallbacks def process_replication(self, result): stream = result.get("presence", {"rows": []}) + states = [] for row in stream["rows"]: ( position, user_id, state, last_active_ts, last_federation_update_ts, last_user_sync_ts, status_msg, currently_active ) = row - self.user_to_current_state[user_id] = UserPresenceState( + state = UserPresenceState( user_id, state, last_active_ts, last_federation_update_ts, last_user_sync_ts, status_msg, currently_active ) + self.user_to_current_state[user_id] = state + states.append(state) + + if states and "position" in stream: + stream_id = int(stream["position"]) + yield self.notify_from_replication(states, stream_id) class SynchrotronTyping(object): @@ -266,10 +290,12 @@ class SynchrotronServer(HomeServer): elif name == "client": resource = JsonResource(self, canonical_json=False) sync.register_servlets(self, resource) + events.register_servlets(self, resource) resources.update({ "/_matrix/client/r0": resource, "/_matrix/client/unstable": resource, "/_matrix/client/v2_alpha": resource, + "/_matrix/client/api/v1": resource, }) root_resource = create_resource_tree(resources, Resource()) @@ -307,15 +333,10 @@ class SynchrotronServer(HomeServer): http_client = 
self.get_simple_http_client() store = self.get_datastore() replication_url = self.config.worker_replication_url - clock = self.get_clock() notifier = self.get_notifier() presence_handler = self.get_presence_handler() typing_handler = self.get_typing_handler() - def expire_broken_caches(): - store.who_forgot_in_room.invalidate_all() - store.get_presence_list_accepted.invalidate_all() - def notify_from_stream( result, stream_name, stream_key, room=None, user=None ): @@ -377,22 +398,15 @@ class SynchrotronServer(HomeServer): result, "typing", "typing_key", room="room_id" ) - next_expire_broken_caches_ms = 0 while True: try: args = store.stream_positions() args.update(typing_handler.stream_positions()) args["timeout"] = 30000 result = yield http_client.get_json(replication_url, args=args) - now_ms = clock.time_msec() - if now_ms > next_expire_broken_caches_ms: - expire_broken_caches() - next_expire_broken_caches_ms = ( - now_ms + store.BROKEN_CACHE_EXPIRY_MS - ) yield store.process_replication(result) typing_handler.process_replication(result) - presence_handler.process_replication(result) + yield presence_handler.process_replication(result) notify(result) except: logger.exception("Error replicating from %r", replication_url) diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index f7178ea0d..bde9b51b2 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -14,6 +14,8 @@ # limitations under the License. 
from synapse.api.constants import EventTypes +from twisted.internet import defer + import logging import re @@ -79,13 +81,17 @@ class ApplicationService(object): NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS] def __init__(self, token, url=None, namespaces=None, hs_token=None, - sender=None, id=None): + sender=None, id=None, protocols=None): self.token = token self.url = url self.hs_token = hs_token self.sender = sender self.namespaces = self._check_namespaces(namespaces) self.id = id + if protocols: + self.protocols = set(protocols) + else: + self.protocols = set() def _check_namespaces(self, namespaces): # Sanity check that it is of the form: @@ -138,65 +144,66 @@ class ApplicationService(object): return regex_obj["exclusive"] return False - def _matches_user(self, event, member_list): - if (hasattr(event, "sender") and - self.is_interested_in_user(event.sender)): - return True + @defer.inlineCallbacks + def _matches_user(self, event, store): + if not event: + defer.returnValue(False) + + if self.is_interested_in_user(event.sender): + defer.returnValue(True) # also check m.room.member state key - if (hasattr(event, "type") and event.type == EventTypes.Member - and hasattr(event, "state_key") - and self.is_interested_in_user(event.state_key)): - return True + if (event.type == EventTypes.Member and + self.is_interested_in_user(event.state_key)): + defer.returnValue(True) + + if not store: + defer.returnValue(False) + + member_list = yield store.get_users_in_room(event.room_id) + # check joined member events for user_id in member_list: if self.is_interested_in_user(user_id): - return True - return False + defer.returnValue(True) + defer.returnValue(False) def _matches_room_id(self, event): if hasattr(event, "room_id"): return self.is_interested_in_room(event.room_id) return False - def _matches_aliases(self, event, alias_list): + @defer.inlineCallbacks + def _matches_aliases(self, event, store): + if not store or not event: + defer.returnValue(False) + + alias_list = 
yield store.get_aliases_for_room(event.room_id) for alias in alias_list: if self.is_interested_in_alias(alias): - return True - return False + defer.returnValue(True) + defer.returnValue(False) - def is_interested(self, event, restrict_to=None, aliases_for_event=None, - member_list=None): + @defer.inlineCallbacks + def is_interested(self, event, store=None): """Check if this service is interested in this event. Args: event(Event): The event to check. - restrict_to(str): The namespace to restrict regex tests to. - aliases_for_event(list): A list of all the known room aliases for - this event. - member_list(list): A list of all joined user_ids in this room. + store(DataStore) Returns: bool: True if this service would like to know about this event. """ - if aliases_for_event is None: - aliases_for_event = [] - if member_list is None: - member_list = [] + # Do cheap checks first + if self._matches_room_id(event): + defer.returnValue(True) - if restrict_to and restrict_to not in ApplicationService.NS_LIST: - # this is a programming error, so fail early and raise a general - # exception - raise Exception("Unexpected restrict_to value: %s". 
restrict_to) + if (yield self._matches_aliases(event, store)): + defer.returnValue(True) - if not restrict_to: - return (self._matches_user(event, member_list) - or self._matches_aliases(event, aliases_for_event) - or self._matches_room_id(event)) - elif restrict_to == ApplicationService.NS_ALIASES: - return self._matches_aliases(event, aliases_for_event) - elif restrict_to == ApplicationService.NS_ROOMS: - return self._matches_room_id(event) - elif restrict_to == ApplicationService.NS_USERS: - return self._matches_user(event, member_list) + if (yield self._matches_user(event, store)): + defer.returnValue(True) + + defer.returnValue(False) def is_interested_in_user(self, user_id): return ( @@ -216,6 +223,9 @@ class ApplicationService(object): or user_id == self.sender ) + def is_interested_in_protocol(self, protocol): + return protocol in self.protocols + def is_exclusive_alias(self, alias): return self._is_exclusive(ApplicationService.NS_ALIASES, alias) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 6da6a1b62..066127b66 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -17,6 +17,7 @@ from twisted.internet import defer from synapse.api.errors import CodeMessageException from synapse.http.client import SimpleHttpClient from synapse.events.utils import serialize_event +from synapse.types import ThirdPartyEntityKind import logging import urllib @@ -24,6 +25,28 @@ import urllib logger = logging.getLogger(__name__) +def _is_valid_3pe_result(r, field): + if not isinstance(r, dict): + return False + + for k in (field, "protocol"): + if k not in r: + return False + if not isinstance(r[k], str): + return False + + if "fields" not in r: + return False + fields = r["fields"] + if not isinstance(fields, dict): + return False + for k in fields.keys(): + if not isinstance(fields[k], str): + return False + + return True + + class ApplicationServiceApi(SimpleHttpClient): """This class manages HS -> AS communications, including 
querying and pushing. @@ -71,6 +94,43 @@ class ApplicationServiceApi(SimpleHttpClient): logger.warning("query_alias to %s threw exception %s", uri, ex) defer.returnValue(False) + @defer.inlineCallbacks + def query_3pe(self, service, kind, protocol, fields): + if kind == ThirdPartyEntityKind.USER: + uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol)) + required_field = "userid" + elif kind == ThirdPartyEntityKind.LOCATION: + uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol)) + required_field = "alias" + else: + raise ValueError( + "Unrecognised 'kind' argument %r to query_3pe()", kind + ) + + try: + response = yield self.get_json(uri, fields) + if not isinstance(response, list): + logger.warning( + "query_3pe to %s returned an invalid response %r", + uri, response + ) + defer.returnValue([]) + + ret = [] + for r in response: + if _is_valid_3pe_result(r, field=required_field): + ret.append(r) + else: + logger.warning( + "query_3pe to %s returned an invalid result %r", + uri, r + ) + + defer.returnValue(ret) + except Exception as ex: + logger.warning("query_3pe to %s threw exception %s", uri, ex) + defer.returnValue([]) + @defer.inlineCallbacks def push_bulk(self, service, events, txn_id=None): events = self._serialize(events) diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 9afc8fd75..68a9de17b 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -48,9 +48,12 @@ UP & quit +---------- YES SUCCESS This is all tied together by the AppServiceScheduler which DIs the required components. 
""" +from twisted.internet import defer from synapse.appservice import ApplicationServiceState -from twisted.internet import defer +from synapse.util.logcontext import preserve_fn +from synapse.util.metrics import Measure + import logging logger = logging.getLogger(__name__) @@ -73,7 +76,7 @@ class ApplicationServiceScheduler(object): self.txn_ctrl = _TransactionController( self.clock, self.store, self.as_api, create_recoverer ) - self.queuer = _ServiceQueuer(self.txn_ctrl) + self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock) @defer.inlineCallbacks def start(self): @@ -94,38 +97,36 @@ class _ServiceQueuer(object): this schedules any other events in the queue to run. """ - def __init__(self, txn_ctrl): + def __init__(self, txn_ctrl, clock): self.queued_events = {} # dict of {service_id: [events]} - self.pending_requests = {} # dict of {service_id: Deferred} + self.requests_in_flight = set() self.txn_ctrl = txn_ctrl + self.clock = clock def enqueue(self, service, event): # if this service isn't being sent something - if not self.pending_requests.get(service.id): - self._send_request(service, [event]) - else: - # add to queue for this service - if service.id not in self.queued_events: - self.queued_events[service.id] = [] - self.queued_events[service.id].append(event) + self.queued_events.setdefault(service.id, []).append(event) + preserve_fn(self._send_request)(service) - def _send_request(self, service, events): - # send request and add callbacks - d = self.txn_ctrl.send(service, events) - d.addBoth(self._on_request_finish) - d.addErrback(self._on_request_fail) - self.pending_requests[service.id] = d + @defer.inlineCallbacks + def _send_request(self, service): + if service.id in self.requests_in_flight: + return - def _on_request_finish(self, service): - self.pending_requests[service.id] = None - # if there are queued events, then send them. 
- if (service.id in self.queued_events - and len(self.queued_events[service.id]) > 0): - self._send_request(service, self.queued_events[service.id]) - self.queued_events[service.id] = [] + self.requests_in_flight.add(service.id) + try: + while True: + events = self.queued_events.pop(service.id, []) + if not events: + return - def _on_request_fail(self, err): - logger.error("AS request failed: %s", err) + with Measure(self.clock, "servicequeuer.send"): + try: + yield self.txn_ctrl.send(service, events) + except: + logger.exception("AS request failed") + finally: + self.requests_in_flight.discard(service.id) class _TransactionController(object): @@ -149,14 +150,12 @@ class _TransactionController(object): if service_is_up: sent = yield txn.send(self.as_api) if sent: - txn.complete(self.store) + yield txn.complete(self.store) else: - self._start_recoverer(service) + preserve_fn(self._start_recoverer)(service) except Exception as e: logger.exception(e) - self._start_recoverer(service) - # request has finished - defer.returnValue(service) + preserve_fn(self._start_recoverer)(service) @defer.inlineCallbacks def on_recovered(self, recoverer): diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index eade80390..dfe43b0b4 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -28,6 +28,7 @@ class AppServiceConfig(Config): def read_config(self, config): self.app_service_config_files = config.get("app_service_config_files", []) + self.notify_appservices = config.get("notify_appservices", True) def default_config(cls, **kwargs): return """\ @@ -122,6 +123,15 @@ def _load_appservice(hostname, as_info, config_filename): raise ValueError( "Missing/bad type 'exclusive' key in %s", regex_obj ) + # protocols check + protocols = as_info.get("protocols") + if protocols: + # Because strings are lists in python + if isinstance(protocols, str) or not isinstance(protocols, list): + raise KeyError("Optional 'protocols' must be a list if 
present.") + for p in protocols: + if not isinstance(p, str): + raise KeyError("Bad value for 'protocols' item") return ApplicationService( token=as_info["as_token"], url=as_info["url"], @@ -129,4 +139,5 @@ def _load_appservice(hostname, as_info, config_filename): hs_token=as_info["hs_token"], sender=user_id, id=as_info["id"], + protocols=protocols, ) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 5012c10ee..d7211ee9b 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -22,6 +22,7 @@ from synapse.util.logcontext import ( preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext, preserve_fn ) +from synapse.util.metrics import Measure from twisted.internet import defer @@ -61,6 +62,10 @@ Attributes: """ +class KeyLookupError(ValueError): + pass + + class Keyring(object): def __init__(self, hs): self.store = hs.get_datastore() @@ -239,59 +244,60 @@ class Keyring(object): @defer.inlineCallbacks def do_iterations(): - merged_results = {} + with Measure(self.clock, "get_server_verify_keys"): + merged_results = {} - missing_keys = {} - for verify_request in verify_requests: - missing_keys.setdefault(verify_request.server_name, set()).update( - verify_request.key_ids - ) - - for fn in key_fetch_fns: - results = yield fn(missing_keys.items()) - merged_results.update(results) - - # We now need to figure out which verify requests we have keys - # for and which we don't missing_keys = {} - requests_missing_keys = [] for verify_request in verify_requests: - server_name = verify_request.server_name - result_keys = merged_results[server_name] + missing_keys.setdefault(verify_request.server_name, set()).update( + verify_request.key_ids + ) - if verify_request.deferred.called: - # We've already called this deferred, which probably - # means that we've already found a key for it. 
- continue + for fn in key_fetch_fns: + results = yield fn(missing_keys.items()) + merged_results.update(results) - for key_id in verify_request.key_ids: - if key_id in result_keys: - with PreserveLoggingContext(): - verify_request.deferred.callback(( - server_name, - key_id, - result_keys[key_id], - )) - break - else: - # The else block is only reached if the loop above - # doesn't break. - missing_keys.setdefault(server_name, set()).update( - verify_request.key_ids - ) - requests_missing_keys.append(verify_request) + # We now need to figure out which verify requests we have keys + # for and which we don't + missing_keys = {} + requests_missing_keys = [] + for verify_request in verify_requests: + server_name = verify_request.server_name + result_keys = merged_results[server_name] - if not missing_keys: - break + if verify_request.deferred.called: + # We've already called this deferred, which probably + # means that we've already found a key for it. + continue - for verify_request in requests_missing_keys.values(): - verify_request.deferred.errback(SynapseError( - 401, - "No key for %s with id %s" % ( - verify_request.server_name, verify_request.key_ids, - ), - Codes.UNAUTHORIZED, - )) + for key_id in verify_request.key_ids: + if key_id in result_keys: + with PreserveLoggingContext(): + verify_request.deferred.callback(( + server_name, + key_id, + result_keys[key_id], + )) + break + else: + # The else block is only reached if the loop above + # doesn't break. 
+ missing_keys.setdefault(server_name, set()).update( + verify_request.key_ids + ) + requests_missing_keys.append(verify_request) + + if not missing_keys: + break + + for verify_request in requests_missing_keys.values(): + verify_request.deferred.errback(SynapseError( + 401, + "No key for %s with id %s" % ( + verify_request.server_name, verify_request.key_ids, + ), + Codes.UNAUTHORIZED, + )) def on_err(err): for verify_request in verify_requests: @@ -302,15 +308,15 @@ class Keyring(object): @defer.inlineCallbacks def get_keys_from_store(self, server_name_and_key_ids): - res = yield defer.gatherResults( + res = yield preserve_context_over_deferred(defer.gatherResults( [ - self.store.get_server_verify_keys( + preserve_fn(self.store.get_server_verify_keys)( server_name, key_ids ).addCallback(lambda ks, server: (server, ks), server_name) for server_name, key_ids in server_name_and_key_ids ], consumeErrors=True, - ).addErrback(unwrapFirstError) + )).addErrback(unwrapFirstError) defer.returnValue(dict(res)) @@ -331,13 +337,13 @@ class Keyring(object): ) defer.returnValue({}) - results = yield defer.gatherResults( + results = yield preserve_context_over_deferred(defer.gatherResults( [ - get_key(p_name, p_keys) + preserve_fn(get_key)(p_name, p_keys) for p_name, p_keys in self.perspective_servers.items() ], consumeErrors=True, - ).addErrback(unwrapFirstError) + )).addErrback(unwrapFirstError) union_of_keys = {} for result in results: @@ -363,7 +369,7 @@ class Keyring(object): ) except Exception as e: logger.info( - "Unable to getting key %r for %r directly: %s %s", + "Unable to get key %r for %r directly: %s %s", key_ids, server_name, type(e).__name__, str(e.message), ) @@ -377,13 +383,13 @@ class Keyring(object): defer.returnValue(keys) - results = yield defer.gatherResults( + results = yield preserve_context_over_deferred(defer.gatherResults( [ - get_key(server_name, key_ids) + preserve_fn(get_key)(server_name, key_ids) for server_name, key_ids in server_name_and_key_ids 
], consumeErrors=True, - ).addErrback(unwrapFirstError) + )).addErrback(unwrapFirstError) merged = {} for result in results: @@ -425,7 +431,7 @@ class Keyring(object): for response in responses: if (u"signatures" not in response or perspective_name not in response[u"signatures"]): - raise ValueError( + raise KeyLookupError( "Key response not signed by perspective server" " %r" % (perspective_name,) ) @@ -448,7 +454,7 @@ class Keyring(object): list(response[u"signatures"][perspective_name]), list(perspective_keys) ) - raise ValueError( + raise KeyLookupError( "Response not signed with a known key for perspective" " server %r" % (perspective_name,) ) @@ -460,9 +466,9 @@ class Keyring(object): for server_name, response_keys in processed_response.items(): keys.setdefault(server_name, {}).update(response_keys) - yield defer.gatherResults( + yield preserve_context_over_deferred(defer.gatherResults( [ - self.store_keys( + preserve_fn(self.store_keys)( server_name=server_name, from_server=perspective_name, verify_keys=response_keys, @@ -470,7 +476,7 @@ class Keyring(object): for server_name, response_keys in keys.items() ], consumeErrors=True - ).addErrback(unwrapFirstError) + )).addErrback(unwrapFirstError) defer.returnValue(keys) @@ -491,10 +497,10 @@ class Keyring(object): if (u"signatures" not in response or server_name not in response[u"signatures"]): - raise ValueError("Key response not signed by remote server") + raise KeyLookupError("Key response not signed by remote server") if "tls_fingerprints" not in response: - raise ValueError("Key response missing TLS fingerprints") + raise KeyLookupError("Key response missing TLS fingerprints") certificate_bytes = crypto.dump_certificate( crypto.FILETYPE_ASN1, tls_certificate @@ -508,7 +514,7 @@ class Keyring(object): response_sha256_fingerprints.add(fingerprint[u"sha256"]) if sha256_fingerprint_b64 not in response_sha256_fingerprints: - raise ValueError("TLS certificate not allowed by fingerprints") + raise 
KeyLookupError("TLS certificate not allowed by fingerprints") response_keys = yield self.process_v2_response( from_server=server_name, @@ -518,7 +524,7 @@ class Keyring(object): keys.update(response_keys) - yield defer.gatherResults( + yield preserve_context_over_deferred(defer.gatherResults( [ preserve_fn(self.store_keys)( server_name=key_server_name, @@ -528,7 +534,7 @@ class Keyring(object): for key_server_name, verify_keys in keys.items() ], consumeErrors=True - ).addErrback(unwrapFirstError) + )).addErrback(unwrapFirstError) defer.returnValue(keys) @@ -560,14 +566,14 @@ class Keyring(object): server_name = response_json["server_name"] if only_from_server: if server_name != from_server: - raise ValueError( + raise KeyLookupError( "Expected a response for server %r not %r" % ( from_server, server_name ) ) for key_id in response_json["signatures"].get(server_name, {}): if key_id not in response_json["verify_keys"]: - raise ValueError( + raise KeyLookupError( "Key response must include verification keys for all" " signatures" ) @@ -594,7 +600,7 @@ class Keyring(object): response_keys.update(verify_keys) response_keys.update(old_verify_keys) - yield defer.gatherResults( + yield preserve_context_over_deferred(defer.gatherResults( [ preserve_fn(self.store.store_server_keys_json)( server_name=server_name, @@ -607,7 +613,7 @@ class Keyring(object): for key_id in updated_key_ids ], consumeErrors=True, - ).addErrback(unwrapFirstError) + )).addErrback(unwrapFirstError) results[server_name] = response_keys @@ -635,15 +641,15 @@ class Keyring(object): if ("signatures" not in response or server_name not in response["signatures"]): - raise ValueError("Key response not signed by remote server") + raise KeyLookupError("Key response not signed by remote server") if "tls_certificate" not in response: - raise ValueError("Key response missing TLS certificate") + raise KeyLookupError("Key response missing TLS certificate") tls_certificate_b64 = response["tls_certificate"] if 
encode_base64(x509_certificate_bytes) != tls_certificate_b64: - raise ValueError("TLS certificate doesn't match") + raise KeyLookupError("TLS certificate doesn't match") # Cache the result in the datastore. @@ -659,7 +665,7 @@ class Keyring(object): for key_id in response["signatures"][server_name]: if key_id not in response["verify_keys"]: - raise ValueError( + raise KeyLookupError( "Key response must include verification keys for all" " signatures" ) @@ -696,7 +702,7 @@ class Keyring(object): A deferred that completes when the keys are stored. """ # TODO(markjh): Store whether the keys have expired. - yield defer.gatherResults( + yield preserve_context_over_deferred(defer.gatherResults( [ preserve_fn(self.store.store_server_verify_key)( server_name, server_name, key.time_added, key @@ -704,4 +710,4 @@ class Keyring(object): for key_id, key in verify_keys.items() ], consumeErrors=True, - ).addErrback(unwrapFirstError) + )).addErrback(unwrapFirstError) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index aab18d7f7..0e9fd902a 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -88,6 +88,8 @@ def prune_event(event): if "age_ts" in event.unsigned: allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"] + if "replaces_state" in event.unsigned: + allowed_fields["unsigned"]["replaces_state"] = event.unsigned["replaces_state"] return type(event)( allowed_fields, diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index da2f5e8cf..2339cc903 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -23,6 +23,7 @@ from synapse.crypto.event_signing import check_event_content_hash from synapse.api.errors import SynapseError from synapse.util import unwrapFirstError +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred import logging @@ -102,10 +103,10 @@ class FederationBase(object): warn, pdu ) - valid_pdus = yield 
defer.gatherResults( + valid_pdus = yield preserve_context_over_deferred(defer.gatherResults( deferreds, consumeErrors=True - ).addErrback(unwrapFirstError) + )).addErrback(unwrapFirstError) if include_none: defer.returnValue(valid_pdus) @@ -129,7 +130,7 @@ class FederationBase(object): for pdu in pdus ] - deferreds = self.keyring.verify_json_objects_for_server([ + deferreds = preserve_fn(self.keyring.verify_json_objects_for_server)([ (p.origin, p.get_pdu_json()) for p in redacted_pdus ]) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index da95c2ad6..f2b3aceb4 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -27,6 +27,7 @@ from synapse.util import unwrapFirstError from synapse.util.async import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logutils import log_function +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.events import FrozenEvent import synapse.metrics @@ -51,10 +52,34 @@ sent_edus_counter = metrics.register_counter("sent_edus") sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"]) +PDU_RETRY_TIME_MS = 1 * 60 * 1000 + + class FederationClient(FederationBase): def __init__(self, hs): super(FederationClient, self).__init__(hs) + self.pdu_destination_tried = {} + self._clock.looping_call( + self._clear_tried_cache, 60 * 1000, + ) + + def _clear_tried_cache(self): + """Clear pdu_destination_tried cache""" + now = self._clock.time_msec() + + old_dict = self.pdu_destination_tried + self.pdu_destination_tried = {} + + for event_id, destination_dict in old_dict.items(): + destination_dict = { + dest: time + for dest, time in destination_dict.items() + if time + PDU_RETRY_TIME_MS > now + } + if destination_dict: + self.pdu_destination_tried[event_id] = destination_dict + def start_get_pdu_cache(self): self._get_pdu_cache = ExpiringCache( 
cache_name="get_pdu_cache", @@ -201,10 +226,10 @@ class FederationClient(FederationBase): ] # FIXME: We should handle signature failures more gracefully. - pdus[:] = yield defer.gatherResults( + pdus[:] = yield preserve_context_over_deferred(defer.gatherResults( self._check_sigs_and_hashes(pdus), consumeErrors=True, - ).addErrback(unwrapFirstError) + )).addErrback(unwrapFirstError) defer.returnValue(pdus) @@ -240,8 +265,15 @@ class FederationClient(FederationBase): if ev: defer.returnValue(ev) + pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {}) + pdu = None for destination in destinations: + now = self._clock.time_msec() + last_attempt = pdu_attempts.get(destination, 0) + if last_attempt + PDU_RETRY_TIME_MS > now: + continue + try: limiter = yield get_retry_limiter( destination, @@ -269,25 +301,19 @@ class FederationClient(FederationBase): break + pdu_attempts[destination] = now + except SynapseError as e: logger.info( "Failed to get PDU %s from %s because %s", event_id, destination, e, ) - continue - except CodeMessageException as e: - if 400 <= e.code < 500: - raise - - logger.info( - "Failed to get PDU %s from %s because %s", - event_id, destination, e, - ) - continue except NotRetryingDestination as e: logger.info(e.message) continue except Exception as e: + pdu_attempts[destination] = now + logger.info( "Failed to get PDU %s from %s because %s", event_id, destination, e, @@ -406,7 +432,7 @@ class FederationClient(FederationBase): events and the second is a list of event ids that we failed to fetch. 
""" if return_local: - seen_events = yield self.store.get_events(event_ids) + seen_events = yield self.store.get_events(event_ids, allow_rejected=True) signed_events = seen_events.values() else: seen_events = yield self.store.have_events(event_ids) @@ -432,14 +458,16 @@ class FederationClient(FederationBase): batch = set(missing_events[i:i + batch_size]) deferreds = [ - self.get_pdu( + preserve_fn(self.get_pdu)( destinations=random_server_list(), event_id=e_id, ) for e_id in batch ] - res = yield defer.DeferredList(deferreds, consumeErrors=True) + res = yield preserve_context_over_deferred( + defer.DeferredList(deferreds, consumeErrors=True) + ) for success, result in res: if success: signed_events.append(result) @@ -828,14 +856,16 @@ class FederationClient(FederationBase): return srvs deferreds = [ - self.get_pdu( + preserve_fn(self.get_pdu)( destinations=random_server_list(), event_id=e_id, ) for e_id, depth in ordered_missing[:limit - len(signed_events)] ] - res = yield defer.DeferredList(deferreds, consumeErrors=True) + res = yield preserve_context_over_deferred( + defer.DeferredList(deferreds, consumeErrors=True) + ) for (result, val), (e_id, _) in zip(res, ordered_missing): if result and val: signed_events.append(val) diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 5787f854d..cb2ef0210 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -21,11 +21,11 @@ from .units import Transaction from synapse.api.errors import HttpResponseException from synapse.util.async import run_on_reactor -from synapse.util.logutils import log_function -from synapse.util.logcontext import PreserveLoggingContext +from synapse.util.logcontext import preserve_context_over_fn from synapse.util.retryutils import ( get_retry_limiter, NotRetryingDestination, ) +from synapse.util.metrics import measure_func import synapse.metrics import logging @@ -51,7 +51,7 @@ class 
TransactionQueue(object): self.transport_layer = transport_layer - self._clock = hs.get_clock() + self.clock = hs.get_clock() # Is a mapping from destinations -> deferreds. Used to keep track # of which destinations have transactions in flight and when they are @@ -82,7 +82,7 @@ class TransactionQueue(object): self.pending_failures_by_dest = {} # HACK to get unique tx id - self._next_txn_id = int(self._clock.time_msec()) + self._next_txn_id = int(self.clock.time_msec()) def can_send_to(self, destination): """Can we send messages to the given server? @@ -119,266 +119,215 @@ class TransactionQueue(object): if not destinations: return - deferreds = [] - for destination in destinations: - deferred = defer.Deferred() self.pending_pdus_by_dest.setdefault(destination, []).append( - (pdu, deferred, order) + (pdu, order) ) - def chain(failure): - if not deferred.called: - deferred.errback(failure) + preserve_context_over_fn( + self._attempt_new_transaction, destination + ) - def log_failure(f): - logger.warn("Failed to send pdu to %s: %s", destination, f.value) - - deferred.addErrback(log_failure) - - with PreserveLoggingContext(): - self._attempt_new_transaction(destination).addErrback(chain) - - deferreds.append(deferred) - - # NO inlineCallbacks def enqueue_edu(self, edu): destination = edu.destination if not self.can_send_to(destination): return - deferred = defer.Deferred() - self.pending_edus_by_dest.setdefault(destination, []).append( - (edu, deferred) + self.pending_edus_by_dest.setdefault(destination, []).append(edu) + + preserve_context_over_fn( + self._attempt_new_transaction, destination ) - def chain(failure): - if not deferred.called: - deferred.errback(failure) - - def log_failure(f): - logger.warn("Failed to send edu to %s: %s", destination, f.value) - - deferred.addErrback(log_failure) - - with PreserveLoggingContext(): - self._attempt_new_transaction(destination).addErrback(chain) - - return deferred - - @defer.inlineCallbacks def enqueue_failure(self, 
failure, destination): if destination == self.server_name or destination == "localhost": return - deferred = defer.Deferred() - if not self.can_send_to(destination): return self.pending_failures_by_dest.setdefault( destination, [] - ).append( - (failure, deferred) + ).append(failure) + + preserve_context_over_fn( + self._attempt_new_transaction, destination ) - def chain(f): - if not deferred.called: - deferred.errback(f) - - def log_failure(f): - logger.warn("Failed to send failure to %s: %s", destination, f.value) - - deferred.addErrback(log_failure) - - with PreserveLoggingContext(): - self._attempt_new_transaction(destination).addErrback(chain) - - yield deferred - @defer.inlineCallbacks - @log_function def _attempt_new_transaction(self, destination): yield run_on_reactor() + while True: + # list of (pending_pdu, deferred, order) + if destination in self.pending_transactions: + # XXX: pending_transactions can get stuck on by a never-ending + # request at which point pending_pdus_by_dest just keeps growing. + # we need application-layer timeouts of some flavour of these + # requests + logger.debug( + "TX [%s] Transaction already in progress", + destination + ) + return - # list of (pending_pdu, deferred, order) - if destination in self.pending_transactions: - # XXX: pending_transactions can get stuck on by a never-ending - # request at which point pending_pdus_by_dest just keeps growing. 
- # we need application-layer timeouts of some flavour of these - # requests - logger.debug( - "TX [%s] Transaction already in progress", - destination + pending_pdus = self.pending_pdus_by_dest.pop(destination, []) + pending_edus = self.pending_edus_by_dest.pop(destination, []) + pending_failures = self.pending_failures_by_dest.pop(destination, []) + + if pending_pdus: + logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d", + destination, len(pending_pdus)) + + if not pending_pdus and not pending_edus and not pending_failures: + logger.debug("TX [%s] Nothing to send", destination) + return + + yield self._send_new_transaction( + destination, pending_pdus, pending_edus, pending_failures ) - return - pending_pdus = self.pending_pdus_by_dest.pop(destination, []) - pending_edus = self.pending_edus_by_dest.pop(destination, []) - pending_failures = self.pending_failures_by_dest.pop(destination, []) - - if pending_pdus: - logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d", - destination, len(pending_pdus)) - - if not pending_pdus and not pending_edus and not pending_failures: - logger.debug("TX [%s] Nothing to send", destination) - return - - try: - self.pending_transactions[destination] = 1 - - logger.debug("TX [%s] _attempt_new_transaction", destination) + @measure_func("_send_new_transaction") + @defer.inlineCallbacks + def _send_new_transaction(self, destination, pending_pdus, pending_edus, + pending_failures): # Sort based on the order field - pending_pdus.sort(key=lambda t: t[2]) - + pending_pdus.sort(key=lambda t: t[1]) pdus = [x[0] for x in pending_pdus] - edus = [x[0] for x in pending_edus] - failures = [x[0].get_dict() for x in pending_failures] - deferreds = [ - x[1] - for x in pending_pdus + pending_edus + pending_failures - ] + edus = pending_edus + failures = [x.get_dict() for x in pending_failures] - txn_id = str(self._next_txn_id) + try: + self.pending_transactions[destination] = 1 - limiter = yield get_retry_limiter( - destination, - 
self._clock, - self.store, - ) + logger.debug("TX [%s] _attempt_new_transaction", destination) - logger.debug( - "TX [%s] {%s} Attempting new transaction" - " (pdus: %d, edus: %d, failures: %d)", - destination, txn_id, - len(pending_pdus), - len(pending_edus), - len(pending_failures) - ) + txn_id = str(self._next_txn_id) - logger.debug("TX [%s] Persisting transaction...", destination) - - transaction = Transaction.create_new( - origin_server_ts=int(self._clock.time_msec()), - transaction_id=txn_id, - origin=self.server_name, - destination=destination, - pdus=pdus, - edus=edus, - pdu_failures=failures, - ) - - self._next_txn_id += 1 - - yield self.transaction_actions.prepare_to_send(transaction) - - logger.debug("TX [%s] Persisted transaction", destination) - logger.info( - "TX [%s] {%s} Sending transaction [%s]," - " (PDUs: %d, EDUs: %d, failures: %d)", - destination, txn_id, - transaction.transaction_id, - len(pending_pdus), - len(pending_edus), - len(pending_failures), - ) - - with limiter: - # Actually send the transaction - - # FIXME (erikj): This is a bit of a hack to make the Pdu age - # keys work - def json_data_cb(): - data = transaction.get_dict() - now = int(self._clock.time_msec()) - if "pdus" in data: - for p in data["pdus"]: - if "age_ts" in p: - unsigned = p.setdefault("unsigned", {}) - unsigned["age"] = now - int(p["age_ts"]) - del p["age_ts"] - return data - - try: - response = yield self.transport_layer.send_transaction( - transaction, json_data_cb - ) - code = 200 - - if response: - for e_id, r in response.get("pdus", {}).items(): - if "error" in r: - logger.warn( - "Transaction returned error for %s: %s", - e_id, r, - ) - except HttpResponseException as e: - code = e.code - response = e.response - - logger.info( - "TX [%s] {%s} got %d response", - destination, txn_id, code + limiter = yield get_retry_limiter( + destination, + self.clock, + self.store, ) - logger.debug("TX [%s] Sent transaction", destination) - logger.debug("TX [%s] Marking as 
delivered...", destination) + logger.debug( + "TX [%s] {%s} Attempting new transaction" + " (pdus: %d, edus: %d, failures: %d)", + destination, txn_id, + len(pending_pdus), + len(pending_edus), + len(pending_failures) + ) - yield self.transaction_actions.delivered( - transaction, code, response - ) + logger.debug("TX [%s] Persisting transaction...", destination) - logger.debug("TX [%s] Marked as delivered", destination) + transaction = Transaction.create_new( + origin_server_ts=int(self.clock.time_msec()), + transaction_id=txn_id, + origin=self.server_name, + destination=destination, + pdus=pdus, + edus=edus, + pdu_failures=failures, + ) - logger.debug("TX [%s] Yielding to callbacks...", destination) + self._next_txn_id += 1 - for deferred in deferreds: - if code == 200: - deferred.callback(None) - else: - deferred.errback(RuntimeError("Got status %d" % code)) + yield self.transaction_actions.prepare_to_send(transaction) - # Ensures we don't continue until all callbacks on that - # deferred have fired - try: - yield deferred - except: - pass + logger.debug("TX [%s] Persisted transaction", destination) + logger.info( + "TX [%s] {%s} Sending transaction [%s]," + " (PDUs: %d, EDUs: %d, failures: %d)", + destination, txn_id, + transaction.transaction_id, + len(pending_pdus), + len(pending_edus), + len(pending_failures), + ) - logger.debug("TX [%s] Yielded to callbacks", destination) - except NotRetryingDestination: - logger.info( - "TX [%s] not ready for retry yet - " - "dropping transaction for now", - destination, - ) - except RuntimeError as e: - # We capture this here as there as nothing actually listens - # for this finishing functions deferred. - logger.warn( - "TX [%s] Problem in _attempt_transaction: %s", - destination, - e, - ) - except Exception as e: - # We capture this here as there as nothing actually listens - # for this finishing functions deferred. 
- logger.warn( - "TX [%s] Problem in _attempt_transaction: %s", - destination, - e, - ) + with limiter: + # Actually send the transaction - for deferred in deferreds: - if not deferred.called: - deferred.errback(e) + # FIXME (erikj): This is a bit of a hack to make the Pdu age + # keys work + def json_data_cb(): + data = transaction.get_dict() + now = int(self.clock.time_msec()) + if "pdus" in data: + for p in data["pdus"]: + if "age_ts" in p: + unsigned = p.setdefault("unsigned", {}) + unsigned["age"] = now - int(p["age_ts"]) + del p["age_ts"] + return data - finally: - # We want to be *very* sure we delete this after we stop processing - self.pending_transactions.pop(destination, None) + try: + response = yield self.transport_layer.send_transaction( + transaction, json_data_cb + ) + code = 200 - # Check to see if there is anything else to send. - self._attempt_new_transaction(destination) + if response: + for e_id, r in response.get("pdus", {}).items(): + if "error" in r: + logger.warn( + "Transaction returned error for %s: %s", + e_id, r, + ) + except HttpResponseException as e: + code = e.code + response = e.response + + logger.info( + "TX [%s] {%s} got %d response", + destination, txn_id, code + ) + + logger.debug("TX [%s] Sent transaction", destination) + logger.debug("TX [%s] Marking as delivered...", destination) + + yield self.transaction_actions.delivered( + transaction, code, response + ) + + logger.debug("TX [%s] Marked as delivered", destination) + + if code != 200: + for p in pdus: + logger.info( + "Failed to send event %s to %s", p.event_id, destination + ) + except NotRetryingDestination: + logger.info( + "TX [%s] not ready for retry yet - " + "dropping transaction for now", + destination, + ) + except RuntimeError as e: + # We capture this here as there as nothing actually listens + # for this finishing functions deferred. 
+ logger.warn( + "TX [%s] Problem in _attempt_transaction: %s", + destination, + e, + ) + + for p in pdus: + logger.info("Failed to send event %s to %s", p.event_id, destination) + except Exception as e: + # We capture this here as there as nothing actually listens + # for this finishing functions deferred. + logger.warn( + "TX [%s] Problem in _attempt_transaction: %s", + destination, + e, + ) + + for p in pdus: + logger.info("Failed to send event %s to %s", p.event_id, destination) + + finally: + # We want to be *very* sure we delete this after we stop processing + self.pending_transactions.pop(destination, None) diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index 1a50a2ec9..63d05f253 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -19,7 +19,6 @@ from .room import ( ) from .room_member import RoomMemberHandler from .message import MessageHandler -from .events import EventStreamHandler, EventHandler from .federation import FederationHandler from .profile import ProfileHandler from .directory import DirectoryHandler @@ -53,8 +52,6 @@ class Handlers(object): self.message_handler = MessageHandler(hs) self.room_creation_handler = RoomCreationHandler(hs) self.room_member_handler = RoomMemberHandler(hs) - self.event_stream_handler = EventStreamHandler(hs) - self.event_handler = EventHandler(hs) self.federation_handler = FederationHandler(hs) self.profile_handler = ProfileHandler(hs) self.directory_handler = DirectoryHandler(hs) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 051ccdb38..306686a38 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -16,7 +16,8 @@ from twisted.internet import defer from synapse.api.constants import EventTypes -from synapse.appservice import ApplicationService +from synapse.util.metrics import Measure +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred import logging @@ -42,36 +43,73 @@ 
class ApplicationServicesHandler(object): self.appservice_api = hs.get_application_service_api() self.scheduler = hs.get_application_service_scheduler() self.started_scheduler = False + self.clock = hs.get_clock() + self.notify_appservices = hs.config.notify_appservices + + self.current_max = 0 + self.is_processing = False @defer.inlineCallbacks - def notify_interested_services(self, event): + def notify_interested_services(self, current_id): """Notifies (pushes) all application services interested in this event. Pushing is done asynchronously, so this method won't block for any prolonged length of time. Args: - event(Event): The event to push out to interested services. + current_id(int): The current maximum ID. """ - # Gather interested services - services = yield self._get_services_for_event(event) - if len(services) == 0: - return # no services need notifying + services = yield self.store.get_app_services() + if not services or not self.notify_appservices: + return - # Do we know this user exists? If not, poke the user query API for - # all services which match that user regex. This needs to block as these - # user queries need to be made BEFORE pushing the event. 
- yield self._check_user_exists(event.sender) - if event.type == EventTypes.Member: - yield self._check_user_exists(event.state_key) + self.current_max = max(self.current_max, current_id) + if self.is_processing: + return - if not self.started_scheduler: - self.scheduler.start().addErrback(log_failure) - self.started_scheduler = True + with Measure(self.clock, "notify_interested_services"): + self.is_processing = True + try: + upper_bound = self.current_max + limit = 100 + while True: + upper_bound, events = yield self.store.get_new_events_for_appservice( + upper_bound, limit + ) - # Fork off pushes to these services - for service in services: - self.scheduler.submit_event_for_as(service, event) + if not events: + break + + for event in events: + # Gather interested services + services = yield self._get_services_for_event(event) + if len(services) == 0: + continue # no services need notifying + + # Do we know this user exists? If not, poke the user + # query API for all services which match that user regex. + # This needs to block as these user queries need to be + # made BEFORE pushing the event. + yield self._check_user_exists(event.sender) + if event.type == EventTypes.Member: + yield self._check_user_exists(event.state_key) + + if not self.started_scheduler: + self.scheduler.start().addErrback(log_failure) + self.started_scheduler = True + + # Fork off pushes to these services + for service in services: + preserve_fn(self.scheduler.submit_event_for_as)( + service, event + ) + + yield self.store.set_appservice_last_pos(upper_bound) + + if len(events) < limit: + break + finally: + self.is_processing = False @defer.inlineCallbacks def query_user_exists(self, user_id): @@ -104,11 +142,12 @@ class ApplicationServicesHandler(object): association can be found. 
""" room_alias_str = room_alias.to_string() - alias_query_services = yield self._get_services_for_event( - event=None, - restrict_to=ApplicationService.NS_ALIASES, - alias_list=[room_alias_str] - ) + services = yield self.store.get_app_services() + alias_query_services = [ + s for s in services if ( + s.is_interested_in_alias(room_alias_str) + ) + ] for alias_service in alias_query_services: is_known_alias = yield self.appservice_api.query_alias( alias_service, room_alias_str @@ -121,34 +160,35 @@ class ApplicationServicesHandler(object): defer.returnValue(result) @defer.inlineCallbacks - def _get_services_for_event(self, event, restrict_to="", alias_list=None): + def query_3pe(self, kind, protocol, fields): + services = yield self._get_services_for_3pn(protocol) + + results = yield preserve_context_over_deferred(defer.DeferredList([ + preserve_fn(self.appservice_api.query_3pe)(service, kind, protocol, fields) + for service in services + ], consumeErrors=True)) + + ret = [] + for (success, result) in results: + if success: + ret.extend(result) + + defer.returnValue(ret) + + @defer.inlineCallbacks + def _get_services_for_event(self, event): """Retrieve a list of application services interested in this event. Args: event(Event): The event to check. Can be None if alias_list is not. - restrict_to(str): The namespace to restrict regex tests to. - alias_list: A list of aliases to get services for. If None, this - list is obtained from the database. Returns: list: A list of services interested in this event based on the service regex. """ - member_list = None - if hasattr(event, "room_id"): - # We need to know the aliases associated with this event.room_id, - # if any. - if not alias_list: - alias_list = yield self.store.get_aliases_for_room( - event.room_id - ) - # We need to know the members associated with this event.room_id, - # if any. 
- member_list = yield self.store.get_users_in_room(event.room_id) - services = yield self.store.get_app_services() interested_list = [ s for s in services if ( - s.is_interested(event, restrict_to, alias_list, member_list) + yield s.is_interested(event, self.store) ) ] defer.returnValue(interested_list) @@ -163,6 +203,14 @@ class ApplicationServicesHandler(object): ] defer.returnValue(interested_list) + @defer.inlineCallbacks + def _get_services_for_3pn(self, protocol): + services = yield self.store.get_app_services() + interested_list = [ + s for s in services if s.is_interested_in_protocol(protocol) + ] + defer.returnValue(interested_list) + @defer.inlineCallbacks def _is_unknown_user(self, user_id): if not self.is_mine_id(user_id): diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 2e138f328..6986930c0 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -70,11 +70,11 @@ class AuthHandler(BaseHandler): self.ldap_uri = hs.config.ldap_uri self.ldap_start_tls = hs.config.ldap_start_tls self.ldap_base = hs.config.ldap_base - self.ldap_filter = hs.config.ldap_filter self.ldap_attributes = hs.config.ldap_attributes if self.ldap_mode == LDAPMode.SEARCH: self.ldap_bind_dn = hs.config.ldap_bind_dn self.ldap_bind_password = hs.config.ldap_bind_password + self.ldap_filter = hs.config.ldap_filter self.hs = hs # FIXME better possibility to access registrationHandler later? 
self.device_handler = hs.get_device_handler() @@ -660,7 +660,7 @@ class AuthHandler(BaseHandler): else: logger.warn( "ldap registration failed: unexpected (%d!=1) amount of results", - len(result) + len(conn.response) ) defer.returnValue(False) @@ -719,13 +719,14 @@ class AuthHandler(BaseHandler): return macaroon.serialize() def validate_short_term_login_token_and_get_user_id(self, login_token): + auth_api = self.hs.get_auth() try: macaroon = pymacaroons.Macaroon.deserialize(login_token) - auth_api = self.hs.get_auth() - auth_api.validate_macaroon(macaroon, "login", True) - return self.get_user_from_macaroon(macaroon) - except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): - raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN) + user_id = auth_api.get_user_id_from_macaroon(macaroon) + auth_api.validate_macaroon(macaroon, "login", True, user_id) + return user_id + except Exception: + raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) def _generate_base_macaroon(self, user_id): macaroon = pymacaroons.Macaroon( @@ -736,21 +737,11 @@ class AuthHandler(BaseHandler): macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) return macaroon - def get_user_from_macaroon(self, macaroon): - user_prefix = "user_id = " - for caveat in macaroon.caveats: - if caveat.caveat_id.startswith(user_prefix): - return caveat.caveat_id[len(user_prefix):] - raise AuthError( - self.INVALID_TOKEN_HTTP_STATUS, "No user_id found in token", - errcode=Codes.UNKNOWN_TOKEN - ) - @defer.inlineCallbacks def set_password(self, user_id, newpassword, requester=None): password_hash = self.hash(newpassword) - except_access_token_ids = [requester.access_token_id] if requester else [] + except_access_token_id = requester.access_token_id if requester else None try: yield self.store.user_set_password_hash(user_id, password_hash) @@ -759,10 +750,10 @@ class AuthHandler(BaseHandler): raise SynapseError(404, "Unknown user", Codes.NOT_FOUND) raise e yield 
self.store.user_delete_access_tokens( - user_id, except_access_token_ids + user_id, except_access_token_id ) yield self.hs.get_pusherpool().remove_pushers_by_user( - user_id, except_access_token_ids + user_id, except_access_token_id ) @defer.inlineCallbacks diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 618cb5362..01a761715 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -26,7 +26,9 @@ from synapse.api.errors import ( from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.events.validator import EventValidator from synapse.util import unwrapFirstError -from synapse.util.logcontext import PreserveLoggingContext, preserve_fn +from synapse.util.logcontext import ( + PreserveLoggingContext, preserve_fn, preserve_context_over_deferred +) from synapse.util.logutils import log_function from synapse.util.async import run_on_reactor from synapse.util.frozenutils import unfreeze @@ -249,7 +251,7 @@ class FederationHandler(BaseHandler): if ev.type != EventTypes.Member: continue try: - domain = UserID.from_string(ev.state_key).domain + domain = get_domain_from_id(ev.state_key) except: continue @@ -274,7 +276,7 @@ class FederationHandler(BaseHandler): @log_function @defer.inlineCallbacks - def backfill(self, dest, room_id, limit, extremities=[]): + def backfill(self, dest, room_id, limit, extremities): """ Trigger a backfill request to `dest` for the given `room_id` This will attempt to get more events from the remote. 
This may return @@ -284,9 +286,6 @@ class FederationHandler(BaseHandler): if dest == self.server_name: raise SynapseError(400, "Can't backfill from self.") - if not extremities: - extremities = yield self.store.get_oldest_events_in_room(room_id) - events = yield self.replication_layer.backfill( dest, room_id, @@ -364,9 +363,9 @@ class FederationHandler(BaseHandler): missing_auth - failed_to_fetch ) - results = yield defer.gatherResults( + results = yield preserve_context_over_deferred(defer.gatherResults( [ - self.replication_layer.get_pdu( + preserve_fn(self.replication_layer.get_pdu)( [dest], event_id, outlier=True, @@ -375,10 +374,10 @@ class FederationHandler(BaseHandler): for event_id in missing_auth - failed_to_fetch ], consumeErrors=True - ).addErrback(unwrapFirstError) - auth_events.update({a.event_id: a for a in results}) + )).addErrback(unwrapFirstError) + auth_events.update({a.event_id: a for a in results if a}) required_auth.update( - a_id for event in results for a_id, _ in event.auth_events + a_id for event in results for a_id, _ in event.auth_events if event ) missing_auth = required_auth - set(auth_events) @@ -455,6 +454,10 @@ class FederationHandler(BaseHandler): ) max_depth = sorted_extremeties_tuple[0][1] + # We don't want to specify too many extremities as it causes the backfill + # request URI to be too long. + extremities = dict(sorted_extremeties_tuple[:5]) + if current_depth > max_depth: logger.debug( "Not backfilling as we don't need to. 
%d < %d", @@ -551,10 +554,10 @@ class FederationHandler(BaseHandler): event_ids = list(extremities.keys()) - states = yield defer.gatherResults([ - self.state_handler.resolve_state_groups(room_id, [e]) + states = yield preserve_context_over_deferred(defer.gatherResults([ + preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e]) for e in event_ids - ]) + ])) states = dict(zip(event_ids, [s[1] for s in states])) for e_id, _ in sorted_extremeties_tuple: @@ -1093,16 +1096,17 @@ class FederationHandler(BaseHandler): ) if event: - # FIXME: This is a temporary work around where we occasionally - # return events slightly differently than when they were - # originally signed - event.signatures.update( - compute_event_signature( - event, - self.hs.hostname, - self.hs.config.signing_key[0] + if self.hs.is_mine_id(event.event_id): + # FIXME: This is a temporary work around where we occasionally + # return events slightly differently than when they were + # originally signed + event.signatures.update( + compute_event_signature( + event, + self.hs.hostname, + self.hs.config.signing_key[0] + ) ) - ) if do_auth: in_room = yield self.auth.check_host_in_room( @@ -1112,6 +1116,12 @@ class FederationHandler(BaseHandler): if not in_room: raise AuthError(403, "Host not in room.") + events = yield self._filter_events_for_server( + origin, event.room_id, [event] + ) + + event = events[0] + defer.returnValue(event) else: defer.returnValue(None) @@ -1158,9 +1168,9 @@ class FederationHandler(BaseHandler): a bunch of outliers, but not a chunk of individual events that depend on each other for state calculations. 
""" - contexts = yield defer.gatherResults( + contexts = yield preserve_context_over_deferred(defer.gatherResults( [ - self._prep_event( + preserve_fn(self._prep_event)( origin, ev_info["event"], state=ev_info.get("state"), @@ -1168,7 +1178,7 @@ class FederationHandler(BaseHandler): ) for ev_info in event_infos ] - ) + )) yield self.store.persist_events( [ @@ -1452,9 +1462,9 @@ class FederationHandler(BaseHandler): # Do auth conflict res. logger.info("Different auth: %s", different_auth) - different_events = yield defer.gatherResults( + different_events = yield preserve_context_over_deferred(defer.gatherResults( [ - self.store.get_event( + preserve_fn(self.store.get_event)( d, allow_none=True, allow_rejected=False, @@ -1463,7 +1473,7 @@ class FederationHandler(BaseHandler): if d in have_events and not have_events[d] ], consumeErrors=True - ).addErrback(unwrapFirstError) + )).addErrback(unwrapFirstError) if different_events: local_view = dict(auth_events) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index dc76d34a5..4c3cd9d12 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -28,7 +28,8 @@ from synapse.types import ( from synapse.util import unwrapFirstError from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLock from synapse.util.caches.snapshot_cache import SnapshotCache -from synapse.util.logcontext import preserve_fn +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred +from synapse.util.metrics import measure_func from synapse.visibility import filter_events_for_client from ._base import BaseHandler @@ -502,15 +503,17 @@ class MessageHandler(BaseHandler): lambda states: states[event.event_id] ) - (messages, token), current_state = yield defer.gatherResults( - [ - self.store.get_recent_events_for_room( - event.room_id, - limit=limit, - end_token=room_end_token, - ), - deferred_room_state, - ] + (messages, token), current_state = yield 
preserve_context_over_deferred( + defer.gatherResults( + [ + preserve_fn(self.store.get_recent_events_for_room)( + event.room_id, + limit=limit, + end_token=room_end_token, + ), + deferred_room_state, + ] + ) ).addErrback(unwrapFirstError) messages = yield filter_events_for_client( @@ -719,9 +722,9 @@ class MessageHandler(BaseHandler): presence, receipts, (messages, token) = yield defer.gatherResults( [ - get_presence(), - get_receipts(), - self.store.get_recent_events_for_room( + preserve_fn(get_presence)(), + preserve_fn(get_receipts)(), + preserve_fn(self.store.get_recent_events_for_room)( room_id, limit=limit, end_token=now_token.room_key, @@ -755,6 +758,7 @@ class MessageHandler(BaseHandler): defer.returnValue(ret) + @measure_func("_create_new_client_event") @defer.inlineCallbacks def _create_new_client_event(self, builder, prev_event_ids=None): if prev_event_ids: @@ -806,6 +810,7 @@ class MessageHandler(BaseHandler): (event, context,) ) + @measure_func("handle_new_client_event") @defer.inlineCallbacks def handle_new_client_event( self, @@ -934,7 +939,7 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def _notify(): yield run_on_reactor() - self.notifier.on_new_room_event( + yield self.notifier.on_new_room_event( event, event_stream_id, max_stream_id, extra_users=extra_users ) @@ -944,6 +949,6 @@ class MessageHandler(BaseHandler): # If invite, remove room_state from unsigned before sending. 
event.unsigned.pop("invite_room_state", None) - federation_handler.handle_new_event( + preserve_fn(federation_handler.handle_new_event)( event, destinations=destinations, ) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 6b70fa381..6a1fe76c8 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -503,7 +503,7 @@ class PresenceHandler(object): defer.returnValue(states) @defer.inlineCallbacks - def _get_interested_parties(self, states): + def _get_interested_parties(self, states, calculate_remote_hosts=True): """Given a list of states return which entities (rooms, users, servers) are interested in the given states. @@ -526,14 +526,15 @@ class PresenceHandler(object): users_to_states.setdefault(state.user_id, []).append(state) hosts_to_states = {} - for room_id, states in room_ids_to_states.items(): - local_states = filter(lambda s: self.is_mine_id(s.user_id), states) - if not local_states: - continue + if calculate_remote_hosts: + for room_id, states in room_ids_to_states.items(): + local_states = filter(lambda s: self.is_mine_id(s.user_id), states) + if not local_states: + continue - hosts = yield self.store.get_joined_hosts_for_room(room_id) - for host in hosts: - hosts_to_states.setdefault(host, []).extend(local_states) + hosts = yield self.store.get_joined_hosts_for_room(room_id) + for host in hosts: + hosts_to_states.setdefault(host, []).extend(local_states) for user_id, states in users_to_states.items(): local_states = filter(lambda s: self.is_mine_id(s.user_id), states) @@ -565,6 +566,16 @@ class PresenceHandler(object): self._push_to_remotes(hosts_to_states) + @defer.inlineCallbacks + def notify_for_states(self, state, stream_id): + parties = yield self._get_interested_parties([state]) + room_ids_to_states, users_to_states, hosts_to_states = parties + + self.notifier.on_new_event( + "presence_key", stream_id, rooms=room_ids_to_states.keys(), + users=[UserID.from_string(u) for u in users_to_states.keys()] 
+ ) + def _push_to_remotes(self, hosts_to_states): """Sends state updates to remote servers. @@ -672,7 +683,7 @@ class PresenceHandler(object): ]) @defer.inlineCallbacks - def set_state(self, target_user, state): + def set_state(self, target_user, state, ignore_status_msg=False): """Set the presence state of the user. """ status_msg = state.get("status_msg", None) @@ -689,10 +700,13 @@ class PresenceHandler(object): prev_state = yield self.current_state_for_user(user_id) new_fields = { - "state": presence, - "status_msg": status_msg if presence != PresenceState.OFFLINE else None + "state": presence } + if not ignore_status_msg: + msg = status_msg if presence != PresenceState.OFFLINE else None + new_fields["status_msg"] = msg + if presence == PresenceState.ONLINE: new_fields["last_active_ts"] = self.clock.time_msec() diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 8cec8fc4e..8b17632fd 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -59,10 +59,13 @@ class RoomMemberHandler(BaseHandler): prev_event_ids, txn_id=None, ratelimit=True, + content=None, ): + if content is None: + content = {} msg_handler = self.hs.get_handlers().message_handler - content = {"membership": membership} + content["membership"] = membership if requester.is_guest: content["kind"] = "guest" @@ -140,8 +143,9 @@ class RoomMemberHandler(BaseHandler): remote_room_hosts=None, third_party_signed=None, ratelimit=True, + content=None, ): - key = (target, room_id,) + key = (room_id,) with (yield self.member_linearizer.queue(key)): result = yield self._update_membership( @@ -153,6 +157,7 @@ class RoomMemberHandler(BaseHandler): remote_room_hosts=remote_room_hosts, third_party_signed=third_party_signed, ratelimit=ratelimit, + content=content, ) defer.returnValue(result) @@ -168,7 +173,11 @@ class RoomMemberHandler(BaseHandler): remote_room_hosts=None, third_party_signed=None, ratelimit=True, + content=None, ): + if content is None: + 
content = {} + effective_membership_state = action if action in ["kick", "unban"]: effective_membership_state = "leave" @@ -218,7 +227,7 @@ class RoomMemberHandler(BaseHandler): if inviter and not self.hs.is_mine(inviter): remote_room_hosts.append(inviter.domain) - content = {"membership": Membership.JOIN} + content["membership"] = Membership.JOIN profile = self.hs.get_handlers().profile_handler content["displayname"] = yield profile.get_displayname(target) @@ -272,6 +281,7 @@ class RoomMemberHandler(BaseHandler): txn_id=txn_id, ratelimit=ratelimit, prev_event_ids=latest_event_ids, + content=content, ) @defer.inlineCallbacks diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 0ee4ebe50..c8dfd02e7 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -464,10 +464,10 @@ class SyncHandler(object): else: state = {} - defer.returnValue({ - (e.type, e.state_key): e - for e in sync_config.filter_collection.filter_room_state(state.values()) - }) + defer.returnValue({ + (e.type, e.state_key): e + for e in sync_config.filter_collection.filter_room_state(state.values()) + }) @defer.inlineCallbacks def unread_notifs_for_room_id(self, room_id, sync_config): @@ -485,9 +485,9 @@ class SyncHandler(object): ) defer.returnValue(notifs) - # There is no new information in this period, so your notification - # count is whatever it was last time. - defer.returnValue(None) + # There is no new information in this period, so your notification + # count is whatever it was last time. 
+ defer.returnValue(None) @defer.inlineCallbacks def generate_sync_result(self, sync_config, since_token=None, full_state=False): diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 5589296c0..46181984c 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -16,7 +16,9 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, AuthError -from synapse.util.logcontext import PreserveLoggingContext +from synapse.util.logcontext import ( + PreserveLoggingContext, preserve_fn, preserve_context_over_deferred, +) from synapse.util.metrics import Measure from synapse.types import UserID @@ -169,13 +171,13 @@ class TypingHandler(object): deferreds = [] for domain in domains: if domain == self.server_name: - self._push_update_local( + preserve_fn(self._push_update_local)( room_id=room_id, user_id=user_id, typing=typing ) else: - deferreds.append(self.federation.send_edu( + deferreds.append(preserve_fn(self.federation.send_edu)( destination=domain, edu_type="m.typing", content={ @@ -185,7 +187,9 @@ class TypingHandler(object): }, )) - yield defer.DeferredList(deferreds, consumeErrors=True) + yield preserve_context_over_deferred( + defer.DeferredList(deferreds, consumeErrors=True) + ) @defer.inlineCallbacks def _recv_edu(self, origin, content): diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index c3589534f..f93093dd8 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -155,9 +155,7 @@ class MatrixFederationHttpClient(object): time_out=timeout / 1000. 
if timeout else 60, ) - response = yield preserve_context_over_fn( - send_request, - ) + response = yield preserve_context_over_fn(send_request) log_result = "%d %s" % (response.code, response.phrase,) break diff --git a/synapse/http/server.py b/synapse/http/server.py index 2b3c05a74..168e53ce0 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -19,6 +19,7 @@ from synapse.api.errors import ( ) from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.caches import intern_dict +from synapse.util.metrics import Measure import synapse.metrics import synapse.events @@ -74,12 +75,12 @@ response_db_txn_duration = metrics.register_distribution( _next_request_id = 0 -def request_handler(report_metrics=True): +def request_handler(include_metrics=False): """Decorator for ``wrap_request_handler``""" - return lambda request_handler: wrap_request_handler(request_handler, report_metrics) + return lambda request_handler: wrap_request_handler(request_handler, include_metrics) -def wrap_request_handler(request_handler, report_metrics): +def wrap_request_handler(request_handler, include_metrics=False): """Wraps a method that acts as a request handler with the necessary logging and exception handling. 
@@ -103,54 +104,56 @@ def wrap_request_handler(request_handler, report_metrics): _next_request_id += 1 with LoggingContext(request_id) as request_context: - if report_metrics: + with Measure(self.clock, "wrapped_request_handler"): request_metrics = RequestMetrics() - request_metrics.start(self.clock) + request_metrics.start(self.clock, name=self.__class__.__name__) - request_context.request = request_id - with request.processing(): - try: - with PreserveLoggingContext(request_context): - yield request_handler(self, request) - except CodeMessageException as e: - code = e.code - if isinstance(e, SynapseError): - logger.info( - "%s SynapseError: %s - %s", request, code, e.msg - ) - else: - logger.exception(e) - outgoing_responses_counter.inc(request.method, str(code)) - respond_with_json( - request, code, cs_exception(e), send_cors=True, - pretty_print=_request_user_agent_is_curl(request), - version_string=self.version_string, - ) - except: - logger.exception( - "Failed handle request %s.%s on %r: %r", - request_handler.__module__, - request_handler.__name__, - self, - request - ) - respond_with_json( - request, - 500, - { - "error": "Internal server error", - "errcode": Codes.UNKNOWN, - }, - send_cors=True - ) - finally: + request_context.request = request_id + with request.processing(): try: - if report_metrics: - request_metrics.stop( - self.clock, request, self.__class__.__name__ + with PreserveLoggingContext(request_context): + if include_metrics: + yield request_handler(self, request, request_metrics) + else: + yield request_handler(self, request) + except CodeMessageException as e: + code = e.code + if isinstance(e, SynapseError): + logger.info( + "%s SynapseError: %s - %s", request, code, e.msg ) + else: + logger.exception(e) + outgoing_responses_counter.inc(request.method, str(code)) + respond_with_json( + request, code, cs_exception(e), send_cors=True, + pretty_print=_request_user_agent_is_curl(request), + version_string=self.version_string, + ) except: - 
pass + logger.exception( + "Failed handle request %s.%s on %r: %r", + request_handler.__module__, + request_handler.__name__, + self, + request + ) + respond_with_json( + request, + 500, + { + "error": "Internal server error", + "errcode": Codes.UNKNOWN, + }, + send_cors=True + ) + finally: + try: + request_metrics.stop( + self.clock, request + ) + except Exception as e: + logger.warn("Failed to stop metrics: %r", e) return wrapped_request_handler @@ -220,9 +223,9 @@ class JsonResource(HttpServer, resource.Resource): # It does its own metric reporting because _async_render dispatches to # a callback and it's the class name of that callback we want to report # against rather than the JsonResource itself. - @request_handler(report_metrics=False) + @request_handler(include_metrics=True) @defer.inlineCallbacks - def _async_render(self, request): + def _async_render(self, request, request_metrics): """ This gets called from render() every time someone sends us a request. This checks if anyone has registered a callback for that method and path. 
@@ -231,9 +234,6 @@ class JsonResource(HttpServer, resource.Resource): self._send_response(request, 200, {}) return - request_metrics = RequestMetrics() - request_metrics.start(self.clock) - # Loop through all the registered callbacks to check if the method # and path regex match for path_entry in self.path_regexs.get(request.method, []): @@ -247,12 +247,6 @@ class JsonResource(HttpServer, resource.Resource): callback = path_entry.callback - servlet_instance = getattr(callback, "__self__", None) - if servlet_instance is not None: - servlet_classname = servlet_instance.__class__.__name__ - else: - servlet_classname = "%r" % callback - kwargs = intern_dict({ name: urllib.unquote(value).decode("UTF-8") if value else value for name, value in m.groupdict().items() @@ -263,10 +257,13 @@ class JsonResource(HttpServer, resource.Resource): code, response = callback_return self._send_response(request, code, response) - try: - request_metrics.stop(self.clock, request, servlet_classname) - except: - pass + servlet_instance = getattr(callback, "__self__", None) + if servlet_instance is not None: + servlet_classname = servlet_instance.__class__.__name__ + else: + servlet_classname = "%r" % callback + + request_metrics.name = servlet_classname return @@ -298,11 +295,12 @@ class JsonResource(HttpServer, resource.Resource): class RequestMetrics(object): - def start(self, clock): + def start(self, clock, name): self.start = clock.time_msec() self.start_context = LoggingContext.current_context() + self.name = name - def stop(self, clock, request, servlet_classname): + def stop(self, clock, request): context = LoggingContext.current_context() tag = "" @@ -316,26 +314,26 @@ class RequestMetrics(object): ) return - incoming_requests_counter.inc(request.method, servlet_classname, tag) + incoming_requests_counter.inc(request.method, self.name, tag) response_timer.inc_by( clock.time_msec() - self.start, request.method, - servlet_classname, tag + self.name, tag ) ru_utime, ru_stime = 
context.get_resource_usage() response_ru_utime.inc_by( - ru_utime, request.method, servlet_classname, tag + ru_utime, request.method, self.name, tag ) response_ru_stime.inc_by( - ru_stime, request.method, servlet_classname, tag + ru_stime, request.method, self.name, tag ) response_db_txn_count.inc_by( - context.db_txn_count, request.method, servlet_classname, tag + context.db_txn_count, request.method, self.name, tag ) response_db_txn_duration.inc_by( - context.db_txn_duration, request.method, servlet_classname, tag + context.db_txn_duration, request.method, self.name, tag ) diff --git a/synapse/notifier.py b/synapse/notifier.py index 30883a069..b86648f5e 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -19,7 +19,8 @@ from synapse.api.errors import AuthError from synapse.util.logutils import log_function from synapse.util.async import ObservableDeferred -from synapse.util.logcontext import PreserveLoggingContext +from synapse.util.logcontext import PreserveLoggingContext, preserve_fn +from synapse.util.metrics import Measure from synapse.types import StreamToken from synapse.visibility import filter_events_for_client import synapse.metrics @@ -67,10 +68,8 @@ class _NotifierUserStream(object): so that it can remove itself from the indexes in the Notifier class. 
""" - def __init__(self, user_id, rooms, current_token, time_now_ms, - appservice=None): + def __init__(self, user_id, rooms, current_token, time_now_ms): self.user_id = user_id - self.appservice = appservice self.rooms = set(rooms) self.current_token = current_token self.last_notified_ms = time_now_ms @@ -107,11 +106,6 @@ class _NotifierUserStream(object): notifier.user_to_user_stream.pop(self.user_id) - if self.appservice: - notifier.appservice_to_user_streams.get( - self.appservice, set() - ).discard(self) - def count_listeners(self): return len(self.notify_deferred.observers()) @@ -142,7 +136,6 @@ class Notifier(object): def __init__(self, hs): self.user_to_user_stream = {} self.room_to_user_streams = {} - self.appservice_to_user_streams = {} self.event_sources = hs.get_event_sources() self.store = hs.get_datastore() @@ -168,8 +161,6 @@ class Notifier(object): all_user_streams |= x for x in self.user_to_user_stream.values(): all_user_streams.add(x) - for x in self.appservice_to_user_streams.values(): - all_user_streams |= x return sum(stream.count_listeners() for stream in all_user_streams) metrics.register_callback("listeners", count_listeners) @@ -182,11 +173,8 @@ class Notifier(object): "users", lambda: len(self.user_to_user_stream), ) - metrics.register_callback( - "appservices", - lambda: count(bool, self.appservice_to_user_streams.values()), - ) + @preserve_fn def on_new_room_event(self, event, room_stream_id, max_room_stream_id, extra_users=[]): """ Used by handlers to inform the notifier something has happened @@ -208,6 +196,7 @@ class Notifier(object): self.notify_replication() + @preserve_fn def _notify_pending_new_room_events(self, max_room_stream_id): """Notify for the room events that were queued waiting for a previous event to be persisted. 
@@ -225,24 +214,11 @@ class Notifier(object): else: self._on_new_room_event(event, room_stream_id, extra_users) + @preserve_fn def _on_new_room_event(self, event, room_stream_id, extra_users=[]): """Notify any user streams that are interested in this room event""" # poke any interested application service. - self.appservice_handler.notify_interested_services(event) - - app_streams = set() - - for appservice in self.appservice_to_user_streams: - # TODO (kegan): Redundant appservice listener checks? - # App services will already be in the room_to_user_streams set, but - # that isn't enough. They need to be checked here in order to - # receive *invites* for users they are interested in. Does this - # make the room_to_user_streams check somewhat obselete? - if appservice.is_interested(event): - app_user_streams = self.appservice_to_user_streams.get( - appservice, set() - ) - app_streams |= app_user_streams + self.appservice_handler.notify_interested_services(room_stream_id) if event.type == EventTypes.Member and event.membership == Membership.JOIN: self._user_joined_room(event.state_key, event.room_id) @@ -251,35 +227,36 @@ class Notifier(object): "room_key", room_stream_id, users=extra_users, rooms=[event.room_id], - extra_streams=app_streams, ) - def on_new_event(self, stream_key, new_token, users=[], rooms=[], - extra_streams=set()): + @preserve_fn + def on_new_event(self, stream_key, new_token, users=[], rooms=[]): """ Used to inform listeners that something has happend event wise. Will wake up all listeners for the given users and rooms. 
""" with PreserveLoggingContext(): - user_streams = set() + with Measure(self.clock, "on_new_event"): + user_streams = set() - for user in users: - user_stream = self.user_to_user_stream.get(str(user)) - if user_stream is not None: - user_streams.add(user_stream) + for user in users: + user_stream = self.user_to_user_stream.get(str(user)) + if user_stream is not None: + user_streams.add(user_stream) - for room in rooms: - user_streams |= self.room_to_user_streams.get(room, set()) + for room in rooms: + user_streams |= self.room_to_user_streams.get(room, set()) - time_now_ms = self.clock.time_msec() - for user_stream in user_streams: - try: - user_stream.notify(stream_key, new_token, time_now_ms) - except: - logger.exception("Failed to notify listener") + time_now_ms = self.clock.time_msec() + for user_stream in user_streams: + try: + user_stream.notify(stream_key, new_token, time_now_ms) + except: + logger.exception("Failed to notify listener") - self.notify_replication() + self.notify_replication() + @preserve_fn def on_new_replication_data(self): """Used to inform replication listeners that something has happend without waking up any of the normal user event streams""" @@ -294,7 +271,6 @@ class Notifier(object): """ user_stream = self.user_to_user_stream.get(user_id) if user_stream is None: - appservice = yield self.store.get_app_service_by_user_id(user_id) current_token = yield self.event_sources.get_current_token() if room_ids is None: rooms = yield self.store.get_rooms_for_user(user_id) @@ -302,7 +278,6 @@ class Notifier(object): user_stream = _NotifierUserStream( user_id=user_id, rooms=room_ids, - appservice=appservice, current_token=current_token, time_now_ms=self.clock.time_msec(), ) @@ -477,11 +452,6 @@ class Notifier(object): s = self.room_to_user_streams.setdefault(room, set()) s.add(user_stream) - if user_stream.appservice: - self.appservice_to_user_stream.setdefault( - user_stream.appservice, set() - ).add(user_stream) - def _user_joined_room(self, 
user_id, room_id): new_user_stream = self.user_to_user_stream.get(user_id) if new_user_stream is not None: diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py index 46e768e35..ed2ccc4df 100644 --- a/synapse/push/action_generator.py +++ b/synapse/push/action_generator.py @@ -38,15 +38,16 @@ class ActionGenerator: @defer.inlineCallbacks def handle_push_actions_for_event(self, event, context): - with Measure(self.clock, "handle_push_actions_for_event"): + with Measure(self.clock, "evaluator_for_event"): bulk_evaluator = yield evaluator_for_event( - event, self.hs, self.store, context.current_state + event, self.hs, self.store, context.state_group, context.current_state ) + with Measure(self.clock, "action_for_event_by_user"): actions_by_user = yield bulk_evaluator.action_for_event_by_user( event, context.current_state ) - context.push_actions = [ - (uid, actions) for uid, actions in actions_by_user.items() - ] + context.push_actions = [ + (uid, actions) for uid, actions in actions_by_user.items() + ] diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 024c14904..edb00ed20 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -217,6 +217,27 @@ BASE_APPEND_OVERRIDE_RULES = [ 'dont_notify' ] }, + # This was changed from underride to override so it's closer in priority + # to the content rules where the user name highlight rule lives. This + # way a room rule is lower priority than both but a custom override rule + # is higher priority than both. 
+ { + 'rule_id': 'global/override/.m.rule.contains_display_name', + 'conditions': [ + { + 'kind': 'contains_display_name' + } + ], + 'actions': [ + 'notify', + { + 'set_tweak': 'sound', + 'value': 'default' + }, { + 'set_tweak': 'highlight' + } + ] + }, ] @@ -242,23 +263,6 @@ BASE_APPEND_UNDERRIDE_RULES = [ } ] }, - { - 'rule_id': 'global/underride/.m.rule.contains_display_name', - 'conditions': [ - { - 'kind': 'contains_display_name' - } - ], - 'actions': [ - 'notify', - { - 'set_tweak': 'sound', - 'value': 'default' - }, { - 'set_tweak': 'highlight' - } - ] - }, { 'rule_id': 'global/underride/.m.rule.room_one_to_one', 'conditions': [ diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 756e5da51..004eded61 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -36,35 +36,11 @@ def _get_rules(room_id, user_ids, store): @defer.inlineCallbacks -def evaluator_for_event(event, hs, store, current_state): - room_id = event.room_id - # We also will want to generate notifs for other people in the room so - # their unread countss are correct in the event stream, but to avoid - # generating them for bot / AS users etc, we only do so for people who've - # sent a read receipt into the room. 
- - local_users_in_room = set( - e.state_key for e in current_state.values() - if e.type == EventTypes.Member and e.membership == Membership.JOIN - and hs.is_mine_id(e.state_key) +def evaluator_for_event(event, hs, store, state_group, current_state): + rules_by_user = yield store.bulk_get_push_rules_for_room( + event.room_id, state_group, current_state ) - # users in the room who have pushers need to get push rules run because - # that's how their pushers work - if_users_with_pushers = yield store.get_if_users_have_pushers( - local_users_in_room - ) - user_ids = set( - uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher - ) - - users_with_receipts = yield store.get_users_with_read_receipts_in_room(room_id) - - # any users with pushers must be ours: they have pushers - for uid in users_with_receipts: - if uid in local_users_in_room: - user_ids.add(uid) - # if this event is an invite event, we may need to run rules for the user # who's been invited, otherwise they won't get told they've been invited if event.type == 'm.room.member' and event.content['membership'] == 'invite': @@ -72,12 +48,12 @@ def evaluator_for_event(event, hs, store, current_state): if invited_user and hs.is_mine_id(invited_user): has_pusher = yield store.user_has_pusher(invited_user) if has_pusher: - user_ids.add(invited_user) - - rules_by_user = yield _get_rules(room_id, user_ids, store) + rules_by_user[invited_user] = yield store.get_push_rules_for_user( + invited_user + ) defer.returnValue(BulkPushRuleEvaluator( - room_id, rules_by_user, user_ids, store + event.room_id, rules_by_user, store )) @@ -90,10 +66,9 @@ class BulkPushRuleEvaluator: the same logic to run the actual rules, but could be optimised further (see https://matrix.org/jira/browse/SYN-562) """ - def __init__(self, room_id, rules_by_user, users_in_room, store): + def __init__(self, room_id, rules_by_user, store): self.room_id = room_id self.rules_by_user = rules_by_user - self.users_in_room = users_in_room 
self.store = store @defer.inlineCallbacks diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index d555a33e9..becb8ef1a 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -17,14 +17,15 @@ from twisted.internet import defer from synapse.util.presentable_names import ( calculate_room_name, name_from_member_event ) +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred @defer.inlineCallbacks def get_badge_count(store, user_id): - invites, joins = yield defer.gatherResults([ - store.get_invited_rooms_for_user(user_id), - store.get_rooms_for_user(user_id), - ], consumeErrors=True) + invites, joins = yield preserve_context_over_deferred(defer.gatherResults([ + preserve_fn(store.get_invited_rooms_for_user)(user_id), + preserve_fn(store.get_rooms_for_user)(user_id), + ], consumeErrors=True)) my_receipts_by_room = yield store.get_receipts_for_user( user_id, "m.read", diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 5853ec36a..3837be523 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -17,7 +17,7 @@ from twisted.internet import defer import pusher -from synapse.util.logcontext import preserve_fn +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.util.async import run_on_reactor import logging @@ -102,14 +102,14 @@ class PusherPool: yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) @defer.inlineCallbacks - def remove_pushers_by_user(self, user_id, except_token_ids=[]): + def remove_pushers_by_user(self, user_id, except_access_token_id=None): all = yield self.store.get_all_pushers() logger.info( - "Removing all pushers for user %s except access tokens ids %r", - user_id, except_token_ids + "Removing all pushers for user %s except access tokens id %r", + user_id, except_access_token_id ) for p in all: - if p['user_name'] == user_id and p['access_token'] not in except_token_ids: + if p['user_name'] 
== user_id and p['access_token'] != except_access_token_id: logger.info( "Removing pusher for app id %s, pushkey %s, user %s", p['app_id'], p['pushkey'], p['user_name'] @@ -130,10 +130,12 @@ class PusherPool: if u in self.pushers: for p in self.pushers[u].values(): deferreds.append( - p.on_new_notifications(min_stream_id, max_stream_id) + preserve_fn(p.on_new_notifications)( + min_stream_id, max_stream_id + ) ) - yield defer.gatherResults(deferreds) + yield preserve_context_over_deferred(defer.gatherResults(deferreds)) except: logger.exception("Exception in pusher on_new_notifications") @@ -155,10 +157,10 @@ class PusherPool: if u in self.pushers: for p in self.pushers[u].values(): deferreds.append( - p.on_new_receipts(min_stream_id, max_stream_id) + preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id) ) - yield defer.gatherResults(deferreds) + yield preserve_context_over_deferred(defer.gatherResults(deferreds)) except: logger.exception("Exception in pusher on_new_receipts") diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index 8c2d487ff..84993b33b 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -41,6 +41,7 @@ STREAM_NAMES = ( ("push_rules",), ("pushers",), ("state",), + ("caches",), ) @@ -70,6 +71,7 @@ class ReplicationResource(Resource): * "backfill": Old events that have been backfilled from other servers. * "push_rules": Per user changes to push rules. * "pushers": Per user changes to their pushers. + * "caches": Cache invalidations. 
The API takes two additional query parameters: @@ -129,6 +131,7 @@ class ReplicationResource(Resource): push_rules_token, room_stream_token = self.store.get_push_rules_stream_token() pushers_token = self.store.get_pushers_stream_token() state_token = self.store.get_state_stream_token() + caches_token = self.store.get_cache_stream_token() defer.returnValue(_ReplicationToken( room_stream_token, @@ -140,6 +143,7 @@ class ReplicationResource(Resource): push_rules_token, pushers_token, state_token, + caches_token, )) @request_handler() @@ -188,6 +192,7 @@ class ReplicationResource(Resource): yield self.push_rules(writer, current_token, limit, request_streams) yield self.pushers(writer, current_token, limit, request_streams) yield self.state(writer, current_token, limit, request_streams) + yield self.caches(writer, current_token, limit, request_streams) self.streams(writer, current_token, request_streams) logger.info("Replicated %d rows", writer.total) @@ -379,6 +384,20 @@ class ReplicationResource(Resource): "position", "type", "state_key", "event_id" )) + @defer.inlineCallbacks + def caches(self, writer, current_token, limit, request_streams): + current_position = current_token.caches + + caches = request_streams.get("caches") + + if caches is not None: + updated_caches = yield self.store.get_all_updated_caches( + caches, current_position, limit + ) + writer.write_header_and_rows("caches", updated_caches, ( + "position", "cache_func", "keys", "invalidation_ts" + )) + class _Writer(object): """Writes the streams as a JSON object as the response to the request""" @@ -407,7 +426,7 @@ class _Writer(object): class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( "events", "presence", "typing", "receipts", "account_data", "backfill", - "push_rules", "pushers", "state" + "push_rules", "pushers", "state", "caches", ))): __slots__ = [] diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 46e43ce1c..f19540d6b 
100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -14,15 +14,43 @@ # limitations under the License. from synapse.storage._base import SQLBaseStore +from synapse.storage.engines import PostgresEngine from twisted.internet import defer +from ._slaved_id_tracker import SlavedIdTracker + +import logging + +logger = logging.getLogger(__name__) + class BaseSlavedStore(SQLBaseStore): def __init__(self, db_conn, hs): super(BaseSlavedStore, self).__init__(hs) + if isinstance(self.database_engine, PostgresEngine): + self._cache_id_gen = SlavedIdTracker( + db_conn, "cache_invalidation_stream", "stream_id", + ) + else: + self._cache_id_gen = None def stream_positions(self): - return {} + pos = {} + if self._cache_id_gen: + pos["caches"] = self._cache_id_gen.get_current_token() + return pos def process_replication(self, result): + stream = result.get("caches") + if stream: + for row in stream["rows"]: + ( + position, cache_func, keys, invalidation_ts, + ) = row + + try: + getattr(self, cache_func).invalidate(tuple(keys)) + except AttributeError: + logger.info("Got unexpected cache_func: %r", cache_func) + self._cache_id_gen.advance(int(stream["position"])) return defer.succeed(None) diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py index 25792d942..a374f2f1a 100644 --- a/synapse/replication/slave/storage/appservice.py +++ b/synapse/replication/slave/storage/appservice.py @@ -28,3 +28,13 @@ class SlavedApplicationServiceStore(BaseSlavedStore): get_app_service_by_token = DataStore.get_app_service_by_token.__func__ get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__ + get_app_services = DataStore.get_app_services.__func__ + get_new_events_for_appservice = DataStore.get_new_events_for_appservice.__func__ + create_appservice_txn = DataStore.create_appservice_txn.__func__ + get_appservices_by_state = DataStore.get_appservices_by_state.__func__ 
+ get_oldest_unsent_txn = DataStore.get_oldest_unsent_txn.__func__ + _get_last_txn = DataStore._get_last_txn.__func__ + complete_appservice_txn = DataStore.complete_appservice_txn.__func__ + get_appservice_state = DataStore.get_appservice_state.__func__ + set_appservice_last_pos = DataStore.set_appservice_last_pos.__func__ + set_appservice_state = DataStore.set_appservice_state.__func__ diff --git a/synapse/replication/slave/storage/directory.py b/synapse/replication/slave/storage/directory.py index 5fbe3a303..7301d885f 100644 --- a/synapse/replication/slave/storage/directory.py +++ b/synapse/replication/slave/storage/directory.py @@ -20,4 +20,4 @@ from synapse.storage.directory import DirectoryStore class DirectoryStore(BaseSlavedStore): get_aliases_for_room = DirectoryStore.__dict__[ "get_aliases_for_room" - ].orig + ] diff --git a/synapse/replication/slave/storage/registration.py b/synapse/replication/slave/storage/registration.py index 307833f9e..e27c7332d 100644 --- a/synapse/replication/slave/storage/registration.py +++ b/synapse/replication/slave/storage/registration.py @@ -25,6 +25,9 @@ class SlavedRegistrationStore(BaseSlavedStore): # TODO: use the cached version and invalidate deleted tokens get_user_by_access_token = RegistrationStore.__dict__[ "get_user_by_access_token" - ].orig + ] _query_for_auth = DataStore._query_for_auth.__func__ + get_user_by_id = RegistrationStore.__dict__[ + "get_user_by_id" + ] diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 14227f1cd..326780405 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -46,7 +46,9 @@ from synapse.rest.client.v2_alpha import ( account_data, report_event, openid, + notifications, devices, + thirdparty, ) from synapse.http.server import JsonResource @@ -91,4 +93,6 @@ class ClientRestResource(JsonResource): account_data.register_servlets(hs, client_resource) report_event.register_servlets(hs, client_resource) openid.register_servlets(hs, client_resource) + 
notifications.register_servlets(hs, client_resource) devices.register_servlets(hs, client_resource) + thirdparty.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index b0cb31a44..af21661d7 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -28,6 +28,10 @@ logger = logging.getLogger(__name__) class WhoisRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/admin/whois/(?P[^/]*)") + def __init__(self, hs): + super(WhoisRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, user_id): target_user = UserID.from_string(user_id) @@ -82,6 +86,10 @@ class PurgeHistoryRestServlet(ClientV1RestServlet): "/admin/purge_history/(?P[^/]*)/(?P[^/]*)" ) + def __init__(self, hs): + super(PurgeHistoryRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_POST(self, request, room_id, event_id): requester = yield self.auth.get_user_by_req(request) diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py index 96b49b01f..c2a844786 100644 --- a/synapse/rest/client/v1/base.py +++ b/synapse/rest/client/v1/base.py @@ -57,7 +57,6 @@ class ClientV1RestServlet(RestServlet): hs (synapse.server.HomeServer): """ self.hs = hs - self.handlers = hs.get_handlers() self.builder_factory = hs.get_event_builder_factory() self.auth = hs.get_v1auth() self.txns = HttpTransactionStore() diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 8ac09419d..09d083159 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -36,6 +36,10 @@ def register_servlets(hs, http_server): class ClientDirectoryServer(ClientV1RestServlet): PATTERNS = client_path_patterns("/directory/room/(?P[^/]*)$") + def __init__(self, hs): + super(ClientDirectoryServer, self).__init__(hs) + self.handlers = hs.get_handlers() + 
@defer.inlineCallbacks def on_GET(self, request, room_alias): room_alias = RoomAlias.from_string(room_alias) @@ -146,6 +150,7 @@ class ClientDirectoryListServer(ClientV1RestServlet): def __init__(self, hs): super(ClientDirectoryListServer, self).__init__(hs) self.store = hs.get_datastore() + self.handlers = hs.get_handlers() @defer.inlineCallbacks def on_GET(self, request, room_id): diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 498bb9e18..701b6f549 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -32,6 +32,10 @@ class EventStreamRestServlet(ClientV1RestServlet): DEFAULT_LONGPOLL_TIME_MS = 30000 + def __init__(self, hs): + super(EventStreamRestServlet, self).__init__(hs) + self.event_stream_handler = hs.get_event_stream_handler() + @defer.inlineCallbacks def on_GET(self, request): requester = yield self.auth.get_user_by_req( @@ -46,7 +50,6 @@ class EventStreamRestServlet(ClientV1RestServlet): if "room_id" in request.args: room_id = request.args["room_id"][0] - handler = self.handlers.event_stream_handler pagin_config = PaginationConfig.from_request(request) timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS if "timeout" in request.args: @@ -57,7 +60,7 @@ class EventStreamRestServlet(ClientV1RestServlet): as_client_event = "raw" not in request.args - chunk = yield handler.get_stream( + chunk = yield self.event_stream_handler.get_stream( requester.user.to_string(), pagin_config, timeout=timeout, @@ -80,12 +83,12 @@ class EventRestServlet(ClientV1RestServlet): def __init__(self, hs): super(EventRestServlet, self).__init__(hs) self.clock = hs.get_clock() + self.event_handler = hs.get_event_handler() @defer.inlineCallbacks def on_GET(self, request, event_id): requester = yield self.auth.get_user_by_req(request) - handler = self.handlers.event_handler - event = yield handler.get_event(requester.user, event_id) + event = yield self.event_handler.get_event(requester.user, event_id) 
time_now = self.clock.time_msec() if event: diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index 36c352056..113a49e53 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -23,6 +23,10 @@ from .base import ClientV1RestServlet, client_path_patterns class InitialSyncRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/initialSync$") + def __init__(self, hs): + super(InitialSyncRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request): requester = yield self.auth.get_user_by_req(request) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 92fcae674..6c0eec8fb 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -54,12 +54,9 @@ class LoginRestServlet(ClientV1RestServlet): self.jwt_secret = hs.config.jwt_secret self.jwt_algorithm = hs.config.jwt_algorithm self.cas_enabled = hs.config.cas_enabled - self.cas_server_url = hs.config.cas_server_url - self.cas_required_attributes = hs.config.cas_required_attributes - self.servername = hs.config.server_name - self.http_client = hs.get_simple_http_client() self.auth_handler = self.hs.get_auth_handler() self.device_handler = self.hs.get_device_handler() + self.handlers = hs.get_handlers() def on_GET(self, request): flows = [] @@ -110,17 +107,6 @@ class LoginRestServlet(ClientV1RestServlet): LoginRestServlet.JWT_TYPE): result = yield self.do_jwt_login(login_submission) defer.returnValue(result) - # TODO Delete this after all CAS clients switch to token login instead - elif self.cas_enabled and (login_submission["type"] == - LoginRestServlet.CAS_TYPE): - uri = "%s/proxyValidate" % (self.cas_server_url,) - args = { - "ticket": login_submission["ticket"], - "service": login_submission["service"] - } - body = yield self.http_client.get_raw(uri, args) - result = yield self.do_cas_login(body) - 
defer.returnValue(result) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: result = yield self.do_token_login(login_submission) defer.returnValue(result) @@ -191,51 +177,6 @@ class LoginRestServlet(ClientV1RestServlet): defer.returnValue((200, result)) - # TODO Delete this after all CAS clients switch to token login instead - @defer.inlineCallbacks - def do_cas_login(self, cas_response_body): - user, attributes = self.parse_cas_response(cas_response_body) - - for required_attribute, required_value in self.cas_required_attributes.items(): - # If required attribute was not in CAS Response - Forbidden - if required_attribute not in attributes: - raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) - - # Also need to check value - if required_value is not None: - actual_value = attributes[required_attribute] - # If required attribute value does not match expected - Forbidden - if required_value != actual_value: - raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) - - user_id = UserID.create(user, self.hs.hostname).to_string() - auth_handler = self.auth_handler - registered_user_id = yield auth_handler.check_user_exists(user_id) - if registered_user_id: - access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id( - registered_user_id - ) - ) - result = { - "user_id": registered_user_id, # may have changed - "access_token": access_token, - "refresh_token": refresh_token, - "home_server": self.hs.hostname, - } - - else: - user_id, access_token = ( - yield self.handlers.registration_handler.register(localpart=user) - ) - result = { - "user_id": user_id, # may have changed - "access_token": access_token, - "home_server": self.hs.hostname, - } - - defer.returnValue((200, result)) - @defer.inlineCallbacks def do_jwt_login(self, login_submission): token = login_submission.get("token", None) @@ -293,33 +234,6 @@ class LoginRestServlet(ClientV1RestServlet): defer.returnValue((200, result)) - # TODO Delete this after 
all CAS clients switch to token login instead - def parse_cas_response(self, cas_response_body): - root = ET.fromstring(cas_response_body) - if not root.tag.endswith("serviceResponse"): - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - if not root[0].tag.endswith("authenticationSuccess"): - raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED) - for child in root[0]: - if child.tag.endswith("user"): - user = child.text - if child.tag.endswith("attributes"): - attributes = {} - for attribute in child: - # ElementTree library expands the namespace in attribute tags - # to the full URL of the namespace. - # See (https://docs.python.org/2/library/xml.etree.elementtree.html) - # We don't care about namespace here and it will always be encased in - # curly braces, so we remove them. - if "}" in attribute.tag: - attributes[attribute.tag.split("}")[1]] = attribute.text - else: - attributes[attribute.tag] = attribute.text - if user is None or attributes is None: - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - - return (user, attributes) - def _register_device(self, user_id, login_submission): """Register a device for a user. 
@@ -347,6 +261,7 @@ class SAML2RestServlet(ClientV1RestServlet): def __init__(self, hs): super(SAML2RestServlet, self).__init__(hs) self.sp_config = hs.config.saml2_config_path + self.handlers = hs.get_handlers() @defer.inlineCallbacks def on_POST(self, request): @@ -384,18 +299,6 @@ class SAML2RestServlet(ClientV1RestServlet): defer.returnValue((200, {"status": "not_authenticated"})) -# TODO Delete this after all CAS clients switch to token login instead -class CasRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/login/cas", releases=()) - - def __init__(self, hs): - super(CasRestServlet, self).__init__(hs) - self.cas_server_url = hs.config.cas_server_url - - def on_GET(self, request): - return (200, {"serverUrl": self.cas_server_url}) - - class CasRedirectServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/login/cas/redirect", releases=()) @@ -427,6 +330,8 @@ class CasTicketServlet(ClientV1RestServlet): self.cas_server_url = hs.config.cas_server_url self.cas_service_url = hs.config.cas_service_url self.cas_required_attributes = hs.config.cas_required_attributes + self.auth_handler = hs.get_auth_handler() + self.handlers = hs.get_handlers() @defer.inlineCallbacks def on_GET(self, request): @@ -479,30 +384,39 @@ class CasTicketServlet(ClientV1RestServlet): return urlparse.urlunparse(url_parts) def parse_cas_response(self, cas_response_body): - root = ET.fromstring(cas_response_body) - if not root.tag.endswith("serviceResponse"): - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - if not root[0].tag.endswith("authenticationSuccess"): - raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED) - for child in root[0]: - if child.tag.endswith("user"): - user = child.text - if child.tag.endswith("attributes"): - attributes = {} - for attribute in child: - # ElementTree library expands the namespace in attribute tags - # to the full URL of the namespace. 
- # See (https://docs.python.org/2/library/xml.etree.elementtree.html) - # We don't care about namespace here and it will always be encased in - # curly braces, so we remove them. - if "}" in attribute.tag: - attributes[attribute.tag.split("}")[1]] = attribute.text - else: - attributes[attribute.tag] = attribute.text - if user is None or attributes is None: - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - - return (user, attributes) + user = None + attributes = None + try: + root = ET.fromstring(cas_response_body) + if not root.tag.endswith("serviceResponse"): + raise Exception("root of CAS response is not serviceResponse") + success = (root[0].tag.endswith("authenticationSuccess")) + for child in root[0]: + if child.tag.endswith("user"): + user = child.text + if child.tag.endswith("attributes"): + attributes = {} + for attribute in child: + # ElementTree library expands the namespace in + # attribute tags to the full URL of the namespace. + # We don't care about namespace here and it will always + # be encased in curly braces, so we remove them. 
+ tag = attribute.tag + if "}" in tag: + tag = tag.split("}")[1] + attributes[tag] = attribute.text + if user is None: + raise Exception("CAS response does not contain user") + if attributes is None: + raise Exception("CAS response does not contain attributes") + except Exception: + logger.error("Error parsing CAS response", exc_info=1) + raise LoginError(401, "Invalid CAS response", + errcode=Codes.UNAUTHORIZED) + if not success: + raise LoginError(401, "Unsuccessful CAS response", + errcode=Codes.UNAUTHORIZED) + return user, attributes def register_servlets(hs, http_server): @@ -512,5 +426,3 @@ def register_servlets(hs, http_server): if hs.config.cas_enabled: CasRedirectServlet(hs).register(http_server) CasTicketServlet(hs).register(http_server) - CasRestServlet(hs).register(http_server) - # TODO PasswordResetRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index 65c4e2ebe..355e82474 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -24,6 +24,10 @@ from synapse.http.servlet import parse_json_object_from_request class ProfileDisplaynameRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/profile/(?P[^/]*)/displayname") + def __init__(self, hs): + super(ProfileDisplaynameRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, user_id): user = UserID.from_string(user_id) @@ -62,6 +66,10 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): class ProfileAvatarURLRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/profile/(?P[^/]*)/avatar_url") + def __init__(self, hs): + super(ProfileAvatarURLRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, user_id): user = UserID.from_string(user_id) @@ -99,6 +107,10 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): class 
ProfileRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/profile/(?P[^/]*)") + def __init__(self, hs): + super(ProfileRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, user_id): user = UserID.from_string(user_id) diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 2383b9df8..71d58c8e8 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -65,6 +65,7 @@ class RegisterRestServlet(ClientV1RestServlet): self.sessions = {} self.enable_registration = hs.config.enable_registration self.auth_handler = hs.get_auth_handler() + self.handlers = hs.get_handlers() def on_GET(self, request): if self.hs.config.enable_registration_captcha: @@ -383,6 +384,7 @@ class CreateUserRestServlet(ClientV1RestServlet): super(CreateUserRestServlet, self).__init__(hs) self.store = hs.get_datastore() self.direct_user_creation_max_duration = hs.config.user_creation_max_duration + self.handlers = hs.get_handlers() @defer.inlineCallbacks def on_POST(self, request): diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 866a1e912..0d8175701 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -35,6 +35,10 @@ logger = logging.getLogger(__name__) class RoomCreateRestServlet(ClientV1RestServlet): # No PATTERN; we have custom dispatch rules here + def __init__(self, hs): + super(RoomCreateRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): PATTERNS = "/createRoom" register_txn_path(self, PATTERNS, http_server) @@ -82,6 +86,10 @@ class RoomCreateRestServlet(ClientV1RestServlet): # TODO: Needs unit testing for generic events class RoomStateEventRestServlet(ClientV1RestServlet): + def __init__(self, hs): + super(RoomStateEventRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): # 
/room/$roomid/state/$eventtype no_state_key = "/rooms/(?P[^/]*)/state/(?P[^/]*)$" @@ -166,6 +174,10 @@ class RoomStateEventRestServlet(ClientV1RestServlet): # TODO: Needs unit testing for generic events + feedback class RoomSendEventRestServlet(ClientV1RestServlet): + def __init__(self, hs): + super(RoomSendEventRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): # /rooms/$roomid/send/$event_type[/$txn_id] PATTERNS = ("/rooms/(?P[^/]*)/send/(?P[^/]*)") @@ -210,6 +222,9 @@ class RoomSendEventRestServlet(ClientV1RestServlet): # TODO: Needs unit testing for room ID + alias joins class JoinRoomAliasServlet(ClientV1RestServlet): + def __init__(self, hs): + super(JoinRoomAliasServlet, self).__init__(hs) + self.handlers = hs.get_handlers() def register(self, http_server): # /join/$room_identifier[/$txn_id] @@ -253,6 +268,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet): action="join", txn_id=txn_id, remote_room_hosts=remote_room_hosts, + content=content, third_party_signed=content.get("third_party_signed", None), ) @@ -296,6 +312,10 @@ class PublicRoomListRestServlet(ClientV1RestServlet): class RoomMemberListRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/rooms/(?P[^/]*)/members$") + def __init__(self, hs): + super(RoomMemberListRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) @@ -322,6 +342,10 @@ class RoomMemberListRestServlet(ClientV1RestServlet): class RoomMessageListRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/rooms/(?P[^/]*)/messages$") + def __init__(self, hs): + super(RoomMessageListRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, room_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) @@ -351,6 +375,10 @@ class 
RoomMessageListRestServlet(ClientV1RestServlet): class RoomStateRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/rooms/(?P[^/]*)/state$") + def __init__(self, hs): + super(RoomStateRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, room_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) @@ -368,6 +396,10 @@ class RoomStateRestServlet(ClientV1RestServlet): class RoomInitialSyncRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/rooms/(?P[^/]*)/initialSync$") + def __init__(self, hs): + super(RoomInitialSyncRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, room_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) @@ -388,6 +420,7 @@ class RoomEventContext(ClientV1RestServlet): def __init__(self, hs): super(RoomEventContext, self).__init__(hs) self.clock = hs.get_clock() + self.handlers = hs.get_handlers() @defer.inlineCallbacks def on_GET(self, request, room_id, event_id): @@ -424,6 +457,10 @@ class RoomEventContext(ClientV1RestServlet): class RoomForgetRestServlet(ClientV1RestServlet): + def __init__(self, hs): + super(RoomForgetRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): PATTERNS = ("/rooms/(?P[^/]*)/forget") register_txn_path(self, PATTERNS, http_server) @@ -462,6 +499,10 @@ class RoomForgetRestServlet(ClientV1RestServlet): # TODO: Needs unit testing class RoomMembershipRestServlet(ClientV1RestServlet): + def __init__(self, hs): + super(RoomMembershipRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): # /rooms/$roomid/[invite|join|leave] PATTERNS = ("/rooms/(?P[^/]*)/" @@ -542,6 +583,10 @@ class RoomMembershipRestServlet(ClientV1RestServlet): class RoomRedactEventRestServlet(ClientV1RestServlet): + def __init__(self, hs): + 
super(RoomRedactEventRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): PATTERNS = ("/rooms/(?P[^/]*)/redact/(?P[^/]*)") register_txn_path(self, PATTERNS, http_server) @@ -624,6 +669,10 @@ class SearchRestServlet(ClientV1RestServlet): "/search$" ) + def __init__(self, hs): + super(SearchRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_POST(self, request): requester = yield self.auth.get_user_by_req(request) diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py new file mode 100644 index 000000000..f1a48acf0 --- /dev/null +++ b/synapse/rest/client/v2_alpha/notifications.py @@ -0,0 +1,99 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from twisted.internet import defer + +from synapse.http.servlet import ( + RestServlet, parse_string, parse_integer +) +from synapse.events.utils import ( + serialize_event, format_event_for_client_v2_without_room_id, +) + +from ._base import client_v2_patterns + +import logging + +logger = logging.getLogger(__name__) + + +class NotificationsServlet(RestServlet): + PATTERNS = client_v2_patterns("/notifications$", releases=()) + + def __init__(self, hs): + super(NotificationsServlet, self).__init__() + self.store = hs.get_datastore() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + + @defer.inlineCallbacks + def on_GET(self, request): + requester = yield self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + + from_token = parse_string(request, "from", required=False) + limit = parse_integer(request, "limit", default=50) + + limit = min(limit, 500) + + push_actions = yield self.store.get_push_actions_for_user( + user_id, from_token, limit + ) + + receipts_by_room = yield self.store.get_receipts_for_user_with_orderings( + user_id, 'm.read' + ) + + notif_event_ids = [pa["event_id"] for pa in push_actions] + notif_events = yield self.store.get_events(notif_event_ids) + + returned_push_actions = [] + + next_token = None + + for pa in push_actions: + returned_pa = { + "room_id": pa["room_id"], + "profile_tag": pa["profile_tag"], + "actions": pa["actions"], + "ts": pa["received_ts"], + "event": serialize_event( + notif_events[pa["event_id"]], + self.clock.time_msec(), + event_format=format_event_for_client_v2_without_room_id, + ), + } + + if pa["room_id"] not in receipts_by_room: + returned_pa["read"] = False + else: + receipt = receipts_by_room[pa["room_id"]] + + returned_pa["read"] = ( + receipt["topological_ordering"], receipt["stream_ordering"] + ) >= ( + pa["topological_ordering"], pa["stream_ordering"] + ) + returned_push_actions.append(returned_pa) + next_token = pa["stream_ordering"] + + defer.returnValue((200, { + 
"notifications": returned_push_actions, + "next_token": next_token, + })) + + +def register_servlets(hs, http_server): + NotificationsServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 943f5676a..2121bd75e 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -403,10 +403,9 @@ class RegisterRestServlet(RestServlet): # register the user's device device_id = params.get("device_id") initial_display_name = params.get("initial_device_display_name") - device_id = self.device_handler.check_device_registered( + return self.device_handler.check_device_registered( user_id, device_id, initial_display_name ) - return device_id @defer.inlineCallbacks def _do_guest_registration(self): diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 43d8e0bf3..b11acdbea 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -146,7 +146,7 @@ class SyncRestServlet(RestServlet): affect_presence = set_presence != PresenceState.OFFLINE if affect_presence: - yield self.presence_handler.set_state(user, {"presence": set_presence}) + yield self.presence_handler.set_state(user, {"presence": set_presence}, True) context = yield self.presence_handler.user_syncing( user.to_string(), affect_presence=affect_presence, diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py new file mode 100644 index 000000000..9abca3a8a --- /dev/null +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -0,0 +1,78 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging + +from twisted.internet import defer + +from synapse.http.servlet import RestServlet +from synapse.types import ThirdPartyEntityKind +from ._base import client_v2_patterns + +logger = logging.getLogger(__name__) + + +class ThirdPartyUserServlet(RestServlet): + PATTERNS = client_v2_patterns("/3pu(/(?P[^/]+))?$", + releases=()) + + def __init__(self, hs): + super(ThirdPartyUserServlet, self).__init__() + + self.auth = hs.get_auth() + self.appservice_handler = hs.get_application_service_handler() + + @defer.inlineCallbacks + def on_GET(self, request, protocol): + yield self.auth.get_user_by_req(request) + + fields = request.args + del fields["access_token"] + + results = yield self.appservice_handler.query_3pe( + ThirdPartyEntityKind.USER, protocol, fields + ) + + defer.returnValue((200, results)) + + +class ThirdPartyLocationServlet(RestServlet): + PATTERNS = client_v2_patterns("/3pl(/(?P[^/]+))?$", + releases=()) + + def __init__(self, hs): + super(ThirdPartyLocationServlet, self).__init__() + + self.auth = hs.get_auth() + self.appservice_handler = hs.get_application_service_handler() + + @defer.inlineCallbacks + def on_GET(self, request, protocol): + yield self.auth.get_user_by_req(request) + + fields = request.args + del fields["access_token"] + + results = yield self.appservice_handler.query_3pe( + ThirdPartyEntityKind.LOCATION, protocol, fields + ) + + defer.returnValue((200, results)) + + +def register_servlets(hs, http_server): + ThirdPartyUserServlet(hs).register(http_server) + ThirdPartyLocationServlet(hs).register(http_server) 
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 7209d5a37..9fe201365 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -15,6 +15,7 @@ from synapse.http.server import request_handler, respond_with_json_bytes from synapse.http.servlet import parse_integer, parse_json_object_from_request from synapse.api.errors import SynapseError, Codes +from synapse.crypto.keyring import KeyLookupError from twisted.web.resource import Resource from twisted.web.server import NOT_DONE_YET @@ -210,9 +211,10 @@ class RemoteKey(Resource): yield self.keyring.get_server_verify_key_v2_direct( server_name, key_ids ) + except KeyLookupError as e: + logger.info("Failed to fetch key: %s", e) except: logger.exception("Failed to get key for %r", server_name) - pass yield self.query_keys( request, query, query_remote_on_cache_miss=False ) diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py index 9f6962077..9f0625a82 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py @@ -45,6 +45,7 @@ class DownloadResource(Resource): @request_handler() @defer.inlineCallbacks def _async_render_GET(self, request): + request.setHeader("Content-Security-Policy", "sandbox") server_name, media_id, name = parse_media_id(request) if server_name == self.server_name: yield self._respond_local_file(request, media_id, name) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index bdd0e60c5..33f35fb44 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -29,14 +29,13 @@ from synapse.http.server import ( from synapse.util.async import ObservableDeferred from synapse.util.stringutils import is_ascii -from copy import deepcopy - import os import re import fnmatch import cgi import ujson as json 
import urlparse +import itertools import logging logger = logging.getLogger(__name__) @@ -163,7 +162,7 @@ class PreviewUrlResource(Resource): logger.debug("got media_info of '%s'" % media_info) - if self._is_media(media_info['media_type']): + if _is_media(media_info['media_type']): dims = yield self.media_repo._generate_local_thumbnails( media_info['filesystem_id'], media_info ) @@ -184,11 +183,9 @@ class PreviewUrlResource(Resource): logger.warn("Couldn't get dims for %s" % url) # define our OG response for this media - elif self._is_html(media_info['media_type']): + elif _is_html(media_info['media_type']): # TODO: somehow stop a big HTML tree from exploding synapse's RAM - from lxml import etree - file = open(media_info['filename']) body = file.read() file.close() @@ -199,17 +196,35 @@ class PreviewUrlResource(Resource): match = re.match(r'.*; *charset=(.*?)(;|$)', media_info['media_type'], re.I) encoding = match.group(1) if match else "utf-8" - try: - parser = etree.HTMLParser(recover=True, encoding=encoding) - tree = etree.fromstring(body, parser) - og = yield self._calc_og(tree, media_info, requester) - except UnicodeDecodeError: - # blindly try decoding the body as utf-8, which seems to fix - # the charset mismatches on https://google.com - parser = etree.HTMLParser(recover=True, encoding=encoding) - tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser) - og = yield self._calc_og(tree, media_info, requester) + og = decode_and_calc_og(body, media_info['uri'], encoding) + # pre-cache the image for posterity + # FIXME: it might be cleaner to use the same flow as the main /preview_url + # request itself and benefit from the same caching etc. But for now we + # just rely on the caching on the master request to speed things up. 
+ if 'og:image' in og and og['og:image']: + image_info = yield self._download_url( + _rebase_url(og['og:image'], media_info['uri']), requester.user + ) + + if _is_media(image_info['media_type']): + # TODO: make sure we don't choke on white-on-transparent images + dims = yield self.media_repo._generate_local_thumbnails( + image_info['filesystem_id'], image_info + ) + if dims: + og["og:image:width"] = dims['width'] + og["og:image:height"] = dims['height'] + else: + logger.warn("Couldn't get dims for %s" % og["og:image"]) + + og["og:image"] = "mxc://%s/%s" % ( + self.server_name, image_info['filesystem_id'] + ) + og["og:image:type"] = image_info['media_type'] + og["matrix:image:size"] = image_info['media_length'] + else: + del og["og:image"] else: logger.warn("Failed to find any OG data in %s", url) og = {} @@ -232,139 +247,6 @@ class PreviewUrlResource(Resource): respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True) - @defer.inlineCallbacks - def _calc_og(self, tree, media_info, requester): - # suck our tree into lxml and define our OG response. 
- - # if we see any image URLs in the OG response, then spider them - # (although the client could choose to do this by asking for previews of those - # URLs to avoid DoSing the server) - - # "og:type" : "video", - # "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw", - # "og:site_name" : "YouTube", - # "og:video:type" : "application/x-shockwave-flash", - # "og:description" : "Fun stuff happening here", - # "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon", - # "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg", - # "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1", - # "og:video:width" : "1280" - # "og:video:height" : "720", - # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3", - - og = {} - for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"): - if 'content' in tag.attrib: - og[tag.attrib['property']] = tag.attrib['content'] - - # TODO: grab article: meta tags too, e.g.: - - # "article:publisher" : "https://www.facebook.com/thethudonline" /> - # "article:author" content="https://www.facebook.com/thethudonline" /> - # "article:tag" content="baby" /> - # "article:section" content="Breaking News" /> - # "article:published_time" content="2016-03-31T19:58:24+00:00" /> - # "article:modified_time" content="2016-04-01T18:31:53+00:00" /> - - if 'og:title' not in og: - # do some basic spidering of the HTML - title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]") - og['og:title'] = title[0].text.strip() if title else None - - if 'og:image' not in og: - # TODO: extract a favicon failing all else - meta_image = tree.xpath( - "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content" - ) - if meta_image: - og['og:image'] = self._rebase_url(meta_image[0], media_info['uri']) - else: - # TODO: consider inlined CSS styles as well as width & height attribs - images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]") - images = 
sorted(images, key=lambda i: ( - -1 * float(i.attrib['width']) * float(i.attrib['height']) - )) - if not images: - images = tree.xpath("//img[@src]") - if images: - og['og:image'] = images[0].attrib['src'] - - # pre-cache the image for posterity - # FIXME: it might be cleaner to use the same flow as the main /preview_url - # request itself and benefit from the same caching etc. But for now we - # just rely on the caching on the master request to speed things up. - if 'og:image' in og and og['og:image']: - image_info = yield self._download_url( - self._rebase_url(og['og:image'], media_info['uri']), requester.user - ) - - if self._is_media(image_info['media_type']): - # TODO: make sure we don't choke on white-on-transparent images - dims = yield self.media_repo._generate_local_thumbnails( - image_info['filesystem_id'], image_info - ) - if dims: - og["og:image:width"] = dims['width'] - og["og:image:height"] = dims['height'] - else: - logger.warn("Couldn't get dims for %s" % og["og:image"]) - - og["og:image"] = "mxc://%s/%s" % ( - self.server_name, image_info['filesystem_id'] - ) - og["og:image:type"] = image_info['media_type'] - og["matrix:image:size"] = image_info['media_length'] - else: - del og["og:image"] - - if 'og:description' not in og: - meta_description = tree.xpath( - "//*/meta" - "[translate(@name, 'DESCRIPTION', 'description')='description']" - "/@content") - if meta_description: - og['og:description'] = meta_description[0] - else: - # grab any text nodes which are inside the tag... - # unless they are within an HTML5 semantic markup tag... - #
,