Merge branch 'release-v0.17.1' of github.com:matrix-org/synapse

This commit is contained in:
Erik Johnston 2016-08-24 14:39:35 +01:00
commit 37638c06c5
98 changed files with 3277 additions and 1430 deletions

View File

@ -1,3 +1,50 @@
Changes in synapse v0.17.1 (2016-08-24)
=======================================
Changes:
* Delete old received_transactions rows (PR #1038)
* Pass through user-supplied content in /join/$room_id (PR #1039)
Bug fixes:
* Fix bug with backfill (PR #1040)
Changes in synapse v0.17.1-rc1 (2016-08-22)
===========================================
Features:
* Add notification API (PR #1028)
Changes:
* Don't print stack traces when failing to get remote keys (PR #996)
* Various federation /event/ perf improvements (PR #998)
* Only process one local membership event per room at a time (PR #1005)
* Move default display name push rule (PR #1011, #1023)
* Fix up preview URL API. Add tests. (PR #1015)
* Set ``Content-Security-Policy`` on media repo (PR #1021)
* Make notify_interested_services faster (PR #1022)
* Add usage stats to prometheus monitoring (PR #1037)
Bug fixes:
* Fix token login (PR #993)
* Fix CAS login (PR #994, #995)
* Fix /sync to not clobber status_msg (PR #997)
* Fix redacted state events to include prev_content (PR #1003)
* Fix some bugs in the auth/ldap handler (PR #1007)
* Fix backfill request to limit URI length, so that remotes don't reject the
requests due to path length limits (PR #1012)
* Fix AS push code to not send duplicate events (PR #1025)
Changes in synapse v0.17.0 (2016-08-08) Changes in synapse v0.17.0 (2016-08-08)
======================================= =======================================

View File

@ -95,7 +95,7 @@ Synapse is the reference python/twisted Matrix homeserver implementation.
System requirements: System requirements:
- POSIX-compliant system (tested on Linux & OS X) - POSIX-compliant system (tested on Linux & OS X)
- Python 2.7 - Python 2.7
- At least 512 MB RAM. - At least 1GB of free RAM if you want to join large public rooms like #matrix:matrix.org
Synapse is written in python but some of the libraries is uses are written in Synapse is written in python but some of the libraries is uses are written in
C. So before we can install synapse itself we need a working C compiler and the C. So before we can install synapse itself we need a working C compiler and the

97
docs/workers.rst Normal file
View File

@ -0,0 +1,97 @@
Scaling synapse via workers
---------------------------
Synapse has experimental support for splitting out functionality into
multiple separate python processes, helping greatly with scalability. These
processes are called 'workers', and are (eventually) intended to scale
horizontally independently.
All processes continue to share the same database instance, and as such, workers
only work with postgres based synapse deployments (sharing a single sqlite
across multiple processes is a recipe for disaster, plus you should be using
postgres anyway if you care about scalability).
The workers communicate with the master synapse process via a synapse-specific
HTTP protocol called 'replication' - analogous to MySQL or Postgres style
database replication; feeding a stream of relevant data to the workers so they
can be kept in sync with the main synapse process and database state.
To enable workers, you need to add a replication listener to the master synapse, e.g.::
listeners:
- port: 9092
bind_address: '127.0.0.1'
type: http
tls: false
x_forwarded: false
resources:
- names: [replication]
compress: false
Under **no circumstances** should this replication API listener be exposed to the
public internet; it currently implements no authentication whatsoever and is
unencrypted HTTP.
You then create a set of configs for the various worker processes. These should be
worker configuration files should be stored in a dedicated subdirectory, to allow
synctl to manipulate them.
The current available worker applications are:
* synapse.app.pusher - handles sending push notifications to sygnal and email
* synapse.app.synchrotron - handles /sync endpoints. can scales horizontally through multiple instances.
* synapse.app.appservice - handles output traffic to Application Services
* synapse.app.federation_reader - handles receiving federation traffic (including public_rooms API)
* synapse.app.media_repository - handles the media repository.
Each worker configuration file inherits the configuration of the main homeserver
configuration file. You can then override configuration specific to that worker,
e.g. the HTTP listener that it provides (if any); logging configuration; etc.
You should minimise the number of overrides though to maintain a usable config.
You must specify the type of worker application (worker_app) and the replication
endpoint that it's talking to on the main synapse process (worker_replication_url).
For instance::
worker_app: synapse.app.synchrotron
# The replication listener on the synapse to talk to.
worker_replication_url: http://127.0.0.1:9092/_synapse/replication
worker_listeners:
- type: http
port: 8083
resources:
- names:
- client
worker_daemonize: True
worker_pid_file: /home/matrix/synapse/synchrotron.pid
worker_log_config: /home/matrix/synapse/config/synchrotron_log_config.yaml
...is a full configuration for a synchrotron worker instance, which will expose a
plain HTTP /sync endpoint on port 8083 separately from the /sync endpoint provided
by the main synapse.
Obviously you should configure your loadbalancer to route the /sync endpoint to
the synchrotron instance(s) in this instance.
Finally, to actually run your worker-based synapse, you must pass synctl the -a
commandline option to tell it to operate on all the worker configurations found
in the given directory, e.g.::
synctl -a $CONFIG/workers start
Currently one should always restart all workers when restarting or upgrading
synapse, unless you explicitly know it's safe not to. For instance, restarting
synapse without restarting all the synchrotrons may result in broken typing
notifications.
To manipulate a specific worker, you pass the -w option to synctl::
synctl -w $CONFIG/workers/synchrotron.yaml restart
All of the above is highly experimental and subject to change as Synapse evolves,
but documenting it here to help folks needing highly scalable Synapses similar
to the one running matrix.org!

View File

@ -25,5 +25,6 @@ rm .coverage* || echo "No coverage files to remove"
tox --notest -e py27 tox --notest -e py27
TOX_BIN=$WORKSPACE/.tox/py27/bin TOX_BIN=$WORKSPACE/.tox/py27/bin
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
$TOX_BIN/pip install lxml
tox -e py27 tox -e py27

View File

@ -14,6 +14,7 @@ fi
tox -e py27 --notest -v tox -e py27 --notest -v
TOX_BIN=$TOX_DIR/py27/bin TOX_BIN=$TOX_DIR/py27/bin
$TOX_BIN/pip install setuptools
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
$TOX_BIN/pip install lxml $TOX_BIN/pip install lxml
$TOX_BIN/pip install psycopg2 $TOX_BIN/pip install psycopg2

View File

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.17.0" __version__ = "0.17.1"

View File

@ -675,27 +675,18 @@ class Auth(object):
try: try:
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str) macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
user_prefix = "user_id = " user_id = self.get_user_id_from_macaroon(macaroon)
user = None user = UserID.from_string(user_id)
user_id = None
guest = False
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix):
user_id = caveat.caveat_id[len(user_prefix):]
user = UserID.from_string(user_id)
elif caveat.caveat_id == "guest = true":
guest = True
self.validate_macaroon( self.validate_macaroon(
macaroon, rights, self.hs.config.expire_access_token, macaroon, rights, self.hs.config.expire_access_token,
user_id=user_id, user_id=user_id,
) )
if user is None: guest = False
raise AuthError( for caveat in macaroon.caveats:
self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon", if caveat.caveat_id == "guest = true":
errcode=Codes.UNKNOWN_TOKEN guest = True
)
if guest: if guest:
ret = { ret = {
@ -743,6 +734,29 @@ class Auth(object):
errcode=Codes.UNKNOWN_TOKEN errcode=Codes.UNKNOWN_TOKEN
) )
def get_user_id_from_macaroon(self, macaroon):
"""Retrieve the user_id given by the caveats on the macaroon.
Does *not* validate the macaroon.
Args:
macaroon (pymacaroons.Macaroon): The macaroon to validate
Returns:
(str) user id
Raises:
AuthError if there is no user_id caveat in the macaroon
"""
user_prefix = "user_id = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix):
return caveat.caveat_id[len(user_prefix):]
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
errcode=Codes.UNKNOWN_TOKEN
)
def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id): def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id):
""" """
validate that a Macaroon is understood by and was signed by this server. validate that a Macaroon is understood by and was signed by this server.
@ -754,6 +768,7 @@ class Auth(object):
verify_expiry(bool): Whether to verify whether the macaroon has expired. verify_expiry(bool): Whether to verify whether the macaroon has expired.
This should really always be True, but no clients currently implement This should really always be True, but no clients currently implement
token refresh, so we can't enforce expiry yet. token refresh, so we can't enforce expiry yet.
user_id (str): The user_id required
""" """
v = pymacaroons.Verifier() v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1") v.satisfy_exact("gen = 1")

209
synapse/app/appservice.py Normal file
View File

@ -0,0 +1,209 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.server import HomeServer
from synapse.config._base import ConfigError
from synapse.config.logger import setup_logging
from synapse.config.homeserver import HomeServerConfig
from synapse.http.site import SynapseSite
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.storage.engines import create_engine
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
from daemonize import Daemonize
import sys
import logging
import gc
logger = logging.getLogger("synapse.app.appservice")
class AppserviceSlaveStore(
DirectoryStore, SlavedEventStore, SlavedApplicationServiceStore,
SlavedRegistrationStore,
):
pass
class AppserviceServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = AppserviceSlaveStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse appservice now listening on port %d", port)
def start_listening(self, listeners):
for listener in listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
appservice_handler = self.get_application_service_handler()
@defer.inlineCallbacks
def replicate(results):
stream = results.get("events")
if stream:
max_stream_id = stream["position"]
yield appservice_handler.notify_interested_services(max_stream_id)
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
yield store.process_replication(result)
replicate(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(30)
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse appservice", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.appservice"
setup_logging(config.worker_log_config, config.worker_log_file)
database_engine = create_engine(config.database_config)
if config.notify_appservices:
sys.stderr.write(
"\nThe appservices must be disabled in the main synapse process"
"\nbefore they can be run in a separate worker."
"\nPlease add ``notify_appservices: false`` to the main config"
"\n"
)
sys.exit(1)
# Force the pushers to start since they will be disabled in the main config
config.notify_appservices = True
ps = AppserviceServer(
config.server_name,
db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ps.setup()
ps.start_listening(config.worker_listeners)
def run():
with LoggingContext("run"):
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
gc.set_threshold(*config.gc_thresholds)
reactor.run()
def start():
ps.replicate()
ps.get_datastore().start_profiling()
reactor.callWhenRunning(start)
if config.worker_daemonize:
daemon = Daemonize(
app="synapse-appservice",
pid=config.worker_pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
run()
if __name__ == '__main__':
with LoggingContext("main"):
start(sys.argv[1:])

View File

@ -51,7 +51,7 @@ from synapse.api.urls import (
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory from synapse.crypto import context_factory
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
from synapse.metrics import register_memory_metrics from synapse.metrics import register_memory_metrics, get_metrics_for
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
from synapse.federation.transport.server import TransportLayerServer from synapse.federation.transport.server import TransportLayerServer
@ -385,6 +385,8 @@ def run(hs):
start_time = hs.get_clock().time() start_time = hs.get_clock().time()
stats = {}
@defer.inlineCallbacks @defer.inlineCallbacks
def phone_stats_home(): def phone_stats_home():
logger.info("Gathering stats for reporting") logger.info("Gathering stats for reporting")
@ -393,7 +395,10 @@ def run(hs):
if uptime < 0: if uptime < 0:
uptime = 0 uptime = 0
stats = {} # If the stats directory is empty then this is the first time we've
# reported stats.
first_time = not stats
stats["homeserver"] = hs.config.server_name stats["homeserver"] = hs.config.server_name
stats["timestamp"] = now stats["timestamp"] = now
stats["uptime_seconds"] = uptime stats["uptime_seconds"] = uptime
@ -406,6 +411,25 @@ def run(hs):
daily_messages = yield hs.get_datastore().count_daily_messages() daily_messages = yield hs.get_datastore().count_daily_messages()
if daily_messages is not None: if daily_messages is not None:
stats["daily_messages"] = daily_messages stats["daily_messages"] = daily_messages
else:
stats.pop("daily_messages", None)
if first_time:
# Add callbacks to report the synapse stats as metrics whenever
# prometheus requests them, typically every 30s.
# As some of the stats are expensive to calculate we only update
# them when synapse phones home to matrix.org every 24 hours.
metrics = get_metrics_for("synapse.usage")
metrics.add_callback("timestamp", lambda: stats["timestamp"])
metrics.add_callback("uptime_seconds", lambda: stats["uptime_seconds"])
metrics.add_callback("total_users", lambda: stats["total_users"])
metrics.add_callback("total_room_count", lambda: stats["total_room_count"])
metrics.add_callback(
"daily_active_users", lambda: stats["daily_active_users"]
)
metrics.add_callback(
"daily_messages", lambda: stats.get("daily_messages", 0)
)
logger.info("Reporting stats to matrix.org: %s" % (stats,)) logger.info("Reporting stats to matrix.org: %s" % (stats,))
try: try:

View File

@ -0,0 +1,212 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.http.site import SynapseSite
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine
from synapse.storage.media_repository import MediaRepositoryStore
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from synapse.api.urls import (
CONTENT_REPO_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX
)
from synapse.crypto import context_factory
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
from daemonize import Daemonize
import sys
import logging
import gc
logger = logging.getLogger("synapse.app.media_repository")
class MediaRepositorySlavedStore(
SlavedApplicationServiceStore,
SlavedRegistrationStore,
BaseSlavedStore,
MediaRepositoryStore,
ClientIpStore,
):
pass
class MediaRepositoryServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
elif name == "media":
media_repo = MediaRepositoryResource(self)
resources.update({
MEDIA_PREFIX: media_repo,
LEGACY_MEDIA_PREFIX: media_repo,
CONTENT_REPO_PREFIX: ContentRepoResource(
self, self.config.uploads_path
),
})
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse media repository now listening on port %d", port)
def start_listening(self, listeners):
for listener in listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
yield store.process_replication(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(5)
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse media repository", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.media_repository"
setup_logging(config.worker_log_config, config.worker_log_file)
database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
ss = MediaRepositoryServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners)
def run():
with LoggingContext("run"):
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
gc.set_threshold(*config.gc_thresholds)
reactor.run()
def start():
ss.get_datastore().start_profiling()
ss.replicate()
reactor.callWhenRunning(start)
if config.worker_daemonize:
daemon = Daemonize(
app="synapse-media-repository",
pid=config.worker_pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
run()
if __name__ == '__main__':
with LoggingContext("main"):
start(sys.argv[1:])

View File

@ -80,11 +80,6 @@ class PusherSlaveStore(
DataStore.get_profile_displayname.__func__ DataStore.get_profile_displayname.__func__
) )
# XXX: This is a bit broken because we don't persist forgotten rooms
# in a way that they can be streamed. This means that we don't have a
# way to invalidate the forgotten rooms cache correctly.
# For now we expire the cache every 10 minutes.
BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
who_forgot_in_room = ( who_forgot_in_room = (
RoomMemberStore.__dict__["who_forgot_in_room"] RoomMemberStore.__dict__["who_forgot_in_room"]
) )
@ -168,7 +163,6 @@ class PusherServer(HomeServer):
store = self.get_datastore() store = self.get_datastore()
replication_url = self.config.worker_replication_url replication_url = self.config.worker_replication_url
pusher_pool = self.get_pusherpool() pusher_pool = self.get_pusherpool()
clock = self.get_clock()
def stop_pusher(user_id, app_id, pushkey): def stop_pusher(user_id, app_id, pushkey):
key = "%s:%s" % (app_id, pushkey) key = "%s:%s" % (app_id, pushkey)
@ -220,21 +214,11 @@ class PusherServer(HomeServer):
min_stream_id, max_stream_id, affected_room_ids min_stream_id, max_stream_id, affected_room_ids
) )
def expire_broken_caches():
store.who_forgot_in_room.invalidate_all()
next_expire_broken_caches_ms = 0
while True: while True:
try: try:
args = store.stream_positions() args = store.stream_positions()
args["timeout"] = 30000 args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args) result = yield http_client.get_json(replication_url, args=args)
now_ms = clock.time_msec()
if now_ms > next_expire_broken_caches_ms:
expire_broken_caches()
next_expire_broken_caches_ms = (
now_ms + store.BROKEN_CACHE_EXPIRY_MS
)
yield store.process_replication(result) yield store.process_replication(result)
poke_pushers(result) poke_pushers(result)
except: except:

View File

@ -26,6 +26,7 @@ from synapse.http.site import SynapseSite
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.rest.client.v2_alpha import sync from synapse.rest.client.v2_alpha import sync
from synapse.rest.client.v1 import events
from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
@ -74,11 +75,6 @@ class SynchrotronSlavedStore(
BaseSlavedStore, BaseSlavedStore,
ClientIpStore, # After BaseSlavedStore because the constructor is different ClientIpStore, # After BaseSlavedStore because the constructor is different
): ):
# XXX: This is a bit broken because we don't persist forgotten rooms
# in a way that they can be streamed. This means that we don't have a
# way to invalidate the forgotten rooms cache correctly.
# For now we expire the cache every 10 minutes.
BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
who_forgot_in_room = ( who_forgot_in_room = (
RoomMemberStore.__dict__["who_forgot_in_room"] RoomMemberStore.__dict__["who_forgot_in_room"]
) )
@ -89,17 +85,23 @@ class SynchrotronSlavedStore(
get_presence_list_accepted = PresenceStore.__dict__[ get_presence_list_accepted = PresenceStore.__dict__[
"get_presence_list_accepted" "get_presence_list_accepted"
] ]
get_presence_list_observers_accepted = PresenceStore.__dict__[
"get_presence_list_observers_accepted"
]
UPDATE_SYNCING_USERS_MS = 10 * 1000 UPDATE_SYNCING_USERS_MS = 10 * 1000
class SynchrotronPresence(object): class SynchrotronPresence(object):
def __init__(self, hs): def __init__(self, hs):
self.is_mine_id = hs.is_mine_id
self.http_client = hs.get_simple_http_client() self.http_client = hs.get_simple_http_client()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.user_to_num_current_syncs = {} self.user_to_num_current_syncs = {}
self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users" self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users"
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
active_presence = self.store.take_presence_startup_info() active_presence = self.store.take_presence_startup_info()
self.user_to_current_state = { self.user_to_current_state = {
@ -119,11 +121,13 @@ class SynchrotronPresence(object):
reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown) reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown)
def set_state(self, user, state): def set_state(self, user, state, ignore_status_msg=False):
# TODO Hows this supposed to work? # TODO Hows this supposed to work?
pass pass
get_states = PresenceHandler.get_states.__func__ get_states = PresenceHandler.get_states.__func__
get_state = PresenceHandler.get_state.__func__
_get_interested_parties = PresenceHandler._get_interested_parties.__func__
current_state_for_users = PresenceHandler.current_state_for_users.__func__ current_state_for_users = PresenceHandler.current_state_for_users.__func__
@defer.inlineCallbacks @defer.inlineCallbacks
@ -194,19 +198,39 @@ class SynchrotronPresence(object):
self._need_to_send_sync = False self._need_to_send_sync = False
yield self._send_syncing_users_now() yield self._send_syncing_users_now()
@defer.inlineCallbacks
def notify_from_replication(self, states, stream_id):
parties = yield self._get_interested_parties(
states, calculate_remote_hosts=False
)
room_ids_to_states, users_to_states, _ = parties
self.notifier.on_new_event(
"presence_key", stream_id, rooms=room_ids_to_states.keys(),
users=users_to_states.keys()
)
@defer.inlineCallbacks
def process_replication(self, result): def process_replication(self, result):
stream = result.get("presence", {"rows": []}) stream = result.get("presence", {"rows": []})
states = []
for row in stream["rows"]: for row in stream["rows"]:
( (
position, user_id, state, last_active_ts, position, user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts, status_msg, last_federation_update_ts, last_user_sync_ts, status_msg,
currently_active currently_active
) = row ) = row
self.user_to_current_state[user_id] = UserPresenceState( state = UserPresenceState(
user_id, state, last_active_ts, user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts, status_msg, last_federation_update_ts, last_user_sync_ts, status_msg,
currently_active currently_active
) )
self.user_to_current_state[user_id] = state
states.append(state)
if states and "position" in stream:
stream_id = int(stream["position"])
yield self.notify_from_replication(states, stream_id)
class SynchrotronTyping(object): class SynchrotronTyping(object):
@ -266,10 +290,12 @@ class SynchrotronServer(HomeServer):
elif name == "client": elif name == "client":
resource = JsonResource(self, canonical_json=False) resource = JsonResource(self, canonical_json=False)
sync.register_servlets(self, resource) sync.register_servlets(self, resource)
events.register_servlets(self, resource)
resources.update({ resources.update({
"/_matrix/client/r0": resource, "/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource, "/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource, "/_matrix/client/v2_alpha": resource,
"/_matrix/client/api/v1": resource,
}) })
root_resource = create_resource_tree(resources, Resource()) root_resource = create_resource_tree(resources, Resource())
@ -307,15 +333,10 @@ class SynchrotronServer(HomeServer):
http_client = self.get_simple_http_client() http_client = self.get_simple_http_client()
store = self.get_datastore() store = self.get_datastore()
replication_url = self.config.worker_replication_url replication_url = self.config.worker_replication_url
clock = self.get_clock()
notifier = self.get_notifier() notifier = self.get_notifier()
presence_handler = self.get_presence_handler() presence_handler = self.get_presence_handler()
typing_handler = self.get_typing_handler() typing_handler = self.get_typing_handler()
def expire_broken_caches():
store.who_forgot_in_room.invalidate_all()
store.get_presence_list_accepted.invalidate_all()
def notify_from_stream( def notify_from_stream(
result, stream_name, stream_key, room=None, user=None result, stream_name, stream_key, room=None, user=None
): ):
@ -377,22 +398,15 @@ class SynchrotronServer(HomeServer):
result, "typing", "typing_key", room="room_id" result, "typing", "typing_key", room="room_id"
) )
next_expire_broken_caches_ms = 0
while True: while True:
try: try:
args = store.stream_positions() args = store.stream_positions()
args.update(typing_handler.stream_positions()) args.update(typing_handler.stream_positions())
args["timeout"] = 30000 args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args) result = yield http_client.get_json(replication_url, args=args)
now_ms = clock.time_msec()
if now_ms > next_expire_broken_caches_ms:
expire_broken_caches()
next_expire_broken_caches_ms = (
now_ms + store.BROKEN_CACHE_EXPIRY_MS
)
yield store.process_replication(result) yield store.process_replication(result)
typing_handler.process_replication(result) typing_handler.process_replication(result)
presence_handler.process_replication(result) yield presence_handler.process_replication(result)
notify(result) notify(result)
except: except:
logger.exception("Error replicating from %r", replication_url) logger.exception("Error replicating from %r", replication_url)

View File

@ -14,6 +14,8 @@
# limitations under the License. # limitations under the License.
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from twisted.internet import defer
import logging import logging
import re import re
@ -79,13 +81,17 @@ class ApplicationService(object):
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS] NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
def __init__(self, token, url=None, namespaces=None, hs_token=None, def __init__(self, token, url=None, namespaces=None, hs_token=None,
sender=None, id=None): sender=None, id=None, protocols=None):
self.token = token self.token = token
self.url = url self.url = url
self.hs_token = hs_token self.hs_token = hs_token
self.sender = sender self.sender = sender
self.namespaces = self._check_namespaces(namespaces) self.namespaces = self._check_namespaces(namespaces)
self.id = id self.id = id
if protocols:
self.protocols = set(protocols)
else:
self.protocols = set()
def _check_namespaces(self, namespaces): def _check_namespaces(self, namespaces):
# Sanity check that it is of the form: # Sanity check that it is of the form:
@ -138,65 +144,66 @@ class ApplicationService(object):
return regex_obj["exclusive"] return regex_obj["exclusive"]
return False return False
def _matches_user(self, event, member_list): @defer.inlineCallbacks
if (hasattr(event, "sender") and def _matches_user(self, event, store):
self.is_interested_in_user(event.sender)): if not event:
return True defer.returnValue(False)
if self.is_interested_in_user(event.sender):
defer.returnValue(True)
# also check m.room.member state key # also check m.room.member state key
if (hasattr(event, "type") and event.type == EventTypes.Member if (event.type == EventTypes.Member and
and hasattr(event, "state_key") self.is_interested_in_user(event.state_key)):
and self.is_interested_in_user(event.state_key)): defer.returnValue(True)
return True
if not store:
defer.returnValue(False)
member_list = yield store.get_users_in_room(event.room_id)
# check joined member events # check joined member events
for user_id in member_list: for user_id in member_list:
if self.is_interested_in_user(user_id): if self.is_interested_in_user(user_id):
return True defer.returnValue(True)
return False defer.returnValue(False)
def _matches_room_id(self, event): def _matches_room_id(self, event):
if hasattr(event, "room_id"): if hasattr(event, "room_id"):
return self.is_interested_in_room(event.room_id) return self.is_interested_in_room(event.room_id)
return False return False
def _matches_aliases(self, event, alias_list): @defer.inlineCallbacks
def _matches_aliases(self, event, store):
if not store or not event:
defer.returnValue(False)
alias_list = yield store.get_aliases_for_room(event.room_id)
for alias in alias_list: for alias in alias_list:
if self.is_interested_in_alias(alias): if self.is_interested_in_alias(alias):
return True defer.returnValue(True)
return False defer.returnValue(False)
def is_interested(self, event, restrict_to=None, aliases_for_event=None, @defer.inlineCallbacks
member_list=None): def is_interested(self, event, store=None):
"""Check if this service is interested in this event. """Check if this service is interested in this event.
Args: Args:
event(Event): The event to check. event(Event): The event to check.
restrict_to(str): The namespace to restrict regex tests to. store(DataStore)
aliases_for_event(list): A list of all the known room aliases for
this event.
member_list(list): A list of all joined user_ids in this room.
Returns: Returns:
bool: True if this service would like to know about this event. bool: True if this service would like to know about this event.
""" """
if aliases_for_event is None: # Do cheap checks first
aliases_for_event = [] if self._matches_room_id(event):
if member_list is None: defer.returnValue(True)
member_list = []
if restrict_to and restrict_to not in ApplicationService.NS_LIST: if (yield self._matches_aliases(event, store)):
# this is a programming error, so fail early and raise a general defer.returnValue(True)
# exception
raise Exception("Unexpected restrict_to value: %s". restrict_to)
if not restrict_to: if (yield self._matches_user(event, store)):
return (self._matches_user(event, member_list) defer.returnValue(True)
or self._matches_aliases(event, aliases_for_event)
or self._matches_room_id(event)) defer.returnValue(False)
elif restrict_to == ApplicationService.NS_ALIASES:
return self._matches_aliases(event, aliases_for_event)
elif restrict_to == ApplicationService.NS_ROOMS:
return self._matches_room_id(event)
elif restrict_to == ApplicationService.NS_USERS:
return self._matches_user(event, member_list)
def is_interested_in_user(self, user_id): def is_interested_in_user(self, user_id):
return ( return (
@ -216,6 +223,9 @@ class ApplicationService(object):
or user_id == self.sender or user_id == self.sender
) )
def is_interested_in_protocol(self, protocol):
return protocol in self.protocols
def is_exclusive_alias(self, alias): def is_exclusive_alias(self, alias):
return self._is_exclusive(ApplicationService.NS_ALIASES, alias) return self._is_exclusive(ApplicationService.NS_ALIASES, alias)

View File

@ -17,6 +17,7 @@ from twisted.internet import defer
from synapse.api.errors import CodeMessageException from synapse.api.errors import CodeMessageException
from synapse.http.client import SimpleHttpClient from synapse.http.client import SimpleHttpClient
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.types import ThirdPartyEntityKind
import logging import logging
import urllib import urllib
@ -24,6 +25,28 @@ import urllib
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _is_valid_3pe_result(r, field):
if not isinstance(r, dict):
return False
for k in (field, "protocol"):
if k not in r:
return False
if not isinstance(r[k], str):
return False
if "fields" not in r:
return False
fields = r["fields"]
if not isinstance(fields, dict):
return False
for k in fields.keys():
if not isinstance(fields[k], str):
return False
return True
class ApplicationServiceApi(SimpleHttpClient): class ApplicationServiceApi(SimpleHttpClient):
"""This class manages HS -> AS communications, including querying and """This class manages HS -> AS communications, including querying and
pushing. pushing.
@ -71,6 +94,43 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning("query_alias to %s threw exception %s", uri, ex) logger.warning("query_alias to %s threw exception %s", uri, ex)
defer.returnValue(False) defer.returnValue(False)
@defer.inlineCallbacks
def query_3pe(self, service, kind, protocol, fields):
if kind == ThirdPartyEntityKind.USER:
uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol))
required_field = "userid"
elif kind == ThirdPartyEntityKind.LOCATION:
uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
required_field = "alias"
else:
raise ValueError(
"Unrecognised 'kind' argument %r to query_3pe()", kind
)
try:
response = yield self.get_json(uri, fields)
if not isinstance(response, list):
logger.warning(
"query_3pe to %s returned an invalid response %r",
uri, response
)
defer.returnValue([])
ret = []
for r in response:
if _is_valid_3pe_result(r, field=required_field):
ret.append(r)
else:
logger.warning(
"query_3pe to %s returned an invalid result %r",
uri, r
)
defer.returnValue(ret)
except Exception as ex:
logger.warning("query_3pe to %s threw exception %s", uri, ex)
defer.returnValue([])
@defer.inlineCallbacks @defer.inlineCallbacks
def push_bulk(self, service, events, txn_id=None): def push_bulk(self, service, events, txn_id=None):
events = self._serialize(events) events = self._serialize(events)

View File

@ -48,9 +48,12 @@ UP & quit +---------- YES SUCCESS
This is all tied together by the AppServiceScheduler which DIs the required This is all tied together by the AppServiceScheduler which DIs the required
components. components.
""" """
from twisted.internet import defer
from synapse.appservice import ApplicationServiceState from synapse.appservice import ApplicationServiceState
from twisted.internet import defer from synapse.util.logcontext import preserve_fn
from synapse.util.metrics import Measure
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -73,7 +76,7 @@ class ApplicationServiceScheduler(object):
self.txn_ctrl = _TransactionController( self.txn_ctrl = _TransactionController(
self.clock, self.store, self.as_api, create_recoverer self.clock, self.store, self.as_api, create_recoverer
) )
self.queuer = _ServiceQueuer(self.txn_ctrl) self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock)
@defer.inlineCallbacks @defer.inlineCallbacks
def start(self): def start(self):
@ -94,38 +97,36 @@ class _ServiceQueuer(object):
this schedules any other events in the queue to run. this schedules any other events in the queue to run.
""" """
def __init__(self, txn_ctrl): def __init__(self, txn_ctrl, clock):
self.queued_events = {} # dict of {service_id: [events]} self.queued_events = {} # dict of {service_id: [events]}
self.pending_requests = {} # dict of {service_id: Deferred} self.requests_in_flight = set()
self.txn_ctrl = txn_ctrl self.txn_ctrl = txn_ctrl
self.clock = clock
def enqueue(self, service, event): def enqueue(self, service, event):
# if this service isn't being sent something # if this service isn't being sent something
if not self.pending_requests.get(service.id): self.queued_events.setdefault(service.id, []).append(event)
self._send_request(service, [event]) preserve_fn(self._send_request)(service)
else:
# add to queue for this service
if service.id not in self.queued_events:
self.queued_events[service.id] = []
self.queued_events[service.id].append(event)
def _send_request(self, service, events): @defer.inlineCallbacks
# send request and add callbacks def _send_request(self, service):
d = self.txn_ctrl.send(service, events) if service.id in self.requests_in_flight:
d.addBoth(self._on_request_finish) return
d.addErrback(self._on_request_fail)
self.pending_requests[service.id] = d
def _on_request_finish(self, service): self.requests_in_flight.add(service.id)
self.pending_requests[service.id] = None try:
# if there are queued events, then send them. while True:
if (service.id in self.queued_events events = self.queued_events.pop(service.id, [])
and len(self.queued_events[service.id]) > 0): if not events:
self._send_request(service, self.queued_events[service.id]) return
self.queued_events[service.id] = []
def _on_request_fail(self, err): with Measure(self.clock, "servicequeuer.send"):
logger.error("AS request failed: %s", err) try:
yield self.txn_ctrl.send(service, events)
except:
logger.exception("AS request failed")
finally:
self.requests_in_flight.discard(service.id)
class _TransactionController(object): class _TransactionController(object):
@ -149,14 +150,12 @@ class _TransactionController(object):
if service_is_up: if service_is_up:
sent = yield txn.send(self.as_api) sent = yield txn.send(self.as_api)
if sent: if sent:
txn.complete(self.store) yield txn.complete(self.store)
else: else:
self._start_recoverer(service) preserve_fn(self._start_recoverer)(service)
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
self._start_recoverer(service) preserve_fn(self._start_recoverer)(service)
# request has finished
defer.returnValue(service)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_recovered(self, recoverer): def on_recovered(self, recoverer):

View File

@ -28,6 +28,7 @@ class AppServiceConfig(Config):
def read_config(self, config): def read_config(self, config):
self.app_service_config_files = config.get("app_service_config_files", []) self.app_service_config_files = config.get("app_service_config_files", [])
self.notify_appservices = config.get("notify_appservices", True)
def default_config(cls, **kwargs): def default_config(cls, **kwargs):
return """\ return """\
@ -122,6 +123,15 @@ def _load_appservice(hostname, as_info, config_filename):
raise ValueError( raise ValueError(
"Missing/bad type 'exclusive' key in %s", regex_obj "Missing/bad type 'exclusive' key in %s", regex_obj
) )
# protocols check
protocols = as_info.get("protocols")
if protocols:
# Because strings are lists in python
if isinstance(protocols, str) or not isinstance(protocols, list):
raise KeyError("Optional 'protocols' must be a list if present.")
for p in protocols:
if not isinstance(p, str):
raise KeyError("Bad value for 'protocols' item")
return ApplicationService( return ApplicationService(
token=as_info["as_token"], token=as_info["as_token"],
url=as_info["url"], url=as_info["url"],
@ -129,4 +139,5 @@ def _load_appservice(hostname, as_info, config_filename):
hs_token=as_info["hs_token"], hs_token=as_info["hs_token"],
sender=user_id, sender=user_id,
id=as_info["id"], id=as_info["id"],
protocols=protocols,
) )

View File

@ -22,6 +22,7 @@ from synapse.util.logcontext import (
preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext,
preserve_fn preserve_fn
) )
from synapse.util.metrics import Measure
from twisted.internet import defer from twisted.internet import defer
@ -61,6 +62,10 @@ Attributes:
""" """
class KeyLookupError(ValueError):
pass
class Keyring(object): class Keyring(object):
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -239,59 +244,60 @@ class Keyring(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def do_iterations(): def do_iterations():
merged_results = {} with Measure(self.clock, "get_server_verify_keys"):
merged_results = {}
missing_keys = {}
for verify_request in verify_requests:
missing_keys.setdefault(verify_request.server_name, set()).update(
verify_request.key_ids
)
for fn in key_fetch_fns:
results = yield fn(missing_keys.items())
merged_results.update(results)
# We now need to figure out which verify requests we have keys
# for and which we don't
missing_keys = {} missing_keys = {}
requests_missing_keys = []
for verify_request in verify_requests: for verify_request in verify_requests:
server_name = verify_request.server_name missing_keys.setdefault(verify_request.server_name, set()).update(
result_keys = merged_results[server_name] verify_request.key_ids
)
if verify_request.deferred.called: for fn in key_fetch_fns:
# We've already called this deferred, which probably results = yield fn(missing_keys.items())
# means that we've already found a key for it. merged_results.update(results)
continue
for key_id in verify_request.key_ids: # We now need to figure out which verify requests we have keys
if key_id in result_keys: # for and which we don't
with PreserveLoggingContext(): missing_keys = {}
verify_request.deferred.callback(( requests_missing_keys = []
server_name, for verify_request in verify_requests:
key_id, server_name = verify_request.server_name
result_keys[key_id], result_keys = merged_results[server_name]
))
break
else:
# The else block is only reached if the loop above
# doesn't break.
missing_keys.setdefault(server_name, set()).update(
verify_request.key_ids
)
requests_missing_keys.append(verify_request)
if not missing_keys: if verify_request.deferred.called:
break # We've already called this deferred, which probably
# means that we've already found a key for it.
continue
for verify_request in requests_missing_keys.values(): for key_id in verify_request.key_ids:
verify_request.deferred.errback(SynapseError( if key_id in result_keys:
401, with PreserveLoggingContext():
"No key for %s with id %s" % ( verify_request.deferred.callback((
verify_request.server_name, verify_request.key_ids, server_name,
), key_id,
Codes.UNAUTHORIZED, result_keys[key_id],
)) ))
break
else:
# The else block is only reached if the loop above
# doesn't break.
missing_keys.setdefault(server_name, set()).update(
verify_request.key_ids
)
requests_missing_keys.append(verify_request)
if not missing_keys:
break
for verify_request in requests_missing_keys.values():
verify_request.deferred.errback(SynapseError(
401,
"No key for %s with id %s" % (
verify_request.server_name, verify_request.key_ids,
),
Codes.UNAUTHORIZED,
))
def on_err(err): def on_err(err):
for verify_request in verify_requests: for verify_request in verify_requests:
@ -302,15 +308,15 @@ class Keyring(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids): def get_keys_from_store(self, server_name_and_key_ids):
res = yield defer.gatherResults( res = yield preserve_context_over_deferred(defer.gatherResults(
[ [
self.store.get_server_verify_keys( preserve_fn(self.store.get_server_verify_keys)(
server_name, key_ids server_name, key_ids
).addCallback(lambda ks, server: (server, ks), server_name) ).addCallback(lambda ks, server: (server, ks), server_name)
for server_name, key_ids in server_name_and_key_ids for server_name, key_ids in server_name_and_key_ids
], ],
consumeErrors=True, consumeErrors=True,
).addErrback(unwrapFirstError) )).addErrback(unwrapFirstError)
defer.returnValue(dict(res)) defer.returnValue(dict(res))
@ -331,13 +337,13 @@ class Keyring(object):
) )
defer.returnValue({}) defer.returnValue({})
results = yield defer.gatherResults( results = yield preserve_context_over_deferred(defer.gatherResults(
[ [
get_key(p_name, p_keys) preserve_fn(get_key)(p_name, p_keys)
for p_name, p_keys in self.perspective_servers.items() for p_name, p_keys in self.perspective_servers.items()
], ],
consumeErrors=True, consumeErrors=True,
).addErrback(unwrapFirstError) )).addErrback(unwrapFirstError)
union_of_keys = {} union_of_keys = {}
for result in results: for result in results:
@ -363,7 +369,7 @@ class Keyring(object):
) )
except Exception as e: except Exception as e:
logger.info( logger.info(
"Unable to getting key %r for %r directly: %s %s", "Unable to get key %r for %r directly: %s %s",
key_ids, server_name, key_ids, server_name,
type(e).__name__, str(e.message), type(e).__name__, str(e.message),
) )
@ -377,13 +383,13 @@ class Keyring(object):
defer.returnValue(keys) defer.returnValue(keys)
results = yield defer.gatherResults( results = yield preserve_context_over_deferred(defer.gatherResults(
[ [
get_key(server_name, key_ids) preserve_fn(get_key)(server_name, key_ids)
for server_name, key_ids in server_name_and_key_ids for server_name, key_ids in server_name_and_key_ids
], ],
consumeErrors=True, consumeErrors=True,
).addErrback(unwrapFirstError) )).addErrback(unwrapFirstError)
merged = {} merged = {}
for result in results: for result in results:
@ -425,7 +431,7 @@ class Keyring(object):
for response in responses: for response in responses:
if (u"signatures" not in response if (u"signatures" not in response
or perspective_name not in response[u"signatures"]): or perspective_name not in response[u"signatures"]):
raise ValueError( raise KeyLookupError(
"Key response not signed by perspective server" "Key response not signed by perspective server"
" %r" % (perspective_name,) " %r" % (perspective_name,)
) )
@ -448,7 +454,7 @@ class Keyring(object):
list(response[u"signatures"][perspective_name]), list(response[u"signatures"][perspective_name]),
list(perspective_keys) list(perspective_keys)
) )
raise ValueError( raise KeyLookupError(
"Response not signed with a known key for perspective" "Response not signed with a known key for perspective"
" server %r" % (perspective_name,) " server %r" % (perspective_name,)
) )
@ -460,9 +466,9 @@ class Keyring(object):
for server_name, response_keys in processed_response.items(): for server_name, response_keys in processed_response.items():
keys.setdefault(server_name, {}).update(response_keys) keys.setdefault(server_name, {}).update(response_keys)
yield defer.gatherResults( yield preserve_context_over_deferred(defer.gatherResults(
[ [
self.store_keys( preserve_fn(self.store_keys)(
server_name=server_name, server_name=server_name,
from_server=perspective_name, from_server=perspective_name,
verify_keys=response_keys, verify_keys=response_keys,
@ -470,7 +476,7 @@ class Keyring(object):
for server_name, response_keys in keys.items() for server_name, response_keys in keys.items()
], ],
consumeErrors=True consumeErrors=True
).addErrback(unwrapFirstError) )).addErrback(unwrapFirstError)
defer.returnValue(keys) defer.returnValue(keys)
@ -491,10 +497,10 @@ class Keyring(object):
if (u"signatures" not in response if (u"signatures" not in response
or server_name not in response[u"signatures"]): or server_name not in response[u"signatures"]):
raise ValueError("Key response not signed by remote server") raise KeyLookupError("Key response not signed by remote server")
if "tls_fingerprints" not in response: if "tls_fingerprints" not in response:
raise ValueError("Key response missing TLS fingerprints") raise KeyLookupError("Key response missing TLS fingerprints")
certificate_bytes = crypto.dump_certificate( certificate_bytes = crypto.dump_certificate(
crypto.FILETYPE_ASN1, tls_certificate crypto.FILETYPE_ASN1, tls_certificate
@ -508,7 +514,7 @@ class Keyring(object):
response_sha256_fingerprints.add(fingerprint[u"sha256"]) response_sha256_fingerprints.add(fingerprint[u"sha256"])
if sha256_fingerprint_b64 not in response_sha256_fingerprints: if sha256_fingerprint_b64 not in response_sha256_fingerprints:
raise ValueError("TLS certificate not allowed by fingerprints") raise KeyLookupError("TLS certificate not allowed by fingerprints")
response_keys = yield self.process_v2_response( response_keys = yield self.process_v2_response(
from_server=server_name, from_server=server_name,
@ -518,7 +524,7 @@ class Keyring(object):
keys.update(response_keys) keys.update(response_keys)
yield defer.gatherResults( yield preserve_context_over_deferred(defer.gatherResults(
[ [
preserve_fn(self.store_keys)( preserve_fn(self.store_keys)(
server_name=key_server_name, server_name=key_server_name,
@ -528,7 +534,7 @@ class Keyring(object):
for key_server_name, verify_keys in keys.items() for key_server_name, verify_keys in keys.items()
], ],
consumeErrors=True consumeErrors=True
).addErrback(unwrapFirstError) )).addErrback(unwrapFirstError)
defer.returnValue(keys) defer.returnValue(keys)
@ -560,14 +566,14 @@ class Keyring(object):
server_name = response_json["server_name"] server_name = response_json["server_name"]
if only_from_server: if only_from_server:
if server_name != from_server: if server_name != from_server:
raise ValueError( raise KeyLookupError(
"Expected a response for server %r not %r" % ( "Expected a response for server %r not %r" % (
from_server, server_name from_server, server_name
) )
) )
for key_id in response_json["signatures"].get(server_name, {}): for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in response_json["verify_keys"]: if key_id not in response_json["verify_keys"]:
raise ValueError( raise KeyLookupError(
"Key response must include verification keys for all" "Key response must include verification keys for all"
" signatures" " signatures"
) )
@ -594,7 +600,7 @@ class Keyring(object):
response_keys.update(verify_keys) response_keys.update(verify_keys)
response_keys.update(old_verify_keys) response_keys.update(old_verify_keys)
yield defer.gatherResults( yield preserve_context_over_deferred(defer.gatherResults(
[ [
preserve_fn(self.store.store_server_keys_json)( preserve_fn(self.store.store_server_keys_json)(
server_name=server_name, server_name=server_name,
@ -607,7 +613,7 @@ class Keyring(object):
for key_id in updated_key_ids for key_id in updated_key_ids
], ],
consumeErrors=True, consumeErrors=True,
).addErrback(unwrapFirstError) )).addErrback(unwrapFirstError)
results[server_name] = response_keys results[server_name] = response_keys
@ -635,15 +641,15 @@ class Keyring(object):
if ("signatures" not in response if ("signatures" not in response
or server_name not in response["signatures"]): or server_name not in response["signatures"]):
raise ValueError("Key response not signed by remote server") raise KeyLookupError("Key response not signed by remote server")
if "tls_certificate" not in response: if "tls_certificate" not in response:
raise ValueError("Key response missing TLS certificate") raise KeyLookupError("Key response missing TLS certificate")
tls_certificate_b64 = response["tls_certificate"] tls_certificate_b64 = response["tls_certificate"]
if encode_base64(x509_certificate_bytes) != tls_certificate_b64: if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
raise ValueError("TLS certificate doesn't match") raise KeyLookupError("TLS certificate doesn't match")
# Cache the result in the datastore. # Cache the result in the datastore.
@ -659,7 +665,7 @@ class Keyring(object):
for key_id in response["signatures"][server_name]: for key_id in response["signatures"][server_name]:
if key_id not in response["verify_keys"]: if key_id not in response["verify_keys"]:
raise ValueError( raise KeyLookupError(
"Key response must include verification keys for all" "Key response must include verification keys for all"
" signatures" " signatures"
) )
@ -696,7 +702,7 @@ class Keyring(object):
A deferred that completes when the keys are stored. A deferred that completes when the keys are stored.
""" """
# TODO(markjh): Store whether the keys have expired. # TODO(markjh): Store whether the keys have expired.
yield defer.gatherResults( yield preserve_context_over_deferred(defer.gatherResults(
[ [
preserve_fn(self.store.store_server_verify_key)( preserve_fn(self.store.store_server_verify_key)(
server_name, server_name, key.time_added, key server_name, server_name, key.time_added, key
@ -704,4 +710,4 @@ class Keyring(object):
for key_id, key in verify_keys.items() for key_id, key in verify_keys.items()
], ],
consumeErrors=True, consumeErrors=True,
).addErrback(unwrapFirstError) )).addErrback(unwrapFirstError)

View File

@ -88,6 +88,8 @@ def prune_event(event):
if "age_ts" in event.unsigned: if "age_ts" in event.unsigned:
allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"] allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"]
if "replaces_state" in event.unsigned:
allowed_fields["unsigned"]["replaces_state"] = event.unsigned["replaces_state"]
return type(event)( return type(event)(
allowed_fields, allowed_fields,

View File

@ -23,6 +23,7 @@ from synapse.crypto.event_signing import check_event_content_hash
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
import logging import logging
@ -102,10 +103,10 @@ class FederationBase(object):
warn, pdu warn, pdu
) )
valid_pdus = yield defer.gatherResults( valid_pdus = yield preserve_context_over_deferred(defer.gatherResults(
deferreds, deferreds,
consumeErrors=True consumeErrors=True
).addErrback(unwrapFirstError) )).addErrback(unwrapFirstError)
if include_none: if include_none:
defer.returnValue(valid_pdus) defer.returnValue(valid_pdus)
@ -129,7 +130,7 @@ class FederationBase(object):
for pdu in pdus for pdu in pdus
] ]
deferreds = self.keyring.verify_json_objects_for_server([ deferreds = preserve_fn(self.keyring.verify_json_objects_for_server)([
(p.origin, p.get_pdu_json()) (p.origin, p.get_pdu_json())
for p in redacted_pdus for p in redacted_pdus
]) ])

View File

@ -27,6 +27,7 @@ from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute from synapse.util.async import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
import synapse.metrics import synapse.metrics
@ -51,10 +52,34 @@ sent_edus_counter = metrics.register_counter("sent_edus")
sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"]) sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
PDU_RETRY_TIME_MS = 1 * 60 * 1000
class FederationClient(FederationBase): class FederationClient(FederationBase):
def __init__(self, hs): def __init__(self, hs):
super(FederationClient, self).__init__(hs) super(FederationClient, self).__init__(hs)
self.pdu_destination_tried = {}
self._clock.looping_call(
self._clear_tried_cache, 60 * 1000,
)
def _clear_tried_cache(self):
"""Clear pdu_destination_tried cache"""
now = self._clock.time_msec()
old_dict = self.pdu_destination_tried
self.pdu_destination_tried = {}
for event_id, destination_dict in old_dict.items():
destination_dict = {
dest: time
for dest, time in destination_dict.items()
if time + PDU_RETRY_TIME_MS > now
}
if destination_dict:
self.pdu_destination_tried[event_id] = destination_dict
def start_get_pdu_cache(self): def start_get_pdu_cache(self):
self._get_pdu_cache = ExpiringCache( self._get_pdu_cache = ExpiringCache(
cache_name="get_pdu_cache", cache_name="get_pdu_cache",
@ -201,10 +226,10 @@ class FederationClient(FederationBase):
] ]
# FIXME: We should handle signature failures more gracefully. # FIXME: We should handle signature failures more gracefully.
pdus[:] = yield defer.gatherResults( pdus[:] = yield preserve_context_over_deferred(defer.gatherResults(
self._check_sigs_and_hashes(pdus), self._check_sigs_and_hashes(pdus),
consumeErrors=True, consumeErrors=True,
).addErrback(unwrapFirstError) )).addErrback(unwrapFirstError)
defer.returnValue(pdus) defer.returnValue(pdus)
@ -240,8 +265,15 @@ class FederationClient(FederationBase):
if ev: if ev:
defer.returnValue(ev) defer.returnValue(ev)
pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
pdu = None pdu = None
for destination in destinations: for destination in destinations:
now = self._clock.time_msec()
last_attempt = pdu_attempts.get(destination, 0)
if last_attempt + PDU_RETRY_TIME_MS > now:
continue
try: try:
limiter = yield get_retry_limiter( limiter = yield get_retry_limiter(
destination, destination,
@ -269,25 +301,19 @@ class FederationClient(FederationBase):
break break
pdu_attempts[destination] = now
except SynapseError as e: except SynapseError as e:
logger.info( logger.info(
"Failed to get PDU %s from %s because %s", "Failed to get PDU %s from %s because %s",
event_id, destination, e, event_id, destination, e,
) )
continue
except CodeMessageException as e:
if 400 <= e.code < 500:
raise
logger.info(
"Failed to get PDU %s from %s because %s",
event_id, destination, e,
)
continue
except NotRetryingDestination as e: except NotRetryingDestination as e:
logger.info(e.message) logger.info(e.message)
continue continue
except Exception as e: except Exception as e:
pdu_attempts[destination] = now
logger.info( logger.info(
"Failed to get PDU %s from %s because %s", "Failed to get PDU %s from %s because %s",
event_id, destination, e, event_id, destination, e,
@ -406,7 +432,7 @@ class FederationClient(FederationBase):
events and the second is a list of event ids that we failed to fetch. events and the second is a list of event ids that we failed to fetch.
""" """
if return_local: if return_local:
seen_events = yield self.store.get_events(event_ids) seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
signed_events = seen_events.values() signed_events = seen_events.values()
else: else:
seen_events = yield self.store.have_events(event_ids) seen_events = yield self.store.have_events(event_ids)
@ -432,14 +458,16 @@ class FederationClient(FederationBase):
batch = set(missing_events[i:i + batch_size]) batch = set(missing_events[i:i + batch_size])
deferreds = [ deferreds = [
self.get_pdu( preserve_fn(self.get_pdu)(
destinations=random_server_list(), destinations=random_server_list(),
event_id=e_id, event_id=e_id,
) )
for e_id in batch for e_id in batch
] ]
res = yield defer.DeferredList(deferreds, consumeErrors=True) res = yield preserve_context_over_deferred(
defer.DeferredList(deferreds, consumeErrors=True)
)
for success, result in res: for success, result in res:
if success: if success:
signed_events.append(result) signed_events.append(result)
@ -828,14 +856,16 @@ class FederationClient(FederationBase):
return srvs return srvs
deferreds = [ deferreds = [
self.get_pdu( preserve_fn(self.get_pdu)(
destinations=random_server_list(), destinations=random_server_list(),
event_id=e_id, event_id=e_id,
) )
for e_id, depth in ordered_missing[:limit - len(signed_events)] for e_id, depth in ordered_missing[:limit - len(signed_events)]
] ]
res = yield defer.DeferredList(deferreds, consumeErrors=True) res = yield preserve_context_over_deferred(
defer.DeferredList(deferreds, consumeErrors=True)
)
for (result, val), (e_id, _) in zip(res, ordered_missing): for (result, val), (e_id, _) in zip(res, ordered_missing):
if result and val: if result and val:
signed_events.append(val) signed_events.append(val)

View File

@ -21,11 +21,11 @@ from .units import Transaction
from synapse.api.errors import HttpResponseException from synapse.api.errors import HttpResponseException
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.logutils import log_function from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.retryutils import ( from synapse.util.retryutils import (
get_retry_limiter, NotRetryingDestination, get_retry_limiter, NotRetryingDestination,
) )
from synapse.util.metrics import measure_func
import synapse.metrics import synapse.metrics
import logging import logging
@ -51,7 +51,7 @@ class TransactionQueue(object):
self.transport_layer = transport_layer self.transport_layer = transport_layer
self._clock = hs.get_clock() self.clock = hs.get_clock()
# Is a mapping from destinations -> deferreds. Used to keep track # Is a mapping from destinations -> deferreds. Used to keep track
# of which destinations have transactions in flight and when they are # of which destinations have transactions in flight and when they are
@ -82,7 +82,7 @@ class TransactionQueue(object):
self.pending_failures_by_dest = {} self.pending_failures_by_dest = {}
# HACK to get unique tx id # HACK to get unique tx id
self._next_txn_id = int(self._clock.time_msec()) self._next_txn_id = int(self.clock.time_msec())
def can_send_to(self, destination): def can_send_to(self, destination):
"""Can we send messages to the given server? """Can we send messages to the given server?
@ -119,266 +119,215 @@ class TransactionQueue(object):
if not destinations: if not destinations:
return return
deferreds = []
for destination in destinations: for destination in destinations:
deferred = defer.Deferred()
self.pending_pdus_by_dest.setdefault(destination, []).append( self.pending_pdus_by_dest.setdefault(destination, []).append(
(pdu, deferred, order) (pdu, order)
) )
def chain(failure): preserve_context_over_fn(
if not deferred.called: self._attempt_new_transaction, destination
deferred.errback(failure) )
def log_failure(f):
logger.warn("Failed to send pdu to %s: %s", destination, f.value)
deferred.addErrback(log_failure)
with PreserveLoggingContext():
self._attempt_new_transaction(destination).addErrback(chain)
deferreds.append(deferred)
# NO inlineCallbacks
def enqueue_edu(self, edu): def enqueue_edu(self, edu):
destination = edu.destination destination = edu.destination
if not self.can_send_to(destination): if not self.can_send_to(destination):
return return
deferred = defer.Deferred() self.pending_edus_by_dest.setdefault(destination, []).append(edu)
self.pending_edus_by_dest.setdefault(destination, []).append(
(edu, deferred) preserve_context_over_fn(
self._attempt_new_transaction, destination
) )
def chain(failure):
if not deferred.called:
deferred.errback(failure)
def log_failure(f):
logger.warn("Failed to send edu to %s: %s", destination, f.value)
deferred.addErrback(log_failure)
with PreserveLoggingContext():
self._attempt_new_transaction(destination).addErrback(chain)
return deferred
@defer.inlineCallbacks
def enqueue_failure(self, failure, destination): def enqueue_failure(self, failure, destination):
if destination == self.server_name or destination == "localhost": if destination == self.server_name or destination == "localhost":
return return
deferred = defer.Deferred()
if not self.can_send_to(destination): if not self.can_send_to(destination):
return return
self.pending_failures_by_dest.setdefault( self.pending_failures_by_dest.setdefault(
destination, [] destination, []
).append( ).append(failure)
(failure, deferred)
preserve_context_over_fn(
self._attempt_new_transaction, destination
) )
def chain(f):
if not deferred.called:
deferred.errback(f)
def log_failure(f):
logger.warn("Failed to send failure to %s: %s", destination, f.value)
deferred.addErrback(log_failure)
with PreserveLoggingContext():
self._attempt_new_transaction(destination).addErrback(chain)
yield deferred
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function
def _attempt_new_transaction(self, destination): def _attempt_new_transaction(self, destination):
yield run_on_reactor() yield run_on_reactor()
while True:
# list of (pending_pdu, deferred, order)
if destination in self.pending_transactions:
# XXX: pending_transactions can get stuck on by a never-ending
# request at which point pending_pdus_by_dest just keeps growing.
# we need application-layer timeouts of some flavour of these
# requests
logger.debug(
"TX [%s] Transaction already in progress",
destination
)
return
# list of (pending_pdu, deferred, order) pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
if destination in self.pending_transactions: pending_edus = self.pending_edus_by_dest.pop(destination, [])
# XXX: pending_transactions can get stuck on by a never-ending pending_failures = self.pending_failures_by_dest.pop(destination, [])
# request at which point pending_pdus_by_dest just keeps growing.
# we need application-layer timeouts of some flavour of these if pending_pdus:
# requests logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
logger.debug( destination, len(pending_pdus))
"TX [%s] Transaction already in progress",
destination if not pending_pdus and not pending_edus and not pending_failures:
logger.debug("TX [%s] Nothing to send", destination)
return
yield self._send_new_transaction(
destination, pending_pdus, pending_edus, pending_failures
) )
return
pending_pdus = self.pending_pdus_by_dest.pop(destination, []) @measure_func("_send_new_transaction")
pending_edus = self.pending_edus_by_dest.pop(destination, []) @defer.inlineCallbacks
pending_failures = self.pending_failures_by_dest.pop(destination, []) def _send_new_transaction(self, destination, pending_pdus, pending_edus,
pending_failures):
if pending_pdus:
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus))
if not pending_pdus and not pending_edus and not pending_failures:
logger.debug("TX [%s] Nothing to send", destination)
return
try:
self.pending_transactions[destination] = 1
logger.debug("TX [%s] _attempt_new_transaction", destination)
# Sort based on the order field # Sort based on the order field
pending_pdus.sort(key=lambda t: t[2]) pending_pdus.sort(key=lambda t: t[1])
pdus = [x[0] for x in pending_pdus] pdus = [x[0] for x in pending_pdus]
edus = [x[0] for x in pending_edus] edus = pending_edus
failures = [x[0].get_dict() for x in pending_failures] failures = [x.get_dict() for x in pending_failures]
deferreds = [
x[1]
for x in pending_pdus + pending_edus + pending_failures
]
txn_id = str(self._next_txn_id) try:
self.pending_transactions[destination] = 1
limiter = yield get_retry_limiter( logger.debug("TX [%s] _attempt_new_transaction", destination)
destination,
self._clock,
self.store,
)
logger.debug( txn_id = str(self._next_txn_id)
"TX [%s] {%s} Attempting new transaction"
" (pdus: %d, edus: %d, failures: %d)",
destination, txn_id,
len(pending_pdus),
len(pending_edus),
len(pending_failures)
)
logger.debug("TX [%s] Persisting transaction...", destination) limiter = yield get_retry_limiter(
destination,
transaction = Transaction.create_new( self.clock,
origin_server_ts=int(self._clock.time_msec()), self.store,
transaction_id=txn_id,
origin=self.server_name,
destination=destination,
pdus=pdus,
edus=edus,
pdu_failures=failures,
)
self._next_txn_id += 1
yield self.transaction_actions.prepare_to_send(transaction)
logger.debug("TX [%s] Persisted transaction", destination)
logger.info(
"TX [%s] {%s} Sending transaction [%s],"
" (PDUs: %d, EDUs: %d, failures: %d)",
destination, txn_id,
transaction.transaction_id,
len(pending_pdus),
len(pending_edus),
len(pending_failures),
)
with limiter:
# Actually send the transaction
# FIXME (erikj): This is a bit of a hack to make the Pdu age
# keys work
def json_data_cb():
data = transaction.get_dict()
now = int(self._clock.time_msec())
if "pdus" in data:
for p in data["pdus"]:
if "age_ts" in p:
unsigned = p.setdefault("unsigned", {})
unsigned["age"] = now - int(p["age_ts"])
del p["age_ts"]
return data
try:
response = yield self.transport_layer.send_transaction(
transaction, json_data_cb
)
code = 200
if response:
for e_id, r in response.get("pdus", {}).items():
if "error" in r:
logger.warn(
"Transaction returned error for %s: %s",
e_id, r,
)
except HttpResponseException as e:
code = e.code
response = e.response
logger.info(
"TX [%s] {%s} got %d response",
destination, txn_id, code
) )
logger.debug("TX [%s] Sent transaction", destination) logger.debug(
logger.debug("TX [%s] Marking as delivered...", destination) "TX [%s] {%s} Attempting new transaction"
" (pdus: %d, edus: %d, failures: %d)",
destination, txn_id,
len(pending_pdus),
len(pending_edus),
len(pending_failures)
)
yield self.transaction_actions.delivered( logger.debug("TX [%s] Persisting transaction...", destination)
transaction, code, response
)
logger.debug("TX [%s] Marked as delivered", destination) transaction = Transaction.create_new(
origin_server_ts=int(self.clock.time_msec()),
transaction_id=txn_id,
origin=self.server_name,
destination=destination,
pdus=pdus,
edus=edus,
pdu_failures=failures,
)
logger.debug("TX [%s] Yielding to callbacks...", destination) self._next_txn_id += 1
for deferred in deferreds: yield self.transaction_actions.prepare_to_send(transaction)
if code == 200:
deferred.callback(None)
else:
deferred.errback(RuntimeError("Got status %d" % code))
# Ensures we don't continue until all callbacks on that logger.debug("TX [%s] Persisted transaction", destination)
# deferred have fired logger.info(
try: "TX [%s] {%s} Sending transaction [%s],"
yield deferred " (PDUs: %d, EDUs: %d, failures: %d)",
except: destination, txn_id,
pass transaction.transaction_id,
len(pending_pdus),
len(pending_edus),
len(pending_failures),
)
logger.debug("TX [%s] Yielded to callbacks", destination) with limiter:
except NotRetryingDestination: # Actually send the transaction
logger.info(
"TX [%s] not ready for retry yet - "
"dropping transaction for now",
destination,
)
except RuntimeError as e:
# We capture this here as there as nothing actually listens
# for this finishing functions deferred.
logger.warn(
"TX [%s] Problem in _attempt_transaction: %s",
destination,
e,
)
except Exception as e:
# We capture this here as there as nothing actually listens
# for this finishing functions deferred.
logger.warn(
"TX [%s] Problem in _attempt_transaction: %s",
destination,
e,
)
for deferred in deferreds: # FIXME (erikj): This is a bit of a hack to make the Pdu age
if not deferred.called: # keys work
deferred.errback(e) def json_data_cb():
data = transaction.get_dict()
now = int(self.clock.time_msec())
if "pdus" in data:
for p in data["pdus"]:
if "age_ts" in p:
unsigned = p.setdefault("unsigned", {})
unsigned["age"] = now - int(p["age_ts"])
del p["age_ts"]
return data
finally: try:
# We want to be *very* sure we delete this after we stop processing response = yield self.transport_layer.send_transaction(
self.pending_transactions.pop(destination, None) transaction, json_data_cb
)
code = 200
# Check to see if there is anything else to send. if response:
self._attempt_new_transaction(destination) for e_id, r in response.get("pdus", {}).items():
if "error" in r:
logger.warn(
"Transaction returned error for %s: %s",
e_id, r,
)
except HttpResponseException as e:
code = e.code
response = e.response
logger.info(
"TX [%s] {%s} got %d response",
destination, txn_id, code
)
logger.debug("TX [%s] Sent transaction", destination)
logger.debug("TX [%s] Marking as delivered...", destination)
yield self.transaction_actions.delivered(
transaction, code, response
)
logger.debug("TX [%s] Marked as delivered", destination)
if code != 200:
for p in pdus:
logger.info(
"Failed to send event %s to %s", p.event_id, destination
)
except NotRetryingDestination:
logger.info(
"TX [%s] not ready for retry yet - "
"dropping transaction for now",
destination,
)
except RuntimeError as e:
# We capture this here as there as nothing actually listens
# for this finishing functions deferred.
logger.warn(
"TX [%s] Problem in _attempt_transaction: %s",
destination,
e,
)
for p in pdus:
logger.info("Failed to send event %s to %s", p.event_id, destination)
except Exception as e:
# We capture this here as there as nothing actually listens
# for this finishing functions deferred.
logger.warn(
"TX [%s] Problem in _attempt_transaction: %s",
destination,
e,
)
for p in pdus:
logger.info("Failed to send event %s to %s", p.event_id, destination)
finally:
# We want to be *very* sure we delete this after we stop processing
self.pending_transactions.pop(destination, None)

View File

@ -19,7 +19,6 @@ from .room import (
) )
from .room_member import RoomMemberHandler from .room_member import RoomMemberHandler
from .message import MessageHandler from .message import MessageHandler
from .events import EventStreamHandler, EventHandler
from .federation import FederationHandler from .federation import FederationHandler
from .profile import ProfileHandler from .profile import ProfileHandler
from .directory import DirectoryHandler from .directory import DirectoryHandler
@ -53,8 +52,6 @@ class Handlers(object):
self.message_handler = MessageHandler(hs) self.message_handler = MessageHandler(hs)
self.room_creation_handler = RoomCreationHandler(hs) self.room_creation_handler = RoomCreationHandler(hs)
self.room_member_handler = RoomMemberHandler(hs) self.room_member_handler = RoomMemberHandler(hs)
self.event_stream_handler = EventStreamHandler(hs)
self.event_handler = EventHandler(hs)
self.federation_handler = FederationHandler(hs) self.federation_handler = FederationHandler(hs)
self.profile_handler = ProfileHandler(hs) self.profile_handler = ProfileHandler(hs)
self.directory_handler = DirectoryHandler(hs) self.directory_handler = DirectoryHandler(hs)

View File

@ -16,7 +16,8 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.appservice import ApplicationService from synapse.util.metrics import Measure
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
import logging import logging
@ -42,36 +43,73 @@ class ApplicationServicesHandler(object):
self.appservice_api = hs.get_application_service_api() self.appservice_api = hs.get_application_service_api()
self.scheduler = hs.get_application_service_scheduler() self.scheduler = hs.get_application_service_scheduler()
self.started_scheduler = False self.started_scheduler = False
self.clock = hs.get_clock()
self.notify_appservices = hs.config.notify_appservices
self.current_max = 0
self.is_processing = False
@defer.inlineCallbacks @defer.inlineCallbacks
def notify_interested_services(self, event): def notify_interested_services(self, current_id):
"""Notifies (pushes) all application services interested in this event. """Notifies (pushes) all application services interested in this event.
Pushing is done asynchronously, so this method won't block for any Pushing is done asynchronously, so this method won't block for any
prolonged length of time. prolonged length of time.
Args: Args:
event(Event): The event to push out to interested services. current_id(int): The current maximum ID.
""" """
# Gather interested services services = yield self.store.get_app_services()
services = yield self._get_services_for_event(event) if not services or not self.notify_appservices:
if len(services) == 0: return
return # no services need notifying
# Do we know this user exists? If not, poke the user query API for self.current_max = max(self.current_max, current_id)
# all services which match that user regex. This needs to block as these if self.is_processing:
# user queries need to be made BEFORE pushing the event. return
yield self._check_user_exists(event.sender)
if event.type == EventTypes.Member:
yield self._check_user_exists(event.state_key)
if not self.started_scheduler: with Measure(self.clock, "notify_interested_services"):
self.scheduler.start().addErrback(log_failure) self.is_processing = True
self.started_scheduler = True try:
upper_bound = self.current_max
limit = 100
while True:
upper_bound, events = yield self.store.get_new_events_for_appservice(
upper_bound, limit
)
# Fork off pushes to these services if not events:
for service in services: break
self.scheduler.submit_event_for_as(service, event)
for event in events:
# Gather interested services
services = yield self._get_services_for_event(event)
if len(services) == 0:
continue # no services need notifying
# Do we know this user exists? If not, poke the user
# query API for all services which match that user regex.
# This needs to block as these user queries need to be
# made BEFORE pushing the event.
yield self._check_user_exists(event.sender)
if event.type == EventTypes.Member:
yield self._check_user_exists(event.state_key)
if not self.started_scheduler:
self.scheduler.start().addErrback(log_failure)
self.started_scheduler = True
# Fork off pushes to these services
for service in services:
preserve_fn(self.scheduler.submit_event_for_as)(
service, event
)
yield self.store.set_appservice_last_pos(upper_bound)
if len(events) < limit:
break
finally:
self.is_processing = False
@defer.inlineCallbacks @defer.inlineCallbacks
def query_user_exists(self, user_id): def query_user_exists(self, user_id):
@ -104,11 +142,12 @@ class ApplicationServicesHandler(object):
association can be found. association can be found.
""" """
room_alias_str = room_alias.to_string() room_alias_str = room_alias.to_string()
alias_query_services = yield self._get_services_for_event( services = yield self.store.get_app_services()
event=None, alias_query_services = [
restrict_to=ApplicationService.NS_ALIASES, s for s in services if (
alias_list=[room_alias_str] s.is_interested_in_alias(room_alias_str)
) )
]
for alias_service in alias_query_services: for alias_service in alias_query_services:
is_known_alias = yield self.appservice_api.query_alias( is_known_alias = yield self.appservice_api.query_alias(
alias_service, room_alias_str alias_service, room_alias_str
@ -121,34 +160,35 @@ class ApplicationServicesHandler(object):
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_services_for_event(self, event, restrict_to="", alias_list=None): def query_3pe(self, kind, protocol, fields):
services = yield self._get_services_for_3pn(protocol)
results = yield preserve_context_over_deferred(defer.DeferredList([
preserve_fn(self.appservice_api.query_3pe)(service, kind, protocol, fields)
for service in services
], consumeErrors=True))
ret = []
for (success, result) in results:
if success:
ret.extend(result)
defer.returnValue(ret)
@defer.inlineCallbacks
def _get_services_for_event(self, event):
"""Retrieve a list of application services interested in this event. """Retrieve a list of application services interested in this event.
Args: Args:
event(Event): The event to check. Can be None if alias_list is not. event(Event): The event to check. Can be None if alias_list is not.
restrict_to(str): The namespace to restrict regex tests to.
alias_list: A list of aliases to get services for. If None, this
list is obtained from the database.
Returns: Returns:
list<ApplicationService>: A list of services interested in this list<ApplicationService>: A list of services interested in this
event based on the service regex. event based on the service regex.
""" """
member_list = None
if hasattr(event, "room_id"):
# We need to know the aliases associated with this event.room_id,
# if any.
if not alias_list:
alias_list = yield self.store.get_aliases_for_room(
event.room_id
)
# We need to know the members associated with this event.room_id,
# if any.
member_list = yield self.store.get_users_in_room(event.room_id)
services = yield self.store.get_app_services() services = yield self.store.get_app_services()
interested_list = [ interested_list = [
s for s in services if ( s for s in services if (
s.is_interested(event, restrict_to, alias_list, member_list) yield s.is_interested(event, self.store)
) )
] ]
defer.returnValue(interested_list) defer.returnValue(interested_list)
@ -163,6 +203,14 @@ class ApplicationServicesHandler(object):
] ]
defer.returnValue(interested_list) defer.returnValue(interested_list)
@defer.inlineCallbacks
def _get_services_for_3pn(self, protocol):
services = yield self.store.get_app_services()
interested_list = [
s for s in services if s.is_interested_in_protocol(protocol)
]
defer.returnValue(interested_list)
@defer.inlineCallbacks @defer.inlineCallbacks
def _is_unknown_user(self, user_id): def _is_unknown_user(self, user_id):
if not self.is_mine_id(user_id): if not self.is_mine_id(user_id):

View File

@ -70,11 +70,11 @@ class AuthHandler(BaseHandler):
self.ldap_uri = hs.config.ldap_uri self.ldap_uri = hs.config.ldap_uri
self.ldap_start_tls = hs.config.ldap_start_tls self.ldap_start_tls = hs.config.ldap_start_tls
self.ldap_base = hs.config.ldap_base self.ldap_base = hs.config.ldap_base
self.ldap_filter = hs.config.ldap_filter
self.ldap_attributes = hs.config.ldap_attributes self.ldap_attributes = hs.config.ldap_attributes
if self.ldap_mode == LDAPMode.SEARCH: if self.ldap_mode == LDAPMode.SEARCH:
self.ldap_bind_dn = hs.config.ldap_bind_dn self.ldap_bind_dn = hs.config.ldap_bind_dn
self.ldap_bind_password = hs.config.ldap_bind_password self.ldap_bind_password = hs.config.ldap_bind_password
self.ldap_filter = hs.config.ldap_filter
self.hs = hs # FIXME better possibility to access registrationHandler later? self.hs = hs # FIXME better possibility to access registrationHandler later?
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
@ -660,7 +660,7 @@ class AuthHandler(BaseHandler):
else: else:
logger.warn( logger.warn(
"ldap registration failed: unexpected (%d!=1) amount of results", "ldap registration failed: unexpected (%d!=1) amount of results",
len(result) len(conn.response)
) )
defer.returnValue(False) defer.returnValue(False)
@ -719,13 +719,14 @@ class AuthHandler(BaseHandler):
return macaroon.serialize() return macaroon.serialize()
def validate_short_term_login_token_and_get_user_id(self, login_token): def validate_short_term_login_token_and_get_user_id(self, login_token):
auth_api = self.hs.get_auth()
try: try:
macaroon = pymacaroons.Macaroon.deserialize(login_token) macaroon = pymacaroons.Macaroon.deserialize(login_token)
auth_api = self.hs.get_auth() user_id = auth_api.get_user_id_from_macaroon(macaroon)
auth_api.validate_macaroon(macaroon, "login", True) auth_api.validate_macaroon(macaroon, "login", True, user_id)
return self.get_user_from_macaroon(macaroon) return user_id
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): except Exception:
raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN) raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
def _generate_base_macaroon(self, user_id): def _generate_base_macaroon(self, user_id):
macaroon = pymacaroons.Macaroon( macaroon = pymacaroons.Macaroon(
@ -736,21 +737,11 @@ class AuthHandler(BaseHandler):
macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon return macaroon
def get_user_from_macaroon(self, macaroon):
user_prefix = "user_id = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix):
return caveat.caveat_id[len(user_prefix):]
raise AuthError(
self.INVALID_TOKEN_HTTP_STATUS, "No user_id found in token",
errcode=Codes.UNKNOWN_TOKEN
)
@defer.inlineCallbacks @defer.inlineCallbacks
def set_password(self, user_id, newpassword, requester=None): def set_password(self, user_id, newpassword, requester=None):
password_hash = self.hash(newpassword) password_hash = self.hash(newpassword)
except_access_token_ids = [requester.access_token_id] if requester else [] except_access_token_id = requester.access_token_id if requester else None
try: try:
yield self.store.user_set_password_hash(user_id, password_hash) yield self.store.user_set_password_hash(user_id, password_hash)
@ -759,10 +750,10 @@ class AuthHandler(BaseHandler):
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND) raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
raise e raise e
yield self.store.user_delete_access_tokens( yield self.store.user_delete_access_tokens(
user_id, except_access_token_ids user_id, except_access_token_id
) )
yield self.hs.get_pusherpool().remove_pushers_by_user( yield self.hs.get_pusherpool().remove_pushers_by_user(
user_id, except_access_token_ids user_id, except_access_token_id
) )
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -26,7 +26,9 @@ from synapse.api.errors import (
from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.logcontext import PreserveLoggingContext, preserve_fn from synapse.util.logcontext import (
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred
)
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.frozenutils import unfreeze from synapse.util.frozenutils import unfreeze
@ -249,7 +251,7 @@ class FederationHandler(BaseHandler):
if ev.type != EventTypes.Member: if ev.type != EventTypes.Member:
continue continue
try: try:
domain = UserID.from_string(ev.state_key).domain domain = get_domain_from_id(ev.state_key)
except: except:
continue continue
@ -274,7 +276,7 @@ class FederationHandler(BaseHandler):
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def backfill(self, dest, room_id, limit, extremities=[]): def backfill(self, dest, room_id, limit, extremities):
""" Trigger a backfill request to `dest` for the given `room_id` """ Trigger a backfill request to `dest` for the given `room_id`
This will attempt to get more events from the remote. This may return This will attempt to get more events from the remote. This may return
@ -284,9 +286,6 @@ class FederationHandler(BaseHandler):
if dest == self.server_name: if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.") raise SynapseError(400, "Can't backfill from self.")
if not extremities:
extremities = yield self.store.get_oldest_events_in_room(room_id)
events = yield self.replication_layer.backfill( events = yield self.replication_layer.backfill(
dest, dest,
room_id, room_id,
@ -364,9 +363,9 @@ class FederationHandler(BaseHandler):
missing_auth - failed_to_fetch missing_auth - failed_to_fetch
) )
results = yield defer.gatherResults( results = yield preserve_context_over_deferred(defer.gatherResults(
[ [
self.replication_layer.get_pdu( preserve_fn(self.replication_layer.get_pdu)(
[dest], [dest],
event_id, event_id,
outlier=True, outlier=True,
@ -375,10 +374,10 @@ class FederationHandler(BaseHandler):
for event_id in missing_auth - failed_to_fetch for event_id in missing_auth - failed_to_fetch
], ],
consumeErrors=True consumeErrors=True
).addErrback(unwrapFirstError) )).addErrback(unwrapFirstError)
auth_events.update({a.event_id: a for a in results}) auth_events.update({a.event_id: a for a in results if a})
required_auth.update( required_auth.update(
a_id for event in results for a_id, _ in event.auth_events a_id for event in results for a_id, _ in event.auth_events if event
) )
missing_auth = required_auth - set(auth_events) missing_auth = required_auth - set(auth_events)
@ -455,6 +454,10 @@ class FederationHandler(BaseHandler):
) )
max_depth = sorted_extremeties_tuple[0][1] max_depth = sorted_extremeties_tuple[0][1]
# We don't want to specify too many extremities as it causes the backfill
# request URI to be too long.
extremities = dict(sorted_extremeties_tuple[:5])
if current_depth > max_depth: if current_depth > max_depth:
logger.debug( logger.debug(
"Not backfilling as we don't need to. %d < %d", "Not backfilling as we don't need to. %d < %d",
@ -551,10 +554,10 @@ class FederationHandler(BaseHandler):
event_ids = list(extremities.keys()) event_ids = list(extremities.keys())
states = yield defer.gatherResults([ states = yield preserve_context_over_deferred(defer.gatherResults([
self.state_handler.resolve_state_groups(room_id, [e]) preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e])
for e in event_ids for e in event_ids
]) ]))
states = dict(zip(event_ids, [s[1] for s in states])) states = dict(zip(event_ids, [s[1] for s in states]))
for e_id, _ in sorted_extremeties_tuple: for e_id, _ in sorted_extremeties_tuple:
@ -1093,16 +1096,17 @@ class FederationHandler(BaseHandler):
) )
if event: if event:
# FIXME: This is a temporary work around where we occasionally if self.hs.is_mine_id(event.event_id):
# return events slightly differently than when they were # FIXME: This is a temporary work around where we occasionally
# originally signed # return events slightly differently than when they were
event.signatures.update( # originally signed
compute_event_signature( event.signatures.update(
event, compute_event_signature(
self.hs.hostname, event,
self.hs.config.signing_key[0] self.hs.hostname,
self.hs.config.signing_key[0]
)
) )
)
if do_auth: if do_auth:
in_room = yield self.auth.check_host_in_room( in_room = yield self.auth.check_host_in_room(
@ -1112,6 +1116,12 @@ class FederationHandler(BaseHandler):
if not in_room: if not in_room:
raise AuthError(403, "Host not in room.") raise AuthError(403, "Host not in room.")
events = yield self._filter_events_for_server(
origin, event.room_id, [event]
)
event = events[0]
defer.returnValue(event) defer.returnValue(event)
else: else:
defer.returnValue(None) defer.returnValue(None)
@ -1158,9 +1168,9 @@ class FederationHandler(BaseHandler):
a bunch of outliers, but not a chunk of individual events that depend a bunch of outliers, but not a chunk of individual events that depend
on each other for state calculations. on each other for state calculations.
""" """
contexts = yield defer.gatherResults( contexts = yield preserve_context_over_deferred(defer.gatherResults(
[ [
self._prep_event( preserve_fn(self._prep_event)(
origin, origin,
ev_info["event"], ev_info["event"],
state=ev_info.get("state"), state=ev_info.get("state"),
@ -1168,7 +1178,7 @@ class FederationHandler(BaseHandler):
) )
for ev_info in event_infos for ev_info in event_infos
] ]
) ))
yield self.store.persist_events( yield self.store.persist_events(
[ [
@ -1452,9 +1462,9 @@ class FederationHandler(BaseHandler):
# Do auth conflict res. # Do auth conflict res.
logger.info("Different auth: %s", different_auth) logger.info("Different auth: %s", different_auth)
different_events = yield defer.gatherResults( different_events = yield preserve_context_over_deferred(defer.gatherResults(
[ [
self.store.get_event( preserve_fn(self.store.get_event)(
d, d,
allow_none=True, allow_none=True,
allow_rejected=False, allow_rejected=False,
@ -1463,7 +1473,7 @@ class FederationHandler(BaseHandler):
if d in have_events and not have_events[d] if d in have_events and not have_events[d]
], ],
consumeErrors=True consumeErrors=True
).addErrback(unwrapFirstError) )).addErrback(unwrapFirstError)
if different_events: if different_events:
local_view = dict(auth_events) local_view = dict(auth_events)

View File

@ -28,7 +28,8 @@ from synapse.types import (
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLock from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLock
from synapse.util.caches.snapshot_cache import SnapshotCache from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.util.logcontext import preserve_fn from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.util.metrics import measure_func
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
from ._base import BaseHandler from ._base import BaseHandler
@ -502,15 +503,17 @@ class MessageHandler(BaseHandler):
lambda states: states[event.event_id] lambda states: states[event.event_id]
) )
(messages, token), current_state = yield defer.gatherResults( (messages, token), current_state = yield preserve_context_over_deferred(
[ defer.gatherResults(
self.store.get_recent_events_for_room( [
event.room_id, preserve_fn(self.store.get_recent_events_for_room)(
limit=limit, event.room_id,
end_token=room_end_token, limit=limit,
), end_token=room_end_token,
deferred_room_state, ),
] deferred_room_state,
]
)
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
messages = yield filter_events_for_client( messages = yield filter_events_for_client(
@ -719,9 +722,9 @@ class MessageHandler(BaseHandler):
presence, receipts, (messages, token) = yield defer.gatherResults( presence, receipts, (messages, token) = yield defer.gatherResults(
[ [
get_presence(), preserve_fn(get_presence)(),
get_receipts(), preserve_fn(get_receipts)(),
self.store.get_recent_events_for_room( preserve_fn(self.store.get_recent_events_for_room)(
room_id, room_id,
limit=limit, limit=limit,
end_token=now_token.room_key, end_token=now_token.room_key,
@ -755,6 +758,7 @@ class MessageHandler(BaseHandler):
defer.returnValue(ret) defer.returnValue(ret)
@measure_func("_create_new_client_event")
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_new_client_event(self, builder, prev_event_ids=None): def _create_new_client_event(self, builder, prev_event_ids=None):
if prev_event_ids: if prev_event_ids:
@ -806,6 +810,7 @@ class MessageHandler(BaseHandler):
(event, context,) (event, context,)
) )
@measure_func("handle_new_client_event")
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_new_client_event( def handle_new_client_event(
self, self,
@ -934,7 +939,7 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _notify(): def _notify():
yield run_on_reactor() yield run_on_reactor()
self.notifier.on_new_room_event( yield self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, event, event_stream_id, max_stream_id,
extra_users=extra_users extra_users=extra_users
) )
@ -944,6 +949,6 @@ class MessageHandler(BaseHandler):
# If invite, remove room_state from unsigned before sending. # If invite, remove room_state from unsigned before sending.
event.unsigned.pop("invite_room_state", None) event.unsigned.pop("invite_room_state", None)
federation_handler.handle_new_event( preserve_fn(federation_handler.handle_new_event)(
event, destinations=destinations, event, destinations=destinations,
) )

View File

@ -503,7 +503,7 @@ class PresenceHandler(object):
defer.returnValue(states) defer.returnValue(states)
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_interested_parties(self, states): def _get_interested_parties(self, states, calculate_remote_hosts=True):
"""Given a list of states return which entities (rooms, users, servers) """Given a list of states return which entities (rooms, users, servers)
are interested in the given states. are interested in the given states.
@ -526,14 +526,15 @@ class PresenceHandler(object):
users_to_states.setdefault(state.user_id, []).append(state) users_to_states.setdefault(state.user_id, []).append(state)
hosts_to_states = {} hosts_to_states = {}
for room_id, states in room_ids_to_states.items(): if calculate_remote_hosts:
local_states = filter(lambda s: self.is_mine_id(s.user_id), states) for room_id, states in room_ids_to_states.items():
if not local_states: local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
continue if not local_states:
continue
hosts = yield self.store.get_joined_hosts_for_room(room_id) hosts = yield self.store.get_joined_hosts_for_room(room_id)
for host in hosts: for host in hosts:
hosts_to_states.setdefault(host, []).extend(local_states) hosts_to_states.setdefault(host, []).extend(local_states)
for user_id, states in users_to_states.items(): for user_id, states in users_to_states.items():
local_states = filter(lambda s: self.is_mine_id(s.user_id), states) local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
@ -565,6 +566,16 @@ class PresenceHandler(object):
self._push_to_remotes(hosts_to_states) self._push_to_remotes(hosts_to_states)
@defer.inlineCallbacks
def notify_for_states(self, state, stream_id):
parties = yield self._get_interested_parties([state])
room_ids_to_states, users_to_states, hosts_to_states = parties
self.notifier.on_new_event(
"presence_key", stream_id, rooms=room_ids_to_states.keys(),
users=[UserID.from_string(u) for u in users_to_states.keys()]
)
def _push_to_remotes(self, hosts_to_states): def _push_to_remotes(self, hosts_to_states):
"""Sends state updates to remote servers. """Sends state updates to remote servers.
@ -672,7 +683,7 @@ class PresenceHandler(object):
]) ])
@defer.inlineCallbacks @defer.inlineCallbacks
def set_state(self, target_user, state): def set_state(self, target_user, state, ignore_status_msg=False):
"""Set the presence state of the user. """Set the presence state of the user.
""" """
status_msg = state.get("status_msg", None) status_msg = state.get("status_msg", None)
@ -689,10 +700,13 @@ class PresenceHandler(object):
prev_state = yield self.current_state_for_user(user_id) prev_state = yield self.current_state_for_user(user_id)
new_fields = { new_fields = {
"state": presence, "state": presence
"status_msg": status_msg if presence != PresenceState.OFFLINE else None
} }
if not ignore_status_msg:
msg = status_msg if presence != PresenceState.OFFLINE else None
new_fields["status_msg"] = msg
if presence == PresenceState.ONLINE: if presence == PresenceState.ONLINE:
new_fields["last_active_ts"] = self.clock.time_msec() new_fields["last_active_ts"] = self.clock.time_msec()

View File

@ -59,10 +59,13 @@ class RoomMemberHandler(BaseHandler):
prev_event_ids, prev_event_ids,
txn_id=None, txn_id=None,
ratelimit=True, ratelimit=True,
content=None,
): ):
if content is None:
content = {}
msg_handler = self.hs.get_handlers().message_handler msg_handler = self.hs.get_handlers().message_handler
content = {"membership": membership} content["membership"] = membership
if requester.is_guest: if requester.is_guest:
content["kind"] = "guest" content["kind"] = "guest"
@ -140,8 +143,9 @@ class RoomMemberHandler(BaseHandler):
remote_room_hosts=None, remote_room_hosts=None,
third_party_signed=None, third_party_signed=None,
ratelimit=True, ratelimit=True,
content=None,
): ):
key = (target, room_id,) key = (room_id,)
with (yield self.member_linearizer.queue(key)): with (yield self.member_linearizer.queue(key)):
result = yield self._update_membership( result = yield self._update_membership(
@ -153,6 +157,7 @@ class RoomMemberHandler(BaseHandler):
remote_room_hosts=remote_room_hosts, remote_room_hosts=remote_room_hosts,
third_party_signed=third_party_signed, third_party_signed=third_party_signed,
ratelimit=ratelimit, ratelimit=ratelimit,
content=content,
) )
defer.returnValue(result) defer.returnValue(result)
@ -168,7 +173,11 @@ class RoomMemberHandler(BaseHandler):
remote_room_hosts=None, remote_room_hosts=None,
third_party_signed=None, third_party_signed=None,
ratelimit=True, ratelimit=True,
content=None,
): ):
if content is None:
content = {}
effective_membership_state = action effective_membership_state = action
if action in ["kick", "unban"]: if action in ["kick", "unban"]:
effective_membership_state = "leave" effective_membership_state = "leave"
@ -218,7 +227,7 @@ class RoomMemberHandler(BaseHandler):
if inviter and not self.hs.is_mine(inviter): if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain) remote_room_hosts.append(inviter.domain)
content = {"membership": Membership.JOIN} content["membership"] = Membership.JOIN
profile = self.hs.get_handlers().profile_handler profile = self.hs.get_handlers().profile_handler
content["displayname"] = yield profile.get_displayname(target) content["displayname"] = yield profile.get_displayname(target)
@ -272,6 +281,7 @@ class RoomMemberHandler(BaseHandler):
txn_id=txn_id, txn_id=txn_id,
ratelimit=ratelimit, ratelimit=ratelimit,
prev_event_ids=latest_event_ids, prev_event_ids=latest_event_ids,
content=content,
) )
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -464,10 +464,10 @@ class SyncHandler(object):
else: else:
state = {} state = {}
defer.returnValue({ defer.returnValue({
(e.type, e.state_key): e (e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(state.values()) for e in sync_config.filter_collection.filter_room_state(state.values())
}) })
@defer.inlineCallbacks @defer.inlineCallbacks
def unread_notifs_for_room_id(self, room_id, sync_config): def unread_notifs_for_room_id(self, room_id, sync_config):
@ -485,9 +485,9 @@ class SyncHandler(object):
) )
defer.returnValue(notifs) defer.returnValue(notifs)
# There is no new information in this period, so your notification # There is no new information in this period, so your notification
# count is whatever it was last time. # count is whatever it was last time.
defer.returnValue(None) defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def generate_sync_result(self, sync_config, since_token=None, full_state=False): def generate_sync_result(self, sync_config, since_token=None, full_state=False):

View File

@ -16,7 +16,9 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError from synapse.api.errors import SynapseError, AuthError
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import (
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
)
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.types import UserID from synapse.types import UserID
@ -169,13 +171,13 @@ class TypingHandler(object):
deferreds = [] deferreds = []
for domain in domains: for domain in domains:
if domain == self.server_name: if domain == self.server_name:
self._push_update_local( preserve_fn(self._push_update_local)(
room_id=room_id, room_id=room_id,
user_id=user_id, user_id=user_id,
typing=typing typing=typing
) )
else: else:
deferreds.append(self.federation.send_edu( deferreds.append(preserve_fn(self.federation.send_edu)(
destination=domain, destination=domain,
edu_type="m.typing", edu_type="m.typing",
content={ content={
@ -185,7 +187,9 @@ class TypingHandler(object):
}, },
)) ))
yield defer.DeferredList(deferreds, consumeErrors=True) yield preserve_context_over_deferred(
defer.DeferredList(deferreds, consumeErrors=True)
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _recv_edu(self, origin, content): def _recv_edu(self, origin, content):

View File

@ -155,9 +155,7 @@ class MatrixFederationHttpClient(object):
time_out=timeout / 1000. if timeout else 60, time_out=timeout / 1000. if timeout else 60,
) )
response = yield preserve_context_over_fn( response = yield preserve_context_over_fn(send_request)
send_request,
)
log_result = "%d %s" % (response.code, response.phrase,) log_result = "%d %s" % (response.code, response.phrase,)
break break

View File

@ -19,6 +19,7 @@ from synapse.api.errors import (
) )
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.caches import intern_dict from synapse.util.caches import intern_dict
from synapse.util.metrics import Measure
import synapse.metrics import synapse.metrics
import synapse.events import synapse.events
@ -74,12 +75,12 @@ response_db_txn_duration = metrics.register_distribution(
_next_request_id = 0 _next_request_id = 0
def request_handler(report_metrics=True): def request_handler(include_metrics=False):
"""Decorator for ``wrap_request_handler``""" """Decorator for ``wrap_request_handler``"""
return lambda request_handler: wrap_request_handler(request_handler, report_metrics) return lambda request_handler: wrap_request_handler(request_handler, include_metrics)
def wrap_request_handler(request_handler, report_metrics): def wrap_request_handler(request_handler, include_metrics=False):
"""Wraps a method that acts as a request handler with the necessary logging """Wraps a method that acts as a request handler with the necessary logging
and exception handling. and exception handling.
@ -103,54 +104,56 @@ def wrap_request_handler(request_handler, report_metrics):
_next_request_id += 1 _next_request_id += 1
with LoggingContext(request_id) as request_context: with LoggingContext(request_id) as request_context:
if report_metrics: with Measure(self.clock, "wrapped_request_handler"):
request_metrics = RequestMetrics() request_metrics = RequestMetrics()
request_metrics.start(self.clock) request_metrics.start(self.clock, name=self.__class__.__name__)
request_context.request = request_id request_context.request = request_id
with request.processing(): with request.processing():
try:
with PreserveLoggingContext(request_context):
yield request_handler(self, request)
except CodeMessageException as e:
code = e.code
if isinstance(e, SynapseError):
logger.info(
"%s SynapseError: %s - %s", request, code, e.msg
)
else:
logger.exception(e)
outgoing_responses_counter.inc(request.method, str(code))
respond_with_json(
request, code, cs_exception(e), send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
version_string=self.version_string,
)
except:
logger.exception(
"Failed handle request %s.%s on %r: %r",
request_handler.__module__,
request_handler.__name__,
self,
request
)
respond_with_json(
request,
500,
{
"error": "Internal server error",
"errcode": Codes.UNKNOWN,
},
send_cors=True
)
finally:
try: try:
if report_metrics: with PreserveLoggingContext(request_context):
request_metrics.stop( if include_metrics:
self.clock, request, self.__class__.__name__ yield request_handler(self, request, request_metrics)
else:
yield request_handler(self, request)
except CodeMessageException as e:
code = e.code
if isinstance(e, SynapseError):
logger.info(
"%s SynapseError: %s - %s", request, code, e.msg
) )
else:
logger.exception(e)
outgoing_responses_counter.inc(request.method, str(code))
respond_with_json(
request, code, cs_exception(e), send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
version_string=self.version_string,
)
except: except:
pass logger.exception(
"Failed handle request %s.%s on %r: %r",
request_handler.__module__,
request_handler.__name__,
self,
request
)
respond_with_json(
request,
500,
{
"error": "Internal server error",
"errcode": Codes.UNKNOWN,
},
send_cors=True
)
finally:
try:
request_metrics.stop(
self.clock, request
)
except Exception as e:
logger.warn("Failed to stop metrics: %r", e)
return wrapped_request_handler return wrapped_request_handler
@ -220,9 +223,9 @@ class JsonResource(HttpServer, resource.Resource):
# It does its own metric reporting because _async_render dispatches to # It does its own metric reporting because _async_render dispatches to
# a callback and it's the class name of that callback we want to report # a callback and it's the class name of that callback we want to report
# against rather than the JsonResource itself. # against rather than the JsonResource itself.
@request_handler(report_metrics=False) @request_handler(include_metrics=True)
@defer.inlineCallbacks @defer.inlineCallbacks
def _async_render(self, request): def _async_render(self, request, request_metrics):
""" This gets called from render() every time someone sends us a request. """ This gets called from render() every time someone sends us a request.
This checks if anyone has registered a callback for that method and This checks if anyone has registered a callback for that method and
path. path.
@ -231,9 +234,6 @@ class JsonResource(HttpServer, resource.Resource):
self._send_response(request, 200, {}) self._send_response(request, 200, {})
return return
request_metrics = RequestMetrics()
request_metrics.start(self.clock)
# Loop through all the registered callbacks to check if the method # Loop through all the registered callbacks to check if the method
# and path regex match # and path regex match
for path_entry in self.path_regexs.get(request.method, []): for path_entry in self.path_regexs.get(request.method, []):
@ -247,12 +247,6 @@ class JsonResource(HttpServer, resource.Resource):
callback = path_entry.callback callback = path_entry.callback
servlet_instance = getattr(callback, "__self__", None)
if servlet_instance is not None:
servlet_classname = servlet_instance.__class__.__name__
else:
servlet_classname = "%r" % callback
kwargs = intern_dict({ kwargs = intern_dict({
name: urllib.unquote(value).decode("UTF-8") if value else value name: urllib.unquote(value).decode("UTF-8") if value else value
for name, value in m.groupdict().items() for name, value in m.groupdict().items()
@ -263,10 +257,13 @@ class JsonResource(HttpServer, resource.Resource):
code, response = callback_return code, response = callback_return
self._send_response(request, code, response) self._send_response(request, code, response)
try: servlet_instance = getattr(callback, "__self__", None)
request_metrics.stop(self.clock, request, servlet_classname) if servlet_instance is not None:
except: servlet_classname = servlet_instance.__class__.__name__
pass else:
servlet_classname = "%r" % callback
request_metrics.name = servlet_classname
return return
@ -298,11 +295,12 @@ class JsonResource(HttpServer, resource.Resource):
class RequestMetrics(object): class RequestMetrics(object):
def start(self, clock): def start(self, clock, name):
self.start = clock.time_msec() self.start = clock.time_msec()
self.start_context = LoggingContext.current_context() self.start_context = LoggingContext.current_context()
self.name = name
def stop(self, clock, request, servlet_classname): def stop(self, clock, request):
context = LoggingContext.current_context() context = LoggingContext.current_context()
tag = "" tag = ""
@ -316,26 +314,26 @@ class RequestMetrics(object):
) )
return return
incoming_requests_counter.inc(request.method, servlet_classname, tag) incoming_requests_counter.inc(request.method, self.name, tag)
response_timer.inc_by( response_timer.inc_by(
clock.time_msec() - self.start, request.method, clock.time_msec() - self.start, request.method,
servlet_classname, tag self.name, tag
) )
ru_utime, ru_stime = context.get_resource_usage() ru_utime, ru_stime = context.get_resource_usage()
response_ru_utime.inc_by( response_ru_utime.inc_by(
ru_utime, request.method, servlet_classname, tag ru_utime, request.method, self.name, tag
) )
response_ru_stime.inc_by( response_ru_stime.inc_by(
ru_stime, request.method, servlet_classname, tag ru_stime, request.method, self.name, tag
) )
response_db_txn_count.inc_by( response_db_txn_count.inc_by(
context.db_txn_count, request.method, servlet_classname, tag context.db_txn_count, request.method, self.name, tag
) )
response_db_txn_duration.inc_by( response_db_txn_duration.inc_by(
context.db_txn_duration, request.method, servlet_classname, tag context.db_txn_duration, request.method, self.name, tag
) )

View File

@ -19,7 +19,8 @@ from synapse.api.errors import AuthError
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
from synapse.util.metrics import Measure
from synapse.types import StreamToken from synapse.types import StreamToken
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
import synapse.metrics import synapse.metrics
@ -67,10 +68,8 @@ class _NotifierUserStream(object):
so that it can remove itself from the indexes in the Notifier class. so that it can remove itself from the indexes in the Notifier class.
""" """
def __init__(self, user_id, rooms, current_token, time_now_ms, def __init__(self, user_id, rooms, current_token, time_now_ms):
appservice=None):
self.user_id = user_id self.user_id = user_id
self.appservice = appservice
self.rooms = set(rooms) self.rooms = set(rooms)
self.current_token = current_token self.current_token = current_token
self.last_notified_ms = time_now_ms self.last_notified_ms = time_now_ms
@ -107,11 +106,6 @@ class _NotifierUserStream(object):
notifier.user_to_user_stream.pop(self.user_id) notifier.user_to_user_stream.pop(self.user_id)
if self.appservice:
notifier.appservice_to_user_streams.get(
self.appservice, set()
).discard(self)
def count_listeners(self): def count_listeners(self):
return len(self.notify_deferred.observers()) return len(self.notify_deferred.observers())
@ -142,7 +136,6 @@ class Notifier(object):
def __init__(self, hs): def __init__(self, hs):
self.user_to_user_stream = {} self.user_to_user_stream = {}
self.room_to_user_streams = {} self.room_to_user_streams = {}
self.appservice_to_user_streams = {}
self.event_sources = hs.get_event_sources() self.event_sources = hs.get_event_sources()
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -168,8 +161,6 @@ class Notifier(object):
all_user_streams |= x all_user_streams |= x
for x in self.user_to_user_stream.values(): for x in self.user_to_user_stream.values():
all_user_streams.add(x) all_user_streams.add(x)
for x in self.appservice_to_user_streams.values():
all_user_streams |= x
return sum(stream.count_listeners() for stream in all_user_streams) return sum(stream.count_listeners() for stream in all_user_streams)
metrics.register_callback("listeners", count_listeners) metrics.register_callback("listeners", count_listeners)
@ -182,11 +173,8 @@ class Notifier(object):
"users", "users",
lambda: len(self.user_to_user_stream), lambda: len(self.user_to_user_stream),
) )
metrics.register_callback(
"appservices",
lambda: count(bool, self.appservice_to_user_streams.values()),
)
@preserve_fn
def on_new_room_event(self, event, room_stream_id, max_room_stream_id, def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
extra_users=[]): extra_users=[]):
""" Used by handlers to inform the notifier something has happened """ Used by handlers to inform the notifier something has happened
@ -208,6 +196,7 @@ class Notifier(object):
self.notify_replication() self.notify_replication()
@preserve_fn
def _notify_pending_new_room_events(self, max_room_stream_id): def _notify_pending_new_room_events(self, max_room_stream_id):
"""Notify for the room events that were queued waiting for a previous """Notify for the room events that were queued waiting for a previous
event to be persisted. event to be persisted.
@ -225,24 +214,11 @@ class Notifier(object):
else: else:
self._on_new_room_event(event, room_stream_id, extra_users) self._on_new_room_event(event, room_stream_id, extra_users)
@preserve_fn
def _on_new_room_event(self, event, room_stream_id, extra_users=[]): def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
"""Notify any user streams that are interested in this room event""" """Notify any user streams that are interested in this room event"""
# poke any interested application service. # poke any interested application service.
self.appservice_handler.notify_interested_services(event) self.appservice_handler.notify_interested_services(room_stream_id)
app_streams = set()
for appservice in self.appservice_to_user_streams:
# TODO (kegan): Redundant appservice listener checks?
# App services will already be in the room_to_user_streams set, but
# that isn't enough. They need to be checked here in order to
# receive *invites* for users they are interested in. Does this
# make the room_to_user_streams check somewhat obselete?
if appservice.is_interested(event):
app_user_streams = self.appservice_to_user_streams.get(
appservice, set()
)
app_streams |= app_user_streams
if event.type == EventTypes.Member and event.membership == Membership.JOIN: if event.type == EventTypes.Member and event.membership == Membership.JOIN:
self._user_joined_room(event.state_key, event.room_id) self._user_joined_room(event.state_key, event.room_id)
@ -251,35 +227,36 @@ class Notifier(object):
"room_key", room_stream_id, "room_key", room_stream_id,
users=extra_users, users=extra_users,
rooms=[event.room_id], rooms=[event.room_id],
extra_streams=app_streams,
) )
def on_new_event(self, stream_key, new_token, users=[], rooms=[], @preserve_fn
extra_streams=set()): def on_new_event(self, stream_key, new_token, users=[], rooms=[]):
""" Used to inform listeners that something has happend event wise. """ Used to inform listeners that something has happend event wise.
Will wake up all listeners for the given users and rooms. Will wake up all listeners for the given users and rooms.
""" """
with PreserveLoggingContext(): with PreserveLoggingContext():
user_streams = set() with Measure(self.clock, "on_new_event"):
user_streams = set()
for user in users: for user in users:
user_stream = self.user_to_user_stream.get(str(user)) user_stream = self.user_to_user_stream.get(str(user))
if user_stream is not None: if user_stream is not None:
user_streams.add(user_stream) user_streams.add(user_stream)
for room in rooms: for room in rooms:
user_streams |= self.room_to_user_streams.get(room, set()) user_streams |= self.room_to_user_streams.get(room, set())
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
for user_stream in user_streams: for user_stream in user_streams:
try: try:
user_stream.notify(stream_key, new_token, time_now_ms) user_stream.notify(stream_key, new_token, time_now_ms)
except: except:
logger.exception("Failed to notify listener") logger.exception("Failed to notify listener")
self.notify_replication() self.notify_replication()
@preserve_fn
def on_new_replication_data(self): def on_new_replication_data(self):
"""Used to inform replication listeners that something has happend """Used to inform replication listeners that something has happend
without waking up any of the normal user event streams""" without waking up any of the normal user event streams"""
@ -294,7 +271,6 @@ class Notifier(object):
""" """
user_stream = self.user_to_user_stream.get(user_id) user_stream = self.user_to_user_stream.get(user_id)
if user_stream is None: if user_stream is None:
appservice = yield self.store.get_app_service_by_user_id(user_id)
current_token = yield self.event_sources.get_current_token() current_token = yield self.event_sources.get_current_token()
if room_ids is None: if room_ids is None:
rooms = yield self.store.get_rooms_for_user(user_id) rooms = yield self.store.get_rooms_for_user(user_id)
@ -302,7 +278,6 @@ class Notifier(object):
user_stream = _NotifierUserStream( user_stream = _NotifierUserStream(
user_id=user_id, user_id=user_id,
rooms=room_ids, rooms=room_ids,
appservice=appservice,
current_token=current_token, current_token=current_token,
time_now_ms=self.clock.time_msec(), time_now_ms=self.clock.time_msec(),
) )
@ -477,11 +452,6 @@ class Notifier(object):
s = self.room_to_user_streams.setdefault(room, set()) s = self.room_to_user_streams.setdefault(room, set())
s.add(user_stream) s.add(user_stream)
if user_stream.appservice:
self.appservice_to_user_stream.setdefault(
user_stream.appservice, set()
).add(user_stream)
def _user_joined_room(self, user_id, room_id): def _user_joined_room(self, user_id, room_id):
new_user_stream = self.user_to_user_stream.get(user_id) new_user_stream = self.user_to_user_stream.get(user_id)
if new_user_stream is not None: if new_user_stream is not None:

View File

@ -38,15 +38,16 @@ class ActionGenerator:
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_push_actions_for_event(self, event, context): def handle_push_actions_for_event(self, event, context):
with Measure(self.clock, "handle_push_actions_for_event"): with Measure(self.clock, "evaluator_for_event"):
bulk_evaluator = yield evaluator_for_event( bulk_evaluator = yield evaluator_for_event(
event, self.hs, self.store, context.current_state event, self.hs, self.store, context.state_group, context.current_state
) )
with Measure(self.clock, "action_for_event_by_user"):
actions_by_user = yield bulk_evaluator.action_for_event_by_user( actions_by_user = yield bulk_evaluator.action_for_event_by_user(
event, context.current_state event, context.current_state
) )
context.push_actions = [ context.push_actions = [
(uid, actions) for uid, actions in actions_by_user.items() (uid, actions) for uid, actions in actions_by_user.items()
] ]

View File

@ -217,6 +217,27 @@ BASE_APPEND_OVERRIDE_RULES = [
'dont_notify' 'dont_notify'
] ]
}, },
# This was changed from underride to override so it's closer in priority
# to the content rules where the user name highlight rule lives. This
# way a room rule is lower priority than both but a custom override rule
# is higher priority than both.
{
'rule_id': 'global/override/.m.rule.contains_display_name',
'conditions': [
{
'kind': 'contains_display_name'
}
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'default'
}, {
'set_tweak': 'highlight'
}
]
},
] ]
@ -242,23 +263,6 @@ BASE_APPEND_UNDERRIDE_RULES = [
} }
] ]
}, },
{
'rule_id': 'global/underride/.m.rule.contains_display_name',
'conditions': [
{
'kind': 'contains_display_name'
}
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'default'
}, {
'set_tweak': 'highlight'
}
]
},
{ {
'rule_id': 'global/underride/.m.rule.room_one_to_one', 'rule_id': 'global/underride/.m.rule.room_one_to_one',
'conditions': [ 'conditions': [

View File

@ -36,35 +36,11 @@ def _get_rules(room_id, user_ids, store):
@defer.inlineCallbacks @defer.inlineCallbacks
def evaluator_for_event(event, hs, store, current_state): def evaluator_for_event(event, hs, store, state_group, current_state):
room_id = event.room_id rules_by_user = yield store.bulk_get_push_rules_for_room(
# We also will want to generate notifs for other people in the room so event.room_id, state_group, current_state
# their unread countss are correct in the event stream, but to avoid
# generating them for bot / AS users etc, we only do so for people who've
# sent a read receipt into the room.
local_users_in_room = set(
e.state_key for e in current_state.values()
if e.type == EventTypes.Member and e.membership == Membership.JOIN
and hs.is_mine_id(e.state_key)
) )
# users in the room who have pushers need to get push rules run because
# that's how their pushers work
if_users_with_pushers = yield store.get_if_users_have_pushers(
local_users_in_room
)
user_ids = set(
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
)
users_with_receipts = yield store.get_users_with_read_receipts_in_room(room_id)
# any users with pushers must be ours: they have pushers
for uid in users_with_receipts:
if uid in local_users_in_room:
user_ids.add(uid)
# if this event is an invite event, we may need to run rules for the user # if this event is an invite event, we may need to run rules for the user
# who's been invited, otherwise they won't get told they've been invited # who's been invited, otherwise they won't get told they've been invited
if event.type == 'm.room.member' and event.content['membership'] == 'invite': if event.type == 'm.room.member' and event.content['membership'] == 'invite':
@ -72,12 +48,12 @@ def evaluator_for_event(event, hs, store, current_state):
if invited_user and hs.is_mine_id(invited_user): if invited_user and hs.is_mine_id(invited_user):
has_pusher = yield store.user_has_pusher(invited_user) has_pusher = yield store.user_has_pusher(invited_user)
if has_pusher: if has_pusher:
user_ids.add(invited_user) rules_by_user[invited_user] = yield store.get_push_rules_for_user(
invited_user
rules_by_user = yield _get_rules(room_id, user_ids, store) )
defer.returnValue(BulkPushRuleEvaluator( defer.returnValue(BulkPushRuleEvaluator(
room_id, rules_by_user, user_ids, store event.room_id, rules_by_user, store
)) ))
@ -90,10 +66,9 @@ class BulkPushRuleEvaluator:
the same logic to run the actual rules, but could be optimised further the same logic to run the actual rules, but could be optimised further
(see https://matrix.org/jira/browse/SYN-562) (see https://matrix.org/jira/browse/SYN-562)
""" """
def __init__(self, room_id, rules_by_user, users_in_room, store): def __init__(self, room_id, rules_by_user, store):
self.room_id = room_id self.room_id = room_id
self.rules_by_user = rules_by_user self.rules_by_user = rules_by_user
self.users_in_room = users_in_room
self.store = store self.store = store
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -17,14 +17,15 @@ from twisted.internet import defer
from synapse.util.presentable_names import ( from synapse.util.presentable_names import (
calculate_room_name, name_from_member_event calculate_room_name, name_from_member_event
) )
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
@defer.inlineCallbacks @defer.inlineCallbacks
def get_badge_count(store, user_id): def get_badge_count(store, user_id):
invites, joins = yield defer.gatherResults([ invites, joins = yield preserve_context_over_deferred(defer.gatherResults([
store.get_invited_rooms_for_user(user_id), preserve_fn(store.get_invited_rooms_for_user)(user_id),
store.get_rooms_for_user(user_id), preserve_fn(store.get_rooms_for_user)(user_id),
], consumeErrors=True) ], consumeErrors=True))
my_receipts_by_room = yield store.get_receipts_for_user( my_receipts_by_room = yield store.get_receipts_for_user(
user_id, "m.read", user_id, "m.read",

View File

@ -17,7 +17,7 @@
from twisted.internet import defer from twisted.internet import defer
import pusher import pusher
from synapse.util.logcontext import preserve_fn from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
import logging import logging
@ -102,14 +102,14 @@ class PusherPool:
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks @defer.inlineCallbacks
def remove_pushers_by_user(self, user_id, except_token_ids=[]): def remove_pushers_by_user(self, user_id, except_access_token_id=None):
all = yield self.store.get_all_pushers() all = yield self.store.get_all_pushers()
logger.info( logger.info(
"Removing all pushers for user %s except access tokens ids %r", "Removing all pushers for user %s except access tokens id %r",
user_id, except_token_ids user_id, except_access_token_id
) )
for p in all: for p in all:
if p['user_name'] == user_id and p['access_token'] not in except_token_ids: if p['user_name'] == user_id and p['access_token'] != except_access_token_id:
logger.info( logger.info(
"Removing pusher for app id %s, pushkey %s, user %s", "Removing pusher for app id %s, pushkey %s, user %s",
p['app_id'], p['pushkey'], p['user_name'] p['app_id'], p['pushkey'], p['user_name']
@ -130,10 +130,12 @@ class PusherPool:
if u in self.pushers: if u in self.pushers:
for p in self.pushers[u].values(): for p in self.pushers[u].values():
deferreds.append( deferreds.append(
p.on_new_notifications(min_stream_id, max_stream_id) preserve_fn(p.on_new_notifications)(
min_stream_id, max_stream_id
)
) )
yield defer.gatherResults(deferreds) yield preserve_context_over_deferred(defer.gatherResults(deferreds))
except: except:
logger.exception("Exception in pusher on_new_notifications") logger.exception("Exception in pusher on_new_notifications")
@ -155,10 +157,10 @@ class PusherPool:
if u in self.pushers: if u in self.pushers:
for p in self.pushers[u].values(): for p in self.pushers[u].values():
deferreds.append( deferreds.append(
p.on_new_receipts(min_stream_id, max_stream_id) preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id)
) )
yield defer.gatherResults(deferreds) yield preserve_context_over_deferred(defer.gatherResults(deferreds))
except: except:
logger.exception("Exception in pusher on_new_receipts") logger.exception("Exception in pusher on_new_receipts")

View File

@ -41,6 +41,7 @@ STREAM_NAMES = (
("push_rules",), ("push_rules",),
("pushers",), ("pushers",),
("state",), ("state",),
("caches",),
) )
@ -70,6 +71,7 @@ class ReplicationResource(Resource):
* "backfill": Old events that have been backfilled from other servers. * "backfill": Old events that have been backfilled from other servers.
* "push_rules": Per user changes to push rules. * "push_rules": Per user changes to push rules.
* "pushers": Per user changes to their pushers. * "pushers": Per user changes to their pushers.
* "caches": Cache invalidations.
The API takes two additional query parameters: The API takes two additional query parameters:
@ -129,6 +131,7 @@ class ReplicationResource(Resource):
push_rules_token, room_stream_token = self.store.get_push_rules_stream_token() push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
pushers_token = self.store.get_pushers_stream_token() pushers_token = self.store.get_pushers_stream_token()
state_token = self.store.get_state_stream_token() state_token = self.store.get_state_stream_token()
caches_token = self.store.get_cache_stream_token()
defer.returnValue(_ReplicationToken( defer.returnValue(_ReplicationToken(
room_stream_token, room_stream_token,
@ -140,6 +143,7 @@ class ReplicationResource(Resource):
push_rules_token, push_rules_token,
pushers_token, pushers_token,
state_token, state_token,
caches_token,
)) ))
@request_handler() @request_handler()
@ -188,6 +192,7 @@ class ReplicationResource(Resource):
yield self.push_rules(writer, current_token, limit, request_streams) yield self.push_rules(writer, current_token, limit, request_streams)
yield self.pushers(writer, current_token, limit, request_streams) yield self.pushers(writer, current_token, limit, request_streams)
yield self.state(writer, current_token, limit, request_streams) yield self.state(writer, current_token, limit, request_streams)
yield self.caches(writer, current_token, limit, request_streams)
self.streams(writer, current_token, request_streams) self.streams(writer, current_token, request_streams)
logger.info("Replicated %d rows", writer.total) logger.info("Replicated %d rows", writer.total)
@ -379,6 +384,20 @@ class ReplicationResource(Resource):
"position", "type", "state_key", "event_id" "position", "type", "state_key", "event_id"
)) ))
@defer.inlineCallbacks
def caches(self, writer, current_token, limit, request_streams):
current_position = current_token.caches
caches = request_streams.get("caches")
if caches is not None:
updated_caches = yield self.store.get_all_updated_caches(
caches, current_position, limit
)
writer.write_header_and_rows("caches", updated_caches, (
"position", "cache_func", "keys", "invalidation_ts"
))
class _Writer(object): class _Writer(object):
"""Writes the streams as a JSON object as the response to the request""" """Writes the streams as a JSON object as the response to the request"""
@ -407,7 +426,7 @@ class _Writer(object):
class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
"events", "presence", "typing", "receipts", "account_data", "backfill", "events", "presence", "typing", "receipts", "account_data", "backfill",
"push_rules", "pushers", "state" "push_rules", "pushers", "state", "caches",
))): ))):
__slots__ = [] __slots__ = []

View File

@ -14,15 +14,43 @@
# limitations under the License. # limitations under the License.
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.engines import PostgresEngine
from twisted.internet import defer from twisted.internet import defer
from ._slaved_id_tracker import SlavedIdTracker
import logging
logger = logging.getLogger(__name__)
class BaseSlavedStore(SQLBaseStore): class BaseSlavedStore(SQLBaseStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(BaseSlavedStore, self).__init__(hs) super(BaseSlavedStore, self).__init__(hs)
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = SlavedIdTracker(
db_conn, "cache_invalidation_stream", "stream_id",
)
else:
self._cache_id_gen = None
def stream_positions(self): def stream_positions(self):
return {} pos = {}
if self._cache_id_gen:
pos["caches"] = self._cache_id_gen.get_current_token()
return pos
def process_replication(self, result): def process_replication(self, result):
stream = result.get("caches")
if stream:
for row in stream["rows"]:
(
position, cache_func, keys, invalidation_ts,
) = row
try:
getattr(self, cache_func).invalidate(tuple(keys))
except AttributeError:
logger.info("Got unexpected cache_func: %r", cache_func)
self._cache_id_gen.advance(int(stream["position"]))
return defer.succeed(None) return defer.succeed(None)

View File

@ -28,3 +28,13 @@ class SlavedApplicationServiceStore(BaseSlavedStore):
get_app_service_by_token = DataStore.get_app_service_by_token.__func__ get_app_service_by_token = DataStore.get_app_service_by_token.__func__
get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__ get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__
get_app_services = DataStore.get_app_services.__func__
get_new_events_for_appservice = DataStore.get_new_events_for_appservice.__func__
create_appservice_txn = DataStore.create_appservice_txn.__func__
get_appservices_by_state = DataStore.get_appservices_by_state.__func__
get_oldest_unsent_txn = DataStore.get_oldest_unsent_txn.__func__
_get_last_txn = DataStore._get_last_txn.__func__
complete_appservice_txn = DataStore.complete_appservice_txn.__func__
get_appservice_state = DataStore.get_appservice_state.__func__
set_appservice_last_pos = DataStore.set_appservice_last_pos.__func__
set_appservice_state = DataStore.set_appservice_state.__func__

View File

@ -20,4 +20,4 @@ from synapse.storage.directory import DirectoryStore
class DirectoryStore(BaseSlavedStore): class DirectoryStore(BaseSlavedStore):
get_aliases_for_room = DirectoryStore.__dict__[ get_aliases_for_room = DirectoryStore.__dict__[
"get_aliases_for_room" "get_aliases_for_room"
].orig ]

View File

@ -25,6 +25,9 @@ class SlavedRegistrationStore(BaseSlavedStore):
# TODO: use the cached version and invalidate deleted tokens # TODO: use the cached version and invalidate deleted tokens
get_user_by_access_token = RegistrationStore.__dict__[ get_user_by_access_token = RegistrationStore.__dict__[
"get_user_by_access_token" "get_user_by_access_token"
].orig ]
_query_for_auth = DataStore._query_for_auth.__func__ _query_for_auth = DataStore._query_for_auth.__func__
get_user_by_id = RegistrationStore.__dict__[
"get_user_by_id"
]

View File

@ -46,7 +46,9 @@ from synapse.rest.client.v2_alpha import (
account_data, account_data,
report_event, report_event,
openid, openid,
notifications,
devices, devices,
thirdparty,
) )
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
@ -91,4 +93,6 @@ class ClientRestResource(JsonResource):
account_data.register_servlets(hs, client_resource) account_data.register_servlets(hs, client_resource)
report_event.register_servlets(hs, client_resource) report_event.register_servlets(hs, client_resource)
openid.register_servlets(hs, client_resource) openid.register_servlets(hs, client_resource)
notifications.register_servlets(hs, client_resource)
devices.register_servlets(hs, client_resource) devices.register_servlets(hs, client_resource)
thirdparty.register_servlets(hs, client_resource)

View File

@ -28,6 +28,10 @@ logger = logging.getLogger(__name__)
class WhoisRestServlet(ClientV1RestServlet): class WhoisRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)") PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)")
def __init__(self, hs):
super(WhoisRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
@ -82,6 +86,10 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
"/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" "/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
) )
def __init__(self, hs):
super(PurgeHistoryRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, event_id): def on_POST(self, request, room_id, event_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)

View File

@ -57,7 +57,6 @@ class ClientV1RestServlet(RestServlet):
hs (synapse.server.HomeServer): hs (synapse.server.HomeServer):
""" """
self.hs = hs self.hs = hs
self.handlers = hs.get_handlers()
self.builder_factory = hs.get_event_builder_factory() self.builder_factory = hs.get_event_builder_factory()
self.auth = hs.get_v1auth() self.auth = hs.get_v1auth()
self.txns = HttpTransactionStore() self.txns = HttpTransactionStore()

View File

@ -36,6 +36,10 @@ def register_servlets(hs, http_server):
class ClientDirectoryServer(ClientV1RestServlet): class ClientDirectoryServer(ClientV1RestServlet):
PATTERNS = client_path_patterns("/directory/room/(?P<room_alias>[^/]*)$") PATTERNS = client_path_patterns("/directory/room/(?P<room_alias>[^/]*)$")
def __init__(self, hs):
super(ClientDirectoryServer, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_alias): def on_GET(self, request, room_alias):
room_alias = RoomAlias.from_string(room_alias) room_alias = RoomAlias.from_string(room_alias)
@ -146,6 +150,7 @@ class ClientDirectoryListServer(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(ClientDirectoryListServer, self).__init__(hs) super(ClientDirectoryListServer, self).__init__(hs)
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):

View File

@ -32,6 +32,10 @@ class EventStreamRestServlet(ClientV1RestServlet):
DEFAULT_LONGPOLL_TIME_MS = 30000 DEFAULT_LONGPOLL_TIME_MS = 30000
def __init__(self, hs):
super(EventStreamRestServlet, self).__init__(hs)
self.event_stream_handler = hs.get_event_stream_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
requester = yield self.auth.get_user_by_req( requester = yield self.auth.get_user_by_req(
@ -46,7 +50,6 @@ class EventStreamRestServlet(ClientV1RestServlet):
if "room_id" in request.args: if "room_id" in request.args:
room_id = request.args["room_id"][0] room_id = request.args["room_id"][0]
handler = self.handlers.event_stream_handler
pagin_config = PaginationConfig.from_request(request) pagin_config = PaginationConfig.from_request(request)
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
if "timeout" in request.args: if "timeout" in request.args:
@ -57,7 +60,7 @@ class EventStreamRestServlet(ClientV1RestServlet):
as_client_event = "raw" not in request.args as_client_event = "raw" not in request.args
chunk = yield handler.get_stream( chunk = yield self.event_stream_handler.get_stream(
requester.user.to_string(), requester.user.to_string(),
pagin_config, pagin_config,
timeout=timeout, timeout=timeout,
@ -80,12 +83,12 @@ class EventRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(EventRestServlet, self).__init__(hs) super(EventRestServlet, self).__init__(hs)
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, event_id): def on_GET(self, request, event_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
handler = self.handlers.event_handler event = yield self.event_handler.get_event(requester.user, event_id)
event = yield handler.get_event(requester.user, event_id)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
if event: if event:

View File

@ -23,6 +23,10 @@ from .base import ClientV1RestServlet, client_path_patterns
class InitialSyncRestServlet(ClientV1RestServlet): class InitialSyncRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/initialSync$") PATTERNS = client_path_patterns("/initialSync$")
def __init__(self, hs):
super(InitialSyncRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)

View File

@ -54,12 +54,9 @@ class LoginRestServlet(ClientV1RestServlet):
self.jwt_secret = hs.config.jwt_secret self.jwt_secret = hs.config.jwt_secret
self.jwt_algorithm = hs.config.jwt_algorithm self.jwt_algorithm = hs.config.jwt_algorithm
self.cas_enabled = hs.config.cas_enabled self.cas_enabled = hs.config.cas_enabled
self.cas_server_url = hs.config.cas_server_url
self.cas_required_attributes = hs.config.cas_required_attributes
self.servername = hs.config.server_name
self.http_client = hs.get_simple_http_client()
self.auth_handler = self.hs.get_auth_handler() self.auth_handler = self.hs.get_auth_handler()
self.device_handler = self.hs.get_device_handler() self.device_handler = self.hs.get_device_handler()
self.handlers = hs.get_handlers()
def on_GET(self, request): def on_GET(self, request):
flows = [] flows = []
@ -110,17 +107,6 @@ class LoginRestServlet(ClientV1RestServlet):
LoginRestServlet.JWT_TYPE): LoginRestServlet.JWT_TYPE):
result = yield self.do_jwt_login(login_submission) result = yield self.do_jwt_login(login_submission)
defer.returnValue(result) defer.returnValue(result)
# TODO Delete this after all CAS clients switch to token login instead
elif self.cas_enabled and (login_submission["type"] ==
LoginRestServlet.CAS_TYPE):
uri = "%s/proxyValidate" % (self.cas_server_url,)
args = {
"ticket": login_submission["ticket"],
"service": login_submission["service"]
}
body = yield self.http_client.get_raw(uri, args)
result = yield self.do_cas_login(body)
defer.returnValue(result)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
result = yield self.do_token_login(login_submission) result = yield self.do_token_login(login_submission)
defer.returnValue(result) defer.returnValue(result)
@ -191,51 +177,6 @@ class LoginRestServlet(ClientV1RestServlet):
defer.returnValue((200, result)) defer.returnValue((200, result))
# TODO Delete this after all CAS clients switch to token login instead
@defer.inlineCallbacks
def do_cas_login(self, cas_response_body):
user, attributes = self.parse_cas_response(cas_response_body)
for required_attribute, required_value in self.cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden
if required_attribute not in attributes:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
# Also need to check value
if required_value is not None:
actual_value = attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.auth_handler
registered_user_id = yield auth_handler.check_user_exists(user_id)
if registered_user_id:
access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(
registered_user_id
)
)
result = {
"user_id": registered_user_id, # may have changed
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
}
else:
user_id, access_token = (
yield self.handlers.registration_handler.register(localpart=user)
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"home_server": self.hs.hostname,
}
defer.returnValue((200, result))
@defer.inlineCallbacks @defer.inlineCallbacks
def do_jwt_login(self, login_submission): def do_jwt_login(self, login_submission):
token = login_submission.get("token", None) token = login_submission.get("token", None)
@ -293,33 +234,6 @@ class LoginRestServlet(ClientV1RestServlet):
defer.returnValue((200, result)) defer.returnValue((200, result))
# TODO Delete this after all CAS clients switch to token login instead
def parse_cas_response(self, cas_response_body):
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not root[0].tag.endswith("authenticationSuccess"):
raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
if child.tag.endswith("attributes"):
attributes = {}
for attribute in child:
# ElementTree library expands the namespace in attribute tags
# to the full URL of the namespace.
# See (https://docs.python.org/2/library/xml.etree.elementtree.html)
# We don't care about namespace here and it will always be encased in
# curly braces, so we remove them.
if "}" in attribute.tag:
attributes[attribute.tag.split("}")[1]] = attribute.text
else:
attributes[attribute.tag] = attribute.text
if user is None or attributes is None:
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
return (user, attributes)
def _register_device(self, user_id, login_submission): def _register_device(self, user_id, login_submission):
"""Register a device for a user. """Register a device for a user.
@ -347,6 +261,7 @@ class SAML2RestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(SAML2RestServlet, self).__init__(hs) super(SAML2RestServlet, self).__init__(hs)
self.sp_config = hs.config.saml2_config_path self.sp_config = hs.config.saml2_config_path
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
@ -384,18 +299,6 @@ class SAML2RestServlet(ClientV1RestServlet):
defer.returnValue((200, {"status": "not_authenticated"})) defer.returnValue((200, {"status": "not_authenticated"}))
# TODO Delete this after all CAS clients switch to token login instead
class CasRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/login/cas", releases=())
def __init__(self, hs):
super(CasRestServlet, self).__init__(hs)
self.cas_server_url = hs.config.cas_server_url
def on_GET(self, request):
return (200, {"serverUrl": self.cas_server_url})
class CasRedirectServlet(ClientV1RestServlet): class CasRedirectServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/login/cas/redirect", releases=()) PATTERNS = client_path_patterns("/login/cas/redirect", releases=())
@ -427,6 +330,8 @@ class CasTicketServlet(ClientV1RestServlet):
self.cas_server_url = hs.config.cas_server_url self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url self.cas_service_url = hs.config.cas_service_url
self.cas_required_attributes = hs.config.cas_required_attributes self.cas_required_attributes = hs.config.cas_required_attributes
self.auth_handler = hs.get_auth_handler()
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
@ -479,30 +384,39 @@ class CasTicketServlet(ClientV1RestServlet):
return urlparse.urlunparse(url_parts) return urlparse.urlunparse(url_parts)
def parse_cas_response(self, cas_response_body): def parse_cas_response(self, cas_response_body):
root = ET.fromstring(cas_response_body) user = None
if not root.tag.endswith("serviceResponse"): attributes = None
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) try:
if not root[0].tag.endswith("authenticationSuccess"): root = ET.fromstring(cas_response_body)
raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED) if not root.tag.endswith("serviceResponse"):
for child in root[0]: raise Exception("root of CAS response is not serviceResponse")
if child.tag.endswith("user"): success = (root[0].tag.endswith("authenticationSuccess"))
user = child.text for child in root[0]:
if child.tag.endswith("attributes"): if child.tag.endswith("user"):
attributes = {} user = child.text
for attribute in child: if child.tag.endswith("attributes"):
# ElementTree library expands the namespace in attribute tags attributes = {}
# to the full URL of the namespace. for attribute in child:
# See (https://docs.python.org/2/library/xml.etree.elementtree.html) # ElementTree library expands the namespace in
# We don't care about namespace here and it will always be encased in # attribute tags to the full URL of the namespace.
# curly braces, so we remove them. # We don't care about namespace here and it will always
if "}" in attribute.tag: # be encased in curly braces, so we remove them.
attributes[attribute.tag.split("}")[1]] = attribute.text tag = attribute.tag
else: if "}" in tag:
attributes[attribute.tag] = attribute.text tag = tag.split("}")[1]
if user is None or attributes is None: attributes[tag] = attribute.text
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) if user is None:
raise Exception("CAS response does not contain user")
return (user, attributes) if attributes is None:
raise Exception("CAS response does not contain attributes")
except Exception:
logger.error("Error parsing CAS response", exc_info=1)
raise LoginError(401, "Invalid CAS response",
errcode=Codes.UNAUTHORIZED)
if not success:
raise LoginError(401, "Unsuccessful CAS response",
errcode=Codes.UNAUTHORIZED)
return user, attributes
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
@ -512,5 +426,3 @@ def register_servlets(hs, http_server):
if hs.config.cas_enabled: if hs.config.cas_enabled:
CasRedirectServlet(hs).register(http_server) CasRedirectServlet(hs).register(http_server)
CasTicketServlet(hs).register(http_server) CasTicketServlet(hs).register(http_server)
CasRestServlet(hs).register(http_server)
# TODO PasswordResetRestServlet(hs).register(http_server)

View File

@ -24,6 +24,10 @@ from synapse.http.servlet import parse_json_object_from_request
class ProfileDisplaynameRestServlet(ClientV1RestServlet): class ProfileDisplaynameRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname") PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname")
def __init__(self, hs):
super(ProfileDisplaynameRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
@ -62,6 +66,10 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
class ProfileAvatarURLRestServlet(ClientV1RestServlet): class ProfileAvatarURLRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url") PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url")
def __init__(self, hs):
super(ProfileAvatarURLRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
@ -99,6 +107,10 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
class ProfileRestServlet(ClientV1RestServlet): class ProfileRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)") PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)")
def __init__(self, hs):
super(ProfileRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
user = UserID.from_string(user_id) user = UserID.from_string(user_id)

View File

@ -65,6 +65,7 @@ class RegisterRestServlet(ClientV1RestServlet):
self.sessions = {} self.sessions = {}
self.enable_registration = hs.config.enable_registration self.enable_registration = hs.config.enable_registration
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self.handlers = hs.get_handlers()
def on_GET(self, request): def on_GET(self, request):
if self.hs.config.enable_registration_captcha: if self.hs.config.enable_registration_captcha:
@ -383,6 +384,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
super(CreateUserRestServlet, self).__init__(hs) super(CreateUserRestServlet, self).__init__(hs)
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.direct_user_creation_max_duration = hs.config.user_creation_max_duration self.direct_user_creation_max_duration = hs.config.user_creation_max_duration
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):

View File

@ -35,6 +35,10 @@ logger = logging.getLogger(__name__)
class RoomCreateRestServlet(ClientV1RestServlet): class RoomCreateRestServlet(ClientV1RestServlet):
# No PATTERN; we have custom dispatch rules here # No PATTERN; we have custom dispatch rules here
def __init__(self, hs):
super(RoomCreateRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server): def register(self, http_server):
PATTERNS = "/createRoom" PATTERNS = "/createRoom"
register_txn_path(self, PATTERNS, http_server) register_txn_path(self, PATTERNS, http_server)
@ -82,6 +86,10 @@ class RoomCreateRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for generic events # TODO: Needs unit testing for generic events
class RoomStateEventRestServlet(ClientV1RestServlet): class RoomStateEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomStateEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server): def register(self, http_server):
# /room/$roomid/state/$eventtype # /room/$roomid/state/$eventtype
no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$" no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$"
@ -166,6 +174,10 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for generic events + feedback # TODO: Needs unit testing for generic events + feedback
class RoomSendEventRestServlet(ClientV1RestServlet): class RoomSendEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomSendEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server): def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id] # /rooms/$roomid/send/$event_type[/$txn_id]
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)") PATTERNS = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)")
@ -210,6 +222,9 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for room ID + alias joins # TODO: Needs unit testing for room ID + alias joins
class JoinRoomAliasServlet(ClientV1RestServlet): class JoinRoomAliasServlet(ClientV1RestServlet):
def __init__(self, hs):
super(JoinRoomAliasServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server): def register(self, http_server):
# /join/$room_identifier[/$txn_id] # /join/$room_identifier[/$txn_id]
@ -253,6 +268,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
action="join", action="join",
txn_id=txn_id, txn_id=txn_id,
remote_room_hosts=remote_room_hosts, remote_room_hosts=remote_room_hosts,
content=content,
third_party_signed=content.get("third_party_signed", None), third_party_signed=content.get("third_party_signed", None),
) )
@ -296,6 +312,10 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
class RoomMemberListRestServlet(ClientV1RestServlet): class RoomMemberListRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/members$") PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/members$")
def __init__(self, hs):
super(RoomMemberListRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens) # TODO support Pagination stream API (limit/tokens)
@ -322,6 +342,10 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
class RoomMessageListRestServlet(ClientV1RestServlet): class RoomMessageListRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$") PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$")
def __init__(self, hs):
super(RoomMessageListRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
@ -351,6 +375,10 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
class RoomStateRestServlet(ClientV1RestServlet): class RoomStateRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/state$") PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/state$")
def __init__(self, hs):
super(RoomStateRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
@ -368,6 +396,10 @@ class RoomStateRestServlet(ClientV1RestServlet):
class RoomInitialSyncRestServlet(ClientV1RestServlet): class RoomInitialSyncRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$") PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$")
def __init__(self, hs):
super(RoomInitialSyncRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
@ -388,6 +420,7 @@ class RoomEventContext(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(RoomEventContext, self).__init__(hs) super(RoomEventContext, self).__init__(hs)
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id, event_id): def on_GET(self, request, room_id, event_id):
@ -424,6 +457,10 @@ class RoomEventContext(ClientV1RestServlet):
class RoomForgetRestServlet(ClientV1RestServlet): class RoomForgetRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomForgetRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server): def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget") PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
register_txn_path(self, PATTERNS, http_server) register_txn_path(self, PATTERNS, http_server)
@ -462,6 +499,10 @@ class RoomForgetRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing # TODO: Needs unit testing
class RoomMembershipRestServlet(ClientV1RestServlet): class RoomMembershipRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomMembershipRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server): def register(self, http_server):
# /rooms/$roomid/[invite|join|leave] # /rooms/$roomid/[invite|join|leave]
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/" PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
@ -542,6 +583,10 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
class RoomRedactEventRestServlet(ClientV1RestServlet): class RoomRedactEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomRedactEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server): def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)") PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
register_txn_path(self, PATTERNS, http_server) register_txn_path(self, PATTERNS, http_server)
@ -624,6 +669,10 @@ class SearchRestServlet(ClientV1RestServlet):
"/search$" "/search$"
) )
def __init__(self, hs):
super(SearchRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)

View File

@ -0,0 +1,99 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.http.servlet import (
RestServlet, parse_string, parse_integer
)
from synapse.events.utils import (
serialize_event, format_event_for_client_v2_without_room_id,
)
from ._base import client_v2_patterns
import logging
logger = logging.getLogger(__name__)
class NotificationsServlet(RestServlet):
PATTERNS = client_v2_patterns("/notifications$", releases=())
def __init__(self, hs):
super(NotificationsServlet, self).__init__()
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
from_token = parse_string(request, "from", required=False)
limit = parse_integer(request, "limit", default=50)
limit = min(limit, 500)
push_actions = yield self.store.get_push_actions_for_user(
user_id, from_token, limit
)
receipts_by_room = yield self.store.get_receipts_for_user_with_orderings(
user_id, 'm.read'
)
notif_event_ids = [pa["event_id"] for pa in push_actions]
notif_events = yield self.store.get_events(notif_event_ids)
returned_push_actions = []
next_token = None
for pa in push_actions:
returned_pa = {
"room_id": pa["room_id"],
"profile_tag": pa["profile_tag"],
"actions": pa["actions"],
"ts": pa["received_ts"],
"event": serialize_event(
notif_events[pa["event_id"]],
self.clock.time_msec(),
event_format=format_event_for_client_v2_without_room_id,
),
}
if pa["room_id"] not in receipts_by_room:
returned_pa["read"] = False
else:
receipt = receipts_by_room[pa["room_id"]]
returned_pa["read"] = (
receipt["topological_ordering"], receipt["stream_ordering"]
) >= (
pa["topological_ordering"], pa["stream_ordering"]
)
returned_push_actions.append(returned_pa)
next_token = pa["stream_ordering"]
defer.returnValue((200, {
"notifications": returned_push_actions,
"next_token": next_token,
}))
def register_servlets(hs, http_server):
NotificationsServlet(hs).register(http_server)

View File

@ -403,10 +403,9 @@ class RegisterRestServlet(RestServlet):
# register the user's device # register the user's device
device_id = params.get("device_id") device_id = params.get("device_id")
initial_display_name = params.get("initial_device_display_name") initial_display_name = params.get("initial_device_display_name")
device_id = self.device_handler.check_device_registered( return self.device_handler.check_device_registered(
user_id, device_id, initial_display_name user_id, device_id, initial_display_name
) )
return device_id
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_guest_registration(self): def _do_guest_registration(self):

View File

@ -146,7 +146,7 @@ class SyncRestServlet(RestServlet):
affect_presence = set_presence != PresenceState.OFFLINE affect_presence = set_presence != PresenceState.OFFLINE
if affect_presence: if affect_presence:
yield self.presence_handler.set_state(user, {"presence": set_presence}) yield self.presence_handler.set_state(user, {"presence": set_presence}, True)
context = yield self.presence_handler.user_syncing( context = yield self.presence_handler.user_syncing(
user.to_string(), affect_presence=affect_presence, user.to_string(), affect_presence=affect_presence,

View File

@ -0,0 +1,78 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.http.servlet import RestServlet
from synapse.types import ThirdPartyEntityKind
from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
class ThirdPartyUserServlet(RestServlet):
PATTERNS = client_v2_patterns("/3pu(/(?P<protocol>[^/]+))?$",
releases=())
def __init__(self, hs):
super(ThirdPartyUserServlet, self).__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
@defer.inlineCallbacks
def on_GET(self, request, protocol):
yield self.auth.get_user_by_req(request)
fields = request.args
del fields["access_token"]
results = yield self.appservice_handler.query_3pe(
ThirdPartyEntityKind.USER, protocol, fields
)
defer.returnValue((200, results))
class ThirdPartyLocationServlet(RestServlet):
PATTERNS = client_v2_patterns("/3pl(/(?P<protocol>[^/]+))?$",
releases=())
def __init__(self, hs):
super(ThirdPartyLocationServlet, self).__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
@defer.inlineCallbacks
def on_GET(self, request, protocol):
yield self.auth.get_user_by_req(request)
fields = request.args
del fields["access_token"]
results = yield self.appservice_handler.query_3pe(
ThirdPartyEntityKind.LOCATION, protocol, fields
)
defer.returnValue((200, results))
def register_servlets(hs, http_server):
ThirdPartyUserServlet(hs).register(http_server)
ThirdPartyLocationServlet(hs).register(http_server)

View File

@ -15,6 +15,7 @@
from synapse.http.server import request_handler, respond_with_json_bytes from synapse.http.server import request_handler, respond_with_json_bytes
from synapse.http.servlet import parse_integer, parse_json_object_from_request from synapse.http.servlet import parse_integer, parse_json_object_from_request
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes
from synapse.crypto.keyring import KeyLookupError
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET from twisted.web.server import NOT_DONE_YET
@ -210,9 +211,10 @@ class RemoteKey(Resource):
yield self.keyring.get_server_verify_key_v2_direct( yield self.keyring.get_server_verify_key_v2_direct(
server_name, key_ids server_name, key_ids
) )
except KeyLookupError as e:
logger.info("Failed to fetch key: %s", e)
except: except:
logger.exception("Failed to get key for %r", server_name) logger.exception("Failed to get key for %r", server_name)
pass
yield self.query_keys( yield self.query_keys(
request, query, query_remote_on_cache_miss=False request, query, query_remote_on_cache_miss=False
) )

View File

@ -45,6 +45,7 @@ class DownloadResource(Resource):
@request_handler() @request_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def _async_render_GET(self, request): def _async_render_GET(self, request):
request.setHeader("Content-Security-Policy", "sandbox")
server_name, media_id, name = parse_media_id(request) server_name, media_id, name = parse_media_id(request)
if server_name == self.server_name: if server_name == self.server_name:
yield self._respond_local_file(request, media_id, name) yield self._respond_local_file(request, media_id, name)

View File

@ -29,14 +29,13 @@ from synapse.http.server import (
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.stringutils import is_ascii from synapse.util.stringutils import is_ascii
from copy import deepcopy
import os import os
import re import re
import fnmatch import fnmatch
import cgi import cgi
import ujson as json import ujson as json
import urlparse import urlparse
import itertools
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -163,7 +162,7 @@ class PreviewUrlResource(Resource):
logger.debug("got media_info of '%s'" % media_info) logger.debug("got media_info of '%s'" % media_info)
if self._is_media(media_info['media_type']): if _is_media(media_info['media_type']):
dims = yield self.media_repo._generate_local_thumbnails( dims = yield self.media_repo._generate_local_thumbnails(
media_info['filesystem_id'], media_info media_info['filesystem_id'], media_info
) )
@ -184,11 +183,9 @@ class PreviewUrlResource(Resource):
logger.warn("Couldn't get dims for %s" % url) logger.warn("Couldn't get dims for %s" % url)
# define our OG response for this media # define our OG response for this media
elif self._is_html(media_info['media_type']): elif _is_html(media_info['media_type']):
# TODO: somehow stop a big HTML tree from exploding synapse's RAM # TODO: somehow stop a big HTML tree from exploding synapse's RAM
from lxml import etree
file = open(media_info['filename']) file = open(media_info['filename'])
body = file.read() body = file.read()
file.close() file.close()
@ -199,17 +196,35 @@ class PreviewUrlResource(Resource):
match = re.match(r'.*; *charset=(.*?)(;|$)', media_info['media_type'], re.I) match = re.match(r'.*; *charset=(.*?)(;|$)', media_info['media_type'], re.I)
encoding = match.group(1) if match else "utf-8" encoding = match.group(1) if match else "utf-8"
try: og = decode_and_calc_og(body, media_info['uri'], encoding)
parser = etree.HTMLParser(recover=True, encoding=encoding)
tree = etree.fromstring(body, parser)
og = yield self._calc_og(tree, media_info, requester)
except UnicodeDecodeError:
# blindly try decoding the body as utf-8, which seems to fix
# the charset mismatches on https://google.com
parser = etree.HTMLParser(recover=True, encoding=encoding)
tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser)
og = yield self._calc_og(tree, media_info, requester)
# pre-cache the image for posterity
# FIXME: it might be cleaner to use the same flow as the main /preview_url
# request itself and benefit from the same caching etc. But for now we
# just rely on the caching on the master request to speed things up.
if 'og:image' in og and og['og:image']:
image_info = yield self._download_url(
_rebase_url(og['og:image'], media_info['uri']), requester.user
)
if _is_media(image_info['media_type']):
# TODO: make sure we don't choke on white-on-transparent images
dims = yield self.media_repo._generate_local_thumbnails(
image_info['filesystem_id'], image_info
)
if dims:
og["og:image:width"] = dims['width']
og["og:image:height"] = dims['height']
else:
logger.warn("Couldn't get dims for %s" % og["og:image"])
og["og:image"] = "mxc://%s/%s" % (
self.server_name, image_info['filesystem_id']
)
og["og:image:type"] = image_info['media_type']
og["matrix:image:size"] = image_info['media_length']
else:
del og["og:image"]
else: else:
logger.warn("Failed to find any OG data in %s", url) logger.warn("Failed to find any OG data in %s", url)
og = {} og = {}
@ -232,139 +247,6 @@ class PreviewUrlResource(Resource):
respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True) respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
@defer.inlineCallbacks
def _calc_og(self, tree, media_info, requester):
# suck our tree into lxml and define our OG response.
# if we see any image URLs in the OG response, then spider them
# (although the client could choose to do this by asking for previews of those
# URLs to avoid DoSing the server)
# "og:type" : "video",
# "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
# "og:site_name" : "YouTube",
# "og:video:type" : "application/x-shockwave-flash",
# "og:description" : "Fun stuff happening here",
# "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
# "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
# "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
# "og:video:width" : "1280"
# "og:video:height" : "720",
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
og = {}
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
if 'content' in tag.attrib:
og[tag.attrib['property']] = tag.attrib['content']
# TODO: grab article: meta tags too, e.g.:
# "article:publisher" : "https://www.facebook.com/thethudonline" />
# "article:author" content="https://www.facebook.com/thethudonline" />
# "article:tag" content="baby" />
# "article:section" content="Breaking News" />
# "article:published_time" content="2016-03-31T19:58:24+00:00" />
# "article:modified_time" content="2016-04-01T18:31:53+00:00" />
if 'og:title' not in og:
# do some basic spidering of the HTML
title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
og['og:title'] = title[0].text.strip() if title else None
if 'og:image' not in og:
# TODO: extract a favicon failing all else
meta_image = tree.xpath(
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
)
if meta_image:
og['og:image'] = self._rebase_url(meta_image[0], media_info['uri'])
else:
# TODO: consider inlined CSS styles as well as width & height attribs
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
images = sorted(images, key=lambda i: (
-1 * float(i.attrib['width']) * float(i.attrib['height'])
))
if not images:
images = tree.xpath("//img[@src]")
if images:
og['og:image'] = images[0].attrib['src']
# pre-cache the image for posterity
# FIXME: it might be cleaner to use the same flow as the main /preview_url
# request itself and benefit from the same caching etc. But for now we
# just rely on the caching on the master request to speed things up.
if 'og:image' in og and og['og:image']:
image_info = yield self._download_url(
self._rebase_url(og['og:image'], media_info['uri']), requester.user
)
if self._is_media(image_info['media_type']):
# TODO: make sure we don't choke on white-on-transparent images
dims = yield self.media_repo._generate_local_thumbnails(
image_info['filesystem_id'], image_info
)
if dims:
og["og:image:width"] = dims['width']
og["og:image:height"] = dims['height']
else:
logger.warn("Couldn't get dims for %s" % og["og:image"])
og["og:image"] = "mxc://%s/%s" % (
self.server_name, image_info['filesystem_id']
)
og["og:image:type"] = image_info['media_type']
og["matrix:image:size"] = image_info['media_length']
else:
del og["og:image"]
if 'og:description' not in og:
meta_description = tree.xpath(
"//*/meta"
"[translate(@name, 'DESCRIPTION', 'description')='description']"
"/@content")
if meta_description:
og['og:description'] = meta_description[0]
else:
# grab any text nodes which are inside the <body/> tag...
# unless they are within an HTML5 semantic markup tag...
# <header/>, <nav/>, <aside/>, <footer/>
# ...or if they are within a <script/> or <style/> tag.
# This is a very very very coarse approximation to a plain text
# render of the page.
# We don't just use XPATH here as that is slow on some machines.
# We clone `tree` as we modify it.
cloned_tree = deepcopy(tree.find("body"))
TAGS_TO_REMOVE = ("header", "nav", "aside", "footer", "script", "style",)
for el in cloned_tree.iter(TAGS_TO_REMOVE):
el.getparent().remove(el)
# Split all the text nodes into paragraphs (by splitting on new
# lines)
text_nodes = (
re.sub(r'\s+', '\n', el.text).strip()
for el in cloned_tree.iter()
if el.text and isinstance(el.tag, basestring) # Removes comments
)
og['og:description'] = summarize_paragraphs(text_nodes)
# TODO: delete the url downloads to stop diskfilling,
# as we only ever cared about its OG
defer.returnValue(og)
def _rebase_url(self, url, base):
base = list(urlparse.urlparse(base))
url = list(urlparse.urlparse(url))
if not url[0]: # fix up schema
url[0] = base[0] or "http"
if not url[1]: # fix up hostname
url[1] = base[1]
if not url[2].startswith('/'):
url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2]
return urlparse.urlunparse(url)
@defer.inlineCallbacks @defer.inlineCallbacks
def _download_url(self, url, user): def _download_url(self, url, user):
# TODO: we should probably honour robots.txt... except in practice # TODO: we should probably honour robots.txt... except in practice
@ -445,17 +327,171 @@ class PreviewUrlResource(Resource):
"etag": headers["ETag"][0] if "ETag" in headers else None, "etag": headers["ETag"][0] if "ETag" in headers else None,
}) })
def _is_media(self, content_type):
if content_type.lower().startswith("image/"):
return True
def _is_html(self, content_type): def decode_and_calc_og(body, media_uri, request_encoding=None):
content_type = content_type.lower() from lxml import etree
if (
content_type.startswith("text/html") or try:
content_type.startswith("application/xhtml") parser = etree.HTMLParser(recover=True, encoding=request_encoding)
): tree = etree.fromstring(body, parser)
return True og = _calc_og(tree, media_uri)
except UnicodeDecodeError:
# blindly try decoding the body as utf-8, which seems to fix
# the charset mismatches on https://google.com
parser = etree.HTMLParser(recover=True, encoding=request_encoding)
tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser)
og = _calc_og(tree, media_uri)
return og
def _calc_og(tree, media_uri):
# suck our tree into lxml and define our OG response.
# if we see any image URLs in the OG response, then spider them
# (although the client could choose to do this by asking for previews of those
# URLs to avoid DoSing the server)
# "og:type" : "video",
# "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
# "og:site_name" : "YouTube",
# "og:video:type" : "application/x-shockwave-flash",
# "og:description" : "Fun stuff happening here",
# "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
# "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
# "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
# "og:video:width" : "1280"
# "og:video:height" : "720",
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
og = {}
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
if 'content' in tag.attrib:
og[tag.attrib['property']] = tag.attrib['content']
# TODO: grab article: meta tags too, e.g.:
# "article:publisher" : "https://www.facebook.com/thethudonline" />
# "article:author" content="https://www.facebook.com/thethudonline" />
# "article:tag" content="baby" />
# "article:section" content="Breaking News" />
# "article:published_time" content="2016-03-31T19:58:24+00:00" />
# "article:modified_time" content="2016-04-01T18:31:53+00:00" />
if 'og:title' not in og:
# do some basic spidering of the HTML
title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
og['og:title'] = title[0].text.strip() if title else None
if 'og:image' not in og:
# TODO: extract a favicon failing all else
meta_image = tree.xpath(
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
)
if meta_image:
og['og:image'] = _rebase_url(meta_image[0], media_uri)
else:
# TODO: consider inlined CSS styles as well as width & height attribs
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
images = sorted(images, key=lambda i: (
-1 * float(i.attrib['width']) * float(i.attrib['height'])
))
if not images:
images = tree.xpath("//img[@src]")
if images:
og['og:image'] = images[0].attrib['src']
if 'og:description' not in og:
meta_description = tree.xpath(
"//*/meta"
"[translate(@name, 'DESCRIPTION', 'description')='description']"
"/@content")
if meta_description:
og['og:description'] = meta_description[0]
else:
# grab any text nodes which are inside the <body/> tag...
# unless they are within an HTML5 semantic markup tag...
# <header/>, <nav/>, <aside/>, <footer/>
# ...or if they are within a <script/> or <style/> tag.
# This is a very very very coarse approximation to a plain text
# render of the page.
# We don't just use XPATH here as that is slow on some machines.
from lxml import etree
TAGS_TO_REMOVE = (
"header", "nav", "aside", "footer", "script", "style", etree.Comment
)
# Split all the text nodes into paragraphs (by splitting on new
# lines)
text_nodes = (
re.sub(r'\s+', '\n', el).strip()
for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
)
og['og:description'] = summarize_paragraphs(text_nodes)
# TODO: delete the url downloads to stop diskfilling,
# as we only ever cared about its OG
return og
def _iterate_over_text(tree, *tags_to_ignore):
"""Iterate over the tree returning text nodes in a depth first fashion,
skipping text nodes inside certain tags.
"""
# This is basically a stack that we extend using itertools.chain.
# This will either consist of an element to iterate over *or* a string
# to be returned.
elements = iter([tree])
while True:
el = elements.next()
if isinstance(el, basestring):
yield el
elif el is not None and el.tag not in tags_to_ignore:
# el.text is the text before the first child, so we can immediately
# return it if the text exists.
if el.text:
yield el.text
# We add to the stack all the elements children, interspersed with
# each child's tail text (if it exists). The tail text of a node
# is text that comes *after* the node, so we always include it even
# if we ignore the child node.
elements = itertools.chain(
itertools.chain.from_iterable( # Basically a flatmap
[child, child.tail] if child.tail else [child]
for child in el.iterchildren()
),
elements
)
def _rebase_url(url, base):
base = list(urlparse.urlparse(base))
url = list(urlparse.urlparse(url))
if not url[0]: # fix up schema
url[0] = base[0] or "http"
if not url[1]: # fix up hostname
url[1] = base[1]
if not url[2].startswith('/'):
url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2]
return urlparse.urlunparse(url)
def _is_media(content_type):
if content_type.lower().startswith("image/"):
return True
def _is_html(content_type):
content_type = content_type.lower()
if (
content_type.startswith("text/html") or
content_type.startswith("application/xhtml")
):
return True
def summarize_paragraphs(text_nodes, min_size=200, max_size=500): def summarize_paragraphs(text_nodes, min_size=200, max_size=500):

View File

@ -41,6 +41,7 @@ from synapse.handlers.presence import PresenceHandler
from synapse.handlers.room import RoomListHandler from synapse.handlers.room import RoomListHandler
from synapse.handlers.sync import SyncHandler from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler from synapse.handlers.typing import TypingHandler
from synapse.handlers.events import EventHandler, EventStreamHandler
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.notifier import Notifier from synapse.notifier import Notifier
@ -94,6 +95,8 @@ class HomeServer(object):
'auth_handler', 'auth_handler',
'device_handler', 'device_handler',
'e2e_keys_handler', 'e2e_keys_handler',
'event_handler',
'event_stream_handler',
'application_service_api', 'application_service_api',
'application_service_scheduler', 'application_service_scheduler',
'application_service_handler', 'application_service_handler',
@ -214,6 +217,12 @@ class HomeServer(object):
def build_application_service_handler(self): def build_application_service_handler(self):
return ApplicationServicesHandler(self) return ApplicationServicesHandler(self)
def build_event_handler(self):
return EventHandler(self)
def build_event_stream_handler(self):
return EventStreamHandler(self)
def build_event_sources(self): def build_event_sources(self):
return EventSources(self) return EventSources(self)

View File

@ -1,3 +1,4 @@
import synapse.api.auth
import synapse.handlers import synapse.handlers
import synapse.handlers.auth import synapse.handlers.auth
import synapse.handlers.device import synapse.handlers.device
@ -6,6 +7,9 @@ import synapse.storage
import synapse.state import synapse.state
class HomeServer(object): class HomeServer(object):
def get_auth(self) -> synapse.api.auth.Auth:
pass
def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler: def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler:
pass pass

View File

@ -50,6 +50,7 @@ from .openid import OpenIdStore
from .client_ips import ClientIpStore from .client_ips import ClientIpStore
from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
from .engines import PostgresEngine
from synapse.api.constants import PresenceState from synapse.api.constants import PresenceState
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -123,6 +124,13 @@ class DataStore(RoomMemberStore, RoomStore,
extra_tables=[("deleted_pushers", "stream_id")], extra_tables=[("deleted_pushers", "stream_id")],
) )
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = StreamIdGenerator(
db_conn, "cache_invalidation_stream", "stream_id",
)
else:
self._cache_id_gen = None
events_max = self._stream_id_gen.get_current_token() events_max = self._stream_id_gen.get_current_token()
event_cache_prefill, min_event_val = self._get_cache_dict( event_cache_prefill, min_event_val = self._get_cache_dict(
db_conn, "events", db_conn, "events",

View File

@ -19,6 +19,7 @@ from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache from synapse.util.caches.descriptors import Cache
from synapse.util.caches import intern_dict from synapse.util.caches import intern_dict
from synapse.storage.engines import PostgresEngine
import synapse.metrics import synapse.metrics
@ -165,7 +166,7 @@ class SQLBaseStore(object):
self._txn_perf_counters = PerformanceCounters() self._txn_perf_counters = PerformanceCounters()
self._get_event_counters = PerformanceCounters() self._get_event_counters = PerformanceCounters()
self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True, self._get_event_cache = Cache("*getEvent*", keylen=3,
max_entries=hs.config.event_cache_size) max_entries=hs.config.event_cache_size)
self._state_group_cache = DictionaryCache( self._state_group_cache = DictionaryCache(
@ -305,13 +306,14 @@ class SQLBaseStore(object):
func, *args, **kwargs func, *args, **kwargs
) )
with PreserveLoggingContext(): try:
result = yield self._db_pool.runWithConnection( with PreserveLoggingContext():
inner_func, *args, **kwargs result = yield self._db_pool.runWithConnection(
) inner_func, *args, **kwargs
)
for after_callback, after_args in after_callbacks: finally:
after_callback(*after_args) for after_callback, after_args in after_callbacks:
after_callback(*after_args)
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -860,6 +862,62 @@ class SQLBaseStore(object):
return cache, min_val return cache, min_val
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
This should only be used to invalidate caches where slaves won't
otherwise know from other replication streams that the cache should
be invalidated.
"""
txn.call_after(cache_func.invalidate, keys)
if isinstance(self.database_engine, PostgresEngine):
# get_next() returns a context manager which is designed to wrap
# the transaction. However, we want to only get an ID when we want
# to use it, here, so we need to call __enter__ manually, and have
# __exit__ called after the transaction finishes.
ctx = self._cache_id_gen.get_next()
stream_id = ctx.__enter__()
txn.call_after(ctx.__exit__, None, None, None)
txn.call_after(self.hs.get_notifier().on_new_replication_data)
self._simple_insert_txn(
txn,
table="cache_invalidation_stream",
values={
"stream_id": stream_id,
"cache_func": cache_func.__name__,
"keys": list(keys),
"invalidation_ts": self.clock.time_msec(),
}
)
def get_all_updated_caches(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed([])
def get_all_updated_caches_txn(txn):
# We purposefully don't bound by the current token, as we want to
# send across cache invalidations as quickly as possible. Cache
# invalidations are idempotent, so duplicates are fine.
sql = (
"SELECT stream_id, cache_func, keys, invalidation_ts"
" FROM cache_invalidation_stream"
" WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_id, limit,))
return txn.fetchall()
return self.runInteraction(
"get_all_updated_caches", get_all_updated_caches_txn
)
def get_cache_stream_token(self):
if self._cache_id_gen:
return self._cache_id_gen.get_current_token()
else:
return 0
class _RollbackButIsFineException(Exception): class _RollbackButIsFineException(Exception):
""" This exception is used to rollback a transaction without implying """ This exception is used to rollback a transaction without implying

View File

@ -218,38 +218,37 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
Returns: Returns:
AppServiceTransaction: A new transaction. AppServiceTransaction: A new transaction.
""" """
def _create_appservice_txn(txn):
# work out new txn id (highest txn id for this service += 1)
# The highest id may be the last one sent (in which case it is last_txn)
# or it may be the highest in the txns list (which are waiting to be/are
# being sent)
last_txn_id = self._get_last_txn(txn, service.id)
txn.execute(
"SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?",
(service.id,)
)
highest_txn_id = txn.fetchone()[0]
if highest_txn_id is None:
highest_txn_id = 0
new_txn_id = max(highest_txn_id, last_txn_id) + 1
# Insert new txn into txn table
event_ids = json.dumps([e.event_id for e in events])
txn.execute(
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
"VALUES(?,?,?)",
(service.id, new_txn_id, event_ids)
)
return AppServiceTransaction(
service=service, id=new_txn_id, events=events
)
return self.runInteraction( return self.runInteraction(
"create_appservice_txn", "create_appservice_txn",
self._create_appservice_txn, _create_appservice_txn,
service, events
)
def _create_appservice_txn(self, txn, service, events):
# work out new txn id (highest txn id for this service += 1)
# The highest id may be the last one sent (in which case it is last_txn)
# or it may be the highest in the txns list (which are waiting to be/are
# being sent)
last_txn_id = self._get_last_txn(txn, service.id)
txn.execute(
"SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?",
(service.id,)
)
highest_txn_id = txn.fetchone()[0]
if highest_txn_id is None:
highest_txn_id = 0
new_txn_id = max(highest_txn_id, last_txn_id) + 1
# Insert new txn into txn table
event_ids = json.dumps([e.event_id for e in events])
txn.execute(
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
"VALUES(?,?,?)",
(service.id, new_txn_id, event_ids)
)
return AppServiceTransaction(
service=service, id=new_txn_id, events=events
) )
def complete_appservice_txn(self, txn_id, service): def complete_appservice_txn(self, txn_id, service):
@ -263,39 +262,38 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
A Deferred which resolves if this transaction was stored A Deferred which resolves if this transaction was stored
successfully. successfully.
""" """
return self.runInteraction(
"complete_appservice_txn",
self._complete_appservice_txn,
txn_id, service
)
def _complete_appservice_txn(self, txn, txn_id, service):
txn_id = int(txn_id) txn_id = int(txn_id)
# Debugging query: Make sure the txn being completed is EXACTLY +1 from def _complete_appservice_txn(txn):
# what was there before. If it isn't, we've got problems (e.g. the AS # Debugging query: Make sure the txn being completed is EXACTLY +1 from
# has probably missed some events), so whine loudly but still continue, # what was there before. If it isn't, we've got problems (e.g. the AS
# since it shouldn't fail completion of the transaction. # has probably missed some events), so whine loudly but still continue,
last_txn_id = self._get_last_txn(txn, service.id) # since it shouldn't fail completion of the transaction.
if (last_txn_id + 1) != txn_id: last_txn_id = self._get_last_txn(txn, service.id)
logger.error( if (last_txn_id + 1) != txn_id:
"appservice: Completing a transaction which has an ID > 1 from " logger.error(
"the last ID sent to this AS. We've either dropped events or " "appservice: Completing a transaction which has an ID > 1 from "
"sent it to the AS out of order. FIX ME. last_txn=%s " "the last ID sent to this AS. We've either dropped events or "
"completing_txn=%s service_id=%s", last_txn_id, txn_id, "sent it to the AS out of order. FIX ME. last_txn=%s "
service.id "completing_txn=%s service_id=%s", last_txn_id, txn_id,
service.id
)
# Set current txn_id for AS to 'txn_id'
self._simple_upsert_txn(
txn, "application_services_state", dict(as_id=service.id),
dict(last_txn=txn_id)
) )
# Set current txn_id for AS to 'txn_id' # Delete txn
self._simple_upsert_txn( self._simple_delete_txn(
txn, "application_services_state", dict(as_id=service.id), txn, "application_services_txns",
dict(last_txn=txn_id) dict(txn_id=txn_id, as_id=service.id)
) )
# Delete txn return self.runInteraction(
self._simple_delete_txn( "complete_appservice_txn",
txn, "application_services_txns", _complete_appservice_txn,
dict(txn_id=txn_id, as_id=service.id)
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -309,10 +307,25 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
A Deferred which resolves to an AppServiceTransaction or A Deferred which resolves to an AppServiceTransaction or
None. None.
""" """
def _get_oldest_unsent_txn(txn):
# Monotonically increasing txn ids, so just select the smallest
# one in the txns table (we delete them when they are sent)
txn.execute(
"SELECT * FROM application_services_txns WHERE as_id=?"
" ORDER BY txn_id ASC LIMIT 1",
(service.id,)
)
rows = self.cursor_to_dict(txn)
if not rows:
return None
entry = rows[0]
return entry
entry = yield self.runInteraction( entry = yield self.runInteraction(
"get_oldest_unsent_appservice_txn", "get_oldest_unsent_appservice_txn",
self._get_oldest_unsent_txn, _get_oldest_unsent_txn,
service
) )
if not entry: if not entry:
@ -326,22 +339,6 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
service=service, id=entry["txn_id"], events=events service=service, id=entry["txn_id"], events=events
)) ))
def _get_oldest_unsent_txn(self, txn, service):
# Monotonically increasing txn ids, so just select the smallest
# one in the txns table (we delete them when they are sent)
txn.execute(
"SELECT * FROM application_services_txns WHERE as_id=?"
" ORDER BY txn_id ASC LIMIT 1",
(service.id,)
)
rows = self.cursor_to_dict(txn)
if not rows:
return None
entry = rows[0]
return entry
def _get_last_txn(self, txn, service_id): def _get_last_txn(self, txn, service_id):
txn.execute( txn.execute(
"SELECT last_txn FROM application_services_state WHERE as_id=?", "SELECT last_txn FROM application_services_state WHERE as_id=?",
@ -352,3 +349,45 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
return 0 return 0
else: else:
return int(last_txn_id[0]) # select 'last_txn' col return int(last_txn_id[0]) # select 'last_txn' col
def set_appservice_last_pos(self, pos):
def set_appservice_last_pos_txn(txn):
txn.execute(
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
)
return self.runInteraction(
"set_appservice_last_pos", set_appservice_last_pos_txn
)
@defer.inlineCallbacks
def get_new_events_for_appservice(self, current_id, limit):
"""Get all new evnets"""
def get_new_events_for_appservice_txn(txn):
sql = (
"SELECT e.stream_ordering, e.event_id"
" FROM events AS e"
" WHERE"
" (SELECT stream_ordering FROM appservice_stream_position)"
" < e.stream_ordering"
" AND e.stream_ordering <= ?"
" ORDER BY e.stream_ordering ASC"
" LIMIT ?"
)
txn.execute(sql, (current_id, limit))
rows = txn.fetchall()
upper_bound = current_id
if len(rows) == limit:
upper_bound = rows[-1][0]
return upper_bound, [row[1] for row in rows]
upper_bound, event_ids = yield self.runInteraction(
"get_new_events_for_appservice", get_new_events_for_appservice_txn,
)
events = yield self._get_events(event_ids)
defer.returnValue((upper_bound, events))

View File

@ -82,32 +82,39 @@ class DirectoryStore(SQLBaseStore):
Returns: Returns:
Deferred Deferred
""" """
try: def alias_txn(txn):
yield self._simple_insert( self._simple_insert_txn(
txn,
"room_aliases", "room_aliases",
{ {
"room_alias": room_alias.to_string(), "room_alias": room_alias.to_string(),
"room_id": room_id, "room_id": room_id,
"creator": creator, "creator": creator,
}, },
desc="create_room_alias_association", )
self._simple_insert_many_txn(
txn,
table="room_alias_servers",
values=[{
"room_alias": room_alias.to_string(),
"server": server,
} for server in servers],
)
self._invalidate_cache_and_stream(
txn, self.get_aliases_for_room, (room_id,)
)
try:
ret = yield self.runInteraction(
"create_room_alias_association", alias_txn
) )
except self.database_engine.module.IntegrityError: except self.database_engine.module.IntegrityError:
raise SynapseError( raise SynapseError(
409, "Room alias %s already exists" % room_alias.to_string() 409, "Room alias %s already exists" % room_alias.to_string()
) )
defer.returnValue(ret)
for server in servers:
# TODO(erikj): Fix this to bulk insert
yield self._simple_insert(
"room_alias_servers",
{
"room_alias": room_alias.to_string(),
"server": server,
},
desc="create_room_alias_association",
)
self.get_aliases_for_room.invalidate((room_id,))
def get_room_alias_creator(self, room_alias): def get_room_alias_creator(self, room_alias):
return self._simple_select_one_onecol( return self._simple_select_one_onecol(

View File

@ -56,7 +56,7 @@ class EventPushActionsStore(SQLBaseStore):
) )
self._simple_insert_many_txn(txn, "event_push_actions", values) self._simple_insert_many_txn(txn, "event_push_actions", values)
@cachedInlineCallbacks(num_args=3, lru=True, tree=True, max_entries=5000) @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
def get_unread_event_push_actions_by_room_for_user( def get_unread_event_push_actions_by_room_for_user(
self, room_id, user_id, last_read_event_id self, room_id, user_id, last_read_event_id
): ):
@ -337,6 +337,36 @@ class EventPushActionsStore(SQLBaseStore):
# Now return the first `limit` # Now return the first `limit`
defer.returnValue(notifs[:limit]) defer.returnValue(notifs[:limit])
@defer.inlineCallbacks
def get_push_actions_for_user(self, user_id, before=None, limit=50):
def f(txn):
before_clause = ""
if before:
before_clause = "AND stream_ordering < ?"
args = [user_id, before, limit]
else:
args = [user_id, limit]
sql = (
"SELECT epa.event_id, epa.room_id,"
" epa.stream_ordering, epa.topological_ordering,"
" epa.actions, epa.profile_tag, e.received_ts"
" FROM event_push_actions epa, events e"
" WHERE epa.room_id = e.room_id AND epa.event_id = e.event_id"
" AND epa.user_id = ? %s"
" ORDER BY epa.stream_ordering DESC"
" LIMIT ?"
% (before_clause,)
)
txn.execute(sql, args)
return self.cursor_to_dict(txn)
push_actions = yield self.runInteraction(
"get_push_actions_for_user", f
)
for pa in push_actions:
pa["actions"] = json.loads(pa["actions"])
defer.returnValue(push_actions)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_time_of_last_push_action_before(self, stream_ordering): def get_time_of_last_push_action_before(self, stream_ordering):
def f(txn): def f(txn):

View File

@ -20,8 +20,11 @@ from synapse.events import FrozenEvent, USE_FROZEN_DICTS
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import preserve_fn, PreserveLoggingContext from synapse.util.logcontext import (
preserve_fn, PreserveLoggingContext, preserve_context_over_deferred
)
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
@ -201,7 +204,7 @@ class EventsStore(SQLBaseStore):
deferreds = [] deferreds = []
for room_id, evs_ctxs in partitioned.items(): for room_id, evs_ctxs in partitioned.items():
d = self._event_persist_queue.add_to_queue( d = preserve_fn(self._event_persist_queue.add_to_queue)(
room_id, evs_ctxs, room_id, evs_ctxs,
backfilled=backfilled, backfilled=backfilled,
current_state=None, current_state=None,
@ -211,7 +214,9 @@ class EventsStore(SQLBaseStore):
for room_id in partitioned.keys(): for room_id in partitioned.keys():
self._maybe_start_persisting(room_id) self._maybe_start_persisting(room_id)
return defer.gatherResults(deferreds, consumeErrors=True) return preserve_context_over_deferred(
defer.gatherResults(deferreds, consumeErrors=True)
)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
@ -224,7 +229,7 @@ class EventsStore(SQLBaseStore):
self._maybe_start_persisting(event.room_id) self._maybe_start_persisting(event.room_id)
yield deferred yield preserve_context_over_deferred(deferred)
max_persisted_id = yield self._stream_id_gen.get_current_token() max_persisted_id = yield self._stream_id_gen.get_current_token()
defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id)) defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id))
@ -600,7 +605,8 @@ class EventsStore(SQLBaseStore):
"rejections", "rejections",
"redactions", "redactions",
"room_memberships", "room_memberships",
"state_events" "state_events",
"topics"
): ):
txn.executemany( txn.executemany(
"DELETE FROM %s WHERE event_id = ?" % (table,), "DELETE FROM %s WHERE event_id = ?" % (table,),
@ -1086,7 +1092,7 @@ class EventsStore(SQLBaseStore):
if not allow_rejected: if not allow_rejected:
rows[:] = [r for r in rows if not r["rejects"]] rows[:] = [r for r in rows if not r["rejects"]]
res = yield defer.gatherResults( res = yield preserve_context_over_deferred(defer.gatherResults(
[ [
preserve_fn(self._get_event_from_row)( preserve_fn(self._get_event_from_row)(
row["internal_metadata"], row["json"], row["redacts"], row["internal_metadata"], row["json"], row["redacts"],
@ -1095,7 +1101,7 @@ class EventsStore(SQLBaseStore):
for row in rows for row in rows
], ],
consumeErrors=True consumeErrors=True
) ))
defer.returnValue({ defer.returnValue({
e.event.event_id: e e.event.event_id: e
@ -1131,54 +1137,55 @@ class EventsStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_event_from_row(self, internal_metadata, js, redacted, def _get_event_from_row(self, internal_metadata, js, redacted,
rejected_reason=None): rejected_reason=None):
d = json.loads(js) with Measure(self._clock, "_get_event_from_row"):
internal_metadata = json.loads(internal_metadata) d = json.loads(js)
internal_metadata = json.loads(internal_metadata)
if rejected_reason: if rejected_reason:
rejected_reason = yield self._simple_select_one_onecol( rejected_reason = yield self._simple_select_one_onecol(
table="rejections", table="rejections",
keyvalues={"event_id": rejected_reason}, keyvalues={"event_id": rejected_reason},
retcol="reason", retcol="reason",
desc="_get_event_from_row_rejected_reason", desc="_get_event_from_row_rejected_reason",
)
original_ev = FrozenEvent(
d,
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
) )
original_ev = FrozenEvent( redacted_event = None
d, if redacted:
internal_metadata_dict=internal_metadata, redacted_event = prune_event(original_ev)
rejected_reason=rejected_reason,
)
redacted_event = None redaction_id = yield self._simple_select_one_onecol(
if redacted: table="redactions",
redacted_event = prune_event(original_ev) keyvalues={"redacts": redacted_event.event_id},
retcol="event_id",
desc="_get_event_from_row_redactions",
)
redaction_id = yield self._simple_select_one_onecol( redacted_event.unsigned["redacted_by"] = redaction_id
table="redactions", # Get the redaction event.
keyvalues={"redacts": redacted_event.event_id},
retcol="event_id", because = yield self.get_event(
desc="_get_event_from_row_redactions", redaction_id,
check_redacted=False,
allow_none=True,
)
if because:
# It's fine to do add the event directly, since get_pdu_json
# will serialise this field correctly
redacted_event.unsigned["redacted_because"] = because
cache_entry = _EventCacheEntry(
event=original_ev,
redacted_event=redacted_event,
) )
redacted_event.unsigned["redacted_by"] = redaction_id self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
# Get the redaction event.
because = yield self.get_event(
redaction_id,
check_redacted=False,
allow_none=True,
)
if because:
# It's fine to do add the event directly, since get_pdu_json
# will serialise this field correctly
redacted_event.unsigned["redacted_because"] = because
cache_entry = _EventCacheEntry(
event=original_ev,
redacted_event=redacted_event,
)
self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
defer.returnValue(cache_entry) defer.returnValue(cache_entry)

View File

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 33 SCHEMA_VERSION = 34
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))

View File

@ -189,18 +189,30 @@ class PresenceStore(SQLBaseStore):
desc="add_presence_list_pending", desc="add_presence_list_pending",
) )
@defer.inlineCallbacks
def set_presence_list_accepted(self, observer_localpart, observed_userid): def set_presence_list_accepted(self, observer_localpart, observed_userid):
result = yield self._simple_update_one( def update_presence_list_txn(txn):
table="presence_list", result = self._simple_update_one_txn(
keyvalues={"user_id": observer_localpart, txn,
"observed_user_id": observed_userid}, table="presence_list",
updatevalues={"accepted": True}, keyvalues={
desc="set_presence_list_accepted", "user_id": observer_localpart,
"observed_user_id": observed_userid
},
updatevalues={"accepted": True},
)
self._invalidate_cache_and_stream(
txn, self.get_presence_list_accepted, (observer_localpart,)
)
self._invalidate_cache_and_stream(
txn, self.get_presence_list_observers_accepted, (observed_userid,)
)
return result
return self.runInteraction(
"set_presence_list_accepted", update_presence_list_txn,
) )
self.get_presence_list_accepted.invalidate((observer_localpart,))
self.get_presence_list_observers_accepted.invalidate((observed_userid,))
defer.returnValue(result)
def get_presence_list(self, observer_localpart, accepted=None): def get_presence_list(self, observer_localpart, accepted=None):
if accepted: if accepted:

View File

@ -16,6 +16,7 @@
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from synapse.push.baserules import list_with_base_rules from synapse.push.baserules import list_with_base_rules
from synapse.api.constants import EventTypes, Membership
from twisted.internet import defer from twisted.internet import defer
import logging import logging
@ -48,7 +49,7 @@ def _load_rules(rawrules, enabled_map):
class PushRuleStore(SQLBaseStore): class PushRuleStore(SQLBaseStore):
@cachedInlineCallbacks(lru=True) @cachedInlineCallbacks()
def get_push_rules_for_user(self, user_id): def get_push_rules_for_user(self, user_id):
rows = yield self._simple_select_list( rows = yield self._simple_select_list(
table="push_rules", table="push_rules",
@ -72,7 +73,7 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(rules) defer.returnValue(rules)
@cachedInlineCallbacks(lru=True) @cachedInlineCallbacks()
def get_push_rules_enabled_for_user(self, user_id): def get_push_rules_enabled_for_user(self, user_id):
results = yield self._simple_select_list( results = yield self._simple_select_list(
table="push_rules_enable", table="push_rules_enable",
@ -123,6 +124,61 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(results) defer.returnValue(results)
def bulk_get_push_rules_for_room(self, room_id, state_group, current_state):
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
# of None don't hit previous cached calls with a None state_group.
# To do this we set the state_group to a new object as object() != object()
state_group = object()
return self._bulk_get_push_rules_for_room(room_id, state_group, current_state)
@cachedInlineCallbacks(num_args=2, cache_context=True)
def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state,
cache_context):
# We don't use `state_group`, its there so that we can cache based
# on it. However, its important that its never None, since two current_state's
# with a state_group of None are likely to be different.
# See bulk_get_push_rules_for_room for how we work around this.
assert state_group is not None
# We also will want to generate notifs for other people in the room so
# their unread countss are correct in the event stream, but to avoid
# generating them for bot / AS users etc, we only do so for people who've
# sent a read receipt into the room.
local_users_in_room = set(
e.state_key for e in current_state.values()
if e.type == EventTypes.Member and e.membership == Membership.JOIN
and self.hs.is_mine_id(e.state_key)
)
# users in the room who have pushers need to get push rules run because
# that's how their pushers work
if_users_with_pushers = yield self.get_if_users_have_pushers(
local_users_in_room, on_invalidate=cache_context.invalidate,
)
user_ids = set(
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
)
users_with_receipts = yield self.get_users_with_read_receipts_in_room(
room_id, on_invalidate=cache_context.invalidate,
)
# any users with pushers must be ours: they have pushers
for uid in users_with_receipts:
if uid in local_users_in_room:
user_ids.add(uid)
rules_by_user = yield self.bulk_get_push_rules(
user_ids, on_invalidate=cache_context.invalidate,
)
rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
defer.returnValue(rules_by_user)
@cachedList(cached_method_name="get_push_rules_enabled_for_user", @cachedList(cached_method_name="get_push_rules_enabled_for_user",
list_name="user_ids", num_args=1, inlineCallbacks=True) list_name="user_ids", num_args=1, inlineCallbacks=True)
def bulk_get_push_rules_enabled(self, user_ids): def bulk_get_push_rules_enabled(self, user_ids):

View File

@ -135,7 +135,7 @@ class PusherStore(SQLBaseStore):
"get_all_updated_pushers", get_all_updated_pushers_txn "get_all_updated_pushers", get_all_updated_pushers_txn
) )
@cachedInlineCallbacks(lru=True, num_args=1, max_entries=15000) @cachedInlineCallbacks(num_args=1, max_entries=15000)
def get_if_user_has_pusher(self, user_id): def get_if_user_has_pusher(self, user_id):
result = yield self._simple_select_many_batch( result = yield self._simple_select_many_batch(
table='pushers', table='pushers',

View File

@ -94,6 +94,31 @@ class ReceiptsStore(SQLBaseStore):
defer.returnValue({row["room_id"]: row["event_id"] for row in rows}) defer.returnValue({row["room_id"]: row["event_id"] for row in rows})
@defer.inlineCallbacks
def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
def f(txn):
sql = (
"SELECT rl.room_id, rl.event_id,"
" e.topological_ordering, e.stream_ordering"
" FROM receipts_linearized AS rl"
" INNER JOIN events AS e USING (room_id, event_id)"
" WHERE rl.room_id = e.room_id"
" AND rl.event_id = e.event_id"
" AND user_id = ?"
)
txn.execute(sql, (user_id,))
return txn.fetchall()
rows = yield self.runInteraction(
"get_receipts_for_user_with_orderings", f
)
defer.returnValue({
row[0]: {
"event_id": row[1],
"topological_ordering": row[2],
"stream_ordering": row[3],
} for row in rows
})
@defer.inlineCallbacks @defer.inlineCallbacks
def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
"""Get receipts for multiple rooms for sending to clients. """Get receipts for multiple rooms for sending to clients.
@ -120,7 +145,7 @@ class ReceiptsStore(SQLBaseStore):
defer.returnValue([ev for res in results.values() for ev in res]) defer.returnValue([ev for res in results.values() for ev in res])
@cachedInlineCallbacks(num_args=3, max_entries=5000, lru=True, tree=True) @cachedInlineCallbacks(num_args=3, max_entries=5000, tree=True)
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
"""Get receipts for a single room for sending to clients. """Get receipts for a single room for sending to clients.

View File

@ -93,7 +93,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
desc="add_refresh_token_to_user", desc="add_refresh_token_to_user",
) )
@defer.inlineCallbacks
def register(self, user_id, token=None, password_hash=None, def register(self, user_id, token=None, password_hash=None,
was_guest=False, make_guest=False, appservice_id=None, was_guest=False, make_guest=False, appservice_id=None,
create_profile_with_localpart=None, admin=False): create_profile_with_localpart=None, admin=False):
@ -115,7 +114,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
Raises: Raises:
StoreError if the user_id could not be registered. StoreError if the user_id could not be registered.
""" """
yield self.runInteraction( return self.runInteraction(
"register", "register",
self._register, self._register,
user_id, user_id,
@ -127,8 +126,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
create_profile_with_localpart, create_profile_with_localpart,
admin admin
) )
self.get_user_by_id.invalidate((user_id,))
self.is_guest.invalidate((user_id,))
def _register( def _register(
self, self,
@ -210,6 +207,11 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
(create_profile_with_localpart,) (create_profile_with_localpart,)
) )
self._invalidate_cache_and_stream(
txn, self.get_user_by_id, (user_id,)
)
txn.call_after(self.is_guest.invalidate, (user_id,))
@cached() @cached()
def get_user_by_id(self, user_id): def get_user_by_id(self, user_id):
return self._simple_select_one( return self._simple_select_one(
@ -236,22 +238,31 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
return self.runInteraction("get_users_by_id_case_insensitive", f) return self.runInteraction("get_users_by_id_case_insensitive", f)
@defer.inlineCallbacks
def user_set_password_hash(self, user_id, password_hash): def user_set_password_hash(self, user_id, password_hash):
""" """
NB. This does *not* evict any cache because the one use for this NB. This does *not* evict any cache because the one use for this
removes most of the entries subsequently anyway so it would be removes most of the entries subsequently anyway so it would be
pointless. Use flush_user separately. pointless. Use flush_user separately.
""" """
yield self._simple_update_one('users', { def user_set_password_hash_txn(txn):
'name': user_id self._simple_update_one_txn(
}, { txn,
'password_hash': password_hash 'users', {
}) 'name': user_id
self.get_user_by_id.invalidate((user_id,)) },
{
'password_hash': password_hash
}
)
self._invalidate_cache_and_stream(
txn, self.get_user_by_id, (user_id,)
)
return self.runInteraction(
"user_set_password_hash", user_set_password_hash_txn
)
@defer.inlineCallbacks @defer.inlineCallbacks
def user_delete_access_tokens(self, user_id, except_token_ids=[], def user_delete_access_tokens(self, user_id, except_token_id=None,
device_id=None, device_id=None,
delete_refresh_tokens=False): delete_refresh_tokens=False):
""" """
@ -259,7 +270,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
Args: Args:
user_id (str): ID of user the tokens belong to user_id (str): ID of user the tokens belong to
except_token_ids (list[str]): list of access_tokens which should except_token_id (str): list of access_tokens IDs which should
*not* be deleted *not* be deleted
device_id (str|None): ID of device the tokens are associated with. device_id (str|None): ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will If None, tokens associated with any device (or no device) will
@ -269,53 +280,45 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
Returns: Returns:
defer.Deferred: defer.Deferred:
""" """
def f(txn, table, except_tokens, call_after_delete): def f(txn):
sql = "SELECT token FROM %s WHERE user_id = ?" % table keyvalues = {
clauses = [user_id] "user_id": user_id,
}
if device_id is not None: if device_id is not None:
sql += " AND device_id = ?" keyvalues["device_id"] = device_id
clauses.append(device_id)
if except_tokens: if delete_refresh_tokens:
sql += " AND id NOT IN (%s)" % ( self._simple_delete_txn(
",".join(["?" for _ in except_tokens]), txn,
) table="refresh_tokens",
clauses += except_tokens keyvalues=keyvalues,
txn.execute(sql, clauses)
rows = txn.fetchall()
n = 100
chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)]
for chunk in chunks:
if call_after_delete:
for row in chunk:
txn.call_after(call_after_delete, (row[0],))
txn.execute(
"DELETE FROM %s WHERE token in (%s)" % (
table,
",".join(["?" for _ in chunk]),
), [r[0] for r in chunk]
) )
# delete refresh tokens first, to stop new access tokens being items = keyvalues.items()
# allocated while our backs are turned where_clause = " AND ".join(k + " = ?" for k, _ in items)
if delete_refresh_tokens: values = [v for _, v in items]
yield self.runInteraction( if except_token_id:
"user_delete_access_tokens", f, where_clause += " AND id != ?"
table="refresh_tokens", values.append(except_token_id)
except_tokens=[],
call_after_delete=None, txn.execute(
"SELECT token FROM access_tokens WHERE %s" % where_clause,
values
)
rows = self.cursor_to_dict(txn)
for row in rows:
self._invalidate_cache_and_stream(
txn, self.get_user_by_access_token, (row["token"],)
)
txn.execute(
"DELETE FROM access_tokens WHERE %s" % where_clause,
values
) )
yield self.runInteraction( yield self.runInteraction(
"user_delete_access_tokens", f, "user_delete_access_tokens", f,
table="access_tokens",
except_tokens=except_token_ids,
call_after_delete=self.get_user_by_access_token.invalidate,
) )
def delete_access_token(self, access_token): def delete_access_token(self, access_token):
@ -328,7 +331,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
}, },
) )
txn.call_after(self.get_user_by_access_token.invalidate, (access_token,)) self._invalidate_cache_and_stream(
txn, self.get_user_by_access_token, (access_token,)
)
return self.runInteraction("delete_access_token", f) return self.runInteraction("delete_access_token", f)

View File

@ -277,7 +277,6 @@ class RoomMemberStore(SQLBaseStore):
user_id, membership_list=[Membership.JOIN], user_id, membership_list=[Membership.JOIN],
) )
@defer.inlineCallbacks
def forget(self, user_id, room_id): def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id.""" """Indicate that user_id wishes to discard history for room_id."""
def f(txn): def f(txn):
@ -292,10 +291,13 @@ class RoomMemberStore(SQLBaseStore):
" room_id = ?" " room_id = ?"
) )
txn.execute(sql, (user_id, room_id)) txn.execute(sql, (user_id, room_id))
yield self.runInteraction("forget_membership", f)
self.was_forgotten_at.invalidate_all() txn.call_after(self.was_forgotten_at.invalidate_all)
self.who_forgot_in_room.invalidate_all() txn.call_after(self.did_forget.invalidate, (user_id, room_id))
self.did_forget.invalidate((user_id, room_id)) self._invalidate_cache_and_stream(
txn, self.who_forgot_in_room, (room_id,)
)
return self.runInteraction("forget_membership", f)
@cachedInlineCallbacks(num_args=2) @cachedInlineCallbacks(num_args=2)
def did_forget(self, user_id, room_id): def did_forget(self, user_id, room_id):

View File

@ -0,0 +1,23 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS appservice_stream_position(
Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
stream_ordering BIGINT,
CHECK (Lock='X')
);
INSERT INTO appservice_stream_position (stream_ordering)
SELECT COALESCE(MAX(stream_ordering), 0) FROM events;

View File

@ -0,0 +1,46 @@
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage.prepare_database import get_statements
from synapse.storage.engines import PostgresEngine
import logging
logger = logging.getLogger(__name__)
# This stream is used to notify replication slaves that some caches have
# been invalidated that they cannot infer from the other streams.
CREATE_TABLE = """
CREATE TABLE cache_invalidation_stream (
stream_id BIGINT,
cache_func TEXT,
keys TEXT[],
invalidation_ts BIGINT
);
CREATE INDEX cache_invalidation_stream_id ON cache_invalidation_stream(stream_id);
"""
def run_create(cur, database_engine, *args, **kwargs):
if not isinstance(database_engine, PostgresEngine):
return
for statement in get_statements(CREATE_TABLE.splitlines()):
cur.execute(statement)
def run_upgrade(cur, database_engine, *args, **kwargs):
pass

View File

@ -0,0 +1,20 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
DELETE FROM push_rules WHERE rule_id = 'global/override/.m.rule.contains_display_name';
UPDATE push_rules SET rule_id = 'global/override/.m.rule.contains_display_name' WHERE rule_id = 'global/underride/.m.rule.contains_display_name';
DELETE FROM push_rules_enable WHERE rule_id = 'global/override/.m.rule.contains_display_name';
UPDATE push_rules_enable SET rule_id = 'global/override/.m.rule.contains_display_name' WHERE rule_id = 'global/underride/.m.rule.contains_display_name';

View File

@ -0,0 +1,32 @@
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage.engines import PostgresEngine
import logging
logger = logging.getLogger(__name__)
def run_create(cur, database_engine, *args, **kwargs):
if isinstance(database_engine, PostgresEngine):
cur.execute("TRUNCATE received_transactions")
else:
cur.execute("DELETE FROM received_transactions")
cur.execute("CREATE INDEX received_transactions_ts ON received_transactions(ts)")
def run_upgrade(cur, database_engine, *args, **kwargs):
pass

View File

@ -25,7 +25,7 @@ from synapse.util.caches.descriptors import cached, cachedList
class SignatureStore(SQLBaseStore): class SignatureStore(SQLBaseStore):
"""Persistence for event signatures and hashes""" """Persistence for event signatures and hashes"""
@cached(lru=True) @cached()
def get_event_reference_hash(self, event_id): def get_event_reference_hash(self, event_id):
return self._get_event_reference_hashes_txn(event_id) return self._get_event_reference_hashes_txn(event_id)

View File

@ -174,7 +174,7 @@ class StateStore(SQLBaseStore):
return [r[0] for r in results] return [r[0] for r in results]
return self.runInteraction("get_current_state_for_key", f) return self.runInteraction("get_current_state_for_key", f)
@cached(num_args=2, lru=True, max_entries=1000) @cached(num_args=2, max_entries=1000)
def _get_state_group_from_group(self, group, types): def _get_state_group_from_group(self, group, types):
raise NotImplementedError() raise NotImplementedError()
@ -272,7 +272,7 @@ class StateStore(SQLBaseStore):
state_map = yield self.get_state_for_events([event_id], types) state_map = yield self.get_state_for_events([event_id], types)
defer.returnValue(state_map[event_id]) defer.returnValue(state_map[event_id])
@cached(num_args=2, lru=True, max_entries=10000) @cached(num_args=2, max_entries=10000)
def _get_state_group_for_event(self, room_id, event_id): def _get_state_group_for_event(self, room_id, event_id):
return self._simple_select_one_onecol( return self._simple_select_one_onecol(
table="event_to_state_groups", table="event_to_state_groups",

View File

@ -39,7 +39,7 @@ from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.types import RoomStreamToken from synapse.types import RoomStreamToken
from synapse.util.logcontext import preserve_fn from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.engines import PostgresEngine, Sqlite3Engine
import logging import logging
@ -234,12 +234,12 @@ class StreamStore(SQLBaseStore):
results = {} results = {}
room_ids = list(room_ids) room_ids = list(room_ids)
for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)): for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)):
res = yield defer.gatherResults([ res = yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(self.get_room_events_stream_for_room)( preserve_fn(self.get_room_events_stream_for_room)(
room_id, from_key, to_key, limit, order=order, room_id, from_key, to_key, limit, order=order,
) )
for room_id in rm_ids for room_id in rm_ids
]) ]))
results.update(dict(zip(rm_ids, res))) results.update(dict(zip(rm_ids, res)))
defer.returnValue(results) defer.returnValue(results)

View File

@ -62,10 +62,9 @@ class TransactionStore(SQLBaseStore):
self.last_transaction = {} self.last_transaction = {}
reactor.addSystemEventTrigger("before", "shutdown", self._persist_in_mem_txns) reactor.addSystemEventTrigger("before", "shutdown", self._persist_in_mem_txns)
hs.get_clock().looping_call( self._clock.looping_call(self._persist_in_mem_txns, 1000)
self._persist_in_mem_txns,
1000, self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000)
)
def get_received_txn_response(self, transaction_id, origin): def get_received_txn_response(self, transaction_id, origin):
"""For an incoming transaction from a given origin, check if we have """For an incoming transaction from a given origin, check if we have
@ -127,6 +126,7 @@ class TransactionStore(SQLBaseStore):
"origin": origin, "origin": origin,
"response_code": code, "response_code": code,
"response_json": buffer(encode_canonical_json(response_dict)), "response_json": buffer(encode_canonical_json(response_dict)),
"ts": self._clock.time_msec(),
}, },
or_ignore=True, or_ignore=True,
desc="set_received_txn_response", desc="set_received_txn_response",
@ -383,3 +383,12 @@ class TransactionStore(SQLBaseStore):
yield self.runInteraction("_persist_in_mem_txns", f) yield self.runInteraction("_persist_in_mem_txns", f)
except: except:
logger.exception("Failed to persist transactions!") logger.exception("Failed to persist transactions!")
def _cleanup_transactions(self):
now = self._clock.time_msec()
month_ago = now - 30 * 24 * 60 * 60 * 1000
def _cleanup_transactions_txn(txn):
txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
return self.runInteraction("_persist_in_mem_txns", _cleanup_transactions_txn)

View File

@ -269,3 +269,10 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
return "t%d-%d" % (self.topological, self.stream) return "t%d-%d" % (self.topological, self.stream)
else: else:
return "s%d" % (self.stream,) return "s%d" % (self.stream,)
# Some arbitrary constants used for internal API enumerations. Don't rely on
# exact values; always pass or compare symbolically
class ThirdPartyEntityKind(object):
USER = 'user'
LOCATION = 'location'

View File

@ -146,10 +146,10 @@ def concurrently_execute(func, args, limit):
except StopIteration: except StopIteration:
pass pass
return defer.gatherResults([ return preserve_context_over_deferred(defer.gatherResults([
preserve_fn(_concurrently_execute_inner)() preserve_fn(_concurrently_execute_inner)()
for _ in xrange(limit) for _ in xrange(limit)
], consumeErrors=True).addErrback(unwrapFirstError) ], consumeErrors=True)).addErrback(unwrapFirstError)
class Linearizer(object): class Linearizer(object):
@ -181,7 +181,8 @@ class Linearizer(object):
self.key_to_defer[key] = new_defer self.key_to_defer[key] = new_defer
if current_defer: if current_defer:
yield preserve_context_over_deferred(current_defer) with PreserveLoggingContext():
yield current_defer
@contextmanager @contextmanager
def _ctx_manager(): def _ctx_manager():
@ -264,7 +265,7 @@ class ReadWriteLock(object):
curr_readers.clear() curr_readers.clear()
self.key_to_current_writer[key] = new_defer self.key_to_current_writer[key] = new_defer
yield defer.gatherResults(to_wait_on) yield preserve_context_over_deferred(defer.gatherResults(to_wait_on))
@contextmanager @contextmanager
def _ctx_manager(): def _ctx_manager():

View File

@ -25,8 +25,7 @@ from synapse.util.logcontext import (
from . import DEBUG_CACHES, register_cache from . import DEBUG_CACHES, register_cache
from twisted.internet import defer from twisted.internet import defer
from collections import namedtuple
from collections import OrderedDict
import os import os
import functools import functools
@ -54,16 +53,11 @@ class Cache(object):
"metrics", "metrics",
) )
def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False): def __init__(self, name, max_entries=1000, keylen=1, tree=False):
if lru: cache_type = TreeCache if tree else dict
cache_type = TreeCache if tree else dict self.cache = LruCache(
self.cache = LruCache( max_size=max_entries, keylen=keylen, cache_type=cache_type
max_size=max_entries, keylen=keylen, cache_type=cache_type )
)
self.max_entries = None
else:
self.cache = OrderedDict()
self.max_entries = max_entries
self.name = name self.name = name
self.keylen = keylen self.keylen = keylen
@ -81,8 +75,8 @@ class Cache(object):
"Cache objects can only be accessed from the main thread" "Cache objects can only be accessed from the main thread"
) )
def get(self, key, default=_CacheSentinel): def get(self, key, default=_CacheSentinel, callback=None):
val = self.cache.get(key, _CacheSentinel) val = self.cache.get(key, _CacheSentinel, callback=callback)
if val is not _CacheSentinel: if val is not _CacheSentinel:
self.metrics.inc_hits() self.metrics.inc_hits()
return val return val
@ -94,19 +88,15 @@ class Cache(object):
else: else:
return default return default
def update(self, sequence, key, value): def update(self, sequence, key, value, callback=None):
self.check_thread() self.check_thread()
if self.sequence == sequence: if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the # Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369) # number that the cache had before the SELECT was started (SYN-369)
self.prefill(key, value) self.prefill(key, value, callback=callback)
def prefill(self, key, value): def prefill(self, key, value, callback=None):
if self.max_entries is not None: self.cache.set(key, value, callback=callback)
while len(self.cache) >= self.max_entries:
self.cache.popitem(last=False)
self.cache[key] = value
def invalidate(self, key): def invalidate(self, key):
self.check_thread() self.check_thread()
@ -151,9 +141,21 @@ class CacheDescriptor(object):
The wrapped function has another additional callable, called "prefill", The wrapped function has another additional callable, called "prefill",
which can be used to insert values into the cache specifically, without which can be used to insert values into the cache specifically, without
calling the calculation function. calling the calculation function.
Cached functions can be "chained" (i.e. a cached function can call other cached
functions and get appropriately invalidated when they called caches are
invalidated) by adding a special "cache_context" argument to the function
and passing that as a kwarg to all caches called. For example::
@cachedInlineCallbacks(cache_context=True)
def foo(self, key, cache_context):
r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate)
r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
defer.returnValue(r1 + r2)
""" """
def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False, def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
inlineCallbacks=False): inlineCallbacks=False, cache_context=False):
max_entries = int(max_entries * CACHE_SIZE_FACTOR) max_entries = int(max_entries * CACHE_SIZE_FACTOR)
self.orig = orig self.orig = orig
@ -165,15 +167,33 @@ class CacheDescriptor(object):
self.max_entries = max_entries self.max_entries = max_entries
self.num_args = num_args self.num_args = num_args
self.lru = lru
self.tree = tree self.tree = tree
self.arg_names = inspect.getargspec(orig).args[1:num_args + 1] all_args = inspect.getargspec(orig)
self.arg_names = all_args.args[1:num_args + 1]
if "cache_context" in all_args.args:
if not cache_context:
raise ValueError(
"Cannot have a 'cache_context' arg without setting"
" cache_context=True"
)
try:
self.arg_names.remove("cache_context")
except ValueError:
pass
elif cache_context:
raise ValueError(
"Cannot have cache_context=True without having an arg"
" named `cache_context`"
)
self.add_cache_context = cache_context
if len(self.arg_names) < self.num_args: if len(self.arg_names) < self.num_args:
raise Exception( raise Exception(
"Not enough explicit positional arguments to key off of for %r." "Not enough explicit positional arguments to key off of for %r."
" (@cached cannot key off of *args or **kwars)" " (@cached cannot key off of *args or **kwargs)"
% (orig.__name__,) % (orig.__name__,)
) )
@ -182,16 +202,29 @@ class CacheDescriptor(object):
name=self.orig.__name__, name=self.orig.__name__,
max_entries=self.max_entries, max_entries=self.max_entries,
keylen=self.num_args, keylen=self.num_args,
lru=self.lru,
tree=self.tree, tree=self.tree,
) )
@functools.wraps(self.orig) @functools.wraps(self.orig)
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
# Add temp cache_context so inspect.getcallargs doesn't explode
if self.add_cache_context:
kwargs["cache_context"] = None
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
# Add our own `cache_context` to argument list if the wrapped function
# has asked for one
if self.add_cache_context:
kwargs["cache_context"] = _CacheContext(cache, cache_key)
try: try:
cached_result_d = cache.get(cache_key) cached_result_d = cache.get(cache_key, callback=invalidate_callback)
observer = cached_result_d.observe() observer = cached_result_d.observe()
if DEBUG_CACHES: if DEBUG_CACHES:
@ -228,7 +261,7 @@ class CacheDescriptor(object):
ret.addErrback(onErr) ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=True) ret = ObservableDeferred(ret, consumeErrors=True)
cache.update(sequence, cache_key, ret) cache.update(sequence, cache_key, ret, callback=invalidate_callback)
return preserve_context_over_deferred(ret.observe()) return preserve_context_over_deferred(ret.observe())
@ -297,6 +330,10 @@ class CacheListDescriptor(object):
@functools.wraps(self.orig) @functools.wraps(self.orig)
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
list_args = arg_dict[self.list_name] list_args = arg_dict[self.list_name]
@ -311,7 +348,7 @@ class CacheListDescriptor(object):
key[self.list_pos] = arg key[self.list_pos] = arg
try: try:
res = cache.get(tuple(key)) res = cache.get(tuple(key), callback=invalidate_callback)
if not res.has_succeeded(): if not res.has_succeeded():
res = res.observe() res = res.observe()
res.addCallback(lambda r, arg: (arg, r), arg) res.addCallback(lambda r, arg: (arg, r), arg)
@ -345,7 +382,10 @@ class CacheListDescriptor(object):
key = list(keyargs) key = list(keyargs)
key[self.list_pos] = arg key[self.list_pos] = arg
cache.update(sequence, tuple(key), observer) cache.update(
sequence, tuple(key), observer,
callback=invalidate_callback
)
def invalidate(f, key): def invalidate(f, key):
cache.invalidate(key) cache.invalidate(key)
@ -376,24 +416,29 @@ class CacheListDescriptor(object):
return wrapped return wrapped
def cached(max_entries=1000, num_args=1, lru=True, tree=False): class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
def invalidate(self):
self.cache.invalidate(self.key)
def cached(max_entries=1000, num_args=1, tree=False, cache_context=False):
return lambda orig: CacheDescriptor( return lambda orig: CacheDescriptor(
orig, orig,
max_entries=max_entries, max_entries=max_entries,
num_args=num_args, num_args=num_args,
lru=lru,
tree=tree, tree=tree,
cache_context=cache_context,
) )
def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False): def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False):
return lambda orig: CacheDescriptor( return lambda orig: CacheDescriptor(
orig, orig,
max_entries=max_entries, max_entries=max_entries,
num_args=num_args, num_args=num_args,
lru=lru,
tree=tree, tree=tree,
inlineCallbacks=True, inlineCallbacks=True,
cache_context=cache_context,
) )

View File

@ -30,13 +30,14 @@ def enumerate_leaves(node, depth):
class _Node(object): class _Node(object):
__slots__ = ["prev_node", "next_node", "key", "value"] __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"]
def __init__(self, prev_node, next_node, key, value): def __init__(self, prev_node, next_node, key, value, callbacks=set()):
self.prev_node = prev_node self.prev_node = prev_node
self.next_node = next_node self.next_node = next_node
self.key = key self.key = key
self.value = value self.value = value
self.callbacks = callbacks
class LruCache(object): class LruCache(object):
@ -44,6 +45,9 @@ class LruCache(object):
Least-recently-used cache. Least-recently-used cache.
Supports del_multi only if cache_type=TreeCache Supports del_multi only if cache_type=TreeCache
If cache_type=TreeCache, all keys must be tuples. If cache_type=TreeCache, all keys must be tuples.
Can also set callbacks on objects when getting/setting which are fired
when that key gets invalidated/evicted.
""" """
def __init__(self, max_size, keylen=1, cache_type=dict): def __init__(self, max_size, keylen=1, cache_type=dict):
cache = cache_type() cache = cache_type()
@ -62,10 +66,10 @@ class LruCache(object):
return inner return inner
def add_node(key, value): def add_node(key, value, callbacks=set()):
prev_node = list_root prev_node = list_root
next_node = prev_node.next_node next_node = prev_node.next_node
node = _Node(prev_node, next_node, key, value) node = _Node(prev_node, next_node, key, value, callbacks)
prev_node.next_node = node prev_node.next_node = node
next_node.prev_node = node next_node.prev_node = node
cache[key] = node cache[key] = node
@ -88,23 +92,41 @@ class LruCache(object):
prev_node.next_node = next_node prev_node.next_node = next_node
next_node.prev_node = prev_node next_node.prev_node = prev_node
for cb in node.callbacks:
cb()
node.callbacks.clear()
@synchronized @synchronized
def cache_get(key, default=None): def cache_get(key, default=None, callback=None):
node = cache.get(key, None) node = cache.get(key, None)
if node is not None: if node is not None:
move_node_to_front(node) move_node_to_front(node)
if callback:
node.callbacks.add(callback)
return node.value return node.value
else: else:
return default return default
@synchronized @synchronized
def cache_set(key, value): def cache_set(key, value, callback=None):
node = cache.get(key, None) node = cache.get(key, None)
if node is not None: if node is not None:
if value != node.value:
for cb in node.callbacks:
cb()
node.callbacks.clear()
if callback:
node.callbacks.add(callback)
move_node_to_front(node) move_node_to_front(node)
node.value = value node.value = value
else: else:
add_node(key, value) if callback:
callbacks = set([callback])
else:
callbacks = set()
add_node(key, value, callbacks)
if len(cache) > max_size: if len(cache) > max_size:
todelete = list_root.prev_node todelete = list_root.prev_node
delete_node(todelete) delete_node(todelete)
@ -148,6 +170,9 @@ class LruCache(object):
def cache_clear(): def cache_clear():
list_root.next_node = list_root list_root.next_node = list_root
list_root.prev_node = list_root list_root.prev_node = list_root
for node in cache.values():
for cb in node.callbacks:
cb()
cache.clear() cache.clear()
@synchronized @synchronized

View File

@ -64,6 +64,9 @@ class TreeCache(object):
self.size -= cnt self.size -= cnt
return popped return popped
def values(self):
return [e.value for e in self.root.values()]
def __len__(self): def __len__(self):
return self.size return self.size

View File

@ -297,12 +297,13 @@ def preserve_context_over_fn(fn, *args, **kwargs):
return res return res
def preserve_context_over_deferred(deferred): def preserve_context_over_deferred(deferred, context=None):
"""Given a deferred wrap it such that any callbacks added later to it will """Given a deferred wrap it such that any callbacks added later to it will
be invoked with the current context. be invoked with the current context.
""" """
current_context = LoggingContext.current_context() if context is None:
d = _PreservingContextDeferred(current_context) context = LoggingContext.current_context()
d = _PreservingContextDeferred(context)
deferred.chainDeferred(d) deferred.chainDeferred(d)
return d return d
@ -316,8 +317,13 @@ def preserve_fn(f):
def g(*args, **kwargs): def g(*args, **kwargs):
with PreserveLoggingContext(current): with PreserveLoggingContext(current):
return f(*args, **kwargs) res = f(*args, **kwargs)
if isinstance(res, defer.Deferred):
return preserve_context_over_deferred(
res, context=LoggingContext.sentinel
)
else:
return res
return g return g

View File

@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
import synapse.metrics import synapse.metrics
from functools import wraps
import logging import logging
@ -47,6 +49,18 @@ block_db_txn_duration = metrics.register_distribution(
) )
def measure_func(name):
def wrapper(func):
@wraps(func)
@defer.inlineCallbacks
def measured_func(self, *args, **kwargs):
with Measure(self.clock, name):
r = yield func(self, *args, **kwargs)
defer.returnValue(r)
return measured_func
return wrapper
class Measure(object): class Measure(object):
__slots__ = [ __slots__ = [
"clock", "name", "start_context", "start", "new_context", "ru_utime", "clock", "name", "start_context", "start", "new_context", "ru_utime",
@ -64,7 +78,6 @@ class Measure(object):
self.start = self.clock.time_msec() self.start = self.clock.time_msec()
self.start_context = LoggingContext.current_context() self.start_context = LoggingContext.current_context()
if not self.start_context: if not self.start_context:
logger.warn("Entered Measure without log context: %s", self.name)
self.start_context = LoggingContext("Measure") self.start_context = LoggingContext("Measure")
self.start_context.__enter__() self.start_context.__enter__()
self.created_context = True self.created_context = True
@ -74,7 +87,7 @@ class Measure(object):
self.db_txn_duration = self.start_context.db_txn_duration self.db_txn_duration = self.start_context.db_txn_duration
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None or not self.start_context: if isinstance(exc_type, Exception) or not self.start_context:
return return
duration = self.clock.time_msec() - self.start duration = self.clock.time_msec() - self.start
@ -85,7 +98,7 @@ class Measure(object):
if context != self.start_context: if context != self.start_context:
logger.warn( logger.warn(
"Context has unexpectedly changed from '%s' to '%s'. (%r)", "Context has unexpectedly changed from '%s' to '%s'. (%r)",
context, self.start_context, self.name self.start_context, context, self.name
) )
return return

View File

@ -17,7 +17,7 @@ from twisted.internet import defer
from synapse.api.constants import Membership, EventTypes from synapse.api.constants import Membership, EventTypes
from synapse.util.logcontext import preserve_fn from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
import logging import logging
@ -55,12 +55,12 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
given events given events
events ([synapse.events.EventBase]): list of events to filter events ([synapse.events.EventBase]): list of events to filter
""" """
forgotten = yield defer.gatherResults([ forgotten = yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(store.who_forgot_in_room)( preserve_fn(store.who_forgot_in_room)(
room_id, room_id,
) )
for room_id in frozenset(e.room_id for e in events) for room_id in frozenset(e.room_id for e in events)
], consumeErrors=True) ], consumeErrors=True))
# Set of membership event_ids that have been forgotten # Set of membership event_ids that have been forgotten
event_id_forgotten = frozenset( event_id_forgotten = frozenset(

View File

@ -14,6 +14,8 @@
# limitations under the License. # limitations under the License.
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from twisted.internet import defer
from mock import Mock from mock import Mock
from tests import unittest from tests import unittest
@ -42,20 +44,25 @@ class ApplicationServiceTestCase(unittest.TestCase):
type="m.something", room_id="!foo:bar", sender="@someone:somewhere" type="m.something", room_id="!foo:bar", sender="@someone:somewhere"
) )
self.store = Mock()
@defer.inlineCallbacks
def test_regex_user_id_prefix_match(self): def test_regex_user_id_prefix_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*") _regex("@irc_.*")
) )
self.event.sender = "@irc_foobar:matrix.org" self.event.sender = "@irc_foobar:matrix.org"
self.assertTrue(self.service.is_interested(self.event)) self.assertTrue((yield self.service.is_interested(self.event)))
@defer.inlineCallbacks
def test_regex_user_id_prefix_no_match(self): def test_regex_user_id_prefix_no_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*") _regex("@irc_.*")
) )
self.event.sender = "@someone_else:matrix.org" self.event.sender = "@someone_else:matrix.org"
self.assertFalse(self.service.is_interested(self.event)) self.assertFalse((yield self.service.is_interested(self.event)))
@defer.inlineCallbacks
def test_regex_room_member_is_checked(self): def test_regex_room_member_is_checked(self):
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*") _regex("@irc_.*")
@ -63,30 +70,36 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.event.sender = "@someone_else:matrix.org" self.event.sender = "@someone_else:matrix.org"
self.event.type = "m.room.member" self.event.type = "m.room.member"
self.event.state_key = "@irc_foobar:matrix.org" self.event.state_key = "@irc_foobar:matrix.org"
self.assertTrue(self.service.is_interested(self.event)) self.assertTrue((yield self.service.is_interested(self.event)))
@defer.inlineCallbacks
def test_regex_room_id_match(self): def test_regex_room_id_match(self):
self.service.namespaces[ApplicationService.NS_ROOMS].append( self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!some_prefix.*some_suffix:matrix.org") _regex("!some_prefix.*some_suffix:matrix.org")
) )
self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org" self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org"
self.assertTrue(self.service.is_interested(self.event)) self.assertTrue((yield self.service.is_interested(self.event)))
@defer.inlineCallbacks
def test_regex_room_id_no_match(self): def test_regex_room_id_no_match(self):
self.service.namespaces[ApplicationService.NS_ROOMS].append( self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!some_prefix.*some_suffix:matrix.org") _regex("!some_prefix.*some_suffix:matrix.org")
) )
self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org" self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org"
self.assertFalse(self.service.is_interested(self.event)) self.assertFalse((yield self.service.is_interested(self.event)))
@defer.inlineCallbacks
def test_regex_alias_match(self): def test_regex_alias_match(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org") _regex("#irc_.*:matrix.org")
) )
self.assertTrue(self.service.is_interested( self.store.get_aliases_for_room.return_value = [
self.event, "#irc_foobar:matrix.org", "#athing:matrix.org"
aliases_for_event=["#irc_foobar:matrix.org", "#athing:matrix.org"] ]
)) self.store.get_users_in_room.return_value = []
self.assertTrue((yield self.service.is_interested(
self.event, self.store
)))
def test_non_exclusive_alias(self): def test_non_exclusive_alias(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
@ -136,15 +149,20 @@ class ApplicationServiceTestCase(unittest.TestCase):
"!irc_foobar:matrix.org" "!irc_foobar:matrix.org"
)) ))
@defer.inlineCallbacks
def test_regex_alias_no_match(self): def test_regex_alias_no_match(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org") _regex("#irc_.*:matrix.org")
) )
self.assertFalse(self.service.is_interested( self.store.get_aliases_for_room.return_value = [
self.event, "#xmpp_foobar:matrix.org", "#athing:matrix.org"
aliases_for_event=["#xmpp_foobar:matrix.org", "#athing:matrix.org"] ]
)) self.store.get_users_in_room.return_value = []
self.assertFalse((yield self.service.is_interested(
self.event, self.store
)))
@defer.inlineCallbacks
def test_regex_multiple_matches(self): def test_regex_multiple_matches(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org") _regex("#irc_.*:matrix.org")
@ -153,53 +171,13 @@ class ApplicationServiceTestCase(unittest.TestCase):
_regex("@irc_.*") _regex("@irc_.*")
) )
self.event.sender = "@irc_foobar:matrix.org" self.event.sender = "@irc_foobar:matrix.org"
self.assertTrue(self.service.is_interested( self.store.get_aliases_for_room.return_value = ["#irc_barfoo:matrix.org"]
self.event, self.store.get_users_in_room.return_value = []
aliases_for_event=["#irc_barfoo:matrix.org"] self.assertTrue((yield self.service.is_interested(
)) self.event, self.store
)))
def test_restrict_to_rooms(self):
self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!flibble_.*:matrix.org")
)
self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*")
)
self.event.sender = "@irc_foobar:matrix.org"
self.event.room_id = "!wibblewoo:matrix.org"
self.assertFalse(self.service.is_interested(
self.event,
restrict_to=ApplicationService.NS_ROOMS
))
def test_restrict_to_aliases(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#xmpp_.*:matrix.org")
)
self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*")
)
self.event.sender = "@irc_foobar:matrix.org"
self.assertFalse(self.service.is_interested(
self.event,
restrict_to=ApplicationService.NS_ALIASES,
aliases_for_event=["#irc_barfoo:matrix.org"]
))
def test_restrict_to_senders(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#xmpp_.*:matrix.org")
)
self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*")
)
self.event.sender = "@xmpp_foobar:matrix.org"
self.assertFalse(self.service.is_interested(
self.event,
restrict_to=ApplicationService.NS_USERS,
aliases_for_event=["#xmpp_barfoo:matrix.org"]
))
@defer.inlineCallbacks
def test_interested_in_self(self): def test_interested_in_self(self):
# make sure invites get through # make sure invites get through
self.service.sender = "@appservice:name" self.service.sender = "@appservice:name"
@ -211,20 +189,21 @@ class ApplicationServiceTestCase(unittest.TestCase):
"membership": "invite" "membership": "invite"
} }
self.event.state_key = self.service.sender self.event.state_key = self.service.sender
self.assertTrue(self.service.is_interested(self.event)) self.assertTrue((yield self.service.is_interested(self.event)))
@defer.inlineCallbacks
def test_member_list_match(self): def test_member_list_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*") _regex("@irc_.*")
) )
join_list = [ self.store.get_users_in_room.return_value = [
"@alice:here", "@alice:here",
"@irc_fo:here", # AS user "@irc_fo:here", # AS user
"@bob:here", "@bob:here",
] ]
self.store.get_aliases_for_room.return_value = []
self.event.sender = "@xmpp_foobar:matrix.org" self.event.sender = "@xmpp_foobar:matrix.org"
self.assertTrue(self.service.is_interested( self.assertTrue((yield self.service.is_interested(
event=self.event, event=self.event, store=self.store
member_list=join_list )))
))

View File

@ -193,7 +193,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.txn_ctrl = Mock() self.txn_ctrl = Mock()
self.queuer = _ServiceQueuer(self.txn_ctrl) self.queuer = _ServiceQueuer(self.txn_ctrl, MockClock())
def test_send_single_event_no_queue(self): def test_send_single_event_no_queue(self):
# Expect the event to be sent immediately. # Expect the event to be sent immediately.

View File

@ -15,6 +15,7 @@
from twisted.internet import defer from twisted.internet import defer
from .. import unittest from .. import unittest
from tests.utils import MockClock
from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.appservice import ApplicationServicesHandler
@ -32,6 +33,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
hs.get_datastore = Mock(return_value=self.mock_store) hs.get_datastore = Mock(return_value=self.mock_store)
hs.get_application_service_api = Mock(return_value=self.mock_as_api) hs.get_application_service_api = Mock(return_value=self.mock_as_api)
hs.get_application_service_scheduler = Mock(return_value=self.mock_scheduler) hs.get_application_service_scheduler = Mock(return_value=self.mock_scheduler)
hs.get_clock.return_value = MockClock()
self.handler = ApplicationServicesHandler(hs) self.handler = ApplicationServicesHandler(hs)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -51,8 +53,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
type="m.room.message", type="m.room.message",
room_id="!foo:bar" room_id="!foo:bar"
) )
self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
self.mock_as_api.push = Mock() self.mock_as_api.push = Mock()
yield self.handler.notify_interested_services(event) yield self.handler.notify_interested_services(0)
self.mock_scheduler.submit_event_for_as.assert_called_once_with( self.mock_scheduler.submit_event_for_as.assert_called_once_with(
interested_service, event interested_service, event
) )
@ -72,7 +75,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
) )
self.mock_as_api.push = Mock() self.mock_as_api.push = Mock()
self.mock_as_api.query_user = Mock() self.mock_as_api.query_user = Mock()
yield self.handler.notify_interested_services(event) self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
yield self.handler.notify_interested_services(0)
self.mock_as_api.query_user.assert_called_once_with( self.mock_as_api.query_user.assert_called_once_with(
services[0], user_id services[0], user_id
) )
@ -94,7 +98,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
) )
self.mock_as_api.push = Mock() self.mock_as_api.push = Mock()
self.mock_as_api.query_user = Mock() self.mock_as_api.query_user = Mock()
yield self.handler.notify_interested_services(event) self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
yield self.handler.notify_interested_services(0)
self.assertFalse( self.assertFalse(
self.mock_as_api.query_user.called, self.mock_as_api.query_user.called,
"query_user called when it shouldn't have been." "query_user called when it shouldn't have been."
@ -108,11 +113,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
room_id = "!alpha:bet" room_id = "!alpha:bet"
servers = ["aperture"] servers = ["aperture"]
interested_service = self._mkservice(is_interested=True) interested_service = self._mkservice_alias(is_interested_in_alias=True)
services = [ services = [
self._mkservice(is_interested=False), self._mkservice_alias(is_interested_in_alias=False),
interested_service, interested_service,
self._mkservice(is_interested=False) self._mkservice_alias(is_interested_in_alias=False)
] ]
self.mock_store.get_app_services = Mock(return_value=services) self.mock_store.get_app_services = Mock(return_value=services)
@ -135,3 +140,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
service.token = "mock_service_token" service.token = "mock_service_token"
service.url = "mock_service_url" service.url = "mock_service_url"
return service return service
def _mkservice_alias(self, is_interested_in_alias):
service = Mock()
service.is_interested_in_alias = Mock(return_value=is_interested_in_alias)
service.token = "mock_service_token"
service.url = "mock_service_url"
return service

View File

@ -14,11 +14,13 @@
# limitations under the License. # limitations under the License.
import pymacaroons import pymacaroons
from twisted.internet import defer
import synapse
import synapse.api.errors
from synapse.handlers.auth import AuthHandler from synapse.handlers.auth import AuthHandler
from tests import unittest from tests import unittest
from tests.utils import setup_test_homeserver from tests.utils import setup_test_homeserver
from twisted.internet import defer
class AuthHandlers(object): class AuthHandlers(object):
@ -31,11 +33,12 @@ class AuthTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.hs = yield setup_test_homeserver(handlers=None) self.hs = yield setup_test_homeserver(handlers=None)
self.hs.handlers = AuthHandlers(self.hs) self.hs.handlers = AuthHandlers(self.hs)
self.auth_handler = self.hs.handlers.auth_handler
def test_token_is_a_macaroon(self): def test_token_is_a_macaroon(self):
self.hs.config.macaroon_secret_key = "this key is a huge secret" self.hs.config.macaroon_secret_key = "this key is a huge secret"
token = self.hs.handlers.auth_handler.generate_access_token("some_user") token = self.auth_handler.generate_access_token("some_user")
# Check that we can parse the thing with pymacaroons # Check that we can parse the thing with pymacaroons
macaroon = pymacaroons.Macaroon.deserialize(token) macaroon = pymacaroons.Macaroon.deserialize(token)
# The most basic of sanity checks # The most basic of sanity checks
@ -46,7 +49,7 @@ class AuthTestCase(unittest.TestCase):
self.hs.config.macaroon_secret_key = "this key is a massive secret" self.hs.config.macaroon_secret_key = "this key is a massive secret"
self.hs.clock.now = 5000 self.hs.clock.now = 5000
token = self.hs.handlers.auth_handler.generate_access_token("a_user") token = self.auth_handler.generate_access_token("a_user")
macaroon = pymacaroons.Macaroon.deserialize(token) macaroon = pymacaroons.Macaroon.deserialize(token)
def verify_gen(caveat): def verify_gen(caveat):
@ -67,3 +70,46 @@ class AuthTestCase(unittest.TestCase):
v.satisfy_general(verify_type) v.satisfy_general(verify_type)
v.satisfy_general(verify_expiry) v.satisfy_general(verify_expiry)
v.verify(macaroon, self.hs.config.macaroon_secret_key) v.verify(macaroon, self.hs.config.macaroon_secret_key)
def test_short_term_login_token_gives_user_id(self):
self.hs.clock.now = 1000
token = self.auth_handler.generate_short_term_login_token(
"a_user", 5000
)
self.assertEqual(
"a_user",
self.auth_handler.validate_short_term_login_token_and_get_user_id(
token
)
)
# when we advance the clock, the token should be rejected
self.hs.clock.now = 6000
with self.assertRaises(synapse.api.errors.AuthError):
self.auth_handler.validate_short_term_login_token_and_get_user_id(
token
)
def test_short_term_login_token_cannot_replace_user_id(self):
token = self.auth_handler.generate_short_term_login_token(
"a_user", 5000
)
macaroon = pymacaroons.Macaroon.deserialize(token)
self.assertEqual(
"a_user",
self.auth_handler.validate_short_term_login_token_and_get_user_id(
macaroon.serialize()
)
)
# add another "user_id" caveat, which might allow us to override the
# user_id.
macaroon.add_first_party_caveat("user_id = b_user")
with self.assertRaises(synapse.api.errors.AuthError):
self.auth_handler.validate_short_term_login_token_and_get_user_id(
macaroon.serialize()
)

View File

@ -17,6 +17,8 @@
from tests import unittest from tests import unittest
from twisted.internet import defer from twisted.internet import defer
from mock import Mock
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.caches.descriptors import Cache, cached from synapse.util.caches.descriptors import Cache, cached
@ -72,7 +74,7 @@ class CacheTestCase(unittest.TestCase):
cache.get(3) cache.get(3)
def test_eviction_lru(self): def test_eviction_lru(self):
cache = Cache("test", max_entries=2, lru=True) cache = Cache("test", max_entries=2)
cache.prefill(1, "one") cache.prefill(1, "one")
cache.prefill(2, "two") cache.prefill(2, "two")
@ -199,3 +201,115 @@ class CacheDecoratorTestCase(unittest.TestCase):
self.assertEquals(a.func("foo").result, d.result) self.assertEquals(a.func("foo").result, d.result)
self.assertEquals(callcount[0], 0) self.assertEquals(callcount[0], 0)
@defer.inlineCallbacks
def test_invalidate_context(self):
callcount = [0]
callcount2 = [0]
class A(object):
@cached()
def func(self, key):
callcount[0] += 1
return key
@cached(cache_context=True)
def func2(self, key, cache_context):
callcount2[0] += 1
return self.func(key, on_invalidate=cache_context.invalidate)
a = A()
yield a.func2("foo")
self.assertEquals(callcount[0], 1)
self.assertEquals(callcount2[0], 1)
a.func.invalidate(("foo",))
yield a.func("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 1)
yield a.func2("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
@defer.inlineCallbacks
def test_eviction_context(self):
callcount = [0]
callcount2 = [0]
class A(object):
@cached(max_entries=2)
def func(self, key):
callcount[0] += 1
return key
@cached(cache_context=True)
def func2(self, key, cache_context):
callcount2[0] += 1
return self.func(key, on_invalidate=cache_context.invalidate)
a = A()
yield a.func2("foo")
yield a.func2("foo2")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
yield a.func("foo3")
self.assertEquals(callcount[0], 3)
self.assertEquals(callcount2[0], 2)
yield a.func2("foo")
self.assertEquals(callcount[0], 4)
self.assertEquals(callcount2[0], 3)
@defer.inlineCallbacks
def test_double_get(self):
callcount = [0]
callcount2 = [0]
class A(object):
@cached()
def func(self, key):
callcount[0] += 1
return key
@cached(cache_context=True)
def func2(self, key, cache_context):
callcount2[0] += 1
return self.func(key, on_invalidate=cache_context.invalidate)
a = A()
a.func2.cache.cache = Mock(wraps=a.func2.cache.cache)
yield a.func2("foo")
self.assertEquals(callcount[0], 1)
self.assertEquals(callcount2[0], 1)
a.func2.invalidate(("foo",))
self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
yield a.func2("foo")
a.func2.invalidate(("foo",))
self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
self.assertEquals(callcount[0], 1)
self.assertEquals(callcount2[0], 2)
a.func.invalidate(("foo",))
self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
yield a.func("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
yield a.func2("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 3)

View File

@ -15,7 +15,9 @@
from . import unittest from . import unittest
from synapse.rest.media.v1.preview_url_resource import summarize_paragraphs from synapse.rest.media.v1.preview_url_resource import (
summarize_paragraphs, decode_and_calc_og
)
class PreviewTestCase(unittest.TestCase): class PreviewTestCase(unittest.TestCase):
@ -137,3 +139,79 @@ class PreviewTestCase(unittest.TestCase):
" of old wooden houses in Northern Norway, the oldest house dating from" " of old wooden houses in Northern Norway, the oldest house dating from"
" 1789. The Arctic Cathedral, a modern church…" " 1789. The Arctic Cathedral, a modern church…"
) )
class PreviewUrlTestCase(unittest.TestCase):
def test_simple(self):
html = """
<html>
<head><title>Foo</title></head>
<body>
Some text.
</body>
</html>
"""
og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEquals(og, {
"og:title": "Foo",
"og:description": "Some text."
})
def test_comment(self):
html = """
<html>
<head><title>Foo</title></head>
<body>
<!-- HTML comment -->
Some text.
</body>
</html>
"""
og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEquals(og, {
"og:title": "Foo",
"og:description": "Some text."
})
def test_comment2(self):
html = """
<html>
<head><title>Foo</title></head>
<body>
Some text.
<!-- HTML comment -->
Some more text.
<p>Text</p>
More text
</body>
</html>
"""
og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEquals(og, {
"og:title": "Foo",
"og:description": "Some text.\n\nSome more text.\n\nText\n\nMore text"
})
def test_script(self):
html = """
<html>
<head><title>Foo</title></head>
<body>
<script> (function() {})() </script>
Some text.
</body>
</html>
"""
og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEquals(og, {
"og:title": "Foo",
"og:description": "Some text."
})

View File

@ -19,6 +19,8 @@ from .. import unittest
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache from synapse.util.caches.treecache import TreeCache
from mock import Mock
class LruCacheTestCase(unittest.TestCase): class LruCacheTestCase(unittest.TestCase):
@ -48,6 +50,8 @@ class LruCacheTestCase(unittest.TestCase):
self.assertEquals(cache.get("key"), 1) self.assertEquals(cache.get("key"), 1)
self.assertEquals(cache.setdefault("key", 2), 1) self.assertEquals(cache.setdefault("key", 2), 1)
self.assertEquals(cache.get("key"), 1) self.assertEquals(cache.get("key"), 1)
cache["key"] = 2 # Make sure overriding works.
self.assertEquals(cache.get("key"), 2)
def test_pop(self): def test_pop(self):
cache = LruCache(1) cache = LruCache(1)
@ -79,3 +83,152 @@ class LruCacheTestCase(unittest.TestCase):
cache["key"] = 1 cache["key"] = 1
cache.clear() cache.clear()
self.assertEquals(len(cache), 0) self.assertEquals(len(cache), 0)
class LruCacheCallbacksTestCase(unittest.TestCase):
def test_get(self):
m = Mock()
cache = LruCache(1)
cache.set("key", "value")
self.assertFalse(m.called)
cache.get("key", callback=m)
self.assertFalse(m.called)
cache.get("key", "value")
self.assertFalse(m.called)
cache.set("key", "value2")
self.assertEquals(m.call_count, 1)
cache.set("key", "value")
self.assertEquals(m.call_count, 1)
def test_multi_get(self):
m = Mock()
cache = LruCache(1)
cache.set("key", "value")
self.assertFalse(m.called)
cache.get("key", callback=m)
self.assertFalse(m.called)
cache.get("key", callback=m)
self.assertFalse(m.called)
cache.set("key", "value2")
self.assertEquals(m.call_count, 1)
cache.set("key", "value")
self.assertEquals(m.call_count, 1)
def test_set(self):
m = Mock()
cache = LruCache(1)
cache.set("key", "value", m)
self.assertFalse(m.called)
cache.set("key", "value")
self.assertFalse(m.called)
cache.set("key", "value2")
self.assertEquals(m.call_count, 1)
cache.set("key", "value")
self.assertEquals(m.call_count, 1)
def test_pop(self):
m = Mock()
cache = LruCache(1)
cache.set("key", "value", m)
self.assertFalse(m.called)
cache.pop("key")
self.assertEquals(m.call_count, 1)
cache.set("key", "value")
self.assertEquals(m.call_count, 1)
cache.pop("key")
self.assertEquals(m.call_count, 1)
def test_del_multi(self):
m1 = Mock()
m2 = Mock()
m3 = Mock()
m4 = Mock()
cache = LruCache(4, 2, cache_type=TreeCache)
cache.set(("a", "1"), "value", m1)
cache.set(("a", "2"), "value", m2)
cache.set(("b", "1"), "value", m3)
cache.set(("b", "2"), "value", m4)
self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
self.assertEquals(m4.call_count, 0)
cache.del_multi(("a",))
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 1)
self.assertEquals(m3.call_count, 0)
self.assertEquals(m4.call_count, 0)
def test_clear(self):
m1 = Mock()
m2 = Mock()
cache = LruCache(5)
cache.set("key1", "value", m1)
cache.set("key2", "value", m2)
self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0)
cache.clear()
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 1)
def test_eviction(self):
m1 = Mock(name="m1")
m2 = Mock(name="m2")
m3 = Mock(name="m3")
cache = LruCache(2)
cache.set("key1", "value", m1)
cache.set("key2", "value", m2)
self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
cache.set("key3", "value", m3)
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
cache.set("key3", "value")
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
cache.get("key2")
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
cache.set("key1", "value", m1)
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 1)