mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2025-01-26 00:25:58 -05:00
Merge branch 'release-v0.17.1' of github.com:matrix-org/synapse
This commit is contained in:
commit
37638c06c5
47
CHANGES.rst
47
CHANGES.rst
@ -1,3 +1,50 @@
|
|||||||
|
Changes in synapse v0.17.1 (2016-08-24)
|
||||||
|
=======================================
|
||||||
|
|
||||||
|
Changes:
|
||||||
|
|
||||||
|
* Delete old received_transactions rows (PR #1038)
|
||||||
|
* Pass through user-supplied content in /join/$room_id (PR #1039)
|
||||||
|
|
||||||
|
|
||||||
|
Bug fixes:
|
||||||
|
|
||||||
|
* Fix bug with backfill (PR #1040)
|
||||||
|
|
||||||
|
|
||||||
|
Changes in synapse v0.17.1-rc1 (2016-08-22)
|
||||||
|
===========================================
|
||||||
|
|
||||||
|
Features:
|
||||||
|
|
||||||
|
* Add notification API (PR #1028)
|
||||||
|
|
||||||
|
|
||||||
|
Changes:
|
||||||
|
|
||||||
|
* Don't print stack traces when failing to get remote keys (PR #996)
|
||||||
|
* Various federation /event/ perf improvements (PR #998)
|
||||||
|
* Only process one local membership event per room at a time (PR #1005)
|
||||||
|
* Move default display name push rule (PR #1011, #1023)
|
||||||
|
* Fix up preview URL API. Add tests. (PR #1015)
|
||||||
|
* Set ``Content-Security-Policy`` on media repo (PR #1021)
|
||||||
|
* Make notify_interested_services faster (PR #1022)
|
||||||
|
* Add usage stats to prometheus monitoring (PR #1037)
|
||||||
|
|
||||||
|
|
||||||
|
Bug fixes:
|
||||||
|
|
||||||
|
* Fix token login (PR #993)
|
||||||
|
* Fix CAS login (PR #994, #995)
|
||||||
|
* Fix /sync to not clobber status_msg (PR #997)
|
||||||
|
* Fix redacted state events to include prev_content (PR #1003)
|
||||||
|
* Fix some bugs in the auth/ldap handler (PR #1007)
|
||||||
|
* Fix backfill request to limit URI length, so that remotes don't reject the
|
||||||
|
requests due to path length limits (PR #1012)
|
||||||
|
* Fix AS push code to not send duplicate events (PR #1025)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Changes in synapse v0.17.0 (2016-08-08)
|
Changes in synapse v0.17.0 (2016-08-08)
|
||||||
=======================================
|
=======================================
|
||||||
|
|
||||||
|
@ -95,7 +95,7 @@ Synapse is the reference python/twisted Matrix homeserver implementation.
|
|||||||
System requirements:
|
System requirements:
|
||||||
- POSIX-compliant system (tested on Linux & OS X)
|
- POSIX-compliant system (tested on Linux & OS X)
|
||||||
- Python 2.7
|
- Python 2.7
|
||||||
- At least 512 MB RAM.
|
- At least 1GB of free RAM if you want to join large public rooms like #matrix:matrix.org
|
||||||
|
|
||||||
Synapse is written in python but some of the libraries is uses are written in
|
Synapse is written in python but some of the libraries is uses are written in
|
||||||
C. So before we can install synapse itself we need a working C compiler and the
|
C. So before we can install synapse itself we need a working C compiler and the
|
||||||
|
97
docs/workers.rst
Normal file
97
docs/workers.rst
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
Scaling synapse via workers
|
||||||
|
---------------------------
|
||||||
|
|
||||||
|
Synapse has experimental support for splitting out functionality into
|
||||||
|
multiple separate python processes, helping greatly with scalability. These
|
||||||
|
processes are called 'workers', and are (eventually) intended to scale
|
||||||
|
horizontally independently.
|
||||||
|
|
||||||
|
All processes continue to share the same database instance, and as such, workers
|
||||||
|
only work with postgres based synapse deployments (sharing a single sqlite
|
||||||
|
across multiple processes is a recipe for disaster, plus you should be using
|
||||||
|
postgres anyway if you care about scalability).
|
||||||
|
|
||||||
|
The workers communicate with the master synapse process via a synapse-specific
|
||||||
|
HTTP protocol called 'replication' - analogous to MySQL or Postgres style
|
||||||
|
database replication; feeding a stream of relevant data to the workers so they
|
||||||
|
can be kept in sync with the main synapse process and database state.
|
||||||
|
|
||||||
|
To enable workers, you need to add a replication listener to the master synapse, e.g.::
|
||||||
|
|
||||||
|
listeners:
|
||||||
|
- port: 9092
|
||||||
|
bind_address: '127.0.0.1'
|
||||||
|
type: http
|
||||||
|
tls: false
|
||||||
|
x_forwarded: false
|
||||||
|
resources:
|
||||||
|
- names: [replication]
|
||||||
|
compress: false
|
||||||
|
|
||||||
|
Under **no circumstances** should this replication API listener be exposed to the
|
||||||
|
public internet; it currently implements no authentication whatsoever and is
|
||||||
|
unencrypted HTTP.
|
||||||
|
|
||||||
|
You then create a set of configs for the various worker processes. These should be
|
||||||
|
worker configuration files should be stored in a dedicated subdirectory, to allow
|
||||||
|
synctl to manipulate them.
|
||||||
|
|
||||||
|
The current available worker applications are:
|
||||||
|
* synapse.app.pusher - handles sending push notifications to sygnal and email
|
||||||
|
* synapse.app.synchrotron - handles /sync endpoints. can scales horizontally through multiple instances.
|
||||||
|
* synapse.app.appservice - handles output traffic to Application Services
|
||||||
|
* synapse.app.federation_reader - handles receiving federation traffic (including public_rooms API)
|
||||||
|
* synapse.app.media_repository - handles the media repository.
|
||||||
|
|
||||||
|
Each worker configuration file inherits the configuration of the main homeserver
|
||||||
|
configuration file. You can then override configuration specific to that worker,
|
||||||
|
e.g. the HTTP listener that it provides (if any); logging configuration; etc.
|
||||||
|
You should minimise the number of overrides though to maintain a usable config.
|
||||||
|
|
||||||
|
You must specify the type of worker application (worker_app) and the replication
|
||||||
|
endpoint that it's talking to on the main synapse process (worker_replication_url).
|
||||||
|
|
||||||
|
For instance::
|
||||||
|
|
||||||
|
worker_app: synapse.app.synchrotron
|
||||||
|
|
||||||
|
# The replication listener on the synapse to talk to.
|
||||||
|
worker_replication_url: http://127.0.0.1:9092/_synapse/replication
|
||||||
|
|
||||||
|
worker_listeners:
|
||||||
|
- type: http
|
||||||
|
port: 8083
|
||||||
|
resources:
|
||||||
|
- names:
|
||||||
|
- client
|
||||||
|
|
||||||
|
worker_daemonize: True
|
||||||
|
worker_pid_file: /home/matrix/synapse/synchrotron.pid
|
||||||
|
worker_log_config: /home/matrix/synapse/config/synchrotron_log_config.yaml
|
||||||
|
|
||||||
|
...is a full configuration for a synchrotron worker instance, which will expose a
|
||||||
|
plain HTTP /sync endpoint on port 8083 separately from the /sync endpoint provided
|
||||||
|
by the main synapse.
|
||||||
|
|
||||||
|
Obviously you should configure your loadbalancer to route the /sync endpoint to
|
||||||
|
the synchrotron instance(s) in this instance.
|
||||||
|
|
||||||
|
Finally, to actually run your worker-based synapse, you must pass synctl the -a
|
||||||
|
commandline option to tell it to operate on all the worker configurations found
|
||||||
|
in the given directory, e.g.::
|
||||||
|
|
||||||
|
synctl -a $CONFIG/workers start
|
||||||
|
|
||||||
|
Currently one should always restart all workers when restarting or upgrading
|
||||||
|
synapse, unless you explicitly know it's safe not to. For instance, restarting
|
||||||
|
synapse without restarting all the synchrotrons may result in broken typing
|
||||||
|
notifications.
|
||||||
|
|
||||||
|
To manipulate a specific worker, you pass the -w option to synctl::
|
||||||
|
|
||||||
|
synctl -w $CONFIG/workers/synchrotron.yaml restart
|
||||||
|
|
||||||
|
All of the above is highly experimental and subject to change as Synapse evolves,
|
||||||
|
but documenting it here to help folks needing highly scalable Synapses similar
|
||||||
|
to the one running matrix.org!
|
||||||
|
|
@ -25,5 +25,6 @@ rm .coverage* || echo "No coverage files to remove"
|
|||||||
tox --notest -e py27
|
tox --notest -e py27
|
||||||
TOX_BIN=$WORKSPACE/.tox/py27/bin
|
TOX_BIN=$WORKSPACE/.tox/py27/bin
|
||||||
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
|
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
|
||||||
|
$TOX_BIN/pip install lxml
|
||||||
|
|
||||||
tox -e py27
|
tox -e py27
|
||||||
|
@ -14,6 +14,7 @@ fi
|
|||||||
tox -e py27 --notest -v
|
tox -e py27 --notest -v
|
||||||
|
|
||||||
TOX_BIN=$TOX_DIR/py27/bin
|
TOX_BIN=$TOX_DIR/py27/bin
|
||||||
|
$TOX_BIN/pip install setuptools
|
||||||
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
|
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
|
||||||
$TOX_BIN/pip install lxml
|
$TOX_BIN/pip install lxml
|
||||||
$TOX_BIN/pip install psycopg2
|
$TOX_BIN/pip install psycopg2
|
||||||
|
@ -16,4 +16,4 @@
|
|||||||
""" This is a reference implementation of a Matrix home server.
|
""" This is a reference implementation of a Matrix home server.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__version__ = "0.17.0"
|
__version__ = "0.17.1"
|
||||||
|
@ -675,27 +675,18 @@ class Auth(object):
|
|||||||
try:
|
try:
|
||||||
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
|
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
|
||||||
|
|
||||||
user_prefix = "user_id = "
|
user_id = self.get_user_id_from_macaroon(macaroon)
|
||||||
user = None
|
user = UserID.from_string(user_id)
|
||||||
user_id = None
|
|
||||||
guest = False
|
|
||||||
for caveat in macaroon.caveats:
|
|
||||||
if caveat.caveat_id.startswith(user_prefix):
|
|
||||||
user_id = caveat.caveat_id[len(user_prefix):]
|
|
||||||
user = UserID.from_string(user_id)
|
|
||||||
elif caveat.caveat_id == "guest = true":
|
|
||||||
guest = True
|
|
||||||
|
|
||||||
self.validate_macaroon(
|
self.validate_macaroon(
|
||||||
macaroon, rights, self.hs.config.expire_access_token,
|
macaroon, rights, self.hs.config.expire_access_token,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if user is None:
|
guest = False
|
||||||
raise AuthError(
|
for caveat in macaroon.caveats:
|
||||||
self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
|
if caveat.caveat_id == "guest = true":
|
||||||
errcode=Codes.UNKNOWN_TOKEN
|
guest = True
|
||||||
)
|
|
||||||
|
|
||||||
if guest:
|
if guest:
|
||||||
ret = {
|
ret = {
|
||||||
@ -743,6 +734,29 @@ class Auth(object):
|
|||||||
errcode=Codes.UNKNOWN_TOKEN
|
errcode=Codes.UNKNOWN_TOKEN
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_user_id_from_macaroon(self, macaroon):
|
||||||
|
"""Retrieve the user_id given by the caveats on the macaroon.
|
||||||
|
|
||||||
|
Does *not* validate the macaroon.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
macaroon (pymacaroons.Macaroon): The macaroon to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(str) user id
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AuthError if there is no user_id caveat in the macaroon
|
||||||
|
"""
|
||||||
|
user_prefix = "user_id = "
|
||||||
|
for caveat in macaroon.caveats:
|
||||||
|
if caveat.caveat_id.startswith(user_prefix):
|
||||||
|
return caveat.caveat_id[len(user_prefix):]
|
||||||
|
raise AuthError(
|
||||||
|
self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
|
||||||
|
errcode=Codes.UNKNOWN_TOKEN
|
||||||
|
)
|
||||||
|
|
||||||
def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id):
|
def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id):
|
||||||
"""
|
"""
|
||||||
validate that a Macaroon is understood by and was signed by this server.
|
validate that a Macaroon is understood by and was signed by this server.
|
||||||
@ -754,6 +768,7 @@ class Auth(object):
|
|||||||
verify_expiry(bool): Whether to verify whether the macaroon has expired.
|
verify_expiry(bool): Whether to verify whether the macaroon has expired.
|
||||||
This should really always be True, but no clients currently implement
|
This should really always be True, but no clients currently implement
|
||||||
token refresh, so we can't enforce expiry yet.
|
token refresh, so we can't enforce expiry yet.
|
||||||
|
user_id (str): The user_id required
|
||||||
"""
|
"""
|
||||||
v = pymacaroons.Verifier()
|
v = pymacaroons.Verifier()
|
||||||
v.satisfy_exact("gen = 1")
|
v.satisfy_exact("gen = 1")
|
||||||
|
209
synapse/app/appservice.py
Normal file
209
synapse/app/appservice.py
Normal file
@ -0,0 +1,209 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import synapse
|
||||||
|
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.config._base import ConfigError
|
||||||
|
from synapse.config.logger import setup_logging
|
||||||
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
|
from synapse.http.site import SynapseSite
|
||||||
|
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||||
|
from synapse.replication.slave.storage.directory import DirectoryStore
|
||||||
|
from synapse.replication.slave.storage.events import SlavedEventStore
|
||||||
|
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
|
||||||
|
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||||
|
from synapse.storage.engines import create_engine
|
||||||
|
from synapse.util.async import sleep
|
||||||
|
from synapse.util.httpresourcetree import create_resource_tree
|
||||||
|
from synapse.util.logcontext import LoggingContext
|
||||||
|
from synapse.util.manhole import manhole
|
||||||
|
from synapse.util.rlimit import change_resource_limit
|
||||||
|
from synapse.util.versionstring import get_version_string
|
||||||
|
|
||||||
|
from twisted.internet import reactor, defer
|
||||||
|
from twisted.web.resource import Resource
|
||||||
|
|
||||||
|
from daemonize import Daemonize
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
import gc
|
||||||
|
|
||||||
|
logger = logging.getLogger("synapse.app.appservice")
|
||||||
|
|
||||||
|
|
||||||
|
class AppserviceSlaveStore(
|
||||||
|
DirectoryStore, SlavedEventStore, SlavedApplicationServiceStore,
|
||||||
|
SlavedRegistrationStore,
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class AppserviceServer(HomeServer):
|
||||||
|
def get_db_conn(self, run_new_connection=True):
|
||||||
|
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||||
|
# not be passed to the database engine.
|
||||||
|
db_params = {
|
||||||
|
k: v for k, v in self.db_config.get("args", {}).items()
|
||||||
|
if not k.startswith("cp_")
|
||||||
|
}
|
||||||
|
db_conn = self.database_engine.module.connect(**db_params)
|
||||||
|
|
||||||
|
if run_new_connection:
|
||||||
|
self.database_engine.on_new_connection(db_conn)
|
||||||
|
return db_conn
|
||||||
|
|
||||||
|
def setup(self):
|
||||||
|
logger.info("Setting up.")
|
||||||
|
self.datastore = AppserviceSlaveStore(self.get_db_conn(), self)
|
||||||
|
logger.info("Finished setting up.")
|
||||||
|
|
||||||
|
def _listen_http(self, listener_config):
|
||||||
|
port = listener_config["port"]
|
||||||
|
bind_address = listener_config.get("bind_address", "")
|
||||||
|
site_tag = listener_config.get("tag", port)
|
||||||
|
resources = {}
|
||||||
|
for res in listener_config["resources"]:
|
||||||
|
for name in res["names"]:
|
||||||
|
if name == "metrics":
|
||||||
|
resources[METRICS_PREFIX] = MetricsResource(self)
|
||||||
|
|
||||||
|
root_resource = create_resource_tree(resources, Resource())
|
||||||
|
reactor.listenTCP(
|
||||||
|
port,
|
||||||
|
SynapseSite(
|
||||||
|
"synapse.access.http.%s" % (site_tag,),
|
||||||
|
site_tag,
|
||||||
|
listener_config,
|
||||||
|
root_resource,
|
||||||
|
),
|
||||||
|
interface=bind_address
|
||||||
|
)
|
||||||
|
logger.info("Synapse appservice now listening on port %d", port)
|
||||||
|
|
||||||
|
def start_listening(self, listeners):
|
||||||
|
for listener in listeners:
|
||||||
|
if listener["type"] == "http":
|
||||||
|
self._listen_http(listener)
|
||||||
|
elif listener["type"] == "manhole":
|
||||||
|
reactor.listenTCP(
|
||||||
|
listener["port"],
|
||||||
|
manhole(
|
||||||
|
username="matrix",
|
||||||
|
password="rabbithole",
|
||||||
|
globals={"hs": self},
|
||||||
|
),
|
||||||
|
interface=listener.get("bind_address", '127.0.0.1')
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warn("Unrecognized listener type: %s", listener["type"])
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def replicate(self):
|
||||||
|
http_client = self.get_simple_http_client()
|
||||||
|
store = self.get_datastore()
|
||||||
|
replication_url = self.config.worker_replication_url
|
||||||
|
appservice_handler = self.get_application_service_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def replicate(results):
|
||||||
|
stream = results.get("events")
|
||||||
|
if stream:
|
||||||
|
max_stream_id = stream["position"]
|
||||||
|
yield appservice_handler.notify_interested_services(max_stream_id)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
args = store.stream_positions()
|
||||||
|
args["timeout"] = 30000
|
||||||
|
result = yield http_client.get_json(replication_url, args=args)
|
||||||
|
yield store.process_replication(result)
|
||||||
|
replicate(result)
|
||||||
|
except:
|
||||||
|
logger.exception("Error replicating from %r", replication_url)
|
||||||
|
yield sleep(30)
|
||||||
|
|
||||||
|
|
||||||
|
def start(config_options):
|
||||||
|
try:
|
||||||
|
config = HomeServerConfig.load_config(
|
||||||
|
"Synapse appservice", config_options
|
||||||
|
)
|
||||||
|
except ConfigError as e:
|
||||||
|
sys.stderr.write("\n" + e.message + "\n")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
assert config.worker_app == "synapse.app.appservice"
|
||||||
|
|
||||||
|
setup_logging(config.worker_log_config, config.worker_log_file)
|
||||||
|
|
||||||
|
database_engine = create_engine(config.database_config)
|
||||||
|
|
||||||
|
if config.notify_appservices:
|
||||||
|
sys.stderr.write(
|
||||||
|
"\nThe appservices must be disabled in the main synapse process"
|
||||||
|
"\nbefore they can be run in a separate worker."
|
||||||
|
"\nPlease add ``notify_appservices: false`` to the main config"
|
||||||
|
"\n"
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Force the pushers to start since they will be disabled in the main config
|
||||||
|
config.notify_appservices = True
|
||||||
|
|
||||||
|
ps = AppserviceServer(
|
||||||
|
config.server_name,
|
||||||
|
db_config=config.database_config,
|
||||||
|
config=config,
|
||||||
|
version_string="Synapse/" + get_version_string(synapse),
|
||||||
|
database_engine=database_engine,
|
||||||
|
)
|
||||||
|
|
||||||
|
ps.setup()
|
||||||
|
ps.start_listening(config.worker_listeners)
|
||||||
|
|
||||||
|
def run():
|
||||||
|
with LoggingContext("run"):
|
||||||
|
logger.info("Running")
|
||||||
|
change_resource_limit(config.soft_file_limit)
|
||||||
|
if config.gc_thresholds:
|
||||||
|
gc.set_threshold(*config.gc_thresholds)
|
||||||
|
reactor.run()
|
||||||
|
|
||||||
|
def start():
|
||||||
|
ps.replicate()
|
||||||
|
ps.get_datastore().start_profiling()
|
||||||
|
|
||||||
|
reactor.callWhenRunning(start)
|
||||||
|
|
||||||
|
if config.worker_daemonize:
|
||||||
|
daemon = Daemonize(
|
||||||
|
app="synapse-appservice",
|
||||||
|
pid=config.worker_pid_file,
|
||||||
|
action=run,
|
||||||
|
auto_close_fds=False,
|
||||||
|
verbose=True,
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
daemon.start()
|
||||||
|
else:
|
||||||
|
run()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
with LoggingContext("main"):
|
||||||
|
start(sys.argv[1:])
|
@ -51,7 +51,7 @@ from synapse.api.urls import (
|
|||||||
from synapse.config.homeserver import HomeServerConfig
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
from synapse.crypto import context_factory
|
from synapse.crypto import context_factory
|
||||||
from synapse.util.logcontext import LoggingContext
|
from synapse.util.logcontext import LoggingContext
|
||||||
from synapse.metrics import register_memory_metrics
|
from synapse.metrics import register_memory_metrics, get_metrics_for
|
||||||
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||||
from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
|
from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
|
||||||
from synapse.federation.transport.server import TransportLayerServer
|
from synapse.federation.transport.server import TransportLayerServer
|
||||||
@ -385,6 +385,8 @@ def run(hs):
|
|||||||
|
|
||||||
start_time = hs.get_clock().time()
|
start_time = hs.get_clock().time()
|
||||||
|
|
||||||
|
stats = {}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def phone_stats_home():
|
def phone_stats_home():
|
||||||
logger.info("Gathering stats for reporting")
|
logger.info("Gathering stats for reporting")
|
||||||
@ -393,7 +395,10 @@ def run(hs):
|
|||||||
if uptime < 0:
|
if uptime < 0:
|
||||||
uptime = 0
|
uptime = 0
|
||||||
|
|
||||||
stats = {}
|
# If the stats directory is empty then this is the first time we've
|
||||||
|
# reported stats.
|
||||||
|
first_time = not stats
|
||||||
|
|
||||||
stats["homeserver"] = hs.config.server_name
|
stats["homeserver"] = hs.config.server_name
|
||||||
stats["timestamp"] = now
|
stats["timestamp"] = now
|
||||||
stats["uptime_seconds"] = uptime
|
stats["uptime_seconds"] = uptime
|
||||||
@ -406,6 +411,25 @@ def run(hs):
|
|||||||
daily_messages = yield hs.get_datastore().count_daily_messages()
|
daily_messages = yield hs.get_datastore().count_daily_messages()
|
||||||
if daily_messages is not None:
|
if daily_messages is not None:
|
||||||
stats["daily_messages"] = daily_messages
|
stats["daily_messages"] = daily_messages
|
||||||
|
else:
|
||||||
|
stats.pop("daily_messages", None)
|
||||||
|
|
||||||
|
if first_time:
|
||||||
|
# Add callbacks to report the synapse stats as metrics whenever
|
||||||
|
# prometheus requests them, typically every 30s.
|
||||||
|
# As some of the stats are expensive to calculate we only update
|
||||||
|
# them when synapse phones home to matrix.org every 24 hours.
|
||||||
|
metrics = get_metrics_for("synapse.usage")
|
||||||
|
metrics.add_callback("timestamp", lambda: stats["timestamp"])
|
||||||
|
metrics.add_callback("uptime_seconds", lambda: stats["uptime_seconds"])
|
||||||
|
metrics.add_callback("total_users", lambda: stats["total_users"])
|
||||||
|
metrics.add_callback("total_room_count", lambda: stats["total_room_count"])
|
||||||
|
metrics.add_callback(
|
||||||
|
"daily_active_users", lambda: stats["daily_active_users"]
|
||||||
|
)
|
||||||
|
metrics.add_callback(
|
||||||
|
"daily_messages", lambda: stats.get("daily_messages", 0)
|
||||||
|
)
|
||||||
|
|
||||||
logger.info("Reporting stats to matrix.org: %s" % (stats,))
|
logger.info("Reporting stats to matrix.org: %s" % (stats,))
|
||||||
try:
|
try:
|
||||||
|
212
synapse/app/media_repository.py
Normal file
212
synapse/app/media_repository.py
Normal file
@ -0,0 +1,212 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import synapse
|
||||||
|
|
||||||
|
from synapse.config._base import ConfigError
|
||||||
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
|
from synapse.config.logger import setup_logging
|
||||||
|
from synapse.http.site import SynapseSite
|
||||||
|
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||||
|
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||||
|
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
|
||||||
|
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||||
|
from synapse.rest.media.v0.content_repository import ContentRepoResource
|
||||||
|
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.storage.client_ips import ClientIpStore
|
||||||
|
from synapse.storage.engines import create_engine
|
||||||
|
from synapse.storage.media_repository import MediaRepositoryStore
|
||||||
|
from synapse.util.async import sleep
|
||||||
|
from synapse.util.httpresourcetree import create_resource_tree
|
||||||
|
from synapse.util.logcontext import LoggingContext
|
||||||
|
from synapse.util.manhole import manhole
|
||||||
|
from synapse.util.rlimit import change_resource_limit
|
||||||
|
from synapse.util.versionstring import get_version_string
|
||||||
|
from synapse.api.urls import (
|
||||||
|
CONTENT_REPO_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX
|
||||||
|
)
|
||||||
|
from synapse.crypto import context_factory
|
||||||
|
|
||||||
|
|
||||||
|
from twisted.internet import reactor, defer
|
||||||
|
from twisted.web.resource import Resource
|
||||||
|
|
||||||
|
from daemonize import Daemonize
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
import gc
|
||||||
|
|
||||||
|
logger = logging.getLogger("synapse.app.media_repository")
|
||||||
|
|
||||||
|
|
||||||
|
class MediaRepositorySlavedStore(
|
||||||
|
SlavedApplicationServiceStore,
|
||||||
|
SlavedRegistrationStore,
|
||||||
|
BaseSlavedStore,
|
||||||
|
MediaRepositoryStore,
|
||||||
|
ClientIpStore,
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MediaRepositoryServer(HomeServer):
|
||||||
|
def get_db_conn(self, run_new_connection=True):
|
||||||
|
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||||
|
# not be passed to the database engine.
|
||||||
|
db_params = {
|
||||||
|
k: v for k, v in self.db_config.get("args", {}).items()
|
||||||
|
if not k.startswith("cp_")
|
||||||
|
}
|
||||||
|
db_conn = self.database_engine.module.connect(**db_params)
|
||||||
|
|
||||||
|
if run_new_connection:
|
||||||
|
self.database_engine.on_new_connection(db_conn)
|
||||||
|
return db_conn
|
||||||
|
|
||||||
|
def setup(self):
|
||||||
|
logger.info("Setting up.")
|
||||||
|
self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self)
|
||||||
|
logger.info("Finished setting up.")
|
||||||
|
|
||||||
|
def _listen_http(self, listener_config):
|
||||||
|
port = listener_config["port"]
|
||||||
|
bind_address = listener_config.get("bind_address", "")
|
||||||
|
site_tag = listener_config.get("tag", port)
|
||||||
|
resources = {}
|
||||||
|
for res in listener_config["resources"]:
|
||||||
|
for name in res["names"]:
|
||||||
|
if name == "metrics":
|
||||||
|
resources[METRICS_PREFIX] = MetricsResource(self)
|
||||||
|
elif name == "media":
|
||||||
|
media_repo = MediaRepositoryResource(self)
|
||||||
|
resources.update({
|
||||||
|
MEDIA_PREFIX: media_repo,
|
||||||
|
LEGACY_MEDIA_PREFIX: media_repo,
|
||||||
|
CONTENT_REPO_PREFIX: ContentRepoResource(
|
||||||
|
self, self.config.uploads_path
|
||||||
|
),
|
||||||
|
})
|
||||||
|
|
||||||
|
root_resource = create_resource_tree(resources, Resource())
|
||||||
|
reactor.listenTCP(
|
||||||
|
port,
|
||||||
|
SynapseSite(
|
||||||
|
"synapse.access.http.%s" % (site_tag,),
|
||||||
|
site_tag,
|
||||||
|
listener_config,
|
||||||
|
root_resource,
|
||||||
|
),
|
||||||
|
interface=bind_address
|
||||||
|
)
|
||||||
|
logger.info("Synapse media repository now listening on port %d", port)
|
||||||
|
|
||||||
|
def start_listening(self, listeners):
|
||||||
|
for listener in listeners:
|
||||||
|
if listener["type"] == "http":
|
||||||
|
self._listen_http(listener)
|
||||||
|
elif listener["type"] == "manhole":
|
||||||
|
reactor.listenTCP(
|
||||||
|
listener["port"],
|
||||||
|
manhole(
|
||||||
|
username="matrix",
|
||||||
|
password="rabbithole",
|
||||||
|
globals={"hs": self},
|
||||||
|
),
|
||||||
|
interface=listener.get("bind_address", '127.0.0.1')
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warn("Unrecognized listener type: %s", listener["type"])
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def replicate(self):
|
||||||
|
http_client = self.get_simple_http_client()
|
||||||
|
store = self.get_datastore()
|
||||||
|
replication_url = self.config.worker_replication_url
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
args = store.stream_positions()
|
||||||
|
args["timeout"] = 30000
|
||||||
|
result = yield http_client.get_json(replication_url, args=args)
|
||||||
|
yield store.process_replication(result)
|
||||||
|
except:
|
||||||
|
logger.exception("Error replicating from %r", replication_url)
|
||||||
|
yield sleep(5)
|
||||||
|
|
||||||
|
|
||||||
|
def start(config_options):
|
||||||
|
try:
|
||||||
|
config = HomeServerConfig.load_config(
|
||||||
|
"Synapse media repository", config_options
|
||||||
|
)
|
||||||
|
except ConfigError as e:
|
||||||
|
sys.stderr.write("\n" + e.message + "\n")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
assert config.worker_app == "synapse.app.media_repository"
|
||||||
|
|
||||||
|
setup_logging(config.worker_log_config, config.worker_log_file)
|
||||||
|
|
||||||
|
database_engine = create_engine(config.database_config)
|
||||||
|
|
||||||
|
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
||||||
|
|
||||||
|
ss = MediaRepositoryServer(
|
||||||
|
config.server_name,
|
||||||
|
db_config=config.database_config,
|
||||||
|
tls_server_context_factory=tls_server_context_factory,
|
||||||
|
config=config,
|
||||||
|
version_string="Synapse/" + get_version_string(synapse),
|
||||||
|
database_engine=database_engine,
|
||||||
|
)
|
||||||
|
|
||||||
|
ss.setup()
|
||||||
|
ss.get_handlers()
|
||||||
|
ss.start_listening(config.worker_listeners)
|
||||||
|
|
||||||
|
def run():
|
||||||
|
with LoggingContext("run"):
|
||||||
|
logger.info("Running")
|
||||||
|
change_resource_limit(config.soft_file_limit)
|
||||||
|
if config.gc_thresholds:
|
||||||
|
gc.set_threshold(*config.gc_thresholds)
|
||||||
|
reactor.run()
|
||||||
|
|
||||||
|
def start():
|
||||||
|
ss.get_datastore().start_profiling()
|
||||||
|
ss.replicate()
|
||||||
|
|
||||||
|
reactor.callWhenRunning(start)
|
||||||
|
|
||||||
|
if config.worker_daemonize:
|
||||||
|
daemon = Daemonize(
|
||||||
|
app="synapse-media-repository",
|
||||||
|
pid=config.worker_pid_file,
|
||||||
|
action=run,
|
||||||
|
auto_close_fds=False,
|
||||||
|
verbose=True,
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
daemon.start()
|
||||||
|
else:
|
||||||
|
run()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
with LoggingContext("main"):
|
||||||
|
start(sys.argv[1:])
|
@ -80,11 +80,6 @@ class PusherSlaveStore(
|
|||||||
DataStore.get_profile_displayname.__func__
|
DataStore.get_profile_displayname.__func__
|
||||||
)
|
)
|
||||||
|
|
||||||
# XXX: This is a bit broken because we don't persist forgotten rooms
|
|
||||||
# in a way that they can be streamed. This means that we don't have a
|
|
||||||
# way to invalidate the forgotten rooms cache correctly.
|
|
||||||
# For now we expire the cache every 10 minutes.
|
|
||||||
BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
|
|
||||||
who_forgot_in_room = (
|
who_forgot_in_room = (
|
||||||
RoomMemberStore.__dict__["who_forgot_in_room"]
|
RoomMemberStore.__dict__["who_forgot_in_room"]
|
||||||
)
|
)
|
||||||
@ -168,7 +163,6 @@ class PusherServer(HomeServer):
|
|||||||
store = self.get_datastore()
|
store = self.get_datastore()
|
||||||
replication_url = self.config.worker_replication_url
|
replication_url = self.config.worker_replication_url
|
||||||
pusher_pool = self.get_pusherpool()
|
pusher_pool = self.get_pusherpool()
|
||||||
clock = self.get_clock()
|
|
||||||
|
|
||||||
def stop_pusher(user_id, app_id, pushkey):
|
def stop_pusher(user_id, app_id, pushkey):
|
||||||
key = "%s:%s" % (app_id, pushkey)
|
key = "%s:%s" % (app_id, pushkey)
|
||||||
@ -220,21 +214,11 @@ class PusherServer(HomeServer):
|
|||||||
min_stream_id, max_stream_id, affected_room_ids
|
min_stream_id, max_stream_id, affected_room_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
def expire_broken_caches():
|
|
||||||
store.who_forgot_in_room.invalidate_all()
|
|
||||||
|
|
||||||
next_expire_broken_caches_ms = 0
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
args = store.stream_positions()
|
args = store.stream_positions()
|
||||||
args["timeout"] = 30000
|
args["timeout"] = 30000
|
||||||
result = yield http_client.get_json(replication_url, args=args)
|
result = yield http_client.get_json(replication_url, args=args)
|
||||||
now_ms = clock.time_msec()
|
|
||||||
if now_ms > next_expire_broken_caches_ms:
|
|
||||||
expire_broken_caches()
|
|
||||||
next_expire_broken_caches_ms = (
|
|
||||||
now_ms + store.BROKEN_CACHE_EXPIRY_MS
|
|
||||||
)
|
|
||||||
yield store.process_replication(result)
|
yield store.process_replication(result)
|
||||||
poke_pushers(result)
|
poke_pushers(result)
|
||||||
except:
|
except:
|
||||||
|
@ -26,6 +26,7 @@ from synapse.http.site import SynapseSite
|
|||||||
from synapse.http.server import JsonResource
|
from synapse.http.server import JsonResource
|
||||||
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||||
from synapse.rest.client.v2_alpha import sync
|
from synapse.rest.client.v2_alpha import sync
|
||||||
|
from synapse.rest.client.v1 import events
|
||||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||||
from synapse.replication.slave.storage.events import SlavedEventStore
|
from synapse.replication.slave.storage.events import SlavedEventStore
|
||||||
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
||||||
@ -74,11 +75,6 @@ class SynchrotronSlavedStore(
|
|||||||
BaseSlavedStore,
|
BaseSlavedStore,
|
||||||
ClientIpStore, # After BaseSlavedStore because the constructor is different
|
ClientIpStore, # After BaseSlavedStore because the constructor is different
|
||||||
):
|
):
|
||||||
# XXX: This is a bit broken because we don't persist forgotten rooms
|
|
||||||
# in a way that they can be streamed. This means that we don't have a
|
|
||||||
# way to invalidate the forgotten rooms cache correctly.
|
|
||||||
# For now we expire the cache every 10 minutes.
|
|
||||||
BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
|
|
||||||
who_forgot_in_room = (
|
who_forgot_in_room = (
|
||||||
RoomMemberStore.__dict__["who_forgot_in_room"]
|
RoomMemberStore.__dict__["who_forgot_in_room"]
|
||||||
)
|
)
|
||||||
@ -89,17 +85,23 @@ class SynchrotronSlavedStore(
|
|||||||
get_presence_list_accepted = PresenceStore.__dict__[
|
get_presence_list_accepted = PresenceStore.__dict__[
|
||||||
"get_presence_list_accepted"
|
"get_presence_list_accepted"
|
||||||
]
|
]
|
||||||
|
get_presence_list_observers_accepted = PresenceStore.__dict__[
|
||||||
|
"get_presence_list_observers_accepted"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
UPDATE_SYNCING_USERS_MS = 10 * 1000
|
UPDATE_SYNCING_USERS_MS = 10 * 1000
|
||||||
|
|
||||||
|
|
||||||
class SynchrotronPresence(object):
|
class SynchrotronPresence(object):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
|
self.is_mine_id = hs.is_mine_id
|
||||||
self.http_client = hs.get_simple_http_client()
|
self.http_client = hs.get_simple_http_client()
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.user_to_num_current_syncs = {}
|
self.user_to_num_current_syncs = {}
|
||||||
self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users"
|
self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users"
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
self.notifier = hs.get_notifier()
|
||||||
|
|
||||||
active_presence = self.store.take_presence_startup_info()
|
active_presence = self.store.take_presence_startup_info()
|
||||||
self.user_to_current_state = {
|
self.user_to_current_state = {
|
||||||
@ -119,11 +121,13 @@ class SynchrotronPresence(object):
|
|||||||
|
|
||||||
reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown)
|
reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown)
|
||||||
|
|
||||||
def set_state(self, user, state):
|
def set_state(self, user, state, ignore_status_msg=False):
|
||||||
# TODO Hows this supposed to work?
|
# TODO Hows this supposed to work?
|
||||||
pass
|
pass
|
||||||
|
|
||||||
get_states = PresenceHandler.get_states.__func__
|
get_states = PresenceHandler.get_states.__func__
|
||||||
|
get_state = PresenceHandler.get_state.__func__
|
||||||
|
_get_interested_parties = PresenceHandler._get_interested_parties.__func__
|
||||||
current_state_for_users = PresenceHandler.current_state_for_users.__func__
|
current_state_for_users = PresenceHandler.current_state_for_users.__func__
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -194,19 +198,39 @@ class SynchrotronPresence(object):
|
|||||||
self._need_to_send_sync = False
|
self._need_to_send_sync = False
|
||||||
yield self._send_syncing_users_now()
|
yield self._send_syncing_users_now()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def notify_from_replication(self, states, stream_id):
|
||||||
|
parties = yield self._get_interested_parties(
|
||||||
|
states, calculate_remote_hosts=False
|
||||||
|
)
|
||||||
|
room_ids_to_states, users_to_states, _ = parties
|
||||||
|
|
||||||
|
self.notifier.on_new_event(
|
||||||
|
"presence_key", stream_id, rooms=room_ids_to_states.keys(),
|
||||||
|
users=users_to_states.keys()
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def process_replication(self, result):
|
def process_replication(self, result):
|
||||||
stream = result.get("presence", {"rows": []})
|
stream = result.get("presence", {"rows": []})
|
||||||
|
states = []
|
||||||
for row in stream["rows"]:
|
for row in stream["rows"]:
|
||||||
(
|
(
|
||||||
position, user_id, state, last_active_ts,
|
position, user_id, state, last_active_ts,
|
||||||
last_federation_update_ts, last_user_sync_ts, status_msg,
|
last_federation_update_ts, last_user_sync_ts, status_msg,
|
||||||
currently_active
|
currently_active
|
||||||
) = row
|
) = row
|
||||||
self.user_to_current_state[user_id] = UserPresenceState(
|
state = UserPresenceState(
|
||||||
user_id, state, last_active_ts,
|
user_id, state, last_active_ts,
|
||||||
last_federation_update_ts, last_user_sync_ts, status_msg,
|
last_federation_update_ts, last_user_sync_ts, status_msg,
|
||||||
currently_active
|
currently_active
|
||||||
)
|
)
|
||||||
|
self.user_to_current_state[user_id] = state
|
||||||
|
states.append(state)
|
||||||
|
|
||||||
|
if states and "position" in stream:
|
||||||
|
stream_id = int(stream["position"])
|
||||||
|
yield self.notify_from_replication(states, stream_id)
|
||||||
|
|
||||||
|
|
||||||
class SynchrotronTyping(object):
|
class SynchrotronTyping(object):
|
||||||
@ -266,10 +290,12 @@ class SynchrotronServer(HomeServer):
|
|||||||
elif name == "client":
|
elif name == "client":
|
||||||
resource = JsonResource(self, canonical_json=False)
|
resource = JsonResource(self, canonical_json=False)
|
||||||
sync.register_servlets(self, resource)
|
sync.register_servlets(self, resource)
|
||||||
|
events.register_servlets(self, resource)
|
||||||
resources.update({
|
resources.update({
|
||||||
"/_matrix/client/r0": resource,
|
"/_matrix/client/r0": resource,
|
||||||
"/_matrix/client/unstable": resource,
|
"/_matrix/client/unstable": resource,
|
||||||
"/_matrix/client/v2_alpha": resource,
|
"/_matrix/client/v2_alpha": resource,
|
||||||
|
"/_matrix/client/api/v1": resource,
|
||||||
})
|
})
|
||||||
|
|
||||||
root_resource = create_resource_tree(resources, Resource())
|
root_resource = create_resource_tree(resources, Resource())
|
||||||
@ -307,15 +333,10 @@ class SynchrotronServer(HomeServer):
|
|||||||
http_client = self.get_simple_http_client()
|
http_client = self.get_simple_http_client()
|
||||||
store = self.get_datastore()
|
store = self.get_datastore()
|
||||||
replication_url = self.config.worker_replication_url
|
replication_url = self.config.worker_replication_url
|
||||||
clock = self.get_clock()
|
|
||||||
notifier = self.get_notifier()
|
notifier = self.get_notifier()
|
||||||
presence_handler = self.get_presence_handler()
|
presence_handler = self.get_presence_handler()
|
||||||
typing_handler = self.get_typing_handler()
|
typing_handler = self.get_typing_handler()
|
||||||
|
|
||||||
def expire_broken_caches():
|
|
||||||
store.who_forgot_in_room.invalidate_all()
|
|
||||||
store.get_presence_list_accepted.invalidate_all()
|
|
||||||
|
|
||||||
def notify_from_stream(
|
def notify_from_stream(
|
||||||
result, stream_name, stream_key, room=None, user=None
|
result, stream_name, stream_key, room=None, user=None
|
||||||
):
|
):
|
||||||
@ -377,22 +398,15 @@ class SynchrotronServer(HomeServer):
|
|||||||
result, "typing", "typing_key", room="room_id"
|
result, "typing", "typing_key", room="room_id"
|
||||||
)
|
)
|
||||||
|
|
||||||
next_expire_broken_caches_ms = 0
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
args = store.stream_positions()
|
args = store.stream_positions()
|
||||||
args.update(typing_handler.stream_positions())
|
args.update(typing_handler.stream_positions())
|
||||||
args["timeout"] = 30000
|
args["timeout"] = 30000
|
||||||
result = yield http_client.get_json(replication_url, args=args)
|
result = yield http_client.get_json(replication_url, args=args)
|
||||||
now_ms = clock.time_msec()
|
|
||||||
if now_ms > next_expire_broken_caches_ms:
|
|
||||||
expire_broken_caches()
|
|
||||||
next_expire_broken_caches_ms = (
|
|
||||||
now_ms + store.BROKEN_CACHE_EXPIRY_MS
|
|
||||||
)
|
|
||||||
yield store.process_replication(result)
|
yield store.process_replication(result)
|
||||||
typing_handler.process_replication(result)
|
typing_handler.process_replication(result)
|
||||||
presence_handler.process_replication(result)
|
yield presence_handler.process_replication(result)
|
||||||
notify(result)
|
notify(result)
|
||||||
except:
|
except:
|
||||||
logger.exception("Error replicating from %r", replication_url)
|
logger.exception("Error replicating from %r", replication_url)
|
||||||
|
@ -14,6 +14,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
|
||||||
@ -79,13 +81,17 @@ class ApplicationService(object):
|
|||||||
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
|
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
|
||||||
|
|
||||||
def __init__(self, token, url=None, namespaces=None, hs_token=None,
|
def __init__(self, token, url=None, namespaces=None, hs_token=None,
|
||||||
sender=None, id=None):
|
sender=None, id=None, protocols=None):
|
||||||
self.token = token
|
self.token = token
|
||||||
self.url = url
|
self.url = url
|
||||||
self.hs_token = hs_token
|
self.hs_token = hs_token
|
||||||
self.sender = sender
|
self.sender = sender
|
||||||
self.namespaces = self._check_namespaces(namespaces)
|
self.namespaces = self._check_namespaces(namespaces)
|
||||||
self.id = id
|
self.id = id
|
||||||
|
if protocols:
|
||||||
|
self.protocols = set(protocols)
|
||||||
|
else:
|
||||||
|
self.protocols = set()
|
||||||
|
|
||||||
def _check_namespaces(self, namespaces):
|
def _check_namespaces(self, namespaces):
|
||||||
# Sanity check that it is of the form:
|
# Sanity check that it is of the form:
|
||||||
@ -138,65 +144,66 @@ class ApplicationService(object):
|
|||||||
return regex_obj["exclusive"]
|
return regex_obj["exclusive"]
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _matches_user(self, event, member_list):
|
@defer.inlineCallbacks
|
||||||
if (hasattr(event, "sender") and
|
def _matches_user(self, event, store):
|
||||||
self.is_interested_in_user(event.sender)):
|
if not event:
|
||||||
return True
|
defer.returnValue(False)
|
||||||
|
|
||||||
|
if self.is_interested_in_user(event.sender):
|
||||||
|
defer.returnValue(True)
|
||||||
# also check m.room.member state key
|
# also check m.room.member state key
|
||||||
if (hasattr(event, "type") and event.type == EventTypes.Member
|
if (event.type == EventTypes.Member and
|
||||||
and hasattr(event, "state_key")
|
self.is_interested_in_user(event.state_key)):
|
||||||
and self.is_interested_in_user(event.state_key)):
|
defer.returnValue(True)
|
||||||
return True
|
|
||||||
|
if not store:
|
||||||
|
defer.returnValue(False)
|
||||||
|
|
||||||
|
member_list = yield store.get_users_in_room(event.room_id)
|
||||||
|
|
||||||
# check joined member events
|
# check joined member events
|
||||||
for user_id in member_list:
|
for user_id in member_list:
|
||||||
if self.is_interested_in_user(user_id):
|
if self.is_interested_in_user(user_id):
|
||||||
return True
|
defer.returnValue(True)
|
||||||
return False
|
defer.returnValue(False)
|
||||||
|
|
||||||
def _matches_room_id(self, event):
|
def _matches_room_id(self, event):
|
||||||
if hasattr(event, "room_id"):
|
if hasattr(event, "room_id"):
|
||||||
return self.is_interested_in_room(event.room_id)
|
return self.is_interested_in_room(event.room_id)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _matches_aliases(self, event, alias_list):
|
@defer.inlineCallbacks
|
||||||
|
def _matches_aliases(self, event, store):
|
||||||
|
if not store or not event:
|
||||||
|
defer.returnValue(False)
|
||||||
|
|
||||||
|
alias_list = yield store.get_aliases_for_room(event.room_id)
|
||||||
for alias in alias_list:
|
for alias in alias_list:
|
||||||
if self.is_interested_in_alias(alias):
|
if self.is_interested_in_alias(alias):
|
||||||
return True
|
defer.returnValue(True)
|
||||||
return False
|
defer.returnValue(False)
|
||||||
|
|
||||||
def is_interested(self, event, restrict_to=None, aliases_for_event=None,
|
@defer.inlineCallbacks
|
||||||
member_list=None):
|
def is_interested(self, event, store=None):
|
||||||
"""Check if this service is interested in this event.
|
"""Check if this service is interested in this event.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event(Event): The event to check.
|
event(Event): The event to check.
|
||||||
restrict_to(str): The namespace to restrict regex tests to.
|
store(DataStore)
|
||||||
aliases_for_event(list): A list of all the known room aliases for
|
|
||||||
this event.
|
|
||||||
member_list(list): A list of all joined user_ids in this room.
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if this service would like to know about this event.
|
bool: True if this service would like to know about this event.
|
||||||
"""
|
"""
|
||||||
if aliases_for_event is None:
|
# Do cheap checks first
|
||||||
aliases_for_event = []
|
if self._matches_room_id(event):
|
||||||
if member_list is None:
|
defer.returnValue(True)
|
||||||
member_list = []
|
|
||||||
|
|
||||||
if restrict_to and restrict_to not in ApplicationService.NS_LIST:
|
if (yield self._matches_aliases(event, store)):
|
||||||
# this is a programming error, so fail early and raise a general
|
defer.returnValue(True)
|
||||||
# exception
|
|
||||||
raise Exception("Unexpected restrict_to value: %s". restrict_to)
|
|
||||||
|
|
||||||
if not restrict_to:
|
if (yield self._matches_user(event, store)):
|
||||||
return (self._matches_user(event, member_list)
|
defer.returnValue(True)
|
||||||
or self._matches_aliases(event, aliases_for_event)
|
|
||||||
or self._matches_room_id(event))
|
defer.returnValue(False)
|
||||||
elif restrict_to == ApplicationService.NS_ALIASES:
|
|
||||||
return self._matches_aliases(event, aliases_for_event)
|
|
||||||
elif restrict_to == ApplicationService.NS_ROOMS:
|
|
||||||
return self._matches_room_id(event)
|
|
||||||
elif restrict_to == ApplicationService.NS_USERS:
|
|
||||||
return self._matches_user(event, member_list)
|
|
||||||
|
|
||||||
def is_interested_in_user(self, user_id):
|
def is_interested_in_user(self, user_id):
|
||||||
return (
|
return (
|
||||||
@ -216,6 +223,9 @@ class ApplicationService(object):
|
|||||||
or user_id == self.sender
|
or user_id == self.sender
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def is_interested_in_protocol(self, protocol):
|
||||||
|
return protocol in self.protocols
|
||||||
|
|
||||||
def is_exclusive_alias(self, alias):
|
def is_exclusive_alias(self, alias):
|
||||||
return self._is_exclusive(ApplicationService.NS_ALIASES, alias)
|
return self._is_exclusive(ApplicationService.NS_ALIASES, alias)
|
||||||
|
|
||||||
|
@ -17,6 +17,7 @@ from twisted.internet import defer
|
|||||||
from synapse.api.errors import CodeMessageException
|
from synapse.api.errors import CodeMessageException
|
||||||
from synapse.http.client import SimpleHttpClient
|
from synapse.http.client import SimpleHttpClient
|
||||||
from synapse.events.utils import serialize_event
|
from synapse.events.utils import serialize_event
|
||||||
|
from synapse.types import ThirdPartyEntityKind
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import urllib
|
import urllib
|
||||||
@ -24,6 +25,28 @@ import urllib
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_valid_3pe_result(r, field):
|
||||||
|
if not isinstance(r, dict):
|
||||||
|
return False
|
||||||
|
|
||||||
|
for k in (field, "protocol"):
|
||||||
|
if k not in r:
|
||||||
|
return False
|
||||||
|
if not isinstance(r[k], str):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if "fields" not in r:
|
||||||
|
return False
|
||||||
|
fields = r["fields"]
|
||||||
|
if not isinstance(fields, dict):
|
||||||
|
return False
|
||||||
|
for k in fields.keys():
|
||||||
|
if not isinstance(fields[k], str):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class ApplicationServiceApi(SimpleHttpClient):
|
class ApplicationServiceApi(SimpleHttpClient):
|
||||||
"""This class manages HS -> AS communications, including querying and
|
"""This class manages HS -> AS communications, including querying and
|
||||||
pushing.
|
pushing.
|
||||||
@ -71,6 +94,43 @@ class ApplicationServiceApi(SimpleHttpClient):
|
|||||||
logger.warning("query_alias to %s threw exception %s", uri, ex)
|
logger.warning("query_alias to %s threw exception %s", uri, ex)
|
||||||
defer.returnValue(False)
|
defer.returnValue(False)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def query_3pe(self, service, kind, protocol, fields):
|
||||||
|
if kind == ThirdPartyEntityKind.USER:
|
||||||
|
uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol))
|
||||||
|
required_field = "userid"
|
||||||
|
elif kind == ThirdPartyEntityKind.LOCATION:
|
||||||
|
uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
|
||||||
|
required_field = "alias"
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Unrecognised 'kind' argument %r to query_3pe()", kind
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = yield self.get_json(uri, fields)
|
||||||
|
if not isinstance(response, list):
|
||||||
|
logger.warning(
|
||||||
|
"query_3pe to %s returned an invalid response %r",
|
||||||
|
uri, response
|
||||||
|
)
|
||||||
|
defer.returnValue([])
|
||||||
|
|
||||||
|
ret = []
|
||||||
|
for r in response:
|
||||||
|
if _is_valid_3pe_result(r, field=required_field):
|
||||||
|
ret.append(r)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"query_3pe to %s returned an invalid result %r",
|
||||||
|
uri, r
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue(ret)
|
||||||
|
except Exception as ex:
|
||||||
|
logger.warning("query_3pe to %s threw exception %s", uri, ex)
|
||||||
|
defer.returnValue([])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def push_bulk(self, service, events, txn_id=None):
|
def push_bulk(self, service, events, txn_id=None):
|
||||||
events = self._serialize(events)
|
events = self._serialize(events)
|
||||||
|
@ -48,9 +48,12 @@ UP & quit +---------- YES SUCCESS
|
|||||||
This is all tied together by the AppServiceScheduler which DIs the required
|
This is all tied together by the AppServiceScheduler which DIs the required
|
||||||
components.
|
components.
|
||||||
"""
|
"""
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.appservice import ApplicationServiceState
|
from synapse.appservice import ApplicationServiceState
|
||||||
from twisted.internet import defer
|
from synapse.util.logcontext import preserve_fn
|
||||||
|
from synapse.util.metrics import Measure
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -73,7 +76,7 @@ class ApplicationServiceScheduler(object):
|
|||||||
self.txn_ctrl = _TransactionController(
|
self.txn_ctrl = _TransactionController(
|
||||||
self.clock, self.store, self.as_api, create_recoverer
|
self.clock, self.store, self.as_api, create_recoverer
|
||||||
)
|
)
|
||||||
self.queuer = _ServiceQueuer(self.txn_ctrl)
|
self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def start(self):
|
def start(self):
|
||||||
@ -94,38 +97,36 @@ class _ServiceQueuer(object):
|
|||||||
this schedules any other events in the queue to run.
|
this schedules any other events in the queue to run.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, txn_ctrl):
|
def __init__(self, txn_ctrl, clock):
|
||||||
self.queued_events = {} # dict of {service_id: [events]}
|
self.queued_events = {} # dict of {service_id: [events]}
|
||||||
self.pending_requests = {} # dict of {service_id: Deferred}
|
self.requests_in_flight = set()
|
||||||
self.txn_ctrl = txn_ctrl
|
self.txn_ctrl = txn_ctrl
|
||||||
|
self.clock = clock
|
||||||
|
|
||||||
def enqueue(self, service, event):
|
def enqueue(self, service, event):
|
||||||
# if this service isn't being sent something
|
# if this service isn't being sent something
|
||||||
if not self.pending_requests.get(service.id):
|
self.queued_events.setdefault(service.id, []).append(event)
|
||||||
self._send_request(service, [event])
|
preserve_fn(self._send_request)(service)
|
||||||
else:
|
|
||||||
# add to queue for this service
|
|
||||||
if service.id not in self.queued_events:
|
|
||||||
self.queued_events[service.id] = []
|
|
||||||
self.queued_events[service.id].append(event)
|
|
||||||
|
|
||||||
def _send_request(self, service, events):
|
@defer.inlineCallbacks
|
||||||
# send request and add callbacks
|
def _send_request(self, service):
|
||||||
d = self.txn_ctrl.send(service, events)
|
if service.id in self.requests_in_flight:
|
||||||
d.addBoth(self._on_request_finish)
|
return
|
||||||
d.addErrback(self._on_request_fail)
|
|
||||||
self.pending_requests[service.id] = d
|
|
||||||
|
|
||||||
def _on_request_finish(self, service):
|
self.requests_in_flight.add(service.id)
|
||||||
self.pending_requests[service.id] = None
|
try:
|
||||||
# if there are queued events, then send them.
|
while True:
|
||||||
if (service.id in self.queued_events
|
events = self.queued_events.pop(service.id, [])
|
||||||
and len(self.queued_events[service.id]) > 0):
|
if not events:
|
||||||
self._send_request(service, self.queued_events[service.id])
|
return
|
||||||
self.queued_events[service.id] = []
|
|
||||||
|
|
||||||
def _on_request_fail(self, err):
|
with Measure(self.clock, "servicequeuer.send"):
|
||||||
logger.error("AS request failed: %s", err)
|
try:
|
||||||
|
yield self.txn_ctrl.send(service, events)
|
||||||
|
except:
|
||||||
|
logger.exception("AS request failed")
|
||||||
|
finally:
|
||||||
|
self.requests_in_flight.discard(service.id)
|
||||||
|
|
||||||
|
|
||||||
class _TransactionController(object):
|
class _TransactionController(object):
|
||||||
@ -149,14 +150,12 @@ class _TransactionController(object):
|
|||||||
if service_is_up:
|
if service_is_up:
|
||||||
sent = yield txn.send(self.as_api)
|
sent = yield txn.send(self.as_api)
|
||||||
if sent:
|
if sent:
|
||||||
txn.complete(self.store)
|
yield txn.complete(self.store)
|
||||||
else:
|
else:
|
||||||
self._start_recoverer(service)
|
preserve_fn(self._start_recoverer)(service)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(e)
|
logger.exception(e)
|
||||||
self._start_recoverer(service)
|
preserve_fn(self._start_recoverer)(service)
|
||||||
# request has finished
|
|
||||||
defer.returnValue(service)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_recovered(self, recoverer):
|
def on_recovered(self, recoverer):
|
||||||
|
@ -28,6 +28,7 @@ class AppServiceConfig(Config):
|
|||||||
|
|
||||||
def read_config(self, config):
|
def read_config(self, config):
|
||||||
self.app_service_config_files = config.get("app_service_config_files", [])
|
self.app_service_config_files = config.get("app_service_config_files", [])
|
||||||
|
self.notify_appservices = config.get("notify_appservices", True)
|
||||||
|
|
||||||
def default_config(cls, **kwargs):
|
def default_config(cls, **kwargs):
|
||||||
return """\
|
return """\
|
||||||
@ -122,6 +123,15 @@ def _load_appservice(hostname, as_info, config_filename):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Missing/bad type 'exclusive' key in %s", regex_obj
|
"Missing/bad type 'exclusive' key in %s", regex_obj
|
||||||
)
|
)
|
||||||
|
# protocols check
|
||||||
|
protocols = as_info.get("protocols")
|
||||||
|
if protocols:
|
||||||
|
# Because strings are lists in python
|
||||||
|
if isinstance(protocols, str) or not isinstance(protocols, list):
|
||||||
|
raise KeyError("Optional 'protocols' must be a list if present.")
|
||||||
|
for p in protocols:
|
||||||
|
if not isinstance(p, str):
|
||||||
|
raise KeyError("Bad value for 'protocols' item")
|
||||||
return ApplicationService(
|
return ApplicationService(
|
||||||
token=as_info["as_token"],
|
token=as_info["as_token"],
|
||||||
url=as_info["url"],
|
url=as_info["url"],
|
||||||
@ -129,4 +139,5 @@ def _load_appservice(hostname, as_info, config_filename):
|
|||||||
hs_token=as_info["hs_token"],
|
hs_token=as_info["hs_token"],
|
||||||
sender=user_id,
|
sender=user_id,
|
||||||
id=as_info["id"],
|
id=as_info["id"],
|
||||||
|
protocols=protocols,
|
||||||
)
|
)
|
||||||
|
@ -22,6 +22,7 @@ from synapse.util.logcontext import (
|
|||||||
preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext,
|
preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext,
|
||||||
preserve_fn
|
preserve_fn
|
||||||
)
|
)
|
||||||
|
from synapse.util.metrics import Measure
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
@ -61,6 +62,10 @@ Attributes:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class KeyLookupError(ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Keyring(object):
|
class Keyring(object):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
@ -239,59 +244,60 @@ class Keyring(object):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def do_iterations():
|
def do_iterations():
|
||||||
merged_results = {}
|
with Measure(self.clock, "get_server_verify_keys"):
|
||||||
|
merged_results = {}
|
||||||
|
|
||||||
missing_keys = {}
|
|
||||||
for verify_request in verify_requests:
|
|
||||||
missing_keys.setdefault(verify_request.server_name, set()).update(
|
|
||||||
verify_request.key_ids
|
|
||||||
)
|
|
||||||
|
|
||||||
for fn in key_fetch_fns:
|
|
||||||
results = yield fn(missing_keys.items())
|
|
||||||
merged_results.update(results)
|
|
||||||
|
|
||||||
# We now need to figure out which verify requests we have keys
|
|
||||||
# for and which we don't
|
|
||||||
missing_keys = {}
|
missing_keys = {}
|
||||||
requests_missing_keys = []
|
|
||||||
for verify_request in verify_requests:
|
for verify_request in verify_requests:
|
||||||
server_name = verify_request.server_name
|
missing_keys.setdefault(verify_request.server_name, set()).update(
|
||||||
result_keys = merged_results[server_name]
|
verify_request.key_ids
|
||||||
|
)
|
||||||
|
|
||||||
if verify_request.deferred.called:
|
for fn in key_fetch_fns:
|
||||||
# We've already called this deferred, which probably
|
results = yield fn(missing_keys.items())
|
||||||
# means that we've already found a key for it.
|
merged_results.update(results)
|
||||||
continue
|
|
||||||
|
|
||||||
for key_id in verify_request.key_ids:
|
# We now need to figure out which verify requests we have keys
|
||||||
if key_id in result_keys:
|
# for and which we don't
|
||||||
with PreserveLoggingContext():
|
missing_keys = {}
|
||||||
verify_request.deferred.callback((
|
requests_missing_keys = []
|
||||||
server_name,
|
for verify_request in verify_requests:
|
||||||
key_id,
|
server_name = verify_request.server_name
|
||||||
result_keys[key_id],
|
result_keys = merged_results[server_name]
|
||||||
))
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
# The else block is only reached if the loop above
|
|
||||||
# doesn't break.
|
|
||||||
missing_keys.setdefault(server_name, set()).update(
|
|
||||||
verify_request.key_ids
|
|
||||||
)
|
|
||||||
requests_missing_keys.append(verify_request)
|
|
||||||
|
|
||||||
if not missing_keys:
|
if verify_request.deferred.called:
|
||||||
break
|
# We've already called this deferred, which probably
|
||||||
|
# means that we've already found a key for it.
|
||||||
|
continue
|
||||||
|
|
||||||
for verify_request in requests_missing_keys.values():
|
for key_id in verify_request.key_ids:
|
||||||
verify_request.deferred.errback(SynapseError(
|
if key_id in result_keys:
|
||||||
401,
|
with PreserveLoggingContext():
|
||||||
"No key for %s with id %s" % (
|
verify_request.deferred.callback((
|
||||||
verify_request.server_name, verify_request.key_ids,
|
server_name,
|
||||||
),
|
key_id,
|
||||||
Codes.UNAUTHORIZED,
|
result_keys[key_id],
|
||||||
))
|
))
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# The else block is only reached if the loop above
|
||||||
|
# doesn't break.
|
||||||
|
missing_keys.setdefault(server_name, set()).update(
|
||||||
|
verify_request.key_ids
|
||||||
|
)
|
||||||
|
requests_missing_keys.append(verify_request)
|
||||||
|
|
||||||
|
if not missing_keys:
|
||||||
|
break
|
||||||
|
|
||||||
|
for verify_request in requests_missing_keys.values():
|
||||||
|
verify_request.deferred.errback(SynapseError(
|
||||||
|
401,
|
||||||
|
"No key for %s with id %s" % (
|
||||||
|
verify_request.server_name, verify_request.key_ids,
|
||||||
|
),
|
||||||
|
Codes.UNAUTHORIZED,
|
||||||
|
))
|
||||||
|
|
||||||
def on_err(err):
|
def on_err(err):
|
||||||
for verify_request in verify_requests:
|
for verify_request in verify_requests:
|
||||||
@ -302,15 +308,15 @@ class Keyring(object):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_keys_from_store(self, server_name_and_key_ids):
|
def get_keys_from_store(self, server_name_and_key_ids):
|
||||||
res = yield defer.gatherResults(
|
res = yield preserve_context_over_deferred(defer.gatherResults(
|
||||||
[
|
[
|
||||||
self.store.get_server_verify_keys(
|
preserve_fn(self.store.get_server_verify_keys)(
|
||||||
server_name, key_ids
|
server_name, key_ids
|
||||||
).addCallback(lambda ks, server: (server, ks), server_name)
|
).addCallback(lambda ks, server: (server, ks), server_name)
|
||||||
for server_name, key_ids in server_name_and_key_ids
|
for server_name, key_ids in server_name_and_key_ids
|
||||||
],
|
],
|
||||||
consumeErrors=True,
|
consumeErrors=True,
|
||||||
).addErrback(unwrapFirstError)
|
)).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
defer.returnValue(dict(res))
|
defer.returnValue(dict(res))
|
||||||
|
|
||||||
@ -331,13 +337,13 @@ class Keyring(object):
|
|||||||
)
|
)
|
||||||
defer.returnValue({})
|
defer.returnValue({})
|
||||||
|
|
||||||
results = yield defer.gatherResults(
|
results = yield preserve_context_over_deferred(defer.gatherResults(
|
||||||
[
|
[
|
||||||
get_key(p_name, p_keys)
|
preserve_fn(get_key)(p_name, p_keys)
|
||||||
for p_name, p_keys in self.perspective_servers.items()
|
for p_name, p_keys in self.perspective_servers.items()
|
||||||
],
|
],
|
||||||
consumeErrors=True,
|
consumeErrors=True,
|
||||||
).addErrback(unwrapFirstError)
|
)).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
union_of_keys = {}
|
union_of_keys = {}
|
||||||
for result in results:
|
for result in results:
|
||||||
@ -363,7 +369,7 @@ class Keyring(object):
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Unable to getting key %r for %r directly: %s %s",
|
"Unable to get key %r for %r directly: %s %s",
|
||||||
key_ids, server_name,
|
key_ids, server_name,
|
||||||
type(e).__name__, str(e.message),
|
type(e).__name__, str(e.message),
|
||||||
)
|
)
|
||||||
@ -377,13 +383,13 @@ class Keyring(object):
|
|||||||
|
|
||||||
defer.returnValue(keys)
|
defer.returnValue(keys)
|
||||||
|
|
||||||
results = yield defer.gatherResults(
|
results = yield preserve_context_over_deferred(defer.gatherResults(
|
||||||
[
|
[
|
||||||
get_key(server_name, key_ids)
|
preserve_fn(get_key)(server_name, key_ids)
|
||||||
for server_name, key_ids in server_name_and_key_ids
|
for server_name, key_ids in server_name_and_key_ids
|
||||||
],
|
],
|
||||||
consumeErrors=True,
|
consumeErrors=True,
|
||||||
).addErrback(unwrapFirstError)
|
)).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
merged = {}
|
merged = {}
|
||||||
for result in results:
|
for result in results:
|
||||||
@ -425,7 +431,7 @@ class Keyring(object):
|
|||||||
for response in responses:
|
for response in responses:
|
||||||
if (u"signatures" not in response
|
if (u"signatures" not in response
|
||||||
or perspective_name not in response[u"signatures"]):
|
or perspective_name not in response[u"signatures"]):
|
||||||
raise ValueError(
|
raise KeyLookupError(
|
||||||
"Key response not signed by perspective server"
|
"Key response not signed by perspective server"
|
||||||
" %r" % (perspective_name,)
|
" %r" % (perspective_name,)
|
||||||
)
|
)
|
||||||
@ -448,7 +454,7 @@ class Keyring(object):
|
|||||||
list(response[u"signatures"][perspective_name]),
|
list(response[u"signatures"][perspective_name]),
|
||||||
list(perspective_keys)
|
list(perspective_keys)
|
||||||
)
|
)
|
||||||
raise ValueError(
|
raise KeyLookupError(
|
||||||
"Response not signed with a known key for perspective"
|
"Response not signed with a known key for perspective"
|
||||||
" server %r" % (perspective_name,)
|
" server %r" % (perspective_name,)
|
||||||
)
|
)
|
||||||
@ -460,9 +466,9 @@ class Keyring(object):
|
|||||||
for server_name, response_keys in processed_response.items():
|
for server_name, response_keys in processed_response.items():
|
||||||
keys.setdefault(server_name, {}).update(response_keys)
|
keys.setdefault(server_name, {}).update(response_keys)
|
||||||
|
|
||||||
yield defer.gatherResults(
|
yield preserve_context_over_deferred(defer.gatherResults(
|
||||||
[
|
[
|
||||||
self.store_keys(
|
preserve_fn(self.store_keys)(
|
||||||
server_name=server_name,
|
server_name=server_name,
|
||||||
from_server=perspective_name,
|
from_server=perspective_name,
|
||||||
verify_keys=response_keys,
|
verify_keys=response_keys,
|
||||||
@ -470,7 +476,7 @@ class Keyring(object):
|
|||||||
for server_name, response_keys in keys.items()
|
for server_name, response_keys in keys.items()
|
||||||
],
|
],
|
||||||
consumeErrors=True
|
consumeErrors=True
|
||||||
).addErrback(unwrapFirstError)
|
)).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
defer.returnValue(keys)
|
defer.returnValue(keys)
|
||||||
|
|
||||||
@ -491,10 +497,10 @@ class Keyring(object):
|
|||||||
|
|
||||||
if (u"signatures" not in response
|
if (u"signatures" not in response
|
||||||
or server_name not in response[u"signatures"]):
|
or server_name not in response[u"signatures"]):
|
||||||
raise ValueError("Key response not signed by remote server")
|
raise KeyLookupError("Key response not signed by remote server")
|
||||||
|
|
||||||
if "tls_fingerprints" not in response:
|
if "tls_fingerprints" not in response:
|
||||||
raise ValueError("Key response missing TLS fingerprints")
|
raise KeyLookupError("Key response missing TLS fingerprints")
|
||||||
|
|
||||||
certificate_bytes = crypto.dump_certificate(
|
certificate_bytes = crypto.dump_certificate(
|
||||||
crypto.FILETYPE_ASN1, tls_certificate
|
crypto.FILETYPE_ASN1, tls_certificate
|
||||||
@ -508,7 +514,7 @@ class Keyring(object):
|
|||||||
response_sha256_fingerprints.add(fingerprint[u"sha256"])
|
response_sha256_fingerprints.add(fingerprint[u"sha256"])
|
||||||
|
|
||||||
if sha256_fingerprint_b64 not in response_sha256_fingerprints:
|
if sha256_fingerprint_b64 not in response_sha256_fingerprints:
|
||||||
raise ValueError("TLS certificate not allowed by fingerprints")
|
raise KeyLookupError("TLS certificate not allowed by fingerprints")
|
||||||
|
|
||||||
response_keys = yield self.process_v2_response(
|
response_keys = yield self.process_v2_response(
|
||||||
from_server=server_name,
|
from_server=server_name,
|
||||||
@ -518,7 +524,7 @@ class Keyring(object):
|
|||||||
|
|
||||||
keys.update(response_keys)
|
keys.update(response_keys)
|
||||||
|
|
||||||
yield defer.gatherResults(
|
yield preserve_context_over_deferred(defer.gatherResults(
|
||||||
[
|
[
|
||||||
preserve_fn(self.store_keys)(
|
preserve_fn(self.store_keys)(
|
||||||
server_name=key_server_name,
|
server_name=key_server_name,
|
||||||
@ -528,7 +534,7 @@ class Keyring(object):
|
|||||||
for key_server_name, verify_keys in keys.items()
|
for key_server_name, verify_keys in keys.items()
|
||||||
],
|
],
|
||||||
consumeErrors=True
|
consumeErrors=True
|
||||||
).addErrback(unwrapFirstError)
|
)).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
defer.returnValue(keys)
|
defer.returnValue(keys)
|
||||||
|
|
||||||
@ -560,14 +566,14 @@ class Keyring(object):
|
|||||||
server_name = response_json["server_name"]
|
server_name = response_json["server_name"]
|
||||||
if only_from_server:
|
if only_from_server:
|
||||||
if server_name != from_server:
|
if server_name != from_server:
|
||||||
raise ValueError(
|
raise KeyLookupError(
|
||||||
"Expected a response for server %r not %r" % (
|
"Expected a response for server %r not %r" % (
|
||||||
from_server, server_name
|
from_server, server_name
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
for key_id in response_json["signatures"].get(server_name, {}):
|
for key_id in response_json["signatures"].get(server_name, {}):
|
||||||
if key_id not in response_json["verify_keys"]:
|
if key_id not in response_json["verify_keys"]:
|
||||||
raise ValueError(
|
raise KeyLookupError(
|
||||||
"Key response must include verification keys for all"
|
"Key response must include verification keys for all"
|
||||||
" signatures"
|
" signatures"
|
||||||
)
|
)
|
||||||
@ -594,7 +600,7 @@ class Keyring(object):
|
|||||||
response_keys.update(verify_keys)
|
response_keys.update(verify_keys)
|
||||||
response_keys.update(old_verify_keys)
|
response_keys.update(old_verify_keys)
|
||||||
|
|
||||||
yield defer.gatherResults(
|
yield preserve_context_over_deferred(defer.gatherResults(
|
||||||
[
|
[
|
||||||
preserve_fn(self.store.store_server_keys_json)(
|
preserve_fn(self.store.store_server_keys_json)(
|
||||||
server_name=server_name,
|
server_name=server_name,
|
||||||
@ -607,7 +613,7 @@ class Keyring(object):
|
|||||||
for key_id in updated_key_ids
|
for key_id in updated_key_ids
|
||||||
],
|
],
|
||||||
consumeErrors=True,
|
consumeErrors=True,
|
||||||
).addErrback(unwrapFirstError)
|
)).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
results[server_name] = response_keys
|
results[server_name] = response_keys
|
||||||
|
|
||||||
@ -635,15 +641,15 @@ class Keyring(object):
|
|||||||
|
|
||||||
if ("signatures" not in response
|
if ("signatures" not in response
|
||||||
or server_name not in response["signatures"]):
|
or server_name not in response["signatures"]):
|
||||||
raise ValueError("Key response not signed by remote server")
|
raise KeyLookupError("Key response not signed by remote server")
|
||||||
|
|
||||||
if "tls_certificate" not in response:
|
if "tls_certificate" not in response:
|
||||||
raise ValueError("Key response missing TLS certificate")
|
raise KeyLookupError("Key response missing TLS certificate")
|
||||||
|
|
||||||
tls_certificate_b64 = response["tls_certificate"]
|
tls_certificate_b64 = response["tls_certificate"]
|
||||||
|
|
||||||
if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
|
if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
|
||||||
raise ValueError("TLS certificate doesn't match")
|
raise KeyLookupError("TLS certificate doesn't match")
|
||||||
|
|
||||||
# Cache the result in the datastore.
|
# Cache the result in the datastore.
|
||||||
|
|
||||||
@ -659,7 +665,7 @@ class Keyring(object):
|
|||||||
|
|
||||||
for key_id in response["signatures"][server_name]:
|
for key_id in response["signatures"][server_name]:
|
||||||
if key_id not in response["verify_keys"]:
|
if key_id not in response["verify_keys"]:
|
||||||
raise ValueError(
|
raise KeyLookupError(
|
||||||
"Key response must include verification keys for all"
|
"Key response must include verification keys for all"
|
||||||
" signatures"
|
" signatures"
|
||||||
)
|
)
|
||||||
@ -696,7 +702,7 @@ class Keyring(object):
|
|||||||
A deferred that completes when the keys are stored.
|
A deferred that completes when the keys are stored.
|
||||||
"""
|
"""
|
||||||
# TODO(markjh): Store whether the keys have expired.
|
# TODO(markjh): Store whether the keys have expired.
|
||||||
yield defer.gatherResults(
|
yield preserve_context_over_deferred(defer.gatherResults(
|
||||||
[
|
[
|
||||||
preserve_fn(self.store.store_server_verify_key)(
|
preserve_fn(self.store.store_server_verify_key)(
|
||||||
server_name, server_name, key.time_added, key
|
server_name, server_name, key.time_added, key
|
||||||
@ -704,4 +710,4 @@ class Keyring(object):
|
|||||||
for key_id, key in verify_keys.items()
|
for key_id, key in verify_keys.items()
|
||||||
],
|
],
|
||||||
consumeErrors=True,
|
consumeErrors=True,
|
||||||
).addErrback(unwrapFirstError)
|
)).addErrback(unwrapFirstError)
|
||||||
|
@ -88,6 +88,8 @@ def prune_event(event):
|
|||||||
|
|
||||||
if "age_ts" in event.unsigned:
|
if "age_ts" in event.unsigned:
|
||||||
allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"]
|
allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"]
|
||||||
|
if "replaces_state" in event.unsigned:
|
||||||
|
allowed_fields["unsigned"]["replaces_state"] = event.unsigned["replaces_state"]
|
||||||
|
|
||||||
return type(event)(
|
return type(event)(
|
||||||
allowed_fields,
|
allowed_fields,
|
||||||
|
@ -23,6 +23,7 @@ from synapse.crypto.event_signing import check_event_content_hash
|
|||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
|
|
||||||
from synapse.util import unwrapFirstError
|
from synapse.util import unwrapFirstError
|
||||||
|
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@ -102,10 +103,10 @@ class FederationBase(object):
|
|||||||
warn, pdu
|
warn, pdu
|
||||||
)
|
)
|
||||||
|
|
||||||
valid_pdus = yield defer.gatherResults(
|
valid_pdus = yield preserve_context_over_deferred(defer.gatherResults(
|
||||||
deferreds,
|
deferreds,
|
||||||
consumeErrors=True
|
consumeErrors=True
|
||||||
).addErrback(unwrapFirstError)
|
)).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
if include_none:
|
if include_none:
|
||||||
defer.returnValue(valid_pdus)
|
defer.returnValue(valid_pdus)
|
||||||
@ -129,7 +130,7 @@ class FederationBase(object):
|
|||||||
for pdu in pdus
|
for pdu in pdus
|
||||||
]
|
]
|
||||||
|
|
||||||
deferreds = self.keyring.verify_json_objects_for_server([
|
deferreds = preserve_fn(self.keyring.verify_json_objects_for_server)([
|
||||||
(p.origin, p.get_pdu_json())
|
(p.origin, p.get_pdu_json())
|
||||||
for p in redacted_pdus
|
for p in redacted_pdus
|
||||||
])
|
])
|
||||||
|
@ -27,6 +27,7 @@ from synapse.util import unwrapFirstError
|
|||||||
from synapse.util.async import concurrently_execute
|
from synapse.util.async import concurrently_execute
|
||||||
from synapse.util.caches.expiringcache import ExpiringCache
|
from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
|
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||||
from synapse.events import FrozenEvent
|
from synapse.events import FrozenEvent
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
|
|
||||||
@ -51,10 +52,34 @@ sent_edus_counter = metrics.register_counter("sent_edus")
|
|||||||
sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
|
sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
|
||||||
|
|
||||||
|
|
||||||
|
PDU_RETRY_TIME_MS = 1 * 60 * 1000
|
||||||
|
|
||||||
|
|
||||||
class FederationClient(FederationBase):
|
class FederationClient(FederationBase):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(FederationClient, self).__init__(hs)
|
super(FederationClient, self).__init__(hs)
|
||||||
|
|
||||||
|
self.pdu_destination_tried = {}
|
||||||
|
self._clock.looping_call(
|
||||||
|
self._clear_tried_cache, 60 * 1000,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _clear_tried_cache(self):
|
||||||
|
"""Clear pdu_destination_tried cache"""
|
||||||
|
now = self._clock.time_msec()
|
||||||
|
|
||||||
|
old_dict = self.pdu_destination_tried
|
||||||
|
self.pdu_destination_tried = {}
|
||||||
|
|
||||||
|
for event_id, destination_dict in old_dict.items():
|
||||||
|
destination_dict = {
|
||||||
|
dest: time
|
||||||
|
for dest, time in destination_dict.items()
|
||||||
|
if time + PDU_RETRY_TIME_MS > now
|
||||||
|
}
|
||||||
|
if destination_dict:
|
||||||
|
self.pdu_destination_tried[event_id] = destination_dict
|
||||||
|
|
||||||
def start_get_pdu_cache(self):
|
def start_get_pdu_cache(self):
|
||||||
self._get_pdu_cache = ExpiringCache(
|
self._get_pdu_cache = ExpiringCache(
|
||||||
cache_name="get_pdu_cache",
|
cache_name="get_pdu_cache",
|
||||||
@ -201,10 +226,10 @@ class FederationClient(FederationBase):
|
|||||||
]
|
]
|
||||||
|
|
||||||
# FIXME: We should handle signature failures more gracefully.
|
# FIXME: We should handle signature failures more gracefully.
|
||||||
pdus[:] = yield defer.gatherResults(
|
pdus[:] = yield preserve_context_over_deferred(defer.gatherResults(
|
||||||
self._check_sigs_and_hashes(pdus),
|
self._check_sigs_and_hashes(pdus),
|
||||||
consumeErrors=True,
|
consumeErrors=True,
|
||||||
).addErrback(unwrapFirstError)
|
)).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
defer.returnValue(pdus)
|
defer.returnValue(pdus)
|
||||||
|
|
||||||
@ -240,8 +265,15 @@ class FederationClient(FederationBase):
|
|||||||
if ev:
|
if ev:
|
||||||
defer.returnValue(ev)
|
defer.returnValue(ev)
|
||||||
|
|
||||||
|
pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
|
||||||
|
|
||||||
pdu = None
|
pdu = None
|
||||||
for destination in destinations:
|
for destination in destinations:
|
||||||
|
now = self._clock.time_msec()
|
||||||
|
last_attempt = pdu_attempts.get(destination, 0)
|
||||||
|
if last_attempt + PDU_RETRY_TIME_MS > now:
|
||||||
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
limiter = yield get_retry_limiter(
|
limiter = yield get_retry_limiter(
|
||||||
destination,
|
destination,
|
||||||
@ -269,25 +301,19 @@ class FederationClient(FederationBase):
|
|||||||
|
|
||||||
break
|
break
|
||||||
|
|
||||||
|
pdu_attempts[destination] = now
|
||||||
|
|
||||||
except SynapseError as e:
|
except SynapseError as e:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Failed to get PDU %s from %s because %s",
|
"Failed to get PDU %s from %s because %s",
|
||||||
event_id, destination, e,
|
event_id, destination, e,
|
||||||
)
|
)
|
||||||
continue
|
|
||||||
except CodeMessageException as e:
|
|
||||||
if 400 <= e.code < 500:
|
|
||||||
raise
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"Failed to get PDU %s from %s because %s",
|
|
||||||
event_id, destination, e,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
except NotRetryingDestination as e:
|
except NotRetryingDestination as e:
|
||||||
logger.info(e.message)
|
logger.info(e.message)
|
||||||
continue
|
continue
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
pdu_attempts[destination] = now
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Failed to get PDU %s from %s because %s",
|
"Failed to get PDU %s from %s because %s",
|
||||||
event_id, destination, e,
|
event_id, destination, e,
|
||||||
@ -406,7 +432,7 @@ class FederationClient(FederationBase):
|
|||||||
events and the second is a list of event ids that we failed to fetch.
|
events and the second is a list of event ids that we failed to fetch.
|
||||||
"""
|
"""
|
||||||
if return_local:
|
if return_local:
|
||||||
seen_events = yield self.store.get_events(event_ids)
|
seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
|
||||||
signed_events = seen_events.values()
|
signed_events = seen_events.values()
|
||||||
else:
|
else:
|
||||||
seen_events = yield self.store.have_events(event_ids)
|
seen_events = yield self.store.have_events(event_ids)
|
||||||
@ -432,14 +458,16 @@ class FederationClient(FederationBase):
|
|||||||
batch = set(missing_events[i:i + batch_size])
|
batch = set(missing_events[i:i + batch_size])
|
||||||
|
|
||||||
deferreds = [
|
deferreds = [
|
||||||
self.get_pdu(
|
preserve_fn(self.get_pdu)(
|
||||||
destinations=random_server_list(),
|
destinations=random_server_list(),
|
||||||
event_id=e_id,
|
event_id=e_id,
|
||||||
)
|
)
|
||||||
for e_id in batch
|
for e_id in batch
|
||||||
]
|
]
|
||||||
|
|
||||||
res = yield defer.DeferredList(deferreds, consumeErrors=True)
|
res = yield preserve_context_over_deferred(
|
||||||
|
defer.DeferredList(deferreds, consumeErrors=True)
|
||||||
|
)
|
||||||
for success, result in res:
|
for success, result in res:
|
||||||
if success:
|
if success:
|
||||||
signed_events.append(result)
|
signed_events.append(result)
|
||||||
@ -828,14 +856,16 @@ class FederationClient(FederationBase):
|
|||||||
return srvs
|
return srvs
|
||||||
|
|
||||||
deferreds = [
|
deferreds = [
|
||||||
self.get_pdu(
|
preserve_fn(self.get_pdu)(
|
||||||
destinations=random_server_list(),
|
destinations=random_server_list(),
|
||||||
event_id=e_id,
|
event_id=e_id,
|
||||||
)
|
)
|
||||||
for e_id, depth in ordered_missing[:limit - len(signed_events)]
|
for e_id, depth in ordered_missing[:limit - len(signed_events)]
|
||||||
]
|
]
|
||||||
|
|
||||||
res = yield defer.DeferredList(deferreds, consumeErrors=True)
|
res = yield preserve_context_over_deferred(
|
||||||
|
defer.DeferredList(deferreds, consumeErrors=True)
|
||||||
|
)
|
||||||
for (result, val), (e_id, _) in zip(res, ordered_missing):
|
for (result, val), (e_id, _) in zip(res, ordered_missing):
|
||||||
if result and val:
|
if result and val:
|
||||||
signed_events.append(val)
|
signed_events.append(val)
|
||||||
|
@ -21,11 +21,11 @@ from .units import Transaction
|
|||||||
|
|
||||||
from synapse.api.errors import HttpResponseException
|
from synapse.api.errors import HttpResponseException
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logcontext import preserve_context_over_fn
|
||||||
from synapse.util.logcontext import PreserveLoggingContext
|
|
||||||
from synapse.util.retryutils import (
|
from synapse.util.retryutils import (
|
||||||
get_retry_limiter, NotRetryingDestination,
|
get_retry_limiter, NotRetryingDestination,
|
||||||
)
|
)
|
||||||
|
from synapse.util.metrics import measure_func
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@ -51,7 +51,7 @@ class TransactionQueue(object):
|
|||||||
|
|
||||||
self.transport_layer = transport_layer
|
self.transport_layer = transport_layer
|
||||||
|
|
||||||
self._clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
# Is a mapping from destinations -> deferreds. Used to keep track
|
# Is a mapping from destinations -> deferreds. Used to keep track
|
||||||
# of which destinations have transactions in flight and when they are
|
# of which destinations have transactions in flight and when they are
|
||||||
@ -82,7 +82,7 @@ class TransactionQueue(object):
|
|||||||
self.pending_failures_by_dest = {}
|
self.pending_failures_by_dest = {}
|
||||||
|
|
||||||
# HACK to get unique tx id
|
# HACK to get unique tx id
|
||||||
self._next_txn_id = int(self._clock.time_msec())
|
self._next_txn_id = int(self.clock.time_msec())
|
||||||
|
|
||||||
def can_send_to(self, destination):
|
def can_send_to(self, destination):
|
||||||
"""Can we send messages to the given server?
|
"""Can we send messages to the given server?
|
||||||
@ -119,266 +119,215 @@ class TransactionQueue(object):
|
|||||||
if not destinations:
|
if not destinations:
|
||||||
return
|
return
|
||||||
|
|
||||||
deferreds = []
|
|
||||||
|
|
||||||
for destination in destinations:
|
for destination in destinations:
|
||||||
deferred = defer.Deferred()
|
|
||||||
self.pending_pdus_by_dest.setdefault(destination, []).append(
|
self.pending_pdus_by_dest.setdefault(destination, []).append(
|
||||||
(pdu, deferred, order)
|
(pdu, order)
|
||||||
)
|
)
|
||||||
|
|
||||||
def chain(failure):
|
preserve_context_over_fn(
|
||||||
if not deferred.called:
|
self._attempt_new_transaction, destination
|
||||||
deferred.errback(failure)
|
)
|
||||||
|
|
||||||
def log_failure(f):
|
|
||||||
logger.warn("Failed to send pdu to %s: %s", destination, f.value)
|
|
||||||
|
|
||||||
deferred.addErrback(log_failure)
|
|
||||||
|
|
||||||
with PreserveLoggingContext():
|
|
||||||
self._attempt_new_transaction(destination).addErrback(chain)
|
|
||||||
|
|
||||||
deferreds.append(deferred)
|
|
||||||
|
|
||||||
# NO inlineCallbacks
|
|
||||||
def enqueue_edu(self, edu):
|
def enqueue_edu(self, edu):
|
||||||
destination = edu.destination
|
destination = edu.destination
|
||||||
|
|
||||||
if not self.can_send_to(destination):
|
if not self.can_send_to(destination):
|
||||||
return
|
return
|
||||||
|
|
||||||
deferred = defer.Deferred()
|
self.pending_edus_by_dest.setdefault(destination, []).append(edu)
|
||||||
self.pending_edus_by_dest.setdefault(destination, []).append(
|
|
||||||
(edu, deferred)
|
preserve_context_over_fn(
|
||||||
|
self._attempt_new_transaction, destination
|
||||||
)
|
)
|
||||||
|
|
||||||
def chain(failure):
|
|
||||||
if not deferred.called:
|
|
||||||
deferred.errback(failure)
|
|
||||||
|
|
||||||
def log_failure(f):
|
|
||||||
logger.warn("Failed to send edu to %s: %s", destination, f.value)
|
|
||||||
|
|
||||||
deferred.addErrback(log_failure)
|
|
||||||
|
|
||||||
with PreserveLoggingContext():
|
|
||||||
self._attempt_new_transaction(destination).addErrback(chain)
|
|
||||||
|
|
||||||
return deferred
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def enqueue_failure(self, failure, destination):
|
def enqueue_failure(self, failure, destination):
|
||||||
if destination == self.server_name or destination == "localhost":
|
if destination == self.server_name or destination == "localhost":
|
||||||
return
|
return
|
||||||
|
|
||||||
deferred = defer.Deferred()
|
|
||||||
|
|
||||||
if not self.can_send_to(destination):
|
if not self.can_send_to(destination):
|
||||||
return
|
return
|
||||||
|
|
||||||
self.pending_failures_by_dest.setdefault(
|
self.pending_failures_by_dest.setdefault(
|
||||||
destination, []
|
destination, []
|
||||||
).append(
|
).append(failure)
|
||||||
(failure, deferred)
|
|
||||||
|
preserve_context_over_fn(
|
||||||
|
self._attempt_new_transaction, destination
|
||||||
)
|
)
|
||||||
|
|
||||||
def chain(f):
|
|
||||||
if not deferred.called:
|
|
||||||
deferred.errback(f)
|
|
||||||
|
|
||||||
def log_failure(f):
|
|
||||||
logger.warn("Failed to send failure to %s: %s", destination, f.value)
|
|
||||||
|
|
||||||
deferred.addErrback(log_failure)
|
|
||||||
|
|
||||||
with PreserveLoggingContext():
|
|
||||||
self._attempt_new_transaction(destination).addErrback(chain)
|
|
||||||
|
|
||||||
yield deferred
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
|
||||||
def _attempt_new_transaction(self, destination):
|
def _attempt_new_transaction(self, destination):
|
||||||
yield run_on_reactor()
|
yield run_on_reactor()
|
||||||
|
while True:
|
||||||
|
# list of (pending_pdu, deferred, order)
|
||||||
|
if destination in self.pending_transactions:
|
||||||
|
# XXX: pending_transactions can get stuck on by a never-ending
|
||||||
|
# request at which point pending_pdus_by_dest just keeps growing.
|
||||||
|
# we need application-layer timeouts of some flavour of these
|
||||||
|
# requests
|
||||||
|
logger.debug(
|
||||||
|
"TX [%s] Transaction already in progress",
|
||||||
|
destination
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
# list of (pending_pdu, deferred, order)
|
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
|
||||||
if destination in self.pending_transactions:
|
pending_edus = self.pending_edus_by_dest.pop(destination, [])
|
||||||
# XXX: pending_transactions can get stuck on by a never-ending
|
pending_failures = self.pending_failures_by_dest.pop(destination, [])
|
||||||
# request at which point pending_pdus_by_dest just keeps growing.
|
|
||||||
# we need application-layer timeouts of some flavour of these
|
if pending_pdus:
|
||||||
# requests
|
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
|
||||||
logger.debug(
|
destination, len(pending_pdus))
|
||||||
"TX [%s] Transaction already in progress",
|
|
||||||
destination
|
if not pending_pdus and not pending_edus and not pending_failures:
|
||||||
|
logger.debug("TX [%s] Nothing to send", destination)
|
||||||
|
return
|
||||||
|
|
||||||
|
yield self._send_new_transaction(
|
||||||
|
destination, pending_pdus, pending_edus, pending_failures
|
||||||
)
|
)
|
||||||
return
|
|
||||||
|
|
||||||
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
|
@measure_func("_send_new_transaction")
|
||||||
pending_edus = self.pending_edus_by_dest.pop(destination, [])
|
@defer.inlineCallbacks
|
||||||
pending_failures = self.pending_failures_by_dest.pop(destination, [])
|
def _send_new_transaction(self, destination, pending_pdus, pending_edus,
|
||||||
|
pending_failures):
|
||||||
if pending_pdus:
|
|
||||||
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
|
|
||||||
destination, len(pending_pdus))
|
|
||||||
|
|
||||||
if not pending_pdus and not pending_edus and not pending_failures:
|
|
||||||
logger.debug("TX [%s] Nothing to send", destination)
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.pending_transactions[destination] = 1
|
|
||||||
|
|
||||||
logger.debug("TX [%s] _attempt_new_transaction", destination)
|
|
||||||
|
|
||||||
# Sort based on the order field
|
# Sort based on the order field
|
||||||
pending_pdus.sort(key=lambda t: t[2])
|
pending_pdus.sort(key=lambda t: t[1])
|
||||||
|
|
||||||
pdus = [x[0] for x in pending_pdus]
|
pdus = [x[0] for x in pending_pdus]
|
||||||
edus = [x[0] for x in pending_edus]
|
edus = pending_edus
|
||||||
failures = [x[0].get_dict() for x in pending_failures]
|
failures = [x.get_dict() for x in pending_failures]
|
||||||
deferreds = [
|
|
||||||
x[1]
|
|
||||||
for x in pending_pdus + pending_edus + pending_failures
|
|
||||||
]
|
|
||||||
|
|
||||||
txn_id = str(self._next_txn_id)
|
try:
|
||||||
|
self.pending_transactions[destination] = 1
|
||||||
|
|
||||||
limiter = yield get_retry_limiter(
|
logger.debug("TX [%s] _attempt_new_transaction", destination)
|
||||||
destination,
|
|
||||||
self._clock,
|
|
||||||
self.store,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(
|
txn_id = str(self._next_txn_id)
|
||||||
"TX [%s] {%s} Attempting new transaction"
|
|
||||||
" (pdus: %d, edus: %d, failures: %d)",
|
|
||||||
destination, txn_id,
|
|
||||||
len(pending_pdus),
|
|
||||||
len(pending_edus),
|
|
||||||
len(pending_failures)
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug("TX [%s] Persisting transaction...", destination)
|
limiter = yield get_retry_limiter(
|
||||||
|
destination,
|
||||||
transaction = Transaction.create_new(
|
self.clock,
|
||||||
origin_server_ts=int(self._clock.time_msec()),
|
self.store,
|
||||||
transaction_id=txn_id,
|
|
||||||
origin=self.server_name,
|
|
||||||
destination=destination,
|
|
||||||
pdus=pdus,
|
|
||||||
edus=edus,
|
|
||||||
pdu_failures=failures,
|
|
||||||
)
|
|
||||||
|
|
||||||
self._next_txn_id += 1
|
|
||||||
|
|
||||||
yield self.transaction_actions.prepare_to_send(transaction)
|
|
||||||
|
|
||||||
logger.debug("TX [%s] Persisted transaction", destination)
|
|
||||||
logger.info(
|
|
||||||
"TX [%s] {%s} Sending transaction [%s],"
|
|
||||||
" (PDUs: %d, EDUs: %d, failures: %d)",
|
|
||||||
destination, txn_id,
|
|
||||||
transaction.transaction_id,
|
|
||||||
len(pending_pdus),
|
|
||||||
len(pending_edus),
|
|
||||||
len(pending_failures),
|
|
||||||
)
|
|
||||||
|
|
||||||
with limiter:
|
|
||||||
# Actually send the transaction
|
|
||||||
|
|
||||||
# FIXME (erikj): This is a bit of a hack to make the Pdu age
|
|
||||||
# keys work
|
|
||||||
def json_data_cb():
|
|
||||||
data = transaction.get_dict()
|
|
||||||
now = int(self._clock.time_msec())
|
|
||||||
if "pdus" in data:
|
|
||||||
for p in data["pdus"]:
|
|
||||||
if "age_ts" in p:
|
|
||||||
unsigned = p.setdefault("unsigned", {})
|
|
||||||
unsigned["age"] = now - int(p["age_ts"])
|
|
||||||
del p["age_ts"]
|
|
||||||
return data
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = yield self.transport_layer.send_transaction(
|
|
||||||
transaction, json_data_cb
|
|
||||||
)
|
|
||||||
code = 200
|
|
||||||
|
|
||||||
if response:
|
|
||||||
for e_id, r in response.get("pdus", {}).items():
|
|
||||||
if "error" in r:
|
|
||||||
logger.warn(
|
|
||||||
"Transaction returned error for %s: %s",
|
|
||||||
e_id, r,
|
|
||||||
)
|
|
||||||
except HttpResponseException as e:
|
|
||||||
code = e.code
|
|
||||||
response = e.response
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"TX [%s] {%s} got %d response",
|
|
||||||
destination, txn_id, code
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug("TX [%s] Sent transaction", destination)
|
logger.debug(
|
||||||
logger.debug("TX [%s] Marking as delivered...", destination)
|
"TX [%s] {%s} Attempting new transaction"
|
||||||
|
" (pdus: %d, edus: %d, failures: %d)",
|
||||||
|
destination, txn_id,
|
||||||
|
len(pending_pdus),
|
||||||
|
len(pending_edus),
|
||||||
|
len(pending_failures)
|
||||||
|
)
|
||||||
|
|
||||||
yield self.transaction_actions.delivered(
|
logger.debug("TX [%s] Persisting transaction...", destination)
|
||||||
transaction, code, response
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug("TX [%s] Marked as delivered", destination)
|
transaction = Transaction.create_new(
|
||||||
|
origin_server_ts=int(self.clock.time_msec()),
|
||||||
|
transaction_id=txn_id,
|
||||||
|
origin=self.server_name,
|
||||||
|
destination=destination,
|
||||||
|
pdus=pdus,
|
||||||
|
edus=edus,
|
||||||
|
pdu_failures=failures,
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug("TX [%s] Yielding to callbacks...", destination)
|
self._next_txn_id += 1
|
||||||
|
|
||||||
for deferred in deferreds:
|
yield self.transaction_actions.prepare_to_send(transaction)
|
||||||
if code == 200:
|
|
||||||
deferred.callback(None)
|
|
||||||
else:
|
|
||||||
deferred.errback(RuntimeError("Got status %d" % code))
|
|
||||||
|
|
||||||
# Ensures we don't continue until all callbacks on that
|
logger.debug("TX [%s] Persisted transaction", destination)
|
||||||
# deferred have fired
|
logger.info(
|
||||||
try:
|
"TX [%s] {%s} Sending transaction [%s],"
|
||||||
yield deferred
|
" (PDUs: %d, EDUs: %d, failures: %d)",
|
||||||
except:
|
destination, txn_id,
|
||||||
pass
|
transaction.transaction_id,
|
||||||
|
len(pending_pdus),
|
||||||
|
len(pending_edus),
|
||||||
|
len(pending_failures),
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug("TX [%s] Yielded to callbacks", destination)
|
with limiter:
|
||||||
except NotRetryingDestination:
|
# Actually send the transaction
|
||||||
logger.info(
|
|
||||||
"TX [%s] not ready for retry yet - "
|
|
||||||
"dropping transaction for now",
|
|
||||||
destination,
|
|
||||||
)
|
|
||||||
except RuntimeError as e:
|
|
||||||
# We capture this here as there as nothing actually listens
|
|
||||||
# for this finishing functions deferred.
|
|
||||||
logger.warn(
|
|
||||||
"TX [%s] Problem in _attempt_transaction: %s",
|
|
||||||
destination,
|
|
||||||
e,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
# We capture this here as there as nothing actually listens
|
|
||||||
# for this finishing functions deferred.
|
|
||||||
logger.warn(
|
|
||||||
"TX [%s] Problem in _attempt_transaction: %s",
|
|
||||||
destination,
|
|
||||||
e,
|
|
||||||
)
|
|
||||||
|
|
||||||
for deferred in deferreds:
|
# FIXME (erikj): This is a bit of a hack to make the Pdu age
|
||||||
if not deferred.called:
|
# keys work
|
||||||
deferred.errback(e)
|
def json_data_cb():
|
||||||
|
data = transaction.get_dict()
|
||||||
|
now = int(self.clock.time_msec())
|
||||||
|
if "pdus" in data:
|
||||||
|
for p in data["pdus"]:
|
||||||
|
if "age_ts" in p:
|
||||||
|
unsigned = p.setdefault("unsigned", {})
|
||||||
|
unsigned["age"] = now - int(p["age_ts"])
|
||||||
|
del p["age_ts"]
|
||||||
|
return data
|
||||||
|
|
||||||
finally:
|
try:
|
||||||
# We want to be *very* sure we delete this after we stop processing
|
response = yield self.transport_layer.send_transaction(
|
||||||
self.pending_transactions.pop(destination, None)
|
transaction, json_data_cb
|
||||||
|
)
|
||||||
|
code = 200
|
||||||
|
|
||||||
# Check to see if there is anything else to send.
|
if response:
|
||||||
self._attempt_new_transaction(destination)
|
for e_id, r in response.get("pdus", {}).items():
|
||||||
|
if "error" in r:
|
||||||
|
logger.warn(
|
||||||
|
"Transaction returned error for %s: %s",
|
||||||
|
e_id, r,
|
||||||
|
)
|
||||||
|
except HttpResponseException as e:
|
||||||
|
code = e.code
|
||||||
|
response = e.response
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"TX [%s] {%s} got %d response",
|
||||||
|
destination, txn_id, code
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("TX [%s] Sent transaction", destination)
|
||||||
|
logger.debug("TX [%s] Marking as delivered...", destination)
|
||||||
|
|
||||||
|
yield self.transaction_actions.delivered(
|
||||||
|
transaction, code, response
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("TX [%s] Marked as delivered", destination)
|
||||||
|
|
||||||
|
if code != 200:
|
||||||
|
for p in pdus:
|
||||||
|
logger.info(
|
||||||
|
"Failed to send event %s to %s", p.event_id, destination
|
||||||
|
)
|
||||||
|
except NotRetryingDestination:
|
||||||
|
logger.info(
|
||||||
|
"TX [%s] not ready for retry yet - "
|
||||||
|
"dropping transaction for now",
|
||||||
|
destination,
|
||||||
|
)
|
||||||
|
except RuntimeError as e:
|
||||||
|
# We capture this here as there as nothing actually listens
|
||||||
|
# for this finishing functions deferred.
|
||||||
|
logger.warn(
|
||||||
|
"TX [%s] Problem in _attempt_transaction: %s",
|
||||||
|
destination,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
|
||||||
|
for p in pdus:
|
||||||
|
logger.info("Failed to send event %s to %s", p.event_id, destination)
|
||||||
|
except Exception as e:
|
||||||
|
# We capture this here as there as nothing actually listens
|
||||||
|
# for this finishing functions deferred.
|
||||||
|
logger.warn(
|
||||||
|
"TX [%s] Problem in _attempt_transaction: %s",
|
||||||
|
destination,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
|
||||||
|
for p in pdus:
|
||||||
|
logger.info("Failed to send event %s to %s", p.event_id, destination)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# We want to be *very* sure we delete this after we stop processing
|
||||||
|
self.pending_transactions.pop(destination, None)
|
||||||
|
@ -19,7 +19,6 @@ from .room import (
|
|||||||
)
|
)
|
||||||
from .room_member import RoomMemberHandler
|
from .room_member import RoomMemberHandler
|
||||||
from .message import MessageHandler
|
from .message import MessageHandler
|
||||||
from .events import EventStreamHandler, EventHandler
|
|
||||||
from .federation import FederationHandler
|
from .federation import FederationHandler
|
||||||
from .profile import ProfileHandler
|
from .profile import ProfileHandler
|
||||||
from .directory import DirectoryHandler
|
from .directory import DirectoryHandler
|
||||||
@ -53,8 +52,6 @@ class Handlers(object):
|
|||||||
self.message_handler = MessageHandler(hs)
|
self.message_handler = MessageHandler(hs)
|
||||||
self.room_creation_handler = RoomCreationHandler(hs)
|
self.room_creation_handler = RoomCreationHandler(hs)
|
||||||
self.room_member_handler = RoomMemberHandler(hs)
|
self.room_member_handler = RoomMemberHandler(hs)
|
||||||
self.event_stream_handler = EventStreamHandler(hs)
|
|
||||||
self.event_handler = EventHandler(hs)
|
|
||||||
self.federation_handler = FederationHandler(hs)
|
self.federation_handler = FederationHandler(hs)
|
||||||
self.profile_handler = ProfileHandler(hs)
|
self.profile_handler = ProfileHandler(hs)
|
||||||
self.directory_handler = DirectoryHandler(hs)
|
self.directory_handler = DirectoryHandler(hs)
|
||||||
|
@ -16,7 +16,8 @@
|
|||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.appservice import ApplicationService
|
from synapse.util.metrics import Measure
|
||||||
|
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@ -42,36 +43,73 @@ class ApplicationServicesHandler(object):
|
|||||||
self.appservice_api = hs.get_application_service_api()
|
self.appservice_api = hs.get_application_service_api()
|
||||||
self.scheduler = hs.get_application_service_scheduler()
|
self.scheduler = hs.get_application_service_scheduler()
|
||||||
self.started_scheduler = False
|
self.started_scheduler = False
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.notify_appservices = hs.config.notify_appservices
|
||||||
|
|
||||||
|
self.current_max = 0
|
||||||
|
self.is_processing = False
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def notify_interested_services(self, event):
|
def notify_interested_services(self, current_id):
|
||||||
"""Notifies (pushes) all application services interested in this event.
|
"""Notifies (pushes) all application services interested in this event.
|
||||||
|
|
||||||
Pushing is done asynchronously, so this method won't block for any
|
Pushing is done asynchronously, so this method won't block for any
|
||||||
prolonged length of time.
|
prolonged length of time.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event(Event): The event to push out to interested services.
|
current_id(int): The current maximum ID.
|
||||||
"""
|
"""
|
||||||
# Gather interested services
|
services = yield self.store.get_app_services()
|
||||||
services = yield self._get_services_for_event(event)
|
if not services or not self.notify_appservices:
|
||||||
if len(services) == 0:
|
return
|
||||||
return # no services need notifying
|
|
||||||
|
|
||||||
# Do we know this user exists? If not, poke the user query API for
|
self.current_max = max(self.current_max, current_id)
|
||||||
# all services which match that user regex. This needs to block as these
|
if self.is_processing:
|
||||||
# user queries need to be made BEFORE pushing the event.
|
return
|
||||||
yield self._check_user_exists(event.sender)
|
|
||||||
if event.type == EventTypes.Member:
|
|
||||||
yield self._check_user_exists(event.state_key)
|
|
||||||
|
|
||||||
if not self.started_scheduler:
|
with Measure(self.clock, "notify_interested_services"):
|
||||||
self.scheduler.start().addErrback(log_failure)
|
self.is_processing = True
|
||||||
self.started_scheduler = True
|
try:
|
||||||
|
upper_bound = self.current_max
|
||||||
|
limit = 100
|
||||||
|
while True:
|
||||||
|
upper_bound, events = yield self.store.get_new_events_for_appservice(
|
||||||
|
upper_bound, limit
|
||||||
|
)
|
||||||
|
|
||||||
# Fork off pushes to these services
|
if not events:
|
||||||
for service in services:
|
break
|
||||||
self.scheduler.submit_event_for_as(service, event)
|
|
||||||
|
for event in events:
|
||||||
|
# Gather interested services
|
||||||
|
services = yield self._get_services_for_event(event)
|
||||||
|
if len(services) == 0:
|
||||||
|
continue # no services need notifying
|
||||||
|
|
||||||
|
# Do we know this user exists? If not, poke the user
|
||||||
|
# query API for all services which match that user regex.
|
||||||
|
# This needs to block as these user queries need to be
|
||||||
|
# made BEFORE pushing the event.
|
||||||
|
yield self._check_user_exists(event.sender)
|
||||||
|
if event.type == EventTypes.Member:
|
||||||
|
yield self._check_user_exists(event.state_key)
|
||||||
|
|
||||||
|
if not self.started_scheduler:
|
||||||
|
self.scheduler.start().addErrback(log_failure)
|
||||||
|
self.started_scheduler = True
|
||||||
|
|
||||||
|
# Fork off pushes to these services
|
||||||
|
for service in services:
|
||||||
|
preserve_fn(self.scheduler.submit_event_for_as)(
|
||||||
|
service, event
|
||||||
|
)
|
||||||
|
|
||||||
|
yield self.store.set_appservice_last_pos(upper_bound)
|
||||||
|
|
||||||
|
if len(events) < limit:
|
||||||
|
break
|
||||||
|
finally:
|
||||||
|
self.is_processing = False
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def query_user_exists(self, user_id):
|
def query_user_exists(self, user_id):
|
||||||
@ -104,11 +142,12 @@ class ApplicationServicesHandler(object):
|
|||||||
association can be found.
|
association can be found.
|
||||||
"""
|
"""
|
||||||
room_alias_str = room_alias.to_string()
|
room_alias_str = room_alias.to_string()
|
||||||
alias_query_services = yield self._get_services_for_event(
|
services = yield self.store.get_app_services()
|
||||||
event=None,
|
alias_query_services = [
|
||||||
restrict_to=ApplicationService.NS_ALIASES,
|
s for s in services if (
|
||||||
alias_list=[room_alias_str]
|
s.is_interested_in_alias(room_alias_str)
|
||||||
)
|
)
|
||||||
|
]
|
||||||
for alias_service in alias_query_services:
|
for alias_service in alias_query_services:
|
||||||
is_known_alias = yield self.appservice_api.query_alias(
|
is_known_alias = yield self.appservice_api.query_alias(
|
||||||
alias_service, room_alias_str
|
alias_service, room_alias_str
|
||||||
@ -121,34 +160,35 @@ class ApplicationServicesHandler(object):
|
|||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _get_services_for_event(self, event, restrict_to="", alias_list=None):
|
def query_3pe(self, kind, protocol, fields):
|
||||||
|
services = yield self._get_services_for_3pn(protocol)
|
||||||
|
|
||||||
|
results = yield preserve_context_over_deferred(defer.DeferredList([
|
||||||
|
preserve_fn(self.appservice_api.query_3pe)(service, kind, protocol, fields)
|
||||||
|
for service in services
|
||||||
|
], consumeErrors=True))
|
||||||
|
|
||||||
|
ret = []
|
||||||
|
for (success, result) in results:
|
||||||
|
if success:
|
||||||
|
ret.extend(result)
|
||||||
|
|
||||||
|
defer.returnValue(ret)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _get_services_for_event(self, event):
|
||||||
"""Retrieve a list of application services interested in this event.
|
"""Retrieve a list of application services interested in this event.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event(Event): The event to check. Can be None if alias_list is not.
|
event(Event): The event to check. Can be None if alias_list is not.
|
||||||
restrict_to(str): The namespace to restrict regex tests to.
|
|
||||||
alias_list: A list of aliases to get services for. If None, this
|
|
||||||
list is obtained from the database.
|
|
||||||
Returns:
|
Returns:
|
||||||
list<ApplicationService>: A list of services interested in this
|
list<ApplicationService>: A list of services interested in this
|
||||||
event based on the service regex.
|
event based on the service regex.
|
||||||
"""
|
"""
|
||||||
member_list = None
|
|
||||||
if hasattr(event, "room_id"):
|
|
||||||
# We need to know the aliases associated with this event.room_id,
|
|
||||||
# if any.
|
|
||||||
if not alias_list:
|
|
||||||
alias_list = yield self.store.get_aliases_for_room(
|
|
||||||
event.room_id
|
|
||||||
)
|
|
||||||
# We need to know the members associated with this event.room_id,
|
|
||||||
# if any.
|
|
||||||
member_list = yield self.store.get_users_in_room(event.room_id)
|
|
||||||
|
|
||||||
services = yield self.store.get_app_services()
|
services = yield self.store.get_app_services()
|
||||||
interested_list = [
|
interested_list = [
|
||||||
s for s in services if (
|
s for s in services if (
|
||||||
s.is_interested(event, restrict_to, alias_list, member_list)
|
yield s.is_interested(event, self.store)
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
defer.returnValue(interested_list)
|
defer.returnValue(interested_list)
|
||||||
@ -163,6 +203,14 @@ class ApplicationServicesHandler(object):
|
|||||||
]
|
]
|
||||||
defer.returnValue(interested_list)
|
defer.returnValue(interested_list)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _get_services_for_3pn(self, protocol):
|
||||||
|
services = yield self.store.get_app_services()
|
||||||
|
interested_list = [
|
||||||
|
s for s in services if s.is_interested_in_protocol(protocol)
|
||||||
|
]
|
||||||
|
defer.returnValue(interested_list)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _is_unknown_user(self, user_id):
|
def _is_unknown_user(self, user_id):
|
||||||
if not self.is_mine_id(user_id):
|
if not self.is_mine_id(user_id):
|
||||||
|
@ -70,11 +70,11 @@ class AuthHandler(BaseHandler):
|
|||||||
self.ldap_uri = hs.config.ldap_uri
|
self.ldap_uri = hs.config.ldap_uri
|
||||||
self.ldap_start_tls = hs.config.ldap_start_tls
|
self.ldap_start_tls = hs.config.ldap_start_tls
|
||||||
self.ldap_base = hs.config.ldap_base
|
self.ldap_base = hs.config.ldap_base
|
||||||
self.ldap_filter = hs.config.ldap_filter
|
|
||||||
self.ldap_attributes = hs.config.ldap_attributes
|
self.ldap_attributes = hs.config.ldap_attributes
|
||||||
if self.ldap_mode == LDAPMode.SEARCH:
|
if self.ldap_mode == LDAPMode.SEARCH:
|
||||||
self.ldap_bind_dn = hs.config.ldap_bind_dn
|
self.ldap_bind_dn = hs.config.ldap_bind_dn
|
||||||
self.ldap_bind_password = hs.config.ldap_bind_password
|
self.ldap_bind_password = hs.config.ldap_bind_password
|
||||||
|
self.ldap_filter = hs.config.ldap_filter
|
||||||
|
|
||||||
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
||||||
self.device_handler = hs.get_device_handler()
|
self.device_handler = hs.get_device_handler()
|
||||||
@ -660,7 +660,7 @@ class AuthHandler(BaseHandler):
|
|||||||
else:
|
else:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"ldap registration failed: unexpected (%d!=1) amount of results",
|
"ldap registration failed: unexpected (%d!=1) amount of results",
|
||||||
len(result)
|
len(conn.response)
|
||||||
)
|
)
|
||||||
defer.returnValue(False)
|
defer.returnValue(False)
|
||||||
|
|
||||||
@ -719,13 +719,14 @@ class AuthHandler(BaseHandler):
|
|||||||
return macaroon.serialize()
|
return macaroon.serialize()
|
||||||
|
|
||||||
def validate_short_term_login_token_and_get_user_id(self, login_token):
|
def validate_short_term_login_token_and_get_user_id(self, login_token):
|
||||||
|
auth_api = self.hs.get_auth()
|
||||||
try:
|
try:
|
||||||
macaroon = pymacaroons.Macaroon.deserialize(login_token)
|
macaroon = pymacaroons.Macaroon.deserialize(login_token)
|
||||||
auth_api = self.hs.get_auth()
|
user_id = auth_api.get_user_id_from_macaroon(macaroon)
|
||||||
auth_api.validate_macaroon(macaroon, "login", True)
|
auth_api.validate_macaroon(macaroon, "login", True, user_id)
|
||||||
return self.get_user_from_macaroon(macaroon)
|
return user_id
|
||||||
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
|
except Exception:
|
||||||
raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN)
|
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
|
||||||
|
|
||||||
def _generate_base_macaroon(self, user_id):
|
def _generate_base_macaroon(self, user_id):
|
||||||
macaroon = pymacaroons.Macaroon(
|
macaroon = pymacaroons.Macaroon(
|
||||||
@ -736,21 +737,11 @@ class AuthHandler(BaseHandler):
|
|||||||
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
||||||
return macaroon
|
return macaroon
|
||||||
|
|
||||||
def get_user_from_macaroon(self, macaroon):
|
|
||||||
user_prefix = "user_id = "
|
|
||||||
for caveat in macaroon.caveats:
|
|
||||||
if caveat.caveat_id.startswith(user_prefix):
|
|
||||||
return caveat.caveat_id[len(user_prefix):]
|
|
||||||
raise AuthError(
|
|
||||||
self.INVALID_TOKEN_HTTP_STATUS, "No user_id found in token",
|
|
||||||
errcode=Codes.UNKNOWN_TOKEN
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def set_password(self, user_id, newpassword, requester=None):
|
def set_password(self, user_id, newpassword, requester=None):
|
||||||
password_hash = self.hash(newpassword)
|
password_hash = self.hash(newpassword)
|
||||||
|
|
||||||
except_access_token_ids = [requester.access_token_id] if requester else []
|
except_access_token_id = requester.access_token_id if requester else None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield self.store.user_set_password_hash(user_id, password_hash)
|
yield self.store.user_set_password_hash(user_id, password_hash)
|
||||||
@ -759,10 +750,10 @@ class AuthHandler(BaseHandler):
|
|||||||
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
|
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
|
||||||
raise e
|
raise e
|
||||||
yield self.store.user_delete_access_tokens(
|
yield self.store.user_delete_access_tokens(
|
||||||
user_id, except_access_token_ids
|
user_id, except_access_token_id
|
||||||
)
|
)
|
||||||
yield self.hs.get_pusherpool().remove_pushers_by_user(
|
yield self.hs.get_pusherpool().remove_pushers_by_user(
|
||||||
user_id, except_access_token_ids
|
user_id, except_access_token_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -26,7 +26,9 @@ from synapse.api.errors import (
|
|||||||
from synapse.api.constants import EventTypes, Membership, RejectedReason
|
from synapse.api.constants import EventTypes, Membership, RejectedReason
|
||||||
from synapse.events.validator import EventValidator
|
from synapse.events.validator import EventValidator
|
||||||
from synapse.util import unwrapFirstError
|
from synapse.util import unwrapFirstError
|
||||||
from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
|
from synapse.util.logcontext import (
|
||||||
|
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred
|
||||||
|
)
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
from synapse.util.frozenutils import unfreeze
|
from synapse.util.frozenutils import unfreeze
|
||||||
@ -249,7 +251,7 @@ class FederationHandler(BaseHandler):
|
|||||||
if ev.type != EventTypes.Member:
|
if ev.type != EventTypes.Member:
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
domain = UserID.from_string(ev.state_key).domain
|
domain = get_domain_from_id(ev.state_key)
|
||||||
except:
|
except:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -274,7 +276,7 @@ class FederationHandler(BaseHandler):
|
|||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def backfill(self, dest, room_id, limit, extremities=[]):
|
def backfill(self, dest, room_id, limit, extremities):
|
||||||
""" Trigger a backfill request to `dest` for the given `room_id`
|
""" Trigger a backfill request to `dest` for the given `room_id`
|
||||||
|
|
||||||
This will attempt to get more events from the remote. This may return
|
This will attempt to get more events from the remote. This may return
|
||||||
@ -284,9 +286,6 @@ class FederationHandler(BaseHandler):
|
|||||||
if dest == self.server_name:
|
if dest == self.server_name:
|
||||||
raise SynapseError(400, "Can't backfill from self.")
|
raise SynapseError(400, "Can't backfill from self.")
|
||||||
|
|
||||||
if not extremities:
|
|
||||||
extremities = yield self.store.get_oldest_events_in_room(room_id)
|
|
||||||
|
|
||||||
events = yield self.replication_layer.backfill(
|
events = yield self.replication_layer.backfill(
|
||||||
dest,
|
dest,
|
||||||
room_id,
|
room_id,
|
||||||
@ -364,9 +363,9 @@ class FederationHandler(BaseHandler):
|
|||||||
missing_auth - failed_to_fetch
|
missing_auth - failed_to_fetch
|
||||||
)
|
)
|
||||||
|
|
||||||
results = yield defer.gatherResults(
|
results = yield preserve_context_over_deferred(defer.gatherResults(
|
||||||
[
|
[
|
||||||
self.replication_layer.get_pdu(
|
preserve_fn(self.replication_layer.get_pdu)(
|
||||||
[dest],
|
[dest],
|
||||||
event_id,
|
event_id,
|
||||||
outlier=True,
|
outlier=True,
|
||||||
@ -375,10 +374,10 @@ class FederationHandler(BaseHandler):
|
|||||||
for event_id in missing_auth - failed_to_fetch
|
for event_id in missing_auth - failed_to_fetch
|
||||||
],
|
],
|
||||||
consumeErrors=True
|
consumeErrors=True
|
||||||
).addErrback(unwrapFirstError)
|
)).addErrback(unwrapFirstError)
|
||||||
auth_events.update({a.event_id: a for a in results})
|
auth_events.update({a.event_id: a for a in results if a})
|
||||||
required_auth.update(
|
required_auth.update(
|
||||||
a_id for event in results for a_id, _ in event.auth_events
|
a_id for event in results for a_id, _ in event.auth_events if event
|
||||||
)
|
)
|
||||||
missing_auth = required_auth - set(auth_events)
|
missing_auth = required_auth - set(auth_events)
|
||||||
|
|
||||||
@ -455,6 +454,10 @@ class FederationHandler(BaseHandler):
|
|||||||
)
|
)
|
||||||
max_depth = sorted_extremeties_tuple[0][1]
|
max_depth = sorted_extremeties_tuple[0][1]
|
||||||
|
|
||||||
|
# We don't want to specify too many extremities as it causes the backfill
|
||||||
|
# request URI to be too long.
|
||||||
|
extremities = dict(sorted_extremeties_tuple[:5])
|
||||||
|
|
||||||
if current_depth > max_depth:
|
if current_depth > max_depth:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Not backfilling as we don't need to. %d < %d",
|
"Not backfilling as we don't need to. %d < %d",
|
||||||
@ -551,10 +554,10 @@ class FederationHandler(BaseHandler):
|
|||||||
|
|
||||||
event_ids = list(extremities.keys())
|
event_ids = list(extremities.keys())
|
||||||
|
|
||||||
states = yield defer.gatherResults([
|
states = yield preserve_context_over_deferred(defer.gatherResults([
|
||||||
self.state_handler.resolve_state_groups(room_id, [e])
|
preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e])
|
||||||
for e in event_ids
|
for e in event_ids
|
||||||
])
|
]))
|
||||||
states = dict(zip(event_ids, [s[1] for s in states]))
|
states = dict(zip(event_ids, [s[1] for s in states]))
|
||||||
|
|
||||||
for e_id, _ in sorted_extremeties_tuple:
|
for e_id, _ in sorted_extremeties_tuple:
|
||||||
@ -1093,16 +1096,17 @@ class FederationHandler(BaseHandler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if event:
|
if event:
|
||||||
# FIXME: This is a temporary work around where we occasionally
|
if self.hs.is_mine_id(event.event_id):
|
||||||
# return events slightly differently than when they were
|
# FIXME: This is a temporary work around where we occasionally
|
||||||
# originally signed
|
# return events slightly differently than when they were
|
||||||
event.signatures.update(
|
# originally signed
|
||||||
compute_event_signature(
|
event.signatures.update(
|
||||||
event,
|
compute_event_signature(
|
||||||
self.hs.hostname,
|
event,
|
||||||
self.hs.config.signing_key[0]
|
self.hs.hostname,
|
||||||
|
self.hs.config.signing_key[0]
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
if do_auth:
|
if do_auth:
|
||||||
in_room = yield self.auth.check_host_in_room(
|
in_room = yield self.auth.check_host_in_room(
|
||||||
@ -1112,6 +1116,12 @@ class FederationHandler(BaseHandler):
|
|||||||
if not in_room:
|
if not in_room:
|
||||||
raise AuthError(403, "Host not in room.")
|
raise AuthError(403, "Host not in room.")
|
||||||
|
|
||||||
|
events = yield self._filter_events_for_server(
|
||||||
|
origin, event.room_id, [event]
|
||||||
|
)
|
||||||
|
|
||||||
|
event = events[0]
|
||||||
|
|
||||||
defer.returnValue(event)
|
defer.returnValue(event)
|
||||||
else:
|
else:
|
||||||
defer.returnValue(None)
|
defer.returnValue(None)
|
||||||
@ -1158,9 +1168,9 @@ class FederationHandler(BaseHandler):
|
|||||||
a bunch of outliers, but not a chunk of individual events that depend
|
a bunch of outliers, but not a chunk of individual events that depend
|
||||||
on each other for state calculations.
|
on each other for state calculations.
|
||||||
"""
|
"""
|
||||||
contexts = yield defer.gatherResults(
|
contexts = yield preserve_context_over_deferred(defer.gatherResults(
|
||||||
[
|
[
|
||||||
self._prep_event(
|
preserve_fn(self._prep_event)(
|
||||||
origin,
|
origin,
|
||||||
ev_info["event"],
|
ev_info["event"],
|
||||||
state=ev_info.get("state"),
|
state=ev_info.get("state"),
|
||||||
@ -1168,7 +1178,7 @@ class FederationHandler(BaseHandler):
|
|||||||
)
|
)
|
||||||
for ev_info in event_infos
|
for ev_info in event_infos
|
||||||
]
|
]
|
||||||
)
|
))
|
||||||
|
|
||||||
yield self.store.persist_events(
|
yield self.store.persist_events(
|
||||||
[
|
[
|
||||||
@ -1452,9 +1462,9 @@ class FederationHandler(BaseHandler):
|
|||||||
# Do auth conflict res.
|
# Do auth conflict res.
|
||||||
logger.info("Different auth: %s", different_auth)
|
logger.info("Different auth: %s", different_auth)
|
||||||
|
|
||||||
different_events = yield defer.gatherResults(
|
different_events = yield preserve_context_over_deferred(defer.gatherResults(
|
||||||
[
|
[
|
||||||
self.store.get_event(
|
preserve_fn(self.store.get_event)(
|
||||||
d,
|
d,
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
allow_rejected=False,
|
allow_rejected=False,
|
||||||
@ -1463,7 +1473,7 @@ class FederationHandler(BaseHandler):
|
|||||||
if d in have_events and not have_events[d]
|
if d in have_events and not have_events[d]
|
||||||
],
|
],
|
||||||
consumeErrors=True
|
consumeErrors=True
|
||||||
).addErrback(unwrapFirstError)
|
)).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
if different_events:
|
if different_events:
|
||||||
local_view = dict(auth_events)
|
local_view = dict(auth_events)
|
||||||
|
@ -28,7 +28,8 @@ from synapse.types import (
|
|||||||
from synapse.util import unwrapFirstError
|
from synapse.util import unwrapFirstError
|
||||||
from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLock
|
from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLock
|
||||||
from synapse.util.caches.snapshot_cache import SnapshotCache
|
from synapse.util.caches.snapshot_cache import SnapshotCache
|
||||||
from synapse.util.logcontext import preserve_fn
|
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||||
|
from synapse.util.metrics import measure_func
|
||||||
from synapse.visibility import filter_events_for_client
|
from synapse.visibility import filter_events_for_client
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
@ -502,15 +503,17 @@ class MessageHandler(BaseHandler):
|
|||||||
lambda states: states[event.event_id]
|
lambda states: states[event.event_id]
|
||||||
)
|
)
|
||||||
|
|
||||||
(messages, token), current_state = yield defer.gatherResults(
|
(messages, token), current_state = yield preserve_context_over_deferred(
|
||||||
[
|
defer.gatherResults(
|
||||||
self.store.get_recent_events_for_room(
|
[
|
||||||
event.room_id,
|
preserve_fn(self.store.get_recent_events_for_room)(
|
||||||
limit=limit,
|
event.room_id,
|
||||||
end_token=room_end_token,
|
limit=limit,
|
||||||
),
|
end_token=room_end_token,
|
||||||
deferred_room_state,
|
),
|
||||||
]
|
deferred_room_state,
|
||||||
|
]
|
||||||
|
)
|
||||||
).addErrback(unwrapFirstError)
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
messages = yield filter_events_for_client(
|
messages = yield filter_events_for_client(
|
||||||
@ -719,9 +722,9 @@ class MessageHandler(BaseHandler):
|
|||||||
|
|
||||||
presence, receipts, (messages, token) = yield defer.gatherResults(
|
presence, receipts, (messages, token) = yield defer.gatherResults(
|
||||||
[
|
[
|
||||||
get_presence(),
|
preserve_fn(get_presence)(),
|
||||||
get_receipts(),
|
preserve_fn(get_receipts)(),
|
||||||
self.store.get_recent_events_for_room(
|
preserve_fn(self.store.get_recent_events_for_room)(
|
||||||
room_id,
|
room_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
end_token=now_token.room_key,
|
end_token=now_token.room_key,
|
||||||
@ -755,6 +758,7 @@ class MessageHandler(BaseHandler):
|
|||||||
|
|
||||||
defer.returnValue(ret)
|
defer.returnValue(ret)
|
||||||
|
|
||||||
|
@measure_func("_create_new_client_event")
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _create_new_client_event(self, builder, prev_event_ids=None):
|
def _create_new_client_event(self, builder, prev_event_ids=None):
|
||||||
if prev_event_ids:
|
if prev_event_ids:
|
||||||
@ -806,6 +810,7 @@ class MessageHandler(BaseHandler):
|
|||||||
(event, context,)
|
(event, context,)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@measure_func("handle_new_client_event")
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def handle_new_client_event(
|
def handle_new_client_event(
|
||||||
self,
|
self,
|
||||||
@ -934,7 +939,7 @@ class MessageHandler(BaseHandler):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _notify():
|
def _notify():
|
||||||
yield run_on_reactor()
|
yield run_on_reactor()
|
||||||
self.notifier.on_new_room_event(
|
yield self.notifier.on_new_room_event(
|
||||||
event, event_stream_id, max_stream_id,
|
event, event_stream_id, max_stream_id,
|
||||||
extra_users=extra_users
|
extra_users=extra_users
|
||||||
)
|
)
|
||||||
@ -944,6 +949,6 @@ class MessageHandler(BaseHandler):
|
|||||||
# If invite, remove room_state from unsigned before sending.
|
# If invite, remove room_state from unsigned before sending.
|
||||||
event.unsigned.pop("invite_room_state", None)
|
event.unsigned.pop("invite_room_state", None)
|
||||||
|
|
||||||
federation_handler.handle_new_event(
|
preserve_fn(federation_handler.handle_new_event)(
|
||||||
event, destinations=destinations,
|
event, destinations=destinations,
|
||||||
)
|
)
|
||||||
|
@ -503,7 +503,7 @@ class PresenceHandler(object):
|
|||||||
defer.returnValue(states)
|
defer.returnValue(states)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _get_interested_parties(self, states):
|
def _get_interested_parties(self, states, calculate_remote_hosts=True):
|
||||||
"""Given a list of states return which entities (rooms, users, servers)
|
"""Given a list of states return which entities (rooms, users, servers)
|
||||||
are interested in the given states.
|
are interested in the given states.
|
||||||
|
|
||||||
@ -526,14 +526,15 @@ class PresenceHandler(object):
|
|||||||
users_to_states.setdefault(state.user_id, []).append(state)
|
users_to_states.setdefault(state.user_id, []).append(state)
|
||||||
|
|
||||||
hosts_to_states = {}
|
hosts_to_states = {}
|
||||||
for room_id, states in room_ids_to_states.items():
|
if calculate_remote_hosts:
|
||||||
local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
|
for room_id, states in room_ids_to_states.items():
|
||||||
if not local_states:
|
local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
|
||||||
continue
|
if not local_states:
|
||||||
|
continue
|
||||||
|
|
||||||
hosts = yield self.store.get_joined_hosts_for_room(room_id)
|
hosts = yield self.store.get_joined_hosts_for_room(room_id)
|
||||||
for host in hosts:
|
for host in hosts:
|
||||||
hosts_to_states.setdefault(host, []).extend(local_states)
|
hosts_to_states.setdefault(host, []).extend(local_states)
|
||||||
|
|
||||||
for user_id, states in users_to_states.items():
|
for user_id, states in users_to_states.items():
|
||||||
local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
|
local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
|
||||||
@ -565,6 +566,16 @@ class PresenceHandler(object):
|
|||||||
|
|
||||||
self._push_to_remotes(hosts_to_states)
|
self._push_to_remotes(hosts_to_states)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def notify_for_states(self, state, stream_id):
|
||||||
|
parties = yield self._get_interested_parties([state])
|
||||||
|
room_ids_to_states, users_to_states, hosts_to_states = parties
|
||||||
|
|
||||||
|
self.notifier.on_new_event(
|
||||||
|
"presence_key", stream_id, rooms=room_ids_to_states.keys(),
|
||||||
|
users=[UserID.from_string(u) for u in users_to_states.keys()]
|
||||||
|
)
|
||||||
|
|
||||||
def _push_to_remotes(self, hosts_to_states):
|
def _push_to_remotes(self, hosts_to_states):
|
||||||
"""Sends state updates to remote servers.
|
"""Sends state updates to remote servers.
|
||||||
|
|
||||||
@ -672,7 +683,7 @@ class PresenceHandler(object):
|
|||||||
])
|
])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def set_state(self, target_user, state):
|
def set_state(self, target_user, state, ignore_status_msg=False):
|
||||||
"""Set the presence state of the user.
|
"""Set the presence state of the user.
|
||||||
"""
|
"""
|
||||||
status_msg = state.get("status_msg", None)
|
status_msg = state.get("status_msg", None)
|
||||||
@ -689,10 +700,13 @@ class PresenceHandler(object):
|
|||||||
prev_state = yield self.current_state_for_user(user_id)
|
prev_state = yield self.current_state_for_user(user_id)
|
||||||
|
|
||||||
new_fields = {
|
new_fields = {
|
||||||
"state": presence,
|
"state": presence
|
||||||
"status_msg": status_msg if presence != PresenceState.OFFLINE else None
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if not ignore_status_msg:
|
||||||
|
msg = status_msg if presence != PresenceState.OFFLINE else None
|
||||||
|
new_fields["status_msg"] = msg
|
||||||
|
|
||||||
if presence == PresenceState.ONLINE:
|
if presence == PresenceState.ONLINE:
|
||||||
new_fields["last_active_ts"] = self.clock.time_msec()
|
new_fields["last_active_ts"] = self.clock.time_msec()
|
||||||
|
|
||||||
|
@ -59,10 +59,13 @@ class RoomMemberHandler(BaseHandler):
|
|||||||
prev_event_ids,
|
prev_event_ids,
|
||||||
txn_id=None,
|
txn_id=None,
|
||||||
ratelimit=True,
|
ratelimit=True,
|
||||||
|
content=None,
|
||||||
):
|
):
|
||||||
|
if content is None:
|
||||||
|
content = {}
|
||||||
msg_handler = self.hs.get_handlers().message_handler
|
msg_handler = self.hs.get_handlers().message_handler
|
||||||
|
|
||||||
content = {"membership": membership}
|
content["membership"] = membership
|
||||||
if requester.is_guest:
|
if requester.is_guest:
|
||||||
content["kind"] = "guest"
|
content["kind"] = "guest"
|
||||||
|
|
||||||
@ -140,8 +143,9 @@ class RoomMemberHandler(BaseHandler):
|
|||||||
remote_room_hosts=None,
|
remote_room_hosts=None,
|
||||||
third_party_signed=None,
|
third_party_signed=None,
|
||||||
ratelimit=True,
|
ratelimit=True,
|
||||||
|
content=None,
|
||||||
):
|
):
|
||||||
key = (target, room_id,)
|
key = (room_id,)
|
||||||
|
|
||||||
with (yield self.member_linearizer.queue(key)):
|
with (yield self.member_linearizer.queue(key)):
|
||||||
result = yield self._update_membership(
|
result = yield self._update_membership(
|
||||||
@ -153,6 +157,7 @@ class RoomMemberHandler(BaseHandler):
|
|||||||
remote_room_hosts=remote_room_hosts,
|
remote_room_hosts=remote_room_hosts,
|
||||||
third_party_signed=third_party_signed,
|
third_party_signed=third_party_signed,
|
||||||
ratelimit=ratelimit,
|
ratelimit=ratelimit,
|
||||||
|
content=content,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
@ -168,7 +173,11 @@ class RoomMemberHandler(BaseHandler):
|
|||||||
remote_room_hosts=None,
|
remote_room_hosts=None,
|
||||||
third_party_signed=None,
|
third_party_signed=None,
|
||||||
ratelimit=True,
|
ratelimit=True,
|
||||||
|
content=None,
|
||||||
):
|
):
|
||||||
|
if content is None:
|
||||||
|
content = {}
|
||||||
|
|
||||||
effective_membership_state = action
|
effective_membership_state = action
|
||||||
if action in ["kick", "unban"]:
|
if action in ["kick", "unban"]:
|
||||||
effective_membership_state = "leave"
|
effective_membership_state = "leave"
|
||||||
@ -218,7 +227,7 @@ class RoomMemberHandler(BaseHandler):
|
|||||||
if inviter and not self.hs.is_mine(inviter):
|
if inviter and not self.hs.is_mine(inviter):
|
||||||
remote_room_hosts.append(inviter.domain)
|
remote_room_hosts.append(inviter.domain)
|
||||||
|
|
||||||
content = {"membership": Membership.JOIN}
|
content["membership"] = Membership.JOIN
|
||||||
|
|
||||||
profile = self.hs.get_handlers().profile_handler
|
profile = self.hs.get_handlers().profile_handler
|
||||||
content["displayname"] = yield profile.get_displayname(target)
|
content["displayname"] = yield profile.get_displayname(target)
|
||||||
@ -272,6 +281,7 @@ class RoomMemberHandler(BaseHandler):
|
|||||||
txn_id=txn_id,
|
txn_id=txn_id,
|
||||||
ratelimit=ratelimit,
|
ratelimit=ratelimit,
|
||||||
prev_event_ids=latest_event_ids,
|
prev_event_ids=latest_event_ids,
|
||||||
|
content=content,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -464,10 +464,10 @@ class SyncHandler(object):
|
|||||||
else:
|
else:
|
||||||
state = {}
|
state = {}
|
||||||
|
|
||||||
defer.returnValue({
|
defer.returnValue({
|
||||||
(e.type, e.state_key): e
|
(e.type, e.state_key): e
|
||||||
for e in sync_config.filter_collection.filter_room_state(state.values())
|
for e in sync_config.filter_collection.filter_room_state(state.values())
|
||||||
})
|
})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def unread_notifs_for_room_id(self, room_id, sync_config):
|
def unread_notifs_for_room_id(self, room_id, sync_config):
|
||||||
@ -485,9 +485,9 @@ class SyncHandler(object):
|
|||||||
)
|
)
|
||||||
defer.returnValue(notifs)
|
defer.returnValue(notifs)
|
||||||
|
|
||||||
# There is no new information in this period, so your notification
|
# There is no new information in this period, so your notification
|
||||||
# count is whatever it was last time.
|
# count is whatever it was last time.
|
||||||
defer.returnValue(None)
|
defer.returnValue(None)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def generate_sync_result(self, sync_config, since_token=None, full_state=False):
|
def generate_sync_result(self, sync_config, since_token=None, full_state=False):
|
||||||
|
@ -16,7 +16,9 @@
|
|||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError, AuthError
|
from synapse.api.errors import SynapseError, AuthError
|
||||||
from synapse.util.logcontext import PreserveLoggingContext
|
from synapse.util.logcontext import (
|
||||||
|
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
|
||||||
|
)
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
|
||||||
@ -169,13 +171,13 @@ class TypingHandler(object):
|
|||||||
deferreds = []
|
deferreds = []
|
||||||
for domain in domains:
|
for domain in domains:
|
||||||
if domain == self.server_name:
|
if domain == self.server_name:
|
||||||
self._push_update_local(
|
preserve_fn(self._push_update_local)(
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
typing=typing
|
typing=typing
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
deferreds.append(self.federation.send_edu(
|
deferreds.append(preserve_fn(self.federation.send_edu)(
|
||||||
destination=domain,
|
destination=domain,
|
||||||
edu_type="m.typing",
|
edu_type="m.typing",
|
||||||
content={
|
content={
|
||||||
@ -185,7 +187,9 @@ class TypingHandler(object):
|
|||||||
},
|
},
|
||||||
))
|
))
|
||||||
|
|
||||||
yield defer.DeferredList(deferreds, consumeErrors=True)
|
yield preserve_context_over_deferred(
|
||||||
|
defer.DeferredList(deferreds, consumeErrors=True)
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _recv_edu(self, origin, content):
|
def _recv_edu(self, origin, content):
|
||||||
|
@ -155,9 +155,7 @@ class MatrixFederationHttpClient(object):
|
|||||||
time_out=timeout / 1000. if timeout else 60,
|
time_out=timeout / 1000. if timeout else 60,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = yield preserve_context_over_fn(
|
response = yield preserve_context_over_fn(send_request)
|
||||||
send_request,
|
|
||||||
)
|
|
||||||
|
|
||||||
log_result = "%d %s" % (response.code, response.phrase,)
|
log_result = "%d %s" % (response.code, response.phrase,)
|
||||||
break
|
break
|
||||||
|
@ -19,6 +19,7 @@ from synapse.api.errors import (
|
|||||||
)
|
)
|
||||||
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
||||||
from synapse.util.caches import intern_dict
|
from synapse.util.caches import intern_dict
|
||||||
|
from synapse.util.metrics import Measure
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
import synapse.events
|
import synapse.events
|
||||||
|
|
||||||
@ -74,12 +75,12 @@ response_db_txn_duration = metrics.register_distribution(
|
|||||||
_next_request_id = 0
|
_next_request_id = 0
|
||||||
|
|
||||||
|
|
||||||
def request_handler(report_metrics=True):
|
def request_handler(include_metrics=False):
|
||||||
"""Decorator for ``wrap_request_handler``"""
|
"""Decorator for ``wrap_request_handler``"""
|
||||||
return lambda request_handler: wrap_request_handler(request_handler, report_metrics)
|
return lambda request_handler: wrap_request_handler(request_handler, include_metrics)
|
||||||
|
|
||||||
|
|
||||||
def wrap_request_handler(request_handler, report_metrics):
|
def wrap_request_handler(request_handler, include_metrics=False):
|
||||||
"""Wraps a method that acts as a request handler with the necessary logging
|
"""Wraps a method that acts as a request handler with the necessary logging
|
||||||
and exception handling.
|
and exception handling.
|
||||||
|
|
||||||
@ -103,54 +104,56 @@ def wrap_request_handler(request_handler, report_metrics):
|
|||||||
_next_request_id += 1
|
_next_request_id += 1
|
||||||
|
|
||||||
with LoggingContext(request_id) as request_context:
|
with LoggingContext(request_id) as request_context:
|
||||||
if report_metrics:
|
with Measure(self.clock, "wrapped_request_handler"):
|
||||||
request_metrics = RequestMetrics()
|
request_metrics = RequestMetrics()
|
||||||
request_metrics.start(self.clock)
|
request_metrics.start(self.clock, name=self.__class__.__name__)
|
||||||
|
|
||||||
request_context.request = request_id
|
request_context.request = request_id
|
||||||
with request.processing():
|
with request.processing():
|
||||||
try:
|
|
||||||
with PreserveLoggingContext(request_context):
|
|
||||||
yield request_handler(self, request)
|
|
||||||
except CodeMessageException as e:
|
|
||||||
code = e.code
|
|
||||||
if isinstance(e, SynapseError):
|
|
||||||
logger.info(
|
|
||||||
"%s SynapseError: %s - %s", request, code, e.msg
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.exception(e)
|
|
||||||
outgoing_responses_counter.inc(request.method, str(code))
|
|
||||||
respond_with_json(
|
|
||||||
request, code, cs_exception(e), send_cors=True,
|
|
||||||
pretty_print=_request_user_agent_is_curl(request),
|
|
||||||
version_string=self.version_string,
|
|
||||||
)
|
|
||||||
except:
|
|
||||||
logger.exception(
|
|
||||||
"Failed handle request %s.%s on %r: %r",
|
|
||||||
request_handler.__module__,
|
|
||||||
request_handler.__name__,
|
|
||||||
self,
|
|
||||||
request
|
|
||||||
)
|
|
||||||
respond_with_json(
|
|
||||||
request,
|
|
||||||
500,
|
|
||||||
{
|
|
||||||
"error": "Internal server error",
|
|
||||||
"errcode": Codes.UNKNOWN,
|
|
||||||
},
|
|
||||||
send_cors=True
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
try:
|
try:
|
||||||
if report_metrics:
|
with PreserveLoggingContext(request_context):
|
||||||
request_metrics.stop(
|
if include_metrics:
|
||||||
self.clock, request, self.__class__.__name__
|
yield request_handler(self, request, request_metrics)
|
||||||
|
else:
|
||||||
|
yield request_handler(self, request)
|
||||||
|
except CodeMessageException as e:
|
||||||
|
code = e.code
|
||||||
|
if isinstance(e, SynapseError):
|
||||||
|
logger.info(
|
||||||
|
"%s SynapseError: %s - %s", request, code, e.msg
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
logger.exception(e)
|
||||||
|
outgoing_responses_counter.inc(request.method, str(code))
|
||||||
|
respond_with_json(
|
||||||
|
request, code, cs_exception(e), send_cors=True,
|
||||||
|
pretty_print=_request_user_agent_is_curl(request),
|
||||||
|
version_string=self.version_string,
|
||||||
|
)
|
||||||
except:
|
except:
|
||||||
pass
|
logger.exception(
|
||||||
|
"Failed handle request %s.%s on %r: %r",
|
||||||
|
request_handler.__module__,
|
||||||
|
request_handler.__name__,
|
||||||
|
self,
|
||||||
|
request
|
||||||
|
)
|
||||||
|
respond_with_json(
|
||||||
|
request,
|
||||||
|
500,
|
||||||
|
{
|
||||||
|
"error": "Internal server error",
|
||||||
|
"errcode": Codes.UNKNOWN,
|
||||||
|
},
|
||||||
|
send_cors=True
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
request_metrics.stop(
|
||||||
|
self.clock, request
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warn("Failed to stop metrics: %r", e)
|
||||||
return wrapped_request_handler
|
return wrapped_request_handler
|
||||||
|
|
||||||
|
|
||||||
@ -220,9 +223,9 @@ class JsonResource(HttpServer, resource.Resource):
|
|||||||
# It does its own metric reporting because _async_render dispatches to
|
# It does its own metric reporting because _async_render dispatches to
|
||||||
# a callback and it's the class name of that callback we want to report
|
# a callback and it's the class name of that callback we want to report
|
||||||
# against rather than the JsonResource itself.
|
# against rather than the JsonResource itself.
|
||||||
@request_handler(report_metrics=False)
|
@request_handler(include_metrics=True)
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _async_render(self, request):
|
def _async_render(self, request, request_metrics):
|
||||||
""" This gets called from render() every time someone sends us a request.
|
""" This gets called from render() every time someone sends us a request.
|
||||||
This checks if anyone has registered a callback for that method and
|
This checks if anyone has registered a callback for that method and
|
||||||
path.
|
path.
|
||||||
@ -231,9 +234,6 @@ class JsonResource(HttpServer, resource.Resource):
|
|||||||
self._send_response(request, 200, {})
|
self._send_response(request, 200, {})
|
||||||
return
|
return
|
||||||
|
|
||||||
request_metrics = RequestMetrics()
|
|
||||||
request_metrics.start(self.clock)
|
|
||||||
|
|
||||||
# Loop through all the registered callbacks to check if the method
|
# Loop through all the registered callbacks to check if the method
|
||||||
# and path regex match
|
# and path regex match
|
||||||
for path_entry in self.path_regexs.get(request.method, []):
|
for path_entry in self.path_regexs.get(request.method, []):
|
||||||
@ -247,12 +247,6 @@ class JsonResource(HttpServer, resource.Resource):
|
|||||||
|
|
||||||
callback = path_entry.callback
|
callback = path_entry.callback
|
||||||
|
|
||||||
servlet_instance = getattr(callback, "__self__", None)
|
|
||||||
if servlet_instance is not None:
|
|
||||||
servlet_classname = servlet_instance.__class__.__name__
|
|
||||||
else:
|
|
||||||
servlet_classname = "%r" % callback
|
|
||||||
|
|
||||||
kwargs = intern_dict({
|
kwargs = intern_dict({
|
||||||
name: urllib.unquote(value).decode("UTF-8") if value else value
|
name: urllib.unquote(value).decode("UTF-8") if value else value
|
||||||
for name, value in m.groupdict().items()
|
for name, value in m.groupdict().items()
|
||||||
@ -263,10 +257,13 @@ class JsonResource(HttpServer, resource.Resource):
|
|||||||
code, response = callback_return
|
code, response = callback_return
|
||||||
self._send_response(request, code, response)
|
self._send_response(request, code, response)
|
||||||
|
|
||||||
try:
|
servlet_instance = getattr(callback, "__self__", None)
|
||||||
request_metrics.stop(self.clock, request, servlet_classname)
|
if servlet_instance is not None:
|
||||||
except:
|
servlet_classname = servlet_instance.__class__.__name__
|
||||||
pass
|
else:
|
||||||
|
servlet_classname = "%r" % callback
|
||||||
|
|
||||||
|
request_metrics.name = servlet_classname
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -298,11 +295,12 @@ class JsonResource(HttpServer, resource.Resource):
|
|||||||
|
|
||||||
|
|
||||||
class RequestMetrics(object):
|
class RequestMetrics(object):
|
||||||
def start(self, clock):
|
def start(self, clock, name):
|
||||||
self.start = clock.time_msec()
|
self.start = clock.time_msec()
|
||||||
self.start_context = LoggingContext.current_context()
|
self.start_context = LoggingContext.current_context()
|
||||||
|
self.name = name
|
||||||
|
|
||||||
def stop(self, clock, request, servlet_classname):
|
def stop(self, clock, request):
|
||||||
context = LoggingContext.current_context()
|
context = LoggingContext.current_context()
|
||||||
|
|
||||||
tag = ""
|
tag = ""
|
||||||
@ -316,26 +314,26 @@ class RequestMetrics(object):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
incoming_requests_counter.inc(request.method, servlet_classname, tag)
|
incoming_requests_counter.inc(request.method, self.name, tag)
|
||||||
|
|
||||||
response_timer.inc_by(
|
response_timer.inc_by(
|
||||||
clock.time_msec() - self.start, request.method,
|
clock.time_msec() - self.start, request.method,
|
||||||
servlet_classname, tag
|
self.name, tag
|
||||||
)
|
)
|
||||||
|
|
||||||
ru_utime, ru_stime = context.get_resource_usage()
|
ru_utime, ru_stime = context.get_resource_usage()
|
||||||
|
|
||||||
response_ru_utime.inc_by(
|
response_ru_utime.inc_by(
|
||||||
ru_utime, request.method, servlet_classname, tag
|
ru_utime, request.method, self.name, tag
|
||||||
)
|
)
|
||||||
response_ru_stime.inc_by(
|
response_ru_stime.inc_by(
|
||||||
ru_stime, request.method, servlet_classname, tag
|
ru_stime, request.method, self.name, tag
|
||||||
)
|
)
|
||||||
response_db_txn_count.inc_by(
|
response_db_txn_count.inc_by(
|
||||||
context.db_txn_count, request.method, servlet_classname, tag
|
context.db_txn_count, request.method, self.name, tag
|
||||||
)
|
)
|
||||||
response_db_txn_duration.inc_by(
|
response_db_txn_duration.inc_by(
|
||||||
context.db_txn_duration, request.method, servlet_classname, tag
|
context.db_txn_duration, request.method, self.name, tag
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -19,7 +19,8 @@ from synapse.api.errors import AuthError
|
|||||||
|
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.util.async import ObservableDeferred
|
from synapse.util.async import ObservableDeferred
|
||||||
from synapse.util.logcontext import PreserveLoggingContext
|
from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
|
||||||
|
from synapse.util.metrics import Measure
|
||||||
from synapse.types import StreamToken
|
from synapse.types import StreamToken
|
||||||
from synapse.visibility import filter_events_for_client
|
from synapse.visibility import filter_events_for_client
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
@ -67,10 +68,8 @@ class _NotifierUserStream(object):
|
|||||||
so that it can remove itself from the indexes in the Notifier class.
|
so that it can remove itself from the indexes in the Notifier class.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, user_id, rooms, current_token, time_now_ms,
|
def __init__(self, user_id, rooms, current_token, time_now_ms):
|
||||||
appservice=None):
|
|
||||||
self.user_id = user_id
|
self.user_id = user_id
|
||||||
self.appservice = appservice
|
|
||||||
self.rooms = set(rooms)
|
self.rooms = set(rooms)
|
||||||
self.current_token = current_token
|
self.current_token = current_token
|
||||||
self.last_notified_ms = time_now_ms
|
self.last_notified_ms = time_now_ms
|
||||||
@ -107,11 +106,6 @@ class _NotifierUserStream(object):
|
|||||||
|
|
||||||
notifier.user_to_user_stream.pop(self.user_id)
|
notifier.user_to_user_stream.pop(self.user_id)
|
||||||
|
|
||||||
if self.appservice:
|
|
||||||
notifier.appservice_to_user_streams.get(
|
|
||||||
self.appservice, set()
|
|
||||||
).discard(self)
|
|
||||||
|
|
||||||
def count_listeners(self):
|
def count_listeners(self):
|
||||||
return len(self.notify_deferred.observers())
|
return len(self.notify_deferred.observers())
|
||||||
|
|
||||||
@ -142,7 +136,6 @@ class Notifier(object):
|
|||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.user_to_user_stream = {}
|
self.user_to_user_stream = {}
|
||||||
self.room_to_user_streams = {}
|
self.room_to_user_streams = {}
|
||||||
self.appservice_to_user_streams = {}
|
|
||||||
|
|
||||||
self.event_sources = hs.get_event_sources()
|
self.event_sources = hs.get_event_sources()
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
@ -168,8 +161,6 @@ class Notifier(object):
|
|||||||
all_user_streams |= x
|
all_user_streams |= x
|
||||||
for x in self.user_to_user_stream.values():
|
for x in self.user_to_user_stream.values():
|
||||||
all_user_streams.add(x)
|
all_user_streams.add(x)
|
||||||
for x in self.appservice_to_user_streams.values():
|
|
||||||
all_user_streams |= x
|
|
||||||
|
|
||||||
return sum(stream.count_listeners() for stream in all_user_streams)
|
return sum(stream.count_listeners() for stream in all_user_streams)
|
||||||
metrics.register_callback("listeners", count_listeners)
|
metrics.register_callback("listeners", count_listeners)
|
||||||
@ -182,11 +173,8 @@ class Notifier(object):
|
|||||||
"users",
|
"users",
|
||||||
lambda: len(self.user_to_user_stream),
|
lambda: len(self.user_to_user_stream),
|
||||||
)
|
)
|
||||||
metrics.register_callback(
|
|
||||||
"appservices",
|
|
||||||
lambda: count(bool, self.appservice_to_user_streams.values()),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
@preserve_fn
|
||||||
def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
|
def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
|
||||||
extra_users=[]):
|
extra_users=[]):
|
||||||
""" Used by handlers to inform the notifier something has happened
|
""" Used by handlers to inform the notifier something has happened
|
||||||
@ -208,6 +196,7 @@ class Notifier(object):
|
|||||||
|
|
||||||
self.notify_replication()
|
self.notify_replication()
|
||||||
|
|
||||||
|
@preserve_fn
|
||||||
def _notify_pending_new_room_events(self, max_room_stream_id):
|
def _notify_pending_new_room_events(self, max_room_stream_id):
|
||||||
"""Notify for the room events that were queued waiting for a previous
|
"""Notify for the room events that were queued waiting for a previous
|
||||||
event to be persisted.
|
event to be persisted.
|
||||||
@ -225,24 +214,11 @@ class Notifier(object):
|
|||||||
else:
|
else:
|
||||||
self._on_new_room_event(event, room_stream_id, extra_users)
|
self._on_new_room_event(event, room_stream_id, extra_users)
|
||||||
|
|
||||||
|
@preserve_fn
|
||||||
def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
|
def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
|
||||||
"""Notify any user streams that are interested in this room event"""
|
"""Notify any user streams that are interested in this room event"""
|
||||||
# poke any interested application service.
|
# poke any interested application service.
|
||||||
self.appservice_handler.notify_interested_services(event)
|
self.appservice_handler.notify_interested_services(room_stream_id)
|
||||||
|
|
||||||
app_streams = set()
|
|
||||||
|
|
||||||
for appservice in self.appservice_to_user_streams:
|
|
||||||
# TODO (kegan): Redundant appservice listener checks?
|
|
||||||
# App services will already be in the room_to_user_streams set, but
|
|
||||||
# that isn't enough. They need to be checked here in order to
|
|
||||||
# receive *invites* for users they are interested in. Does this
|
|
||||||
# make the room_to_user_streams check somewhat obselete?
|
|
||||||
if appservice.is_interested(event):
|
|
||||||
app_user_streams = self.appservice_to_user_streams.get(
|
|
||||||
appservice, set()
|
|
||||||
)
|
|
||||||
app_streams |= app_user_streams
|
|
||||||
|
|
||||||
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
|
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
|
||||||
self._user_joined_room(event.state_key, event.room_id)
|
self._user_joined_room(event.state_key, event.room_id)
|
||||||
@ -251,35 +227,36 @@ class Notifier(object):
|
|||||||
"room_key", room_stream_id,
|
"room_key", room_stream_id,
|
||||||
users=extra_users,
|
users=extra_users,
|
||||||
rooms=[event.room_id],
|
rooms=[event.room_id],
|
||||||
extra_streams=app_streams,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_new_event(self, stream_key, new_token, users=[], rooms=[],
|
@preserve_fn
|
||||||
extra_streams=set()):
|
def on_new_event(self, stream_key, new_token, users=[], rooms=[]):
|
||||||
""" Used to inform listeners that something has happend event wise.
|
""" Used to inform listeners that something has happend event wise.
|
||||||
|
|
||||||
Will wake up all listeners for the given users and rooms.
|
Will wake up all listeners for the given users and rooms.
|
||||||
"""
|
"""
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
user_streams = set()
|
with Measure(self.clock, "on_new_event"):
|
||||||
|
user_streams = set()
|
||||||
|
|
||||||
for user in users:
|
for user in users:
|
||||||
user_stream = self.user_to_user_stream.get(str(user))
|
user_stream = self.user_to_user_stream.get(str(user))
|
||||||
if user_stream is not None:
|
if user_stream is not None:
|
||||||
user_streams.add(user_stream)
|
user_streams.add(user_stream)
|
||||||
|
|
||||||
for room in rooms:
|
for room in rooms:
|
||||||
user_streams |= self.room_to_user_streams.get(room, set())
|
user_streams |= self.room_to_user_streams.get(room, set())
|
||||||
|
|
||||||
time_now_ms = self.clock.time_msec()
|
time_now_ms = self.clock.time_msec()
|
||||||
for user_stream in user_streams:
|
for user_stream in user_streams:
|
||||||
try:
|
try:
|
||||||
user_stream.notify(stream_key, new_token, time_now_ms)
|
user_stream.notify(stream_key, new_token, time_now_ms)
|
||||||
except:
|
except:
|
||||||
logger.exception("Failed to notify listener")
|
logger.exception("Failed to notify listener")
|
||||||
|
|
||||||
self.notify_replication()
|
self.notify_replication()
|
||||||
|
|
||||||
|
@preserve_fn
|
||||||
def on_new_replication_data(self):
|
def on_new_replication_data(self):
|
||||||
"""Used to inform replication listeners that something has happend
|
"""Used to inform replication listeners that something has happend
|
||||||
without waking up any of the normal user event streams"""
|
without waking up any of the normal user event streams"""
|
||||||
@ -294,7 +271,6 @@ class Notifier(object):
|
|||||||
"""
|
"""
|
||||||
user_stream = self.user_to_user_stream.get(user_id)
|
user_stream = self.user_to_user_stream.get(user_id)
|
||||||
if user_stream is None:
|
if user_stream is None:
|
||||||
appservice = yield self.store.get_app_service_by_user_id(user_id)
|
|
||||||
current_token = yield self.event_sources.get_current_token()
|
current_token = yield self.event_sources.get_current_token()
|
||||||
if room_ids is None:
|
if room_ids is None:
|
||||||
rooms = yield self.store.get_rooms_for_user(user_id)
|
rooms = yield self.store.get_rooms_for_user(user_id)
|
||||||
@ -302,7 +278,6 @@ class Notifier(object):
|
|||||||
user_stream = _NotifierUserStream(
|
user_stream = _NotifierUserStream(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
rooms=room_ids,
|
rooms=room_ids,
|
||||||
appservice=appservice,
|
|
||||||
current_token=current_token,
|
current_token=current_token,
|
||||||
time_now_ms=self.clock.time_msec(),
|
time_now_ms=self.clock.time_msec(),
|
||||||
)
|
)
|
||||||
@ -477,11 +452,6 @@ class Notifier(object):
|
|||||||
s = self.room_to_user_streams.setdefault(room, set())
|
s = self.room_to_user_streams.setdefault(room, set())
|
||||||
s.add(user_stream)
|
s.add(user_stream)
|
||||||
|
|
||||||
if user_stream.appservice:
|
|
||||||
self.appservice_to_user_stream.setdefault(
|
|
||||||
user_stream.appservice, set()
|
|
||||||
).add(user_stream)
|
|
||||||
|
|
||||||
def _user_joined_room(self, user_id, room_id):
|
def _user_joined_room(self, user_id, room_id):
|
||||||
new_user_stream = self.user_to_user_stream.get(user_id)
|
new_user_stream = self.user_to_user_stream.get(user_id)
|
||||||
if new_user_stream is not None:
|
if new_user_stream is not None:
|
||||||
|
@ -38,15 +38,16 @@ class ActionGenerator:
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def handle_push_actions_for_event(self, event, context):
|
def handle_push_actions_for_event(self, event, context):
|
||||||
with Measure(self.clock, "handle_push_actions_for_event"):
|
with Measure(self.clock, "evaluator_for_event"):
|
||||||
bulk_evaluator = yield evaluator_for_event(
|
bulk_evaluator = yield evaluator_for_event(
|
||||||
event, self.hs, self.store, context.current_state
|
event, self.hs, self.store, context.state_group, context.current_state
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with Measure(self.clock, "action_for_event_by_user"):
|
||||||
actions_by_user = yield bulk_evaluator.action_for_event_by_user(
|
actions_by_user = yield bulk_evaluator.action_for_event_by_user(
|
||||||
event, context.current_state
|
event, context.current_state
|
||||||
)
|
)
|
||||||
|
|
||||||
context.push_actions = [
|
context.push_actions = [
|
||||||
(uid, actions) for uid, actions in actions_by_user.items()
|
(uid, actions) for uid, actions in actions_by_user.items()
|
||||||
]
|
]
|
||||||
|
@ -217,6 +217,27 @@ BASE_APPEND_OVERRIDE_RULES = [
|
|||||||
'dont_notify'
|
'dont_notify'
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
# This was changed from underride to override so it's closer in priority
|
||||||
|
# to the content rules where the user name highlight rule lives. This
|
||||||
|
# way a room rule is lower priority than both but a custom override rule
|
||||||
|
# is higher priority than both.
|
||||||
|
{
|
||||||
|
'rule_id': 'global/override/.m.rule.contains_display_name',
|
||||||
|
'conditions': [
|
||||||
|
{
|
||||||
|
'kind': 'contains_display_name'
|
||||||
|
}
|
||||||
|
],
|
||||||
|
'actions': [
|
||||||
|
'notify',
|
||||||
|
{
|
||||||
|
'set_tweak': 'sound',
|
||||||
|
'value': 'default'
|
||||||
|
}, {
|
||||||
|
'set_tweak': 'highlight'
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -242,23 +263,6 @@ BASE_APPEND_UNDERRIDE_RULES = [
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
'rule_id': 'global/underride/.m.rule.contains_display_name',
|
|
||||||
'conditions': [
|
|
||||||
{
|
|
||||||
'kind': 'contains_display_name'
|
|
||||||
}
|
|
||||||
],
|
|
||||||
'actions': [
|
|
||||||
'notify',
|
|
||||||
{
|
|
||||||
'set_tweak': 'sound',
|
|
||||||
'value': 'default'
|
|
||||||
}, {
|
|
||||||
'set_tweak': 'highlight'
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
'rule_id': 'global/underride/.m.rule.room_one_to_one',
|
'rule_id': 'global/underride/.m.rule.room_one_to_one',
|
||||||
'conditions': [
|
'conditions': [
|
||||||
|
@ -36,35 +36,11 @@ def _get_rules(room_id, user_ids, store):
|
|||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def evaluator_for_event(event, hs, store, current_state):
|
def evaluator_for_event(event, hs, store, state_group, current_state):
|
||||||
room_id = event.room_id
|
rules_by_user = yield store.bulk_get_push_rules_for_room(
|
||||||
# We also will want to generate notifs for other people in the room so
|
event.room_id, state_group, current_state
|
||||||
# their unread countss are correct in the event stream, but to avoid
|
|
||||||
# generating them for bot / AS users etc, we only do so for people who've
|
|
||||||
# sent a read receipt into the room.
|
|
||||||
|
|
||||||
local_users_in_room = set(
|
|
||||||
e.state_key for e in current_state.values()
|
|
||||||
if e.type == EventTypes.Member and e.membership == Membership.JOIN
|
|
||||||
and hs.is_mine_id(e.state_key)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# users in the room who have pushers need to get push rules run because
|
|
||||||
# that's how their pushers work
|
|
||||||
if_users_with_pushers = yield store.get_if_users_have_pushers(
|
|
||||||
local_users_in_room
|
|
||||||
)
|
|
||||||
user_ids = set(
|
|
||||||
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
|
|
||||||
)
|
|
||||||
|
|
||||||
users_with_receipts = yield store.get_users_with_read_receipts_in_room(room_id)
|
|
||||||
|
|
||||||
# any users with pushers must be ours: they have pushers
|
|
||||||
for uid in users_with_receipts:
|
|
||||||
if uid in local_users_in_room:
|
|
||||||
user_ids.add(uid)
|
|
||||||
|
|
||||||
# if this event is an invite event, we may need to run rules for the user
|
# if this event is an invite event, we may need to run rules for the user
|
||||||
# who's been invited, otherwise they won't get told they've been invited
|
# who's been invited, otherwise they won't get told they've been invited
|
||||||
if event.type == 'm.room.member' and event.content['membership'] == 'invite':
|
if event.type == 'm.room.member' and event.content['membership'] == 'invite':
|
||||||
@ -72,12 +48,12 @@ def evaluator_for_event(event, hs, store, current_state):
|
|||||||
if invited_user and hs.is_mine_id(invited_user):
|
if invited_user and hs.is_mine_id(invited_user):
|
||||||
has_pusher = yield store.user_has_pusher(invited_user)
|
has_pusher = yield store.user_has_pusher(invited_user)
|
||||||
if has_pusher:
|
if has_pusher:
|
||||||
user_ids.add(invited_user)
|
rules_by_user[invited_user] = yield store.get_push_rules_for_user(
|
||||||
|
invited_user
|
||||||
rules_by_user = yield _get_rules(room_id, user_ids, store)
|
)
|
||||||
|
|
||||||
defer.returnValue(BulkPushRuleEvaluator(
|
defer.returnValue(BulkPushRuleEvaluator(
|
||||||
room_id, rules_by_user, user_ids, store
|
event.room_id, rules_by_user, store
|
||||||
))
|
))
|
||||||
|
|
||||||
|
|
||||||
@ -90,10 +66,9 @@ class BulkPushRuleEvaluator:
|
|||||||
the same logic to run the actual rules, but could be optimised further
|
the same logic to run the actual rules, but could be optimised further
|
||||||
(see https://matrix.org/jira/browse/SYN-562)
|
(see https://matrix.org/jira/browse/SYN-562)
|
||||||
"""
|
"""
|
||||||
def __init__(self, room_id, rules_by_user, users_in_room, store):
|
def __init__(self, room_id, rules_by_user, store):
|
||||||
self.room_id = room_id
|
self.room_id = room_id
|
||||||
self.rules_by_user = rules_by_user
|
self.rules_by_user = rules_by_user
|
||||||
self.users_in_room = users_in_room
|
|
||||||
self.store = store
|
self.store = store
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -17,14 +17,15 @@ from twisted.internet import defer
|
|||||||
from synapse.util.presentable_names import (
|
from synapse.util.presentable_names import (
|
||||||
calculate_room_name, name_from_member_event
|
calculate_room_name, name_from_member_event
|
||||||
)
|
)
|
||||||
|
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_badge_count(store, user_id):
|
def get_badge_count(store, user_id):
|
||||||
invites, joins = yield defer.gatherResults([
|
invites, joins = yield preserve_context_over_deferred(defer.gatherResults([
|
||||||
store.get_invited_rooms_for_user(user_id),
|
preserve_fn(store.get_invited_rooms_for_user)(user_id),
|
||||||
store.get_rooms_for_user(user_id),
|
preserve_fn(store.get_rooms_for_user)(user_id),
|
||||||
], consumeErrors=True)
|
], consumeErrors=True))
|
||||||
|
|
||||||
my_receipts_by_room = yield store.get_receipts_for_user(
|
my_receipts_by_room = yield store.get_receipts_for_user(
|
||||||
user_id, "m.read",
|
user_id, "m.read",
|
||||||
|
@ -17,7 +17,7 @@
|
|||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
import pusher
|
import pusher
|
||||||
from synapse.util.logcontext import preserve_fn
|
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@ -102,14 +102,14 @@ class PusherPool:
|
|||||||
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
|
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def remove_pushers_by_user(self, user_id, except_token_ids=[]):
|
def remove_pushers_by_user(self, user_id, except_access_token_id=None):
|
||||||
all = yield self.store.get_all_pushers()
|
all = yield self.store.get_all_pushers()
|
||||||
logger.info(
|
logger.info(
|
||||||
"Removing all pushers for user %s except access tokens ids %r",
|
"Removing all pushers for user %s except access tokens id %r",
|
||||||
user_id, except_token_ids
|
user_id, except_access_token_id
|
||||||
)
|
)
|
||||||
for p in all:
|
for p in all:
|
||||||
if p['user_name'] == user_id and p['access_token'] not in except_token_ids:
|
if p['user_name'] == user_id and p['access_token'] != except_access_token_id:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Removing pusher for app id %s, pushkey %s, user %s",
|
"Removing pusher for app id %s, pushkey %s, user %s",
|
||||||
p['app_id'], p['pushkey'], p['user_name']
|
p['app_id'], p['pushkey'], p['user_name']
|
||||||
@ -130,10 +130,12 @@ class PusherPool:
|
|||||||
if u in self.pushers:
|
if u in self.pushers:
|
||||||
for p in self.pushers[u].values():
|
for p in self.pushers[u].values():
|
||||||
deferreds.append(
|
deferreds.append(
|
||||||
p.on_new_notifications(min_stream_id, max_stream_id)
|
preserve_fn(p.on_new_notifications)(
|
||||||
|
min_stream_id, max_stream_id
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield defer.gatherResults(deferreds)
|
yield preserve_context_over_deferred(defer.gatherResults(deferreds))
|
||||||
except:
|
except:
|
||||||
logger.exception("Exception in pusher on_new_notifications")
|
logger.exception("Exception in pusher on_new_notifications")
|
||||||
|
|
||||||
@ -155,10 +157,10 @@ class PusherPool:
|
|||||||
if u in self.pushers:
|
if u in self.pushers:
|
||||||
for p in self.pushers[u].values():
|
for p in self.pushers[u].values():
|
||||||
deferreds.append(
|
deferreds.append(
|
||||||
p.on_new_receipts(min_stream_id, max_stream_id)
|
preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield defer.gatherResults(deferreds)
|
yield preserve_context_over_deferred(defer.gatherResults(deferreds))
|
||||||
except:
|
except:
|
||||||
logger.exception("Exception in pusher on_new_receipts")
|
logger.exception("Exception in pusher on_new_receipts")
|
||||||
|
|
||||||
|
@ -41,6 +41,7 @@ STREAM_NAMES = (
|
|||||||
("push_rules",),
|
("push_rules",),
|
||||||
("pushers",),
|
("pushers",),
|
||||||
("state",),
|
("state",),
|
||||||
|
("caches",),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -70,6 +71,7 @@ class ReplicationResource(Resource):
|
|||||||
* "backfill": Old events that have been backfilled from other servers.
|
* "backfill": Old events that have been backfilled from other servers.
|
||||||
* "push_rules": Per user changes to push rules.
|
* "push_rules": Per user changes to push rules.
|
||||||
* "pushers": Per user changes to their pushers.
|
* "pushers": Per user changes to their pushers.
|
||||||
|
* "caches": Cache invalidations.
|
||||||
|
|
||||||
The API takes two additional query parameters:
|
The API takes two additional query parameters:
|
||||||
|
|
||||||
@ -129,6 +131,7 @@ class ReplicationResource(Resource):
|
|||||||
push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
|
push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
|
||||||
pushers_token = self.store.get_pushers_stream_token()
|
pushers_token = self.store.get_pushers_stream_token()
|
||||||
state_token = self.store.get_state_stream_token()
|
state_token = self.store.get_state_stream_token()
|
||||||
|
caches_token = self.store.get_cache_stream_token()
|
||||||
|
|
||||||
defer.returnValue(_ReplicationToken(
|
defer.returnValue(_ReplicationToken(
|
||||||
room_stream_token,
|
room_stream_token,
|
||||||
@ -140,6 +143,7 @@ class ReplicationResource(Resource):
|
|||||||
push_rules_token,
|
push_rules_token,
|
||||||
pushers_token,
|
pushers_token,
|
||||||
state_token,
|
state_token,
|
||||||
|
caches_token,
|
||||||
))
|
))
|
||||||
|
|
||||||
@request_handler()
|
@request_handler()
|
||||||
@ -188,6 +192,7 @@ class ReplicationResource(Resource):
|
|||||||
yield self.push_rules(writer, current_token, limit, request_streams)
|
yield self.push_rules(writer, current_token, limit, request_streams)
|
||||||
yield self.pushers(writer, current_token, limit, request_streams)
|
yield self.pushers(writer, current_token, limit, request_streams)
|
||||||
yield self.state(writer, current_token, limit, request_streams)
|
yield self.state(writer, current_token, limit, request_streams)
|
||||||
|
yield self.caches(writer, current_token, limit, request_streams)
|
||||||
self.streams(writer, current_token, request_streams)
|
self.streams(writer, current_token, request_streams)
|
||||||
|
|
||||||
logger.info("Replicated %d rows", writer.total)
|
logger.info("Replicated %d rows", writer.total)
|
||||||
@ -379,6 +384,20 @@ class ReplicationResource(Resource):
|
|||||||
"position", "type", "state_key", "event_id"
|
"position", "type", "state_key", "event_id"
|
||||||
))
|
))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def caches(self, writer, current_token, limit, request_streams):
|
||||||
|
current_position = current_token.caches
|
||||||
|
|
||||||
|
caches = request_streams.get("caches")
|
||||||
|
|
||||||
|
if caches is not None:
|
||||||
|
updated_caches = yield self.store.get_all_updated_caches(
|
||||||
|
caches, current_position, limit
|
||||||
|
)
|
||||||
|
writer.write_header_and_rows("caches", updated_caches, (
|
||||||
|
"position", "cache_func", "keys", "invalidation_ts"
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
class _Writer(object):
|
class _Writer(object):
|
||||||
"""Writes the streams as a JSON object as the response to the request"""
|
"""Writes the streams as a JSON object as the response to the request"""
|
||||||
@ -407,7 +426,7 @@ class _Writer(object):
|
|||||||
|
|
||||||
class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
|
class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
|
||||||
"events", "presence", "typing", "receipts", "account_data", "backfill",
|
"events", "presence", "typing", "receipts", "account_data", "backfill",
|
||||||
"push_rules", "pushers", "state"
|
"push_rules", "pushers", "state", "caches",
|
||||||
))):
|
))):
|
||||||
__slots__ = []
|
__slots__ = []
|
||||||
|
|
||||||
|
@ -14,15 +14,43 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
|
from synapse.storage.engines import PostgresEngine
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from ._slaved_id_tracker import SlavedIdTracker
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BaseSlavedStore(SQLBaseStore):
|
class BaseSlavedStore(SQLBaseStore):
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(BaseSlavedStore, self).__init__(hs)
|
super(BaseSlavedStore, self).__init__(hs)
|
||||||
|
if isinstance(self.database_engine, PostgresEngine):
|
||||||
|
self._cache_id_gen = SlavedIdTracker(
|
||||||
|
db_conn, "cache_invalidation_stream", "stream_id",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._cache_id_gen = None
|
||||||
|
|
||||||
def stream_positions(self):
|
def stream_positions(self):
|
||||||
return {}
|
pos = {}
|
||||||
|
if self._cache_id_gen:
|
||||||
|
pos["caches"] = self._cache_id_gen.get_current_token()
|
||||||
|
return pos
|
||||||
|
|
||||||
def process_replication(self, result):
|
def process_replication(self, result):
|
||||||
|
stream = result.get("caches")
|
||||||
|
if stream:
|
||||||
|
for row in stream["rows"]:
|
||||||
|
(
|
||||||
|
position, cache_func, keys, invalidation_ts,
|
||||||
|
) = row
|
||||||
|
|
||||||
|
try:
|
||||||
|
getattr(self, cache_func).invalidate(tuple(keys))
|
||||||
|
except AttributeError:
|
||||||
|
logger.info("Got unexpected cache_func: %r", cache_func)
|
||||||
|
self._cache_id_gen.advance(int(stream["position"]))
|
||||||
return defer.succeed(None)
|
return defer.succeed(None)
|
||||||
|
@ -28,3 +28,13 @@ class SlavedApplicationServiceStore(BaseSlavedStore):
|
|||||||
|
|
||||||
get_app_service_by_token = DataStore.get_app_service_by_token.__func__
|
get_app_service_by_token = DataStore.get_app_service_by_token.__func__
|
||||||
get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__
|
get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__
|
||||||
|
get_app_services = DataStore.get_app_services.__func__
|
||||||
|
get_new_events_for_appservice = DataStore.get_new_events_for_appservice.__func__
|
||||||
|
create_appservice_txn = DataStore.create_appservice_txn.__func__
|
||||||
|
get_appservices_by_state = DataStore.get_appservices_by_state.__func__
|
||||||
|
get_oldest_unsent_txn = DataStore.get_oldest_unsent_txn.__func__
|
||||||
|
_get_last_txn = DataStore._get_last_txn.__func__
|
||||||
|
complete_appservice_txn = DataStore.complete_appservice_txn.__func__
|
||||||
|
get_appservice_state = DataStore.get_appservice_state.__func__
|
||||||
|
set_appservice_last_pos = DataStore.set_appservice_last_pos.__func__
|
||||||
|
set_appservice_state = DataStore.set_appservice_state.__func__
|
||||||
|
@ -20,4 +20,4 @@ from synapse.storage.directory import DirectoryStore
|
|||||||
class DirectoryStore(BaseSlavedStore):
|
class DirectoryStore(BaseSlavedStore):
|
||||||
get_aliases_for_room = DirectoryStore.__dict__[
|
get_aliases_for_room = DirectoryStore.__dict__[
|
||||||
"get_aliases_for_room"
|
"get_aliases_for_room"
|
||||||
].orig
|
]
|
||||||
|
@ -25,6 +25,9 @@ class SlavedRegistrationStore(BaseSlavedStore):
|
|||||||
# TODO: use the cached version and invalidate deleted tokens
|
# TODO: use the cached version and invalidate deleted tokens
|
||||||
get_user_by_access_token = RegistrationStore.__dict__[
|
get_user_by_access_token = RegistrationStore.__dict__[
|
||||||
"get_user_by_access_token"
|
"get_user_by_access_token"
|
||||||
].orig
|
]
|
||||||
|
|
||||||
_query_for_auth = DataStore._query_for_auth.__func__
|
_query_for_auth = DataStore._query_for_auth.__func__
|
||||||
|
get_user_by_id = RegistrationStore.__dict__[
|
||||||
|
"get_user_by_id"
|
||||||
|
]
|
||||||
|
@ -46,7 +46,9 @@ from synapse.rest.client.v2_alpha import (
|
|||||||
account_data,
|
account_data,
|
||||||
report_event,
|
report_event,
|
||||||
openid,
|
openid,
|
||||||
|
notifications,
|
||||||
devices,
|
devices,
|
||||||
|
thirdparty,
|
||||||
)
|
)
|
||||||
|
|
||||||
from synapse.http.server import JsonResource
|
from synapse.http.server import JsonResource
|
||||||
@ -91,4 +93,6 @@ class ClientRestResource(JsonResource):
|
|||||||
account_data.register_servlets(hs, client_resource)
|
account_data.register_servlets(hs, client_resource)
|
||||||
report_event.register_servlets(hs, client_resource)
|
report_event.register_servlets(hs, client_resource)
|
||||||
openid.register_servlets(hs, client_resource)
|
openid.register_servlets(hs, client_resource)
|
||||||
|
notifications.register_servlets(hs, client_resource)
|
||||||
devices.register_servlets(hs, client_resource)
|
devices.register_servlets(hs, client_resource)
|
||||||
|
thirdparty.register_servlets(hs, client_resource)
|
||||||
|
@ -28,6 +28,10 @@ logger = logging.getLogger(__name__)
|
|||||||
class WhoisRestServlet(ClientV1RestServlet):
|
class WhoisRestServlet(ClientV1RestServlet):
|
||||||
PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)")
|
PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(WhoisRestServlet, self).__init__(hs)
|
||||||
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, user_id):
|
def on_GET(self, request, user_id):
|
||||||
target_user = UserID.from_string(user_id)
|
target_user = UserID.from_string(user_id)
|
||||||
@ -82,6 +86,10 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
|
|||||||
"/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
|
"/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(PurgeHistoryRestServlet, self).__init__(hs)
|
||||||
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request, room_id, event_id):
|
def on_POST(self, request, room_id, event_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
@ -57,7 +57,6 @@ class ClientV1RestServlet(RestServlet):
|
|||||||
hs (synapse.server.HomeServer):
|
hs (synapse.server.HomeServer):
|
||||||
"""
|
"""
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.handlers = hs.get_handlers()
|
|
||||||
self.builder_factory = hs.get_event_builder_factory()
|
self.builder_factory = hs.get_event_builder_factory()
|
||||||
self.auth = hs.get_v1auth()
|
self.auth = hs.get_v1auth()
|
||||||
self.txns = HttpTransactionStore()
|
self.txns = HttpTransactionStore()
|
||||||
|
@ -36,6 +36,10 @@ def register_servlets(hs, http_server):
|
|||||||
class ClientDirectoryServer(ClientV1RestServlet):
|
class ClientDirectoryServer(ClientV1RestServlet):
|
||||||
PATTERNS = client_path_patterns("/directory/room/(?P<room_alias>[^/]*)$")
|
PATTERNS = client_path_patterns("/directory/room/(?P<room_alias>[^/]*)$")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(ClientDirectoryServer, self).__init__(hs)
|
||||||
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, room_alias):
|
def on_GET(self, request, room_alias):
|
||||||
room_alias = RoomAlias.from_string(room_alias)
|
room_alias = RoomAlias.from_string(room_alias)
|
||||||
@ -146,6 +150,7 @@ class ClientDirectoryListServer(ClientV1RestServlet):
|
|||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(ClientDirectoryListServer, self).__init__(hs)
|
super(ClientDirectoryListServer, self).__init__(hs)
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, room_id):
|
def on_GET(self, request, room_id):
|
||||||
|
@ -32,6 +32,10 @@ class EventStreamRestServlet(ClientV1RestServlet):
|
|||||||
|
|
||||||
DEFAULT_LONGPOLL_TIME_MS = 30000
|
DEFAULT_LONGPOLL_TIME_MS = 30000
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(EventStreamRestServlet, self).__init__(hs)
|
||||||
|
self.event_stream_handler = hs.get_event_stream_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
requester = yield self.auth.get_user_by_req(
|
requester = yield self.auth.get_user_by_req(
|
||||||
@ -46,7 +50,6 @@ class EventStreamRestServlet(ClientV1RestServlet):
|
|||||||
if "room_id" in request.args:
|
if "room_id" in request.args:
|
||||||
room_id = request.args["room_id"][0]
|
room_id = request.args["room_id"][0]
|
||||||
|
|
||||||
handler = self.handlers.event_stream_handler
|
|
||||||
pagin_config = PaginationConfig.from_request(request)
|
pagin_config = PaginationConfig.from_request(request)
|
||||||
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
|
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
|
||||||
if "timeout" in request.args:
|
if "timeout" in request.args:
|
||||||
@ -57,7 +60,7 @@ class EventStreamRestServlet(ClientV1RestServlet):
|
|||||||
|
|
||||||
as_client_event = "raw" not in request.args
|
as_client_event = "raw" not in request.args
|
||||||
|
|
||||||
chunk = yield handler.get_stream(
|
chunk = yield self.event_stream_handler.get_stream(
|
||||||
requester.user.to_string(),
|
requester.user.to_string(),
|
||||||
pagin_config,
|
pagin_config,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
@ -80,12 +83,12 @@ class EventRestServlet(ClientV1RestServlet):
|
|||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(EventRestServlet, self).__init__(hs)
|
super(EventRestServlet, self).__init__(hs)
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
self.event_handler = hs.get_event_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, event_id):
|
def on_GET(self, request, event_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
handler = self.handlers.event_handler
|
event = yield self.event_handler.get_event(requester.user, event_id)
|
||||||
event = yield handler.get_event(requester.user, event_id)
|
|
||||||
|
|
||||||
time_now = self.clock.time_msec()
|
time_now = self.clock.time_msec()
|
||||||
if event:
|
if event:
|
||||||
|
@ -23,6 +23,10 @@ from .base import ClientV1RestServlet, client_path_patterns
|
|||||||
class InitialSyncRestServlet(ClientV1RestServlet):
|
class InitialSyncRestServlet(ClientV1RestServlet):
|
||||||
PATTERNS = client_path_patterns("/initialSync$")
|
PATTERNS = client_path_patterns("/initialSync$")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(InitialSyncRestServlet, self).__init__(hs)
|
||||||
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
@ -54,12 +54,9 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||||||
self.jwt_secret = hs.config.jwt_secret
|
self.jwt_secret = hs.config.jwt_secret
|
||||||
self.jwt_algorithm = hs.config.jwt_algorithm
|
self.jwt_algorithm = hs.config.jwt_algorithm
|
||||||
self.cas_enabled = hs.config.cas_enabled
|
self.cas_enabled = hs.config.cas_enabled
|
||||||
self.cas_server_url = hs.config.cas_server_url
|
|
||||||
self.cas_required_attributes = hs.config.cas_required_attributes
|
|
||||||
self.servername = hs.config.server_name
|
|
||||||
self.http_client = hs.get_simple_http_client()
|
|
||||||
self.auth_handler = self.hs.get_auth_handler()
|
self.auth_handler = self.hs.get_auth_handler()
|
||||||
self.device_handler = self.hs.get_device_handler()
|
self.device_handler = self.hs.get_device_handler()
|
||||||
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
flows = []
|
flows = []
|
||||||
@ -110,17 +107,6 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||||||
LoginRestServlet.JWT_TYPE):
|
LoginRestServlet.JWT_TYPE):
|
||||||
result = yield self.do_jwt_login(login_submission)
|
result = yield self.do_jwt_login(login_submission)
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
# TODO Delete this after all CAS clients switch to token login instead
|
|
||||||
elif self.cas_enabled and (login_submission["type"] ==
|
|
||||||
LoginRestServlet.CAS_TYPE):
|
|
||||||
uri = "%s/proxyValidate" % (self.cas_server_url,)
|
|
||||||
args = {
|
|
||||||
"ticket": login_submission["ticket"],
|
|
||||||
"service": login_submission["service"]
|
|
||||||
}
|
|
||||||
body = yield self.http_client.get_raw(uri, args)
|
|
||||||
result = yield self.do_cas_login(body)
|
|
||||||
defer.returnValue(result)
|
|
||||||
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
|
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
|
||||||
result = yield self.do_token_login(login_submission)
|
result = yield self.do_token_login(login_submission)
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
@ -191,51 +177,6 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||||||
|
|
||||||
defer.returnValue((200, result))
|
defer.returnValue((200, result))
|
||||||
|
|
||||||
# TODO Delete this after all CAS clients switch to token login instead
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def do_cas_login(self, cas_response_body):
|
|
||||||
user, attributes = self.parse_cas_response(cas_response_body)
|
|
||||||
|
|
||||||
for required_attribute, required_value in self.cas_required_attributes.items():
|
|
||||||
# If required attribute was not in CAS Response - Forbidden
|
|
||||||
if required_attribute not in attributes:
|
|
||||||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
|
||||||
|
|
||||||
# Also need to check value
|
|
||||||
if required_value is not None:
|
|
||||||
actual_value = attributes[required_attribute]
|
|
||||||
# If required attribute value does not match expected - Forbidden
|
|
||||||
if required_value != actual_value:
|
|
||||||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
|
||||||
|
|
||||||
user_id = UserID.create(user, self.hs.hostname).to_string()
|
|
||||||
auth_handler = self.auth_handler
|
|
||||||
registered_user_id = yield auth_handler.check_user_exists(user_id)
|
|
||||||
if registered_user_id:
|
|
||||||
access_token, refresh_token = (
|
|
||||||
yield auth_handler.get_login_tuple_for_user_id(
|
|
||||||
registered_user_id
|
|
||||||
)
|
|
||||||
)
|
|
||||||
result = {
|
|
||||||
"user_id": registered_user_id, # may have changed
|
|
||||||
"access_token": access_token,
|
|
||||||
"refresh_token": refresh_token,
|
|
||||||
"home_server": self.hs.hostname,
|
|
||||||
}
|
|
||||||
|
|
||||||
else:
|
|
||||||
user_id, access_token = (
|
|
||||||
yield self.handlers.registration_handler.register(localpart=user)
|
|
||||||
)
|
|
||||||
result = {
|
|
||||||
"user_id": user_id, # may have changed
|
|
||||||
"access_token": access_token,
|
|
||||||
"home_server": self.hs.hostname,
|
|
||||||
}
|
|
||||||
|
|
||||||
defer.returnValue((200, result))
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def do_jwt_login(self, login_submission):
|
def do_jwt_login(self, login_submission):
|
||||||
token = login_submission.get("token", None)
|
token = login_submission.get("token", None)
|
||||||
@ -293,33 +234,6 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||||||
|
|
||||||
defer.returnValue((200, result))
|
defer.returnValue((200, result))
|
||||||
|
|
||||||
# TODO Delete this after all CAS clients switch to token login instead
|
|
||||||
def parse_cas_response(self, cas_response_body):
|
|
||||||
root = ET.fromstring(cas_response_body)
|
|
||||||
if not root.tag.endswith("serviceResponse"):
|
|
||||||
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
|
|
||||||
if not root[0].tag.endswith("authenticationSuccess"):
|
|
||||||
raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
|
|
||||||
for child in root[0]:
|
|
||||||
if child.tag.endswith("user"):
|
|
||||||
user = child.text
|
|
||||||
if child.tag.endswith("attributes"):
|
|
||||||
attributes = {}
|
|
||||||
for attribute in child:
|
|
||||||
# ElementTree library expands the namespace in attribute tags
|
|
||||||
# to the full URL of the namespace.
|
|
||||||
# See (https://docs.python.org/2/library/xml.etree.elementtree.html)
|
|
||||||
# We don't care about namespace here and it will always be encased in
|
|
||||||
# curly braces, so we remove them.
|
|
||||||
if "}" in attribute.tag:
|
|
||||||
attributes[attribute.tag.split("}")[1]] = attribute.text
|
|
||||||
else:
|
|
||||||
attributes[attribute.tag] = attribute.text
|
|
||||||
if user is None or attributes is None:
|
|
||||||
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
|
|
||||||
|
|
||||||
return (user, attributes)
|
|
||||||
|
|
||||||
def _register_device(self, user_id, login_submission):
|
def _register_device(self, user_id, login_submission):
|
||||||
"""Register a device for a user.
|
"""Register a device for a user.
|
||||||
|
|
||||||
@ -347,6 +261,7 @@ class SAML2RestServlet(ClientV1RestServlet):
|
|||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(SAML2RestServlet, self).__init__(hs)
|
super(SAML2RestServlet, self).__init__(hs)
|
||||||
self.sp_config = hs.config.saml2_config_path
|
self.sp_config = hs.config.saml2_config_path
|
||||||
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
@ -384,18 +299,6 @@ class SAML2RestServlet(ClientV1RestServlet):
|
|||||||
defer.returnValue((200, {"status": "not_authenticated"}))
|
defer.returnValue((200, {"status": "not_authenticated"}))
|
||||||
|
|
||||||
|
|
||||||
# TODO Delete this after all CAS clients switch to token login instead
|
|
||||||
class CasRestServlet(ClientV1RestServlet):
|
|
||||||
PATTERNS = client_path_patterns("/login/cas", releases=())
|
|
||||||
|
|
||||||
def __init__(self, hs):
|
|
||||||
super(CasRestServlet, self).__init__(hs)
|
|
||||||
self.cas_server_url = hs.config.cas_server_url
|
|
||||||
|
|
||||||
def on_GET(self, request):
|
|
||||||
return (200, {"serverUrl": self.cas_server_url})
|
|
||||||
|
|
||||||
|
|
||||||
class CasRedirectServlet(ClientV1RestServlet):
|
class CasRedirectServlet(ClientV1RestServlet):
|
||||||
PATTERNS = client_path_patterns("/login/cas/redirect", releases=())
|
PATTERNS = client_path_patterns("/login/cas/redirect", releases=())
|
||||||
|
|
||||||
@ -427,6 +330,8 @@ class CasTicketServlet(ClientV1RestServlet):
|
|||||||
self.cas_server_url = hs.config.cas_server_url
|
self.cas_server_url = hs.config.cas_server_url
|
||||||
self.cas_service_url = hs.config.cas_service_url
|
self.cas_service_url = hs.config.cas_service_url
|
||||||
self.cas_required_attributes = hs.config.cas_required_attributes
|
self.cas_required_attributes = hs.config.cas_required_attributes
|
||||||
|
self.auth_handler = hs.get_auth_handler()
|
||||||
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
@ -479,30 +384,39 @@ class CasTicketServlet(ClientV1RestServlet):
|
|||||||
return urlparse.urlunparse(url_parts)
|
return urlparse.urlunparse(url_parts)
|
||||||
|
|
||||||
def parse_cas_response(self, cas_response_body):
|
def parse_cas_response(self, cas_response_body):
|
||||||
root = ET.fromstring(cas_response_body)
|
user = None
|
||||||
if not root.tag.endswith("serviceResponse"):
|
attributes = None
|
||||||
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
|
try:
|
||||||
if not root[0].tag.endswith("authenticationSuccess"):
|
root = ET.fromstring(cas_response_body)
|
||||||
raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
|
if not root.tag.endswith("serviceResponse"):
|
||||||
for child in root[0]:
|
raise Exception("root of CAS response is not serviceResponse")
|
||||||
if child.tag.endswith("user"):
|
success = (root[0].tag.endswith("authenticationSuccess"))
|
||||||
user = child.text
|
for child in root[0]:
|
||||||
if child.tag.endswith("attributes"):
|
if child.tag.endswith("user"):
|
||||||
attributes = {}
|
user = child.text
|
||||||
for attribute in child:
|
if child.tag.endswith("attributes"):
|
||||||
# ElementTree library expands the namespace in attribute tags
|
attributes = {}
|
||||||
# to the full URL of the namespace.
|
for attribute in child:
|
||||||
# See (https://docs.python.org/2/library/xml.etree.elementtree.html)
|
# ElementTree library expands the namespace in
|
||||||
# We don't care about namespace here and it will always be encased in
|
# attribute tags to the full URL of the namespace.
|
||||||
# curly braces, so we remove them.
|
# We don't care about namespace here and it will always
|
||||||
if "}" in attribute.tag:
|
# be encased in curly braces, so we remove them.
|
||||||
attributes[attribute.tag.split("}")[1]] = attribute.text
|
tag = attribute.tag
|
||||||
else:
|
if "}" in tag:
|
||||||
attributes[attribute.tag] = attribute.text
|
tag = tag.split("}")[1]
|
||||||
if user is None or attributes is None:
|
attributes[tag] = attribute.text
|
||||||
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
|
if user is None:
|
||||||
|
raise Exception("CAS response does not contain user")
|
||||||
return (user, attributes)
|
if attributes is None:
|
||||||
|
raise Exception("CAS response does not contain attributes")
|
||||||
|
except Exception:
|
||||||
|
logger.error("Error parsing CAS response", exc_info=1)
|
||||||
|
raise LoginError(401, "Invalid CAS response",
|
||||||
|
errcode=Codes.UNAUTHORIZED)
|
||||||
|
if not success:
|
||||||
|
raise LoginError(401, "Unsuccessful CAS response",
|
||||||
|
errcode=Codes.UNAUTHORIZED)
|
||||||
|
return user, attributes
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs, http_server):
|
def register_servlets(hs, http_server):
|
||||||
@ -512,5 +426,3 @@ def register_servlets(hs, http_server):
|
|||||||
if hs.config.cas_enabled:
|
if hs.config.cas_enabled:
|
||||||
CasRedirectServlet(hs).register(http_server)
|
CasRedirectServlet(hs).register(http_server)
|
||||||
CasTicketServlet(hs).register(http_server)
|
CasTicketServlet(hs).register(http_server)
|
||||||
CasRestServlet(hs).register(http_server)
|
|
||||||
# TODO PasswordResetRestServlet(hs).register(http_server)
|
|
||||||
|
@ -24,6 +24,10 @@ from synapse.http.servlet import parse_json_object_from_request
|
|||||||
class ProfileDisplaynameRestServlet(ClientV1RestServlet):
|
class ProfileDisplaynameRestServlet(ClientV1RestServlet):
|
||||||
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname")
|
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(ProfileDisplaynameRestServlet, self).__init__(hs)
|
||||||
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, user_id):
|
def on_GET(self, request, user_id):
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
@ -62,6 +66,10 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
|
|||||||
class ProfileAvatarURLRestServlet(ClientV1RestServlet):
|
class ProfileAvatarURLRestServlet(ClientV1RestServlet):
|
||||||
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url")
|
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(ProfileAvatarURLRestServlet, self).__init__(hs)
|
||||||
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, user_id):
|
def on_GET(self, request, user_id):
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
@ -99,6 +107,10 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
|
|||||||
class ProfileRestServlet(ClientV1RestServlet):
|
class ProfileRestServlet(ClientV1RestServlet):
|
||||||
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)")
|
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(ProfileRestServlet, self).__init__(hs)
|
||||||
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, user_id):
|
def on_GET(self, request, user_id):
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
|
@ -65,6 +65,7 @@ class RegisterRestServlet(ClientV1RestServlet):
|
|||||||
self.sessions = {}
|
self.sessions = {}
|
||||||
self.enable_registration = hs.config.enable_registration
|
self.enable_registration = hs.config.enable_registration
|
||||||
self.auth_handler = hs.get_auth_handler()
|
self.auth_handler = hs.get_auth_handler()
|
||||||
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
if self.hs.config.enable_registration_captcha:
|
if self.hs.config.enable_registration_captcha:
|
||||||
@ -383,6 +384,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
|
|||||||
super(CreateUserRestServlet, self).__init__(hs)
|
super(CreateUserRestServlet, self).__init__(hs)
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.direct_user_creation_max_duration = hs.config.user_creation_max_duration
|
self.direct_user_creation_max_duration = hs.config.user_creation_max_duration
|
||||||
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
|
@ -35,6 +35,10 @@ logger = logging.getLogger(__name__)
|
|||||||
class RoomCreateRestServlet(ClientV1RestServlet):
|
class RoomCreateRestServlet(ClientV1RestServlet):
|
||||||
# No PATTERN; we have custom dispatch rules here
|
# No PATTERN; we have custom dispatch rules here
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(RoomCreateRestServlet, self).__init__(hs)
|
||||||
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
def register(self, http_server):
|
def register(self, http_server):
|
||||||
PATTERNS = "/createRoom"
|
PATTERNS = "/createRoom"
|
||||||
register_txn_path(self, PATTERNS, http_server)
|
register_txn_path(self, PATTERNS, http_server)
|
||||||
@ -82,6 +86,10 @@ class RoomCreateRestServlet(ClientV1RestServlet):
|
|||||||
|
|
||||||
# TODO: Needs unit testing for generic events
|
# TODO: Needs unit testing for generic events
|
||||||
class RoomStateEventRestServlet(ClientV1RestServlet):
|
class RoomStateEventRestServlet(ClientV1RestServlet):
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(RoomStateEventRestServlet, self).__init__(hs)
|
||||||
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
def register(self, http_server):
|
def register(self, http_server):
|
||||||
# /room/$roomid/state/$eventtype
|
# /room/$roomid/state/$eventtype
|
||||||
no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$"
|
no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$"
|
||||||
@ -166,6 +174,10 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
|
|||||||
# TODO: Needs unit testing for generic events + feedback
|
# TODO: Needs unit testing for generic events + feedback
|
||||||
class RoomSendEventRestServlet(ClientV1RestServlet):
|
class RoomSendEventRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(RoomSendEventRestServlet, self).__init__(hs)
|
||||||
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
def register(self, http_server):
|
def register(self, http_server):
|
||||||
# /rooms/$roomid/send/$event_type[/$txn_id]
|
# /rooms/$roomid/send/$event_type[/$txn_id]
|
||||||
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)")
|
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)")
|
||||||
@ -210,6 +222,9 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
|
|||||||
|
|
||||||
# TODO: Needs unit testing for room ID + alias joins
|
# TODO: Needs unit testing for room ID + alias joins
|
||||||
class JoinRoomAliasServlet(ClientV1RestServlet):
|
class JoinRoomAliasServlet(ClientV1RestServlet):
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(JoinRoomAliasServlet, self).__init__(hs)
|
||||||
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
def register(self, http_server):
|
def register(self, http_server):
|
||||||
# /join/$room_identifier[/$txn_id]
|
# /join/$room_identifier[/$txn_id]
|
||||||
@ -253,6 +268,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
|
|||||||
action="join",
|
action="join",
|
||||||
txn_id=txn_id,
|
txn_id=txn_id,
|
||||||
remote_room_hosts=remote_room_hosts,
|
remote_room_hosts=remote_room_hosts,
|
||||||
|
content=content,
|
||||||
third_party_signed=content.get("third_party_signed", None),
|
third_party_signed=content.get("third_party_signed", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -296,6 +312,10 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
|
|||||||
class RoomMemberListRestServlet(ClientV1RestServlet):
|
class RoomMemberListRestServlet(ClientV1RestServlet):
|
||||||
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/members$")
|
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/members$")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(RoomMemberListRestServlet, self).__init__(hs)
|
||||||
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, room_id):
|
def on_GET(self, request, room_id):
|
||||||
# TODO support Pagination stream API (limit/tokens)
|
# TODO support Pagination stream API (limit/tokens)
|
||||||
@ -322,6 +342,10 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
|
|||||||
class RoomMessageListRestServlet(ClientV1RestServlet):
|
class RoomMessageListRestServlet(ClientV1RestServlet):
|
||||||
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$")
|
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(RoomMessageListRestServlet, self).__init__(hs)
|
||||||
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, room_id):
|
def on_GET(self, request, room_id):
|
||||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
@ -351,6 +375,10 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
|
|||||||
class RoomStateRestServlet(ClientV1RestServlet):
|
class RoomStateRestServlet(ClientV1RestServlet):
|
||||||
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/state$")
|
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/state$")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(RoomStateRestServlet, self).__init__(hs)
|
||||||
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, room_id):
|
def on_GET(self, request, room_id):
|
||||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
@ -368,6 +396,10 @@ class RoomStateRestServlet(ClientV1RestServlet):
|
|||||||
class RoomInitialSyncRestServlet(ClientV1RestServlet):
|
class RoomInitialSyncRestServlet(ClientV1RestServlet):
|
||||||
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$")
|
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(RoomInitialSyncRestServlet, self).__init__(hs)
|
||||||
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, room_id):
|
def on_GET(self, request, room_id):
|
||||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
@ -388,6 +420,7 @@ class RoomEventContext(ClientV1RestServlet):
|
|||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(RoomEventContext, self).__init__(hs)
|
super(RoomEventContext, self).__init__(hs)
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, room_id, event_id):
|
def on_GET(self, request, room_id, event_id):
|
||||||
@ -424,6 +457,10 @@ class RoomEventContext(ClientV1RestServlet):
|
|||||||
|
|
||||||
|
|
||||||
class RoomForgetRestServlet(ClientV1RestServlet):
|
class RoomForgetRestServlet(ClientV1RestServlet):
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(RoomForgetRestServlet, self).__init__(hs)
|
||||||
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
def register(self, http_server):
|
def register(self, http_server):
|
||||||
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
|
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
|
||||||
register_txn_path(self, PATTERNS, http_server)
|
register_txn_path(self, PATTERNS, http_server)
|
||||||
@ -462,6 +499,10 @@ class RoomForgetRestServlet(ClientV1RestServlet):
|
|||||||
# TODO: Needs unit testing
|
# TODO: Needs unit testing
|
||||||
class RoomMembershipRestServlet(ClientV1RestServlet):
|
class RoomMembershipRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(RoomMembershipRestServlet, self).__init__(hs)
|
||||||
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
def register(self, http_server):
|
def register(self, http_server):
|
||||||
# /rooms/$roomid/[invite|join|leave]
|
# /rooms/$roomid/[invite|join|leave]
|
||||||
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
|
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
|
||||||
@ -542,6 +583,10 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
|
|||||||
|
|
||||||
|
|
||||||
class RoomRedactEventRestServlet(ClientV1RestServlet):
|
class RoomRedactEventRestServlet(ClientV1RestServlet):
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(RoomRedactEventRestServlet, self).__init__(hs)
|
||||||
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
def register(self, http_server):
|
def register(self, http_server):
|
||||||
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
|
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
|
||||||
register_txn_path(self, PATTERNS, http_server)
|
register_txn_path(self, PATTERNS, http_server)
|
||||||
@ -624,6 +669,10 @@ class SearchRestServlet(ClientV1RestServlet):
|
|||||||
"/search$"
|
"/search$"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(SearchRestServlet, self).__init__(hs)
|
||||||
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
99
synapse/rest/client/v2_alpha/notifications.py
Normal file
99
synapse/rest/client/v2_alpha/notifications.py
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.http.servlet import (
|
||||||
|
RestServlet, parse_string, parse_integer
|
||||||
|
)
|
||||||
|
from synapse.events.utils import (
|
||||||
|
serialize_event, format_event_for_client_v2_without_room_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ._base import client_v2_patterns
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class NotificationsServlet(RestServlet):
|
||||||
|
PATTERNS = client_v2_patterns("/notifications$", releases=())
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(NotificationsServlet, self).__init__()
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, request):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
from_token = parse_string(request, "from", required=False)
|
||||||
|
limit = parse_integer(request, "limit", default=50)
|
||||||
|
|
||||||
|
limit = min(limit, 500)
|
||||||
|
|
||||||
|
push_actions = yield self.store.get_push_actions_for_user(
|
||||||
|
user_id, from_token, limit
|
||||||
|
)
|
||||||
|
|
||||||
|
receipts_by_room = yield self.store.get_receipts_for_user_with_orderings(
|
||||||
|
user_id, 'm.read'
|
||||||
|
)
|
||||||
|
|
||||||
|
notif_event_ids = [pa["event_id"] for pa in push_actions]
|
||||||
|
notif_events = yield self.store.get_events(notif_event_ids)
|
||||||
|
|
||||||
|
returned_push_actions = []
|
||||||
|
|
||||||
|
next_token = None
|
||||||
|
|
||||||
|
for pa in push_actions:
|
||||||
|
returned_pa = {
|
||||||
|
"room_id": pa["room_id"],
|
||||||
|
"profile_tag": pa["profile_tag"],
|
||||||
|
"actions": pa["actions"],
|
||||||
|
"ts": pa["received_ts"],
|
||||||
|
"event": serialize_event(
|
||||||
|
notif_events[pa["event_id"]],
|
||||||
|
self.clock.time_msec(),
|
||||||
|
event_format=format_event_for_client_v2_without_room_id,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
if pa["room_id"] not in receipts_by_room:
|
||||||
|
returned_pa["read"] = False
|
||||||
|
else:
|
||||||
|
receipt = receipts_by_room[pa["room_id"]]
|
||||||
|
|
||||||
|
returned_pa["read"] = (
|
||||||
|
receipt["topological_ordering"], receipt["stream_ordering"]
|
||||||
|
) >= (
|
||||||
|
pa["topological_ordering"], pa["stream_ordering"]
|
||||||
|
)
|
||||||
|
returned_push_actions.append(returned_pa)
|
||||||
|
next_token = pa["stream_ordering"]
|
||||||
|
|
||||||
|
defer.returnValue((200, {
|
||||||
|
"notifications": returned_push_actions,
|
||||||
|
"next_token": next_token,
|
||||||
|
}))
|
||||||
|
|
||||||
|
|
||||||
|
def register_servlets(hs, http_server):
|
||||||
|
NotificationsServlet(hs).register(http_server)
|
@ -403,10 +403,9 @@ class RegisterRestServlet(RestServlet):
|
|||||||
# register the user's device
|
# register the user's device
|
||||||
device_id = params.get("device_id")
|
device_id = params.get("device_id")
|
||||||
initial_display_name = params.get("initial_device_display_name")
|
initial_display_name = params.get("initial_device_display_name")
|
||||||
device_id = self.device_handler.check_device_registered(
|
return self.device_handler.check_device_registered(
|
||||||
user_id, device_id, initial_display_name
|
user_id, device_id, initial_display_name
|
||||||
)
|
)
|
||||||
return device_id
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _do_guest_registration(self):
|
def _do_guest_registration(self):
|
||||||
|
@ -146,7 +146,7 @@ class SyncRestServlet(RestServlet):
|
|||||||
affect_presence = set_presence != PresenceState.OFFLINE
|
affect_presence = set_presence != PresenceState.OFFLINE
|
||||||
|
|
||||||
if affect_presence:
|
if affect_presence:
|
||||||
yield self.presence_handler.set_state(user, {"presence": set_presence})
|
yield self.presence_handler.set_state(user, {"presence": set_presence}, True)
|
||||||
|
|
||||||
context = yield self.presence_handler.user_syncing(
|
context = yield self.presence_handler.user_syncing(
|
||||||
user.to_string(), affect_presence=affect_presence,
|
user.to_string(), affect_presence=affect_presence,
|
||||||
|
78
synapse/rest/client/v2_alpha/thirdparty.py
Normal file
78
synapse/rest/client/v2_alpha/thirdparty.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2015, 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.http.servlet import RestServlet
|
||||||
|
from synapse.types import ThirdPartyEntityKind
|
||||||
|
from ._base import client_v2_patterns
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ThirdPartyUserServlet(RestServlet):
|
||||||
|
PATTERNS = client_v2_patterns("/3pu(/(?P<protocol>[^/]+))?$",
|
||||||
|
releases=())
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(ThirdPartyUserServlet, self).__init__()
|
||||||
|
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.appservice_handler = hs.get_application_service_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, request, protocol):
|
||||||
|
yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
|
fields = request.args
|
||||||
|
del fields["access_token"]
|
||||||
|
|
||||||
|
results = yield self.appservice_handler.query_3pe(
|
||||||
|
ThirdPartyEntityKind.USER, protocol, fields
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, results))
|
||||||
|
|
||||||
|
|
||||||
|
class ThirdPartyLocationServlet(RestServlet):
|
||||||
|
PATTERNS = client_v2_patterns("/3pl(/(?P<protocol>[^/]+))?$",
|
||||||
|
releases=())
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(ThirdPartyLocationServlet, self).__init__()
|
||||||
|
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.appservice_handler = hs.get_application_service_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, request, protocol):
|
||||||
|
yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
|
fields = request.args
|
||||||
|
del fields["access_token"]
|
||||||
|
|
||||||
|
results = yield self.appservice_handler.query_3pe(
|
||||||
|
ThirdPartyEntityKind.LOCATION, protocol, fields
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, results))
|
||||||
|
|
||||||
|
|
||||||
|
def register_servlets(hs, http_server):
|
||||||
|
ThirdPartyUserServlet(hs).register(http_server)
|
||||||
|
ThirdPartyLocationServlet(hs).register(http_server)
|
@ -15,6 +15,7 @@
|
|||||||
from synapse.http.server import request_handler, respond_with_json_bytes
|
from synapse.http.server import request_handler, respond_with_json_bytes
|
||||||
from synapse.http.servlet import parse_integer, parse_json_object_from_request
|
from synapse.http.servlet import parse_integer, parse_json_object_from_request
|
||||||
from synapse.api.errors import SynapseError, Codes
|
from synapse.api.errors import SynapseError, Codes
|
||||||
|
from synapse.crypto.keyring import KeyLookupError
|
||||||
|
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import Resource
|
||||||
from twisted.web.server import NOT_DONE_YET
|
from twisted.web.server import NOT_DONE_YET
|
||||||
@ -210,9 +211,10 @@ class RemoteKey(Resource):
|
|||||||
yield self.keyring.get_server_verify_key_v2_direct(
|
yield self.keyring.get_server_verify_key_v2_direct(
|
||||||
server_name, key_ids
|
server_name, key_ids
|
||||||
)
|
)
|
||||||
|
except KeyLookupError as e:
|
||||||
|
logger.info("Failed to fetch key: %s", e)
|
||||||
except:
|
except:
|
||||||
logger.exception("Failed to get key for %r", server_name)
|
logger.exception("Failed to get key for %r", server_name)
|
||||||
pass
|
|
||||||
yield self.query_keys(
|
yield self.query_keys(
|
||||||
request, query, query_remote_on_cache_miss=False
|
request, query, query_remote_on_cache_miss=False
|
||||||
)
|
)
|
||||||
|
@ -45,6 +45,7 @@ class DownloadResource(Resource):
|
|||||||
@request_handler()
|
@request_handler()
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _async_render_GET(self, request):
|
def _async_render_GET(self, request):
|
||||||
|
request.setHeader("Content-Security-Policy", "sandbox")
|
||||||
server_name, media_id, name = parse_media_id(request)
|
server_name, media_id, name = parse_media_id(request)
|
||||||
if server_name == self.server_name:
|
if server_name == self.server_name:
|
||||||
yield self._respond_local_file(request, media_id, name)
|
yield self._respond_local_file(request, media_id, name)
|
||||||
|
@ -29,14 +29,13 @@ from synapse.http.server import (
|
|||||||
from synapse.util.async import ObservableDeferred
|
from synapse.util.async import ObservableDeferred
|
||||||
from synapse.util.stringutils import is_ascii
|
from synapse.util.stringutils import is_ascii
|
||||||
|
|
||||||
from copy import deepcopy
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import fnmatch
|
import fnmatch
|
||||||
import cgi
|
import cgi
|
||||||
import ujson as json
|
import ujson as json
|
||||||
import urlparse
|
import urlparse
|
||||||
|
import itertools
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -163,7 +162,7 @@ class PreviewUrlResource(Resource):
|
|||||||
|
|
||||||
logger.debug("got media_info of '%s'" % media_info)
|
logger.debug("got media_info of '%s'" % media_info)
|
||||||
|
|
||||||
if self._is_media(media_info['media_type']):
|
if _is_media(media_info['media_type']):
|
||||||
dims = yield self.media_repo._generate_local_thumbnails(
|
dims = yield self.media_repo._generate_local_thumbnails(
|
||||||
media_info['filesystem_id'], media_info
|
media_info['filesystem_id'], media_info
|
||||||
)
|
)
|
||||||
@ -184,11 +183,9 @@ class PreviewUrlResource(Resource):
|
|||||||
logger.warn("Couldn't get dims for %s" % url)
|
logger.warn("Couldn't get dims for %s" % url)
|
||||||
|
|
||||||
# define our OG response for this media
|
# define our OG response for this media
|
||||||
elif self._is_html(media_info['media_type']):
|
elif _is_html(media_info['media_type']):
|
||||||
# TODO: somehow stop a big HTML tree from exploding synapse's RAM
|
# TODO: somehow stop a big HTML tree from exploding synapse's RAM
|
||||||
|
|
||||||
from lxml import etree
|
|
||||||
|
|
||||||
file = open(media_info['filename'])
|
file = open(media_info['filename'])
|
||||||
body = file.read()
|
body = file.read()
|
||||||
file.close()
|
file.close()
|
||||||
@ -199,17 +196,35 @@ class PreviewUrlResource(Resource):
|
|||||||
match = re.match(r'.*; *charset=(.*?)(;|$)', media_info['media_type'], re.I)
|
match = re.match(r'.*; *charset=(.*?)(;|$)', media_info['media_type'], re.I)
|
||||||
encoding = match.group(1) if match else "utf-8"
|
encoding = match.group(1) if match else "utf-8"
|
||||||
|
|
||||||
try:
|
og = decode_and_calc_og(body, media_info['uri'], encoding)
|
||||||
parser = etree.HTMLParser(recover=True, encoding=encoding)
|
|
||||||
tree = etree.fromstring(body, parser)
|
|
||||||
og = yield self._calc_og(tree, media_info, requester)
|
|
||||||
except UnicodeDecodeError:
|
|
||||||
# blindly try decoding the body as utf-8, which seems to fix
|
|
||||||
# the charset mismatches on https://google.com
|
|
||||||
parser = etree.HTMLParser(recover=True, encoding=encoding)
|
|
||||||
tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser)
|
|
||||||
og = yield self._calc_og(tree, media_info, requester)
|
|
||||||
|
|
||||||
|
# pre-cache the image for posterity
|
||||||
|
# FIXME: it might be cleaner to use the same flow as the main /preview_url
|
||||||
|
# request itself and benefit from the same caching etc. But for now we
|
||||||
|
# just rely on the caching on the master request to speed things up.
|
||||||
|
if 'og:image' in og and og['og:image']:
|
||||||
|
image_info = yield self._download_url(
|
||||||
|
_rebase_url(og['og:image'], media_info['uri']), requester.user
|
||||||
|
)
|
||||||
|
|
||||||
|
if _is_media(image_info['media_type']):
|
||||||
|
# TODO: make sure we don't choke on white-on-transparent images
|
||||||
|
dims = yield self.media_repo._generate_local_thumbnails(
|
||||||
|
image_info['filesystem_id'], image_info
|
||||||
|
)
|
||||||
|
if dims:
|
||||||
|
og["og:image:width"] = dims['width']
|
||||||
|
og["og:image:height"] = dims['height']
|
||||||
|
else:
|
||||||
|
logger.warn("Couldn't get dims for %s" % og["og:image"])
|
||||||
|
|
||||||
|
og["og:image"] = "mxc://%s/%s" % (
|
||||||
|
self.server_name, image_info['filesystem_id']
|
||||||
|
)
|
||||||
|
og["og:image:type"] = image_info['media_type']
|
||||||
|
og["matrix:image:size"] = image_info['media_length']
|
||||||
|
else:
|
||||||
|
del og["og:image"]
|
||||||
else:
|
else:
|
||||||
logger.warn("Failed to find any OG data in %s", url)
|
logger.warn("Failed to find any OG data in %s", url)
|
||||||
og = {}
|
og = {}
|
||||||
@ -232,139 +247,6 @@ class PreviewUrlResource(Resource):
|
|||||||
|
|
||||||
respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
|
respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _calc_og(self, tree, media_info, requester):
|
|
||||||
# suck our tree into lxml and define our OG response.
|
|
||||||
|
|
||||||
# if we see any image URLs in the OG response, then spider them
|
|
||||||
# (although the client could choose to do this by asking for previews of those
|
|
||||||
# URLs to avoid DoSing the server)
|
|
||||||
|
|
||||||
# "og:type" : "video",
|
|
||||||
# "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
|
|
||||||
# "og:site_name" : "YouTube",
|
|
||||||
# "og:video:type" : "application/x-shockwave-flash",
|
|
||||||
# "og:description" : "Fun stuff happening here",
|
|
||||||
# "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
|
|
||||||
# "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
|
|
||||||
# "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
|
|
||||||
# "og:video:width" : "1280"
|
|
||||||
# "og:video:height" : "720",
|
|
||||||
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
|
|
||||||
|
|
||||||
og = {}
|
|
||||||
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
|
|
||||||
if 'content' in tag.attrib:
|
|
||||||
og[tag.attrib['property']] = tag.attrib['content']
|
|
||||||
|
|
||||||
# TODO: grab article: meta tags too, e.g.:
|
|
||||||
|
|
||||||
# "article:publisher" : "https://www.facebook.com/thethudonline" />
|
|
||||||
# "article:author" content="https://www.facebook.com/thethudonline" />
|
|
||||||
# "article:tag" content="baby" />
|
|
||||||
# "article:section" content="Breaking News" />
|
|
||||||
# "article:published_time" content="2016-03-31T19:58:24+00:00" />
|
|
||||||
# "article:modified_time" content="2016-04-01T18:31:53+00:00" />
|
|
||||||
|
|
||||||
if 'og:title' not in og:
|
|
||||||
# do some basic spidering of the HTML
|
|
||||||
title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
|
|
||||||
og['og:title'] = title[0].text.strip() if title else None
|
|
||||||
|
|
||||||
if 'og:image' not in og:
|
|
||||||
# TODO: extract a favicon failing all else
|
|
||||||
meta_image = tree.xpath(
|
|
||||||
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
|
|
||||||
)
|
|
||||||
if meta_image:
|
|
||||||
og['og:image'] = self._rebase_url(meta_image[0], media_info['uri'])
|
|
||||||
else:
|
|
||||||
# TODO: consider inlined CSS styles as well as width & height attribs
|
|
||||||
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
|
|
||||||
images = sorted(images, key=lambda i: (
|
|
||||||
-1 * float(i.attrib['width']) * float(i.attrib['height'])
|
|
||||||
))
|
|
||||||
if not images:
|
|
||||||
images = tree.xpath("//img[@src]")
|
|
||||||
if images:
|
|
||||||
og['og:image'] = images[0].attrib['src']
|
|
||||||
|
|
||||||
# pre-cache the image for posterity
|
|
||||||
# FIXME: it might be cleaner to use the same flow as the main /preview_url
|
|
||||||
# request itself and benefit from the same caching etc. But for now we
|
|
||||||
# just rely on the caching on the master request to speed things up.
|
|
||||||
if 'og:image' in og and og['og:image']:
|
|
||||||
image_info = yield self._download_url(
|
|
||||||
self._rebase_url(og['og:image'], media_info['uri']), requester.user
|
|
||||||
)
|
|
||||||
|
|
||||||
if self._is_media(image_info['media_type']):
|
|
||||||
# TODO: make sure we don't choke on white-on-transparent images
|
|
||||||
dims = yield self.media_repo._generate_local_thumbnails(
|
|
||||||
image_info['filesystem_id'], image_info
|
|
||||||
)
|
|
||||||
if dims:
|
|
||||||
og["og:image:width"] = dims['width']
|
|
||||||
og["og:image:height"] = dims['height']
|
|
||||||
else:
|
|
||||||
logger.warn("Couldn't get dims for %s" % og["og:image"])
|
|
||||||
|
|
||||||
og["og:image"] = "mxc://%s/%s" % (
|
|
||||||
self.server_name, image_info['filesystem_id']
|
|
||||||
)
|
|
||||||
og["og:image:type"] = image_info['media_type']
|
|
||||||
og["matrix:image:size"] = image_info['media_length']
|
|
||||||
else:
|
|
||||||
del og["og:image"]
|
|
||||||
|
|
||||||
if 'og:description' not in og:
|
|
||||||
meta_description = tree.xpath(
|
|
||||||
"//*/meta"
|
|
||||||
"[translate(@name, 'DESCRIPTION', 'description')='description']"
|
|
||||||
"/@content")
|
|
||||||
if meta_description:
|
|
||||||
og['og:description'] = meta_description[0]
|
|
||||||
else:
|
|
||||||
# grab any text nodes which are inside the <body/> tag...
|
|
||||||
# unless they are within an HTML5 semantic markup tag...
|
|
||||||
# <header/>, <nav/>, <aside/>, <footer/>
|
|
||||||
# ...or if they are within a <script/> or <style/> tag.
|
|
||||||
# This is a very very very coarse approximation to a plain text
|
|
||||||
# render of the page.
|
|
||||||
|
|
||||||
# We don't just use XPATH here as that is slow on some machines.
|
|
||||||
|
|
||||||
# We clone `tree` as we modify it.
|
|
||||||
cloned_tree = deepcopy(tree.find("body"))
|
|
||||||
|
|
||||||
TAGS_TO_REMOVE = ("header", "nav", "aside", "footer", "script", "style",)
|
|
||||||
for el in cloned_tree.iter(TAGS_TO_REMOVE):
|
|
||||||
el.getparent().remove(el)
|
|
||||||
|
|
||||||
# Split all the text nodes into paragraphs (by splitting on new
|
|
||||||
# lines)
|
|
||||||
text_nodes = (
|
|
||||||
re.sub(r'\s+', '\n', el.text).strip()
|
|
||||||
for el in cloned_tree.iter()
|
|
||||||
if el.text and isinstance(el.tag, basestring) # Removes comments
|
|
||||||
)
|
|
||||||
og['og:description'] = summarize_paragraphs(text_nodes)
|
|
||||||
|
|
||||||
# TODO: delete the url downloads to stop diskfilling,
|
|
||||||
# as we only ever cared about its OG
|
|
||||||
defer.returnValue(og)
|
|
||||||
|
|
||||||
def _rebase_url(self, url, base):
|
|
||||||
base = list(urlparse.urlparse(base))
|
|
||||||
url = list(urlparse.urlparse(url))
|
|
||||||
if not url[0]: # fix up schema
|
|
||||||
url[0] = base[0] or "http"
|
|
||||||
if not url[1]: # fix up hostname
|
|
||||||
url[1] = base[1]
|
|
||||||
if not url[2].startswith('/'):
|
|
||||||
url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2]
|
|
||||||
return urlparse.urlunparse(url)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _download_url(self, url, user):
|
def _download_url(self, url, user):
|
||||||
# TODO: we should probably honour robots.txt... except in practice
|
# TODO: we should probably honour robots.txt... except in practice
|
||||||
@ -445,17 +327,171 @@ class PreviewUrlResource(Resource):
|
|||||||
"etag": headers["ETag"][0] if "ETag" in headers else None,
|
"etag": headers["ETag"][0] if "ETag" in headers else None,
|
||||||
})
|
})
|
||||||
|
|
||||||
def _is_media(self, content_type):
|
|
||||||
if content_type.lower().startswith("image/"):
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _is_html(self, content_type):
|
def decode_and_calc_og(body, media_uri, request_encoding=None):
|
||||||
content_type = content_type.lower()
|
from lxml import etree
|
||||||
if (
|
|
||||||
content_type.startswith("text/html") or
|
try:
|
||||||
content_type.startswith("application/xhtml")
|
parser = etree.HTMLParser(recover=True, encoding=request_encoding)
|
||||||
):
|
tree = etree.fromstring(body, parser)
|
||||||
return True
|
og = _calc_og(tree, media_uri)
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
# blindly try decoding the body as utf-8, which seems to fix
|
||||||
|
# the charset mismatches on https://google.com
|
||||||
|
parser = etree.HTMLParser(recover=True, encoding=request_encoding)
|
||||||
|
tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser)
|
||||||
|
og = _calc_og(tree, media_uri)
|
||||||
|
|
||||||
|
return og
|
||||||
|
|
||||||
|
|
||||||
|
def _calc_og(tree, media_uri):
|
||||||
|
# suck our tree into lxml and define our OG response.
|
||||||
|
|
||||||
|
# if we see any image URLs in the OG response, then spider them
|
||||||
|
# (although the client could choose to do this by asking for previews of those
|
||||||
|
# URLs to avoid DoSing the server)
|
||||||
|
|
||||||
|
# "og:type" : "video",
|
||||||
|
# "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
|
||||||
|
# "og:site_name" : "YouTube",
|
||||||
|
# "og:video:type" : "application/x-shockwave-flash",
|
||||||
|
# "og:description" : "Fun stuff happening here",
|
||||||
|
# "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
|
||||||
|
# "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
|
||||||
|
# "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
|
||||||
|
# "og:video:width" : "1280"
|
||||||
|
# "og:video:height" : "720",
|
||||||
|
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
|
||||||
|
|
||||||
|
og = {}
|
||||||
|
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
|
||||||
|
if 'content' in tag.attrib:
|
||||||
|
og[tag.attrib['property']] = tag.attrib['content']
|
||||||
|
|
||||||
|
# TODO: grab article: meta tags too, e.g.:
|
||||||
|
|
||||||
|
# "article:publisher" : "https://www.facebook.com/thethudonline" />
|
||||||
|
# "article:author" content="https://www.facebook.com/thethudonline" />
|
||||||
|
# "article:tag" content="baby" />
|
||||||
|
# "article:section" content="Breaking News" />
|
||||||
|
# "article:published_time" content="2016-03-31T19:58:24+00:00" />
|
||||||
|
# "article:modified_time" content="2016-04-01T18:31:53+00:00" />
|
||||||
|
|
||||||
|
if 'og:title' not in og:
|
||||||
|
# do some basic spidering of the HTML
|
||||||
|
title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
|
||||||
|
og['og:title'] = title[0].text.strip() if title else None
|
||||||
|
|
||||||
|
if 'og:image' not in og:
|
||||||
|
# TODO: extract a favicon failing all else
|
||||||
|
meta_image = tree.xpath(
|
||||||
|
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
|
||||||
|
)
|
||||||
|
if meta_image:
|
||||||
|
og['og:image'] = _rebase_url(meta_image[0], media_uri)
|
||||||
|
else:
|
||||||
|
# TODO: consider inlined CSS styles as well as width & height attribs
|
||||||
|
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
|
||||||
|
images = sorted(images, key=lambda i: (
|
||||||
|
-1 * float(i.attrib['width']) * float(i.attrib['height'])
|
||||||
|
))
|
||||||
|
if not images:
|
||||||
|
images = tree.xpath("//img[@src]")
|
||||||
|
if images:
|
||||||
|
og['og:image'] = images[0].attrib['src']
|
||||||
|
|
||||||
|
if 'og:description' not in og:
|
||||||
|
meta_description = tree.xpath(
|
||||||
|
"//*/meta"
|
||||||
|
"[translate(@name, 'DESCRIPTION', 'description')='description']"
|
||||||
|
"/@content")
|
||||||
|
if meta_description:
|
||||||
|
og['og:description'] = meta_description[0]
|
||||||
|
else:
|
||||||
|
# grab any text nodes which are inside the <body/> tag...
|
||||||
|
# unless they are within an HTML5 semantic markup tag...
|
||||||
|
# <header/>, <nav/>, <aside/>, <footer/>
|
||||||
|
# ...or if they are within a <script/> or <style/> tag.
|
||||||
|
# This is a very very very coarse approximation to a plain text
|
||||||
|
# render of the page.
|
||||||
|
|
||||||
|
# We don't just use XPATH here as that is slow on some machines.
|
||||||
|
|
||||||
|
from lxml import etree
|
||||||
|
|
||||||
|
TAGS_TO_REMOVE = (
|
||||||
|
"header", "nav", "aside", "footer", "script", "style", etree.Comment
|
||||||
|
)
|
||||||
|
|
||||||
|
# Split all the text nodes into paragraphs (by splitting on new
|
||||||
|
# lines)
|
||||||
|
text_nodes = (
|
||||||
|
re.sub(r'\s+', '\n', el).strip()
|
||||||
|
for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
|
||||||
|
)
|
||||||
|
og['og:description'] = summarize_paragraphs(text_nodes)
|
||||||
|
|
||||||
|
# TODO: delete the url downloads to stop diskfilling,
|
||||||
|
# as we only ever cared about its OG
|
||||||
|
return og
|
||||||
|
|
||||||
|
|
||||||
|
def _iterate_over_text(tree, *tags_to_ignore):
|
||||||
|
"""Iterate over the tree returning text nodes in a depth first fashion,
|
||||||
|
skipping text nodes inside certain tags.
|
||||||
|
"""
|
||||||
|
# This is basically a stack that we extend using itertools.chain.
|
||||||
|
# This will either consist of an element to iterate over *or* a string
|
||||||
|
# to be returned.
|
||||||
|
elements = iter([tree])
|
||||||
|
while True:
|
||||||
|
el = elements.next()
|
||||||
|
if isinstance(el, basestring):
|
||||||
|
yield el
|
||||||
|
elif el is not None and el.tag not in tags_to_ignore:
|
||||||
|
# el.text is the text before the first child, so we can immediately
|
||||||
|
# return it if the text exists.
|
||||||
|
if el.text:
|
||||||
|
yield el.text
|
||||||
|
|
||||||
|
# We add to the stack all the elements children, interspersed with
|
||||||
|
# each child's tail text (if it exists). The tail text of a node
|
||||||
|
# is text that comes *after* the node, so we always include it even
|
||||||
|
# if we ignore the child node.
|
||||||
|
elements = itertools.chain(
|
||||||
|
itertools.chain.from_iterable( # Basically a flatmap
|
||||||
|
[child, child.tail] if child.tail else [child]
|
||||||
|
for child in el.iterchildren()
|
||||||
|
),
|
||||||
|
elements
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _rebase_url(url, base):
|
||||||
|
base = list(urlparse.urlparse(base))
|
||||||
|
url = list(urlparse.urlparse(url))
|
||||||
|
if not url[0]: # fix up schema
|
||||||
|
url[0] = base[0] or "http"
|
||||||
|
if not url[1]: # fix up hostname
|
||||||
|
url[1] = base[1]
|
||||||
|
if not url[2].startswith('/'):
|
||||||
|
url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2]
|
||||||
|
return urlparse.urlunparse(url)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_media(content_type):
|
||||||
|
if content_type.lower().startswith("image/"):
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _is_html(content_type):
|
||||||
|
content_type = content_type.lower()
|
||||||
|
if (
|
||||||
|
content_type.startswith("text/html") or
|
||||||
|
content_type.startswith("application/xhtml")
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
|
def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
|
||||||
|
@ -41,6 +41,7 @@ from synapse.handlers.presence import PresenceHandler
|
|||||||
from synapse.handlers.room import RoomListHandler
|
from synapse.handlers.room import RoomListHandler
|
||||||
from synapse.handlers.sync import SyncHandler
|
from synapse.handlers.sync import SyncHandler
|
||||||
from synapse.handlers.typing import TypingHandler
|
from synapse.handlers.typing import TypingHandler
|
||||||
|
from synapse.handlers.events import EventHandler, EventStreamHandler
|
||||||
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
|
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
|
||||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||||
from synapse.notifier import Notifier
|
from synapse.notifier import Notifier
|
||||||
@ -94,6 +95,8 @@ class HomeServer(object):
|
|||||||
'auth_handler',
|
'auth_handler',
|
||||||
'device_handler',
|
'device_handler',
|
||||||
'e2e_keys_handler',
|
'e2e_keys_handler',
|
||||||
|
'event_handler',
|
||||||
|
'event_stream_handler',
|
||||||
'application_service_api',
|
'application_service_api',
|
||||||
'application_service_scheduler',
|
'application_service_scheduler',
|
||||||
'application_service_handler',
|
'application_service_handler',
|
||||||
@ -214,6 +217,12 @@ class HomeServer(object):
|
|||||||
def build_application_service_handler(self):
|
def build_application_service_handler(self):
|
||||||
return ApplicationServicesHandler(self)
|
return ApplicationServicesHandler(self)
|
||||||
|
|
||||||
|
def build_event_handler(self):
|
||||||
|
return EventHandler(self)
|
||||||
|
|
||||||
|
def build_event_stream_handler(self):
|
||||||
|
return EventStreamHandler(self)
|
||||||
|
|
||||||
def build_event_sources(self):
|
def build_event_sources(self):
|
||||||
return EventSources(self)
|
return EventSources(self)
|
||||||
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import synapse.api.auth
|
||||||
import synapse.handlers
|
import synapse.handlers
|
||||||
import synapse.handlers.auth
|
import synapse.handlers.auth
|
||||||
import synapse.handlers.device
|
import synapse.handlers.device
|
||||||
@ -6,6 +7,9 @@ import synapse.storage
|
|||||||
import synapse.state
|
import synapse.state
|
||||||
|
|
||||||
class HomeServer(object):
|
class HomeServer(object):
|
||||||
|
def get_auth(self) -> synapse.api.auth.Auth:
|
||||||
|
pass
|
||||||
|
|
||||||
def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler:
|
def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -50,6 +50,7 @@ from .openid import OpenIdStore
|
|||||||
from .client_ips import ClientIpStore
|
from .client_ips import ClientIpStore
|
||||||
|
|
||||||
from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
|
from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
|
||||||
|
from .engines import PostgresEngine
|
||||||
|
|
||||||
from synapse.api.constants import PresenceState
|
from synapse.api.constants import PresenceState
|
||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
@ -123,6 +124,13 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||||||
extra_tables=[("deleted_pushers", "stream_id")],
|
extra_tables=[("deleted_pushers", "stream_id")],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if isinstance(self.database_engine, PostgresEngine):
|
||||||
|
self._cache_id_gen = StreamIdGenerator(
|
||||||
|
db_conn, "cache_invalidation_stream", "stream_id",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._cache_id_gen = None
|
||||||
|
|
||||||
events_max = self._stream_id_gen.get_current_token()
|
events_max = self._stream_id_gen.get_current_token()
|
||||||
event_cache_prefill, min_event_val = self._get_cache_dict(
|
event_cache_prefill, min_event_val = self._get_cache_dict(
|
||||||
db_conn, "events",
|
db_conn, "events",
|
||||||
|
@ -19,6 +19,7 @@ from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
|||||||
from synapse.util.caches.dictionary_cache import DictionaryCache
|
from synapse.util.caches.dictionary_cache import DictionaryCache
|
||||||
from synapse.util.caches.descriptors import Cache
|
from synapse.util.caches.descriptors import Cache
|
||||||
from synapse.util.caches import intern_dict
|
from synapse.util.caches import intern_dict
|
||||||
|
from synapse.storage.engines import PostgresEngine
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
|
|
||||||
|
|
||||||
@ -165,7 +166,7 @@ class SQLBaseStore(object):
|
|||||||
self._txn_perf_counters = PerformanceCounters()
|
self._txn_perf_counters = PerformanceCounters()
|
||||||
self._get_event_counters = PerformanceCounters()
|
self._get_event_counters = PerformanceCounters()
|
||||||
|
|
||||||
self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
|
self._get_event_cache = Cache("*getEvent*", keylen=3,
|
||||||
max_entries=hs.config.event_cache_size)
|
max_entries=hs.config.event_cache_size)
|
||||||
|
|
||||||
self._state_group_cache = DictionaryCache(
|
self._state_group_cache = DictionaryCache(
|
||||||
@ -305,13 +306,14 @@ class SQLBaseStore(object):
|
|||||||
func, *args, **kwargs
|
func, *args, **kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
with PreserveLoggingContext():
|
try:
|
||||||
result = yield self._db_pool.runWithConnection(
|
with PreserveLoggingContext():
|
||||||
inner_func, *args, **kwargs
|
result = yield self._db_pool.runWithConnection(
|
||||||
)
|
inner_func, *args, **kwargs
|
||||||
|
)
|
||||||
for after_callback, after_args in after_callbacks:
|
finally:
|
||||||
after_callback(*after_args)
|
for after_callback, after_args in after_callbacks:
|
||||||
|
after_callback(*after_args)
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -860,6 +862,62 @@ class SQLBaseStore(object):
|
|||||||
|
|
||||||
return cache, min_val
|
return cache, min_val
|
||||||
|
|
||||||
|
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
|
||||||
|
"""Invalidates the cache and adds it to the cache stream so slaves
|
||||||
|
will know to invalidate their caches.
|
||||||
|
|
||||||
|
This should only be used to invalidate caches where slaves won't
|
||||||
|
otherwise know from other replication streams that the cache should
|
||||||
|
be invalidated.
|
||||||
|
"""
|
||||||
|
txn.call_after(cache_func.invalidate, keys)
|
||||||
|
|
||||||
|
if isinstance(self.database_engine, PostgresEngine):
|
||||||
|
# get_next() returns a context manager which is designed to wrap
|
||||||
|
# the transaction. However, we want to only get an ID when we want
|
||||||
|
# to use it, here, so we need to call __enter__ manually, and have
|
||||||
|
# __exit__ called after the transaction finishes.
|
||||||
|
ctx = self._cache_id_gen.get_next()
|
||||||
|
stream_id = ctx.__enter__()
|
||||||
|
txn.call_after(ctx.__exit__, None, None, None)
|
||||||
|
txn.call_after(self.hs.get_notifier().on_new_replication_data)
|
||||||
|
|
||||||
|
self._simple_insert_txn(
|
||||||
|
txn,
|
||||||
|
table="cache_invalidation_stream",
|
||||||
|
values={
|
||||||
|
"stream_id": stream_id,
|
||||||
|
"cache_func": cache_func.__name__,
|
||||||
|
"keys": list(keys),
|
||||||
|
"invalidation_ts": self.clock.time_msec(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_all_updated_caches(self, last_id, current_id, limit):
|
||||||
|
if last_id == current_id:
|
||||||
|
return defer.succeed([])
|
||||||
|
|
||||||
|
def get_all_updated_caches_txn(txn):
|
||||||
|
# We purposefully don't bound by the current token, as we want to
|
||||||
|
# send across cache invalidations as quickly as possible. Cache
|
||||||
|
# invalidations are idempotent, so duplicates are fine.
|
||||||
|
sql = (
|
||||||
|
"SELECT stream_id, cache_func, keys, invalidation_ts"
|
||||||
|
" FROM cache_invalidation_stream"
|
||||||
|
" WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
|
||||||
|
)
|
||||||
|
txn.execute(sql, (last_id, limit,))
|
||||||
|
return txn.fetchall()
|
||||||
|
return self.runInteraction(
|
||||||
|
"get_all_updated_caches", get_all_updated_caches_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_cache_stream_token(self):
|
||||||
|
if self._cache_id_gen:
|
||||||
|
return self._cache_id_gen.get_current_token()
|
||||||
|
else:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
class _RollbackButIsFineException(Exception):
|
class _RollbackButIsFineException(Exception):
|
||||||
""" This exception is used to rollback a transaction without implying
|
""" This exception is used to rollback a transaction without implying
|
||||||
|
@ -218,38 +218,37 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
|
|||||||
Returns:
|
Returns:
|
||||||
AppServiceTransaction: A new transaction.
|
AppServiceTransaction: A new transaction.
|
||||||
"""
|
"""
|
||||||
|
def _create_appservice_txn(txn):
|
||||||
|
# work out new txn id (highest txn id for this service += 1)
|
||||||
|
# The highest id may be the last one sent (in which case it is last_txn)
|
||||||
|
# or it may be the highest in the txns list (which are waiting to be/are
|
||||||
|
# being sent)
|
||||||
|
last_txn_id = self._get_last_txn(txn, service.id)
|
||||||
|
|
||||||
|
txn.execute(
|
||||||
|
"SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?",
|
||||||
|
(service.id,)
|
||||||
|
)
|
||||||
|
highest_txn_id = txn.fetchone()[0]
|
||||||
|
if highest_txn_id is None:
|
||||||
|
highest_txn_id = 0
|
||||||
|
|
||||||
|
new_txn_id = max(highest_txn_id, last_txn_id) + 1
|
||||||
|
|
||||||
|
# Insert new txn into txn table
|
||||||
|
event_ids = json.dumps([e.event_id for e in events])
|
||||||
|
txn.execute(
|
||||||
|
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
|
||||||
|
"VALUES(?,?,?)",
|
||||||
|
(service.id, new_txn_id, event_ids)
|
||||||
|
)
|
||||||
|
return AppServiceTransaction(
|
||||||
|
service=service, id=new_txn_id, events=events
|
||||||
|
)
|
||||||
|
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"create_appservice_txn",
|
"create_appservice_txn",
|
||||||
self._create_appservice_txn,
|
_create_appservice_txn,
|
||||||
service, events
|
|
||||||
)
|
|
||||||
|
|
||||||
def _create_appservice_txn(self, txn, service, events):
|
|
||||||
# work out new txn id (highest txn id for this service += 1)
|
|
||||||
# The highest id may be the last one sent (in which case it is last_txn)
|
|
||||||
# or it may be the highest in the txns list (which are waiting to be/are
|
|
||||||
# being sent)
|
|
||||||
last_txn_id = self._get_last_txn(txn, service.id)
|
|
||||||
|
|
||||||
txn.execute(
|
|
||||||
"SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?",
|
|
||||||
(service.id,)
|
|
||||||
)
|
|
||||||
highest_txn_id = txn.fetchone()[0]
|
|
||||||
if highest_txn_id is None:
|
|
||||||
highest_txn_id = 0
|
|
||||||
|
|
||||||
new_txn_id = max(highest_txn_id, last_txn_id) + 1
|
|
||||||
|
|
||||||
# Insert new txn into txn table
|
|
||||||
event_ids = json.dumps([e.event_id for e in events])
|
|
||||||
txn.execute(
|
|
||||||
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
|
|
||||||
"VALUES(?,?,?)",
|
|
||||||
(service.id, new_txn_id, event_ids)
|
|
||||||
)
|
|
||||||
return AppServiceTransaction(
|
|
||||||
service=service, id=new_txn_id, events=events
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def complete_appservice_txn(self, txn_id, service):
|
def complete_appservice_txn(self, txn_id, service):
|
||||||
@ -263,39 +262,38 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
|
|||||||
A Deferred which resolves if this transaction was stored
|
A Deferred which resolves if this transaction was stored
|
||||||
successfully.
|
successfully.
|
||||||
"""
|
"""
|
||||||
return self.runInteraction(
|
|
||||||
"complete_appservice_txn",
|
|
||||||
self._complete_appservice_txn,
|
|
||||||
txn_id, service
|
|
||||||
)
|
|
||||||
|
|
||||||
def _complete_appservice_txn(self, txn, txn_id, service):
|
|
||||||
txn_id = int(txn_id)
|
txn_id = int(txn_id)
|
||||||
|
|
||||||
# Debugging query: Make sure the txn being completed is EXACTLY +1 from
|
def _complete_appservice_txn(txn):
|
||||||
# what was there before. If it isn't, we've got problems (e.g. the AS
|
# Debugging query: Make sure the txn being completed is EXACTLY +1 from
|
||||||
# has probably missed some events), so whine loudly but still continue,
|
# what was there before. If it isn't, we've got problems (e.g. the AS
|
||||||
# since it shouldn't fail completion of the transaction.
|
# has probably missed some events), so whine loudly but still continue,
|
||||||
last_txn_id = self._get_last_txn(txn, service.id)
|
# since it shouldn't fail completion of the transaction.
|
||||||
if (last_txn_id + 1) != txn_id:
|
last_txn_id = self._get_last_txn(txn, service.id)
|
||||||
logger.error(
|
if (last_txn_id + 1) != txn_id:
|
||||||
"appservice: Completing a transaction which has an ID > 1 from "
|
logger.error(
|
||||||
"the last ID sent to this AS. We've either dropped events or "
|
"appservice: Completing a transaction which has an ID > 1 from "
|
||||||
"sent it to the AS out of order. FIX ME. last_txn=%s "
|
"the last ID sent to this AS. We've either dropped events or "
|
||||||
"completing_txn=%s service_id=%s", last_txn_id, txn_id,
|
"sent it to the AS out of order. FIX ME. last_txn=%s "
|
||||||
service.id
|
"completing_txn=%s service_id=%s", last_txn_id, txn_id,
|
||||||
|
service.id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set current txn_id for AS to 'txn_id'
|
||||||
|
self._simple_upsert_txn(
|
||||||
|
txn, "application_services_state", dict(as_id=service.id),
|
||||||
|
dict(last_txn=txn_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set current txn_id for AS to 'txn_id'
|
# Delete txn
|
||||||
self._simple_upsert_txn(
|
self._simple_delete_txn(
|
||||||
txn, "application_services_state", dict(as_id=service.id),
|
txn, "application_services_txns",
|
||||||
dict(last_txn=txn_id)
|
dict(txn_id=txn_id, as_id=service.id)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Delete txn
|
return self.runInteraction(
|
||||||
self._simple_delete_txn(
|
"complete_appservice_txn",
|
||||||
txn, "application_services_txns",
|
_complete_appservice_txn,
|
||||||
dict(txn_id=txn_id, as_id=service.id)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -309,10 +307,25 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
|
|||||||
A Deferred which resolves to an AppServiceTransaction or
|
A Deferred which resolves to an AppServiceTransaction or
|
||||||
None.
|
None.
|
||||||
"""
|
"""
|
||||||
|
def _get_oldest_unsent_txn(txn):
|
||||||
|
# Monotonically increasing txn ids, so just select the smallest
|
||||||
|
# one in the txns table (we delete them when they are sent)
|
||||||
|
txn.execute(
|
||||||
|
"SELECT * FROM application_services_txns WHERE as_id=?"
|
||||||
|
" ORDER BY txn_id ASC LIMIT 1",
|
||||||
|
(service.id,)
|
||||||
|
)
|
||||||
|
rows = self.cursor_to_dict(txn)
|
||||||
|
if not rows:
|
||||||
|
return None
|
||||||
|
|
||||||
|
entry = rows[0]
|
||||||
|
|
||||||
|
return entry
|
||||||
|
|
||||||
entry = yield self.runInteraction(
|
entry = yield self.runInteraction(
|
||||||
"get_oldest_unsent_appservice_txn",
|
"get_oldest_unsent_appservice_txn",
|
||||||
self._get_oldest_unsent_txn,
|
_get_oldest_unsent_txn,
|
||||||
service
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not entry:
|
if not entry:
|
||||||
@ -326,22 +339,6 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
|
|||||||
service=service, id=entry["txn_id"], events=events
|
service=service, id=entry["txn_id"], events=events
|
||||||
))
|
))
|
||||||
|
|
||||||
def _get_oldest_unsent_txn(self, txn, service):
|
|
||||||
# Monotonically increasing txn ids, so just select the smallest
|
|
||||||
# one in the txns table (we delete them when they are sent)
|
|
||||||
txn.execute(
|
|
||||||
"SELECT * FROM application_services_txns WHERE as_id=?"
|
|
||||||
" ORDER BY txn_id ASC LIMIT 1",
|
|
||||||
(service.id,)
|
|
||||||
)
|
|
||||||
rows = self.cursor_to_dict(txn)
|
|
||||||
if not rows:
|
|
||||||
return None
|
|
||||||
|
|
||||||
entry = rows[0]
|
|
||||||
|
|
||||||
return entry
|
|
||||||
|
|
||||||
def _get_last_txn(self, txn, service_id):
|
def _get_last_txn(self, txn, service_id):
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"SELECT last_txn FROM application_services_state WHERE as_id=?",
|
"SELECT last_txn FROM application_services_state WHERE as_id=?",
|
||||||
@ -352,3 +349,45 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
|
|||||||
return 0
|
return 0
|
||||||
else:
|
else:
|
||||||
return int(last_txn_id[0]) # select 'last_txn' col
|
return int(last_txn_id[0]) # select 'last_txn' col
|
||||||
|
|
||||||
|
def set_appservice_last_pos(self, pos):
|
||||||
|
def set_appservice_last_pos_txn(txn):
|
||||||
|
txn.execute(
|
||||||
|
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
|
||||||
|
)
|
||||||
|
return self.runInteraction(
|
||||||
|
"set_appservice_last_pos", set_appservice_last_pos_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_new_events_for_appservice(self, current_id, limit):
|
||||||
|
"""Get all new evnets"""
|
||||||
|
|
||||||
|
def get_new_events_for_appservice_txn(txn):
|
||||||
|
sql = (
|
||||||
|
"SELECT e.stream_ordering, e.event_id"
|
||||||
|
" FROM events AS e"
|
||||||
|
" WHERE"
|
||||||
|
" (SELECT stream_ordering FROM appservice_stream_position)"
|
||||||
|
" < e.stream_ordering"
|
||||||
|
" AND e.stream_ordering <= ?"
|
||||||
|
" ORDER BY e.stream_ordering ASC"
|
||||||
|
" LIMIT ?"
|
||||||
|
)
|
||||||
|
|
||||||
|
txn.execute(sql, (current_id, limit))
|
||||||
|
rows = txn.fetchall()
|
||||||
|
|
||||||
|
upper_bound = current_id
|
||||||
|
if len(rows) == limit:
|
||||||
|
upper_bound = rows[-1][0]
|
||||||
|
|
||||||
|
return upper_bound, [row[1] for row in rows]
|
||||||
|
|
||||||
|
upper_bound, event_ids = yield self.runInteraction(
|
||||||
|
"get_new_events_for_appservice", get_new_events_for_appservice_txn,
|
||||||
|
)
|
||||||
|
|
||||||
|
events = yield self._get_events(event_ids)
|
||||||
|
|
||||||
|
defer.returnValue((upper_bound, events))
|
||||||
|
@ -82,32 +82,39 @@ class DirectoryStore(SQLBaseStore):
|
|||||||
Returns:
|
Returns:
|
||||||
Deferred
|
Deferred
|
||||||
"""
|
"""
|
||||||
try:
|
def alias_txn(txn):
|
||||||
yield self._simple_insert(
|
self._simple_insert_txn(
|
||||||
|
txn,
|
||||||
"room_aliases",
|
"room_aliases",
|
||||||
{
|
{
|
||||||
"room_alias": room_alias.to_string(),
|
"room_alias": room_alias.to_string(),
|
||||||
"room_id": room_id,
|
"room_id": room_id,
|
||||||
"creator": creator,
|
"creator": creator,
|
||||||
},
|
},
|
||||||
desc="create_room_alias_association",
|
)
|
||||||
|
|
||||||
|
self._simple_insert_many_txn(
|
||||||
|
txn,
|
||||||
|
table="room_alias_servers",
|
||||||
|
values=[{
|
||||||
|
"room_alias": room_alias.to_string(),
|
||||||
|
"server": server,
|
||||||
|
} for server in servers],
|
||||||
|
)
|
||||||
|
|
||||||
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self.get_aliases_for_room, (room_id,)
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
ret = yield self.runInteraction(
|
||||||
|
"create_room_alias_association", alias_txn
|
||||||
)
|
)
|
||||||
except self.database_engine.module.IntegrityError:
|
except self.database_engine.module.IntegrityError:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
409, "Room alias %s already exists" % room_alias.to_string()
|
409, "Room alias %s already exists" % room_alias.to_string()
|
||||||
)
|
)
|
||||||
|
defer.returnValue(ret)
|
||||||
for server in servers:
|
|
||||||
# TODO(erikj): Fix this to bulk insert
|
|
||||||
yield self._simple_insert(
|
|
||||||
"room_alias_servers",
|
|
||||||
{
|
|
||||||
"room_alias": room_alias.to_string(),
|
|
||||||
"server": server,
|
|
||||||
},
|
|
||||||
desc="create_room_alias_association",
|
|
||||||
)
|
|
||||||
self.get_aliases_for_room.invalidate((room_id,))
|
|
||||||
|
|
||||||
def get_room_alias_creator(self, room_alias):
|
def get_room_alias_creator(self, room_alias):
|
||||||
return self._simple_select_one_onecol(
|
return self._simple_select_one_onecol(
|
||||||
|
@ -56,7 +56,7 @@ class EventPushActionsStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
self._simple_insert_many_txn(txn, "event_push_actions", values)
|
self._simple_insert_many_txn(txn, "event_push_actions", values)
|
||||||
|
|
||||||
@cachedInlineCallbacks(num_args=3, lru=True, tree=True, max_entries=5000)
|
@cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
|
||||||
def get_unread_event_push_actions_by_room_for_user(
|
def get_unread_event_push_actions_by_room_for_user(
|
||||||
self, room_id, user_id, last_read_event_id
|
self, room_id, user_id, last_read_event_id
|
||||||
):
|
):
|
||||||
@ -337,6 +337,36 @@ class EventPushActionsStore(SQLBaseStore):
|
|||||||
# Now return the first `limit`
|
# Now return the first `limit`
|
||||||
defer.returnValue(notifs[:limit])
|
defer.returnValue(notifs[:limit])
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_push_actions_for_user(self, user_id, before=None, limit=50):
|
||||||
|
def f(txn):
|
||||||
|
before_clause = ""
|
||||||
|
if before:
|
||||||
|
before_clause = "AND stream_ordering < ?"
|
||||||
|
args = [user_id, before, limit]
|
||||||
|
else:
|
||||||
|
args = [user_id, limit]
|
||||||
|
sql = (
|
||||||
|
"SELECT epa.event_id, epa.room_id,"
|
||||||
|
" epa.stream_ordering, epa.topological_ordering,"
|
||||||
|
" epa.actions, epa.profile_tag, e.received_ts"
|
||||||
|
" FROM event_push_actions epa, events e"
|
||||||
|
" WHERE epa.room_id = e.room_id AND epa.event_id = e.event_id"
|
||||||
|
" AND epa.user_id = ? %s"
|
||||||
|
" ORDER BY epa.stream_ordering DESC"
|
||||||
|
" LIMIT ?"
|
||||||
|
% (before_clause,)
|
||||||
|
)
|
||||||
|
txn.execute(sql, args)
|
||||||
|
return self.cursor_to_dict(txn)
|
||||||
|
|
||||||
|
push_actions = yield self.runInteraction(
|
||||||
|
"get_push_actions_for_user", f
|
||||||
|
)
|
||||||
|
for pa in push_actions:
|
||||||
|
pa["actions"] = json.loads(pa["actions"])
|
||||||
|
defer.returnValue(push_actions)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_time_of_last_push_action_before(self, stream_ordering):
|
def get_time_of_last_push_action_before(self, stream_ordering):
|
||||||
def f(txn):
|
def f(txn):
|
||||||
|
@ -20,8 +20,11 @@ from synapse.events import FrozenEvent, USE_FROZEN_DICTS
|
|||||||
from synapse.events.utils import prune_event
|
from synapse.events.utils import prune_event
|
||||||
|
|
||||||
from synapse.util.async import ObservableDeferred
|
from synapse.util.async import ObservableDeferred
|
||||||
from synapse.util.logcontext import preserve_fn, PreserveLoggingContext
|
from synapse.util.logcontext import (
|
||||||
|
preserve_fn, PreserveLoggingContext, preserve_context_over_deferred
|
||||||
|
)
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
|
from synapse.util.metrics import Measure
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
|
|
||||||
@ -201,7 +204,7 @@ class EventsStore(SQLBaseStore):
|
|||||||
|
|
||||||
deferreds = []
|
deferreds = []
|
||||||
for room_id, evs_ctxs in partitioned.items():
|
for room_id, evs_ctxs in partitioned.items():
|
||||||
d = self._event_persist_queue.add_to_queue(
|
d = preserve_fn(self._event_persist_queue.add_to_queue)(
|
||||||
room_id, evs_ctxs,
|
room_id, evs_ctxs,
|
||||||
backfilled=backfilled,
|
backfilled=backfilled,
|
||||||
current_state=None,
|
current_state=None,
|
||||||
@ -211,7 +214,9 @@ class EventsStore(SQLBaseStore):
|
|||||||
for room_id in partitioned.keys():
|
for room_id in partitioned.keys():
|
||||||
self._maybe_start_persisting(room_id)
|
self._maybe_start_persisting(room_id)
|
||||||
|
|
||||||
return defer.gatherResults(deferreds, consumeErrors=True)
|
return preserve_context_over_deferred(
|
||||||
|
defer.gatherResults(deferreds, consumeErrors=True)
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
@ -224,7 +229,7 @@ class EventsStore(SQLBaseStore):
|
|||||||
|
|
||||||
self._maybe_start_persisting(event.room_id)
|
self._maybe_start_persisting(event.room_id)
|
||||||
|
|
||||||
yield deferred
|
yield preserve_context_over_deferred(deferred)
|
||||||
|
|
||||||
max_persisted_id = yield self._stream_id_gen.get_current_token()
|
max_persisted_id = yield self._stream_id_gen.get_current_token()
|
||||||
defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id))
|
defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id))
|
||||||
@ -600,7 +605,8 @@ class EventsStore(SQLBaseStore):
|
|||||||
"rejections",
|
"rejections",
|
||||||
"redactions",
|
"redactions",
|
||||||
"room_memberships",
|
"room_memberships",
|
||||||
"state_events"
|
"state_events",
|
||||||
|
"topics"
|
||||||
):
|
):
|
||||||
txn.executemany(
|
txn.executemany(
|
||||||
"DELETE FROM %s WHERE event_id = ?" % (table,),
|
"DELETE FROM %s WHERE event_id = ?" % (table,),
|
||||||
@ -1086,7 +1092,7 @@ class EventsStore(SQLBaseStore):
|
|||||||
if not allow_rejected:
|
if not allow_rejected:
|
||||||
rows[:] = [r for r in rows if not r["rejects"]]
|
rows[:] = [r for r in rows if not r["rejects"]]
|
||||||
|
|
||||||
res = yield defer.gatherResults(
|
res = yield preserve_context_over_deferred(defer.gatherResults(
|
||||||
[
|
[
|
||||||
preserve_fn(self._get_event_from_row)(
|
preserve_fn(self._get_event_from_row)(
|
||||||
row["internal_metadata"], row["json"], row["redacts"],
|
row["internal_metadata"], row["json"], row["redacts"],
|
||||||
@ -1095,7 +1101,7 @@ class EventsStore(SQLBaseStore):
|
|||||||
for row in rows
|
for row in rows
|
||||||
],
|
],
|
||||||
consumeErrors=True
|
consumeErrors=True
|
||||||
)
|
))
|
||||||
|
|
||||||
defer.returnValue({
|
defer.returnValue({
|
||||||
e.event.event_id: e
|
e.event.event_id: e
|
||||||
@ -1131,54 +1137,55 @@ class EventsStore(SQLBaseStore):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _get_event_from_row(self, internal_metadata, js, redacted,
|
def _get_event_from_row(self, internal_metadata, js, redacted,
|
||||||
rejected_reason=None):
|
rejected_reason=None):
|
||||||
d = json.loads(js)
|
with Measure(self._clock, "_get_event_from_row"):
|
||||||
internal_metadata = json.loads(internal_metadata)
|
d = json.loads(js)
|
||||||
|
internal_metadata = json.loads(internal_metadata)
|
||||||
|
|
||||||
if rejected_reason:
|
if rejected_reason:
|
||||||
rejected_reason = yield self._simple_select_one_onecol(
|
rejected_reason = yield self._simple_select_one_onecol(
|
||||||
table="rejections",
|
table="rejections",
|
||||||
keyvalues={"event_id": rejected_reason},
|
keyvalues={"event_id": rejected_reason},
|
||||||
retcol="reason",
|
retcol="reason",
|
||||||
desc="_get_event_from_row_rejected_reason",
|
desc="_get_event_from_row_rejected_reason",
|
||||||
|
)
|
||||||
|
|
||||||
|
original_ev = FrozenEvent(
|
||||||
|
d,
|
||||||
|
internal_metadata_dict=internal_metadata,
|
||||||
|
rejected_reason=rejected_reason,
|
||||||
)
|
)
|
||||||
|
|
||||||
original_ev = FrozenEvent(
|
redacted_event = None
|
||||||
d,
|
if redacted:
|
||||||
internal_metadata_dict=internal_metadata,
|
redacted_event = prune_event(original_ev)
|
||||||
rejected_reason=rejected_reason,
|
|
||||||
)
|
|
||||||
|
|
||||||
redacted_event = None
|
redaction_id = yield self._simple_select_one_onecol(
|
||||||
if redacted:
|
table="redactions",
|
||||||
redacted_event = prune_event(original_ev)
|
keyvalues={"redacts": redacted_event.event_id},
|
||||||
|
retcol="event_id",
|
||||||
|
desc="_get_event_from_row_redactions",
|
||||||
|
)
|
||||||
|
|
||||||
redaction_id = yield self._simple_select_one_onecol(
|
redacted_event.unsigned["redacted_by"] = redaction_id
|
||||||
table="redactions",
|
# Get the redaction event.
|
||||||
keyvalues={"redacts": redacted_event.event_id},
|
|
||||||
retcol="event_id",
|
because = yield self.get_event(
|
||||||
desc="_get_event_from_row_redactions",
|
redaction_id,
|
||||||
|
check_redacted=False,
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if because:
|
||||||
|
# It's fine to do add the event directly, since get_pdu_json
|
||||||
|
# will serialise this field correctly
|
||||||
|
redacted_event.unsigned["redacted_because"] = because
|
||||||
|
|
||||||
|
cache_entry = _EventCacheEntry(
|
||||||
|
event=original_ev,
|
||||||
|
redacted_event=redacted_event,
|
||||||
)
|
)
|
||||||
|
|
||||||
redacted_event.unsigned["redacted_by"] = redaction_id
|
self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
|
||||||
# Get the redaction event.
|
|
||||||
|
|
||||||
because = yield self.get_event(
|
|
||||||
redaction_id,
|
|
||||||
check_redacted=False,
|
|
||||||
allow_none=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if because:
|
|
||||||
# It's fine to do add the event directly, since get_pdu_json
|
|
||||||
# will serialise this field correctly
|
|
||||||
redacted_event.unsigned["redacted_because"] = because
|
|
||||||
|
|
||||||
cache_entry = _EventCacheEntry(
|
|
||||||
event=original_ev,
|
|
||||||
redacted_event=redacted_event,
|
|
||||||
)
|
|
||||||
|
|
||||||
self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
|
|
||||||
|
|
||||||
defer.returnValue(cache_entry)
|
defer.returnValue(cache_entry)
|
||||||
|
|
||||||
|
@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
# Remember to update this number every time a change is made to database
|
# Remember to update this number every time a change is made to database
|
||||||
# schema files, so the users will be informed on server restarts.
|
# schema files, so the users will be informed on server restarts.
|
||||||
SCHEMA_VERSION = 33
|
SCHEMA_VERSION = 34
|
||||||
|
|
||||||
dir_path = os.path.abspath(os.path.dirname(__file__))
|
dir_path = os.path.abspath(os.path.dirname(__file__))
|
||||||
|
|
||||||
|
@ -189,18 +189,30 @@ class PresenceStore(SQLBaseStore):
|
|||||||
desc="add_presence_list_pending",
|
desc="add_presence_list_pending",
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def set_presence_list_accepted(self, observer_localpart, observed_userid):
|
def set_presence_list_accepted(self, observer_localpart, observed_userid):
|
||||||
result = yield self._simple_update_one(
|
def update_presence_list_txn(txn):
|
||||||
table="presence_list",
|
result = self._simple_update_one_txn(
|
||||||
keyvalues={"user_id": observer_localpart,
|
txn,
|
||||||
"observed_user_id": observed_userid},
|
table="presence_list",
|
||||||
updatevalues={"accepted": True},
|
keyvalues={
|
||||||
desc="set_presence_list_accepted",
|
"user_id": observer_localpart,
|
||||||
|
"observed_user_id": observed_userid
|
||||||
|
},
|
||||||
|
updatevalues={"accepted": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self.get_presence_list_accepted, (observer_localpart,)
|
||||||
|
)
|
||||||
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self.get_presence_list_observers_accepted, (observed_userid,)
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
return self.runInteraction(
|
||||||
|
"set_presence_list_accepted", update_presence_list_txn,
|
||||||
)
|
)
|
||||||
self.get_presence_list_accepted.invalidate((observer_localpart,))
|
|
||||||
self.get_presence_list_observers_accepted.invalidate((observed_userid,))
|
|
||||||
defer.returnValue(result)
|
|
||||||
|
|
||||||
def get_presence_list(self, observer_localpart, accepted=None):
|
def get_presence_list(self, observer_localpart, accepted=None):
|
||||||
if accepted:
|
if accepted:
|
||||||
|
@ -16,6 +16,7 @@
|
|||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
|
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
|
||||||
from synapse.push.baserules import list_with_base_rules
|
from synapse.push.baserules import list_with_base_rules
|
||||||
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@ -48,7 +49,7 @@ def _load_rules(rawrules, enabled_map):
|
|||||||
|
|
||||||
|
|
||||||
class PushRuleStore(SQLBaseStore):
|
class PushRuleStore(SQLBaseStore):
|
||||||
@cachedInlineCallbacks(lru=True)
|
@cachedInlineCallbacks()
|
||||||
def get_push_rules_for_user(self, user_id):
|
def get_push_rules_for_user(self, user_id):
|
||||||
rows = yield self._simple_select_list(
|
rows = yield self._simple_select_list(
|
||||||
table="push_rules",
|
table="push_rules",
|
||||||
@ -72,7 +73,7 @@ class PushRuleStore(SQLBaseStore):
|
|||||||
|
|
||||||
defer.returnValue(rules)
|
defer.returnValue(rules)
|
||||||
|
|
||||||
@cachedInlineCallbacks(lru=True)
|
@cachedInlineCallbacks()
|
||||||
def get_push_rules_enabled_for_user(self, user_id):
|
def get_push_rules_enabled_for_user(self, user_id):
|
||||||
results = yield self._simple_select_list(
|
results = yield self._simple_select_list(
|
||||||
table="push_rules_enable",
|
table="push_rules_enable",
|
||||||
@ -123,6 +124,61 @@ class PushRuleStore(SQLBaseStore):
|
|||||||
|
|
||||||
defer.returnValue(results)
|
defer.returnValue(results)
|
||||||
|
|
||||||
|
def bulk_get_push_rules_for_room(self, room_id, state_group, current_state):
|
||||||
|
if not state_group:
|
||||||
|
# If state_group is None it means it has yet to be assigned a
|
||||||
|
# state group, i.e. we need to make sure that calls with a state_group
|
||||||
|
# of None don't hit previous cached calls with a None state_group.
|
||||||
|
# To do this we set the state_group to a new object as object() != object()
|
||||||
|
state_group = object()
|
||||||
|
|
||||||
|
return self._bulk_get_push_rules_for_room(room_id, state_group, current_state)
|
||||||
|
|
||||||
|
@cachedInlineCallbacks(num_args=2, cache_context=True)
|
||||||
|
def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state,
|
||||||
|
cache_context):
|
||||||
|
# We don't use `state_group`, its there so that we can cache based
|
||||||
|
# on it. However, its important that its never None, since two current_state's
|
||||||
|
# with a state_group of None are likely to be different.
|
||||||
|
# See bulk_get_push_rules_for_room for how we work around this.
|
||||||
|
assert state_group is not None
|
||||||
|
|
||||||
|
# We also will want to generate notifs for other people in the room so
|
||||||
|
# their unread countss are correct in the event stream, but to avoid
|
||||||
|
# generating them for bot / AS users etc, we only do so for people who've
|
||||||
|
# sent a read receipt into the room.
|
||||||
|
local_users_in_room = set(
|
||||||
|
e.state_key for e in current_state.values()
|
||||||
|
if e.type == EventTypes.Member and e.membership == Membership.JOIN
|
||||||
|
and self.hs.is_mine_id(e.state_key)
|
||||||
|
)
|
||||||
|
|
||||||
|
# users in the room who have pushers need to get push rules run because
|
||||||
|
# that's how their pushers work
|
||||||
|
if_users_with_pushers = yield self.get_if_users_have_pushers(
|
||||||
|
local_users_in_room, on_invalidate=cache_context.invalidate,
|
||||||
|
)
|
||||||
|
user_ids = set(
|
||||||
|
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
|
||||||
|
)
|
||||||
|
|
||||||
|
users_with_receipts = yield self.get_users_with_read_receipts_in_room(
|
||||||
|
room_id, on_invalidate=cache_context.invalidate,
|
||||||
|
)
|
||||||
|
|
||||||
|
# any users with pushers must be ours: they have pushers
|
||||||
|
for uid in users_with_receipts:
|
||||||
|
if uid in local_users_in_room:
|
||||||
|
user_ids.add(uid)
|
||||||
|
|
||||||
|
rules_by_user = yield self.bulk_get_push_rules(
|
||||||
|
user_ids, on_invalidate=cache_context.invalidate,
|
||||||
|
)
|
||||||
|
|
||||||
|
rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
|
||||||
|
|
||||||
|
defer.returnValue(rules_by_user)
|
||||||
|
|
||||||
@cachedList(cached_method_name="get_push_rules_enabled_for_user",
|
@cachedList(cached_method_name="get_push_rules_enabled_for_user",
|
||||||
list_name="user_ids", num_args=1, inlineCallbacks=True)
|
list_name="user_ids", num_args=1, inlineCallbacks=True)
|
||||||
def bulk_get_push_rules_enabled(self, user_ids):
|
def bulk_get_push_rules_enabled(self, user_ids):
|
||||||
|
@ -135,7 +135,7 @@ class PusherStore(SQLBaseStore):
|
|||||||
"get_all_updated_pushers", get_all_updated_pushers_txn
|
"get_all_updated_pushers", get_all_updated_pushers_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
@cachedInlineCallbacks(lru=True, num_args=1, max_entries=15000)
|
@cachedInlineCallbacks(num_args=1, max_entries=15000)
|
||||||
def get_if_user_has_pusher(self, user_id):
|
def get_if_user_has_pusher(self, user_id):
|
||||||
result = yield self._simple_select_many_batch(
|
result = yield self._simple_select_many_batch(
|
||||||
table='pushers',
|
table='pushers',
|
||||||
|
@ -94,6 +94,31 @@ class ReceiptsStore(SQLBaseStore):
|
|||||||
|
|
||||||
defer.returnValue({row["room_id"]: row["event_id"] for row in rows})
|
defer.returnValue({row["room_id"]: row["event_id"] for row in rows})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
|
||||||
|
def f(txn):
|
||||||
|
sql = (
|
||||||
|
"SELECT rl.room_id, rl.event_id,"
|
||||||
|
" e.topological_ordering, e.stream_ordering"
|
||||||
|
" FROM receipts_linearized AS rl"
|
||||||
|
" INNER JOIN events AS e USING (room_id, event_id)"
|
||||||
|
" WHERE rl.room_id = e.room_id"
|
||||||
|
" AND rl.event_id = e.event_id"
|
||||||
|
" AND user_id = ?"
|
||||||
|
)
|
||||||
|
txn.execute(sql, (user_id,))
|
||||||
|
return txn.fetchall()
|
||||||
|
rows = yield self.runInteraction(
|
||||||
|
"get_receipts_for_user_with_orderings", f
|
||||||
|
)
|
||||||
|
defer.returnValue({
|
||||||
|
row[0]: {
|
||||||
|
"event_id": row[1],
|
||||||
|
"topological_ordering": row[2],
|
||||||
|
"stream_ordering": row[3],
|
||||||
|
} for row in rows
|
||||||
|
})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
|
def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
|
||||||
"""Get receipts for multiple rooms for sending to clients.
|
"""Get receipts for multiple rooms for sending to clients.
|
||||||
@ -120,7 +145,7 @@ class ReceiptsStore(SQLBaseStore):
|
|||||||
|
|
||||||
defer.returnValue([ev for res in results.values() for ev in res])
|
defer.returnValue([ev for res in results.values() for ev in res])
|
||||||
|
|
||||||
@cachedInlineCallbacks(num_args=3, max_entries=5000, lru=True, tree=True)
|
@cachedInlineCallbacks(num_args=3, max_entries=5000, tree=True)
|
||||||
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
|
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
|
||||||
"""Get receipts for a single room for sending to clients.
|
"""Get receipts for a single room for sending to clients.
|
||||||
|
|
||||||
|
@ -93,7 +93,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||||||
desc="add_refresh_token_to_user",
|
desc="add_refresh_token_to_user",
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def register(self, user_id, token=None, password_hash=None,
|
def register(self, user_id, token=None, password_hash=None,
|
||||||
was_guest=False, make_guest=False, appservice_id=None,
|
was_guest=False, make_guest=False, appservice_id=None,
|
||||||
create_profile_with_localpart=None, admin=False):
|
create_profile_with_localpart=None, admin=False):
|
||||||
@ -115,7 +114,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||||||
Raises:
|
Raises:
|
||||||
StoreError if the user_id could not be registered.
|
StoreError if the user_id could not be registered.
|
||||||
"""
|
"""
|
||||||
yield self.runInteraction(
|
return self.runInteraction(
|
||||||
"register",
|
"register",
|
||||||
self._register,
|
self._register,
|
||||||
user_id,
|
user_id,
|
||||||
@ -127,8 +126,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||||||
create_profile_with_localpart,
|
create_profile_with_localpart,
|
||||||
admin
|
admin
|
||||||
)
|
)
|
||||||
self.get_user_by_id.invalidate((user_id,))
|
|
||||||
self.is_guest.invalidate((user_id,))
|
|
||||||
|
|
||||||
def _register(
|
def _register(
|
||||||
self,
|
self,
|
||||||
@ -210,6 +207,11 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||||||
(create_profile_with_localpart,)
|
(create_profile_with_localpart,)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self.get_user_by_id, (user_id,)
|
||||||
|
)
|
||||||
|
txn.call_after(self.is_guest.invalidate, (user_id,))
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
def get_user_by_id(self, user_id):
|
def get_user_by_id(self, user_id):
|
||||||
return self._simple_select_one(
|
return self._simple_select_one(
|
||||||
@ -236,22 +238,31 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||||||
|
|
||||||
return self.runInteraction("get_users_by_id_case_insensitive", f)
|
return self.runInteraction("get_users_by_id_case_insensitive", f)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def user_set_password_hash(self, user_id, password_hash):
|
def user_set_password_hash(self, user_id, password_hash):
|
||||||
"""
|
"""
|
||||||
NB. This does *not* evict any cache because the one use for this
|
NB. This does *not* evict any cache because the one use for this
|
||||||
removes most of the entries subsequently anyway so it would be
|
removes most of the entries subsequently anyway so it would be
|
||||||
pointless. Use flush_user separately.
|
pointless. Use flush_user separately.
|
||||||
"""
|
"""
|
||||||
yield self._simple_update_one('users', {
|
def user_set_password_hash_txn(txn):
|
||||||
'name': user_id
|
self._simple_update_one_txn(
|
||||||
}, {
|
txn,
|
||||||
'password_hash': password_hash
|
'users', {
|
||||||
})
|
'name': user_id
|
||||||
self.get_user_by_id.invalidate((user_id,))
|
},
|
||||||
|
{
|
||||||
|
'password_hash': password_hash
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self.get_user_by_id, (user_id,)
|
||||||
|
)
|
||||||
|
return self.runInteraction(
|
||||||
|
"user_set_password_hash", user_set_password_hash_txn
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def user_delete_access_tokens(self, user_id, except_token_ids=[],
|
def user_delete_access_tokens(self, user_id, except_token_id=None,
|
||||||
device_id=None,
|
device_id=None,
|
||||||
delete_refresh_tokens=False):
|
delete_refresh_tokens=False):
|
||||||
"""
|
"""
|
||||||
@ -259,7 +270,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): ID of user the tokens belong to
|
user_id (str): ID of user the tokens belong to
|
||||||
except_token_ids (list[str]): list of access_tokens which should
|
except_token_id (str): list of access_tokens IDs which should
|
||||||
*not* be deleted
|
*not* be deleted
|
||||||
device_id (str|None): ID of device the tokens are associated with.
|
device_id (str|None): ID of device the tokens are associated with.
|
||||||
If None, tokens associated with any device (or no device) will
|
If None, tokens associated with any device (or no device) will
|
||||||
@ -269,53 +280,45 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||||||
Returns:
|
Returns:
|
||||||
defer.Deferred:
|
defer.Deferred:
|
||||||
"""
|
"""
|
||||||
def f(txn, table, except_tokens, call_after_delete):
|
def f(txn):
|
||||||
sql = "SELECT token FROM %s WHERE user_id = ?" % table
|
keyvalues = {
|
||||||
clauses = [user_id]
|
"user_id": user_id,
|
||||||
|
}
|
||||||
if device_id is not None:
|
if device_id is not None:
|
||||||
sql += " AND device_id = ?"
|
keyvalues["device_id"] = device_id
|
||||||
clauses.append(device_id)
|
|
||||||
|
|
||||||
if except_tokens:
|
if delete_refresh_tokens:
|
||||||
sql += " AND id NOT IN (%s)" % (
|
self._simple_delete_txn(
|
||||||
",".join(["?" for _ in except_tokens]),
|
txn,
|
||||||
)
|
table="refresh_tokens",
|
||||||
clauses += except_tokens
|
keyvalues=keyvalues,
|
||||||
|
|
||||||
txn.execute(sql, clauses)
|
|
||||||
|
|
||||||
rows = txn.fetchall()
|
|
||||||
|
|
||||||
n = 100
|
|
||||||
chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)]
|
|
||||||
for chunk in chunks:
|
|
||||||
if call_after_delete:
|
|
||||||
for row in chunk:
|
|
||||||
txn.call_after(call_after_delete, (row[0],))
|
|
||||||
|
|
||||||
txn.execute(
|
|
||||||
"DELETE FROM %s WHERE token in (%s)" % (
|
|
||||||
table,
|
|
||||||
",".join(["?" for _ in chunk]),
|
|
||||||
), [r[0] for r in chunk]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# delete refresh tokens first, to stop new access tokens being
|
items = keyvalues.items()
|
||||||
# allocated while our backs are turned
|
where_clause = " AND ".join(k + " = ?" for k, _ in items)
|
||||||
if delete_refresh_tokens:
|
values = [v for _, v in items]
|
||||||
yield self.runInteraction(
|
if except_token_id:
|
||||||
"user_delete_access_tokens", f,
|
where_clause += " AND id != ?"
|
||||||
table="refresh_tokens",
|
values.append(except_token_id)
|
||||||
except_tokens=[],
|
|
||||||
call_after_delete=None,
|
txn.execute(
|
||||||
|
"SELECT token FROM access_tokens WHERE %s" % where_clause,
|
||||||
|
values
|
||||||
|
)
|
||||||
|
rows = self.cursor_to_dict(txn)
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self.get_user_by_access_token, (row["token"],)
|
||||||
|
)
|
||||||
|
|
||||||
|
txn.execute(
|
||||||
|
"DELETE FROM access_tokens WHERE %s" % where_clause,
|
||||||
|
values
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.runInteraction(
|
yield self.runInteraction(
|
||||||
"user_delete_access_tokens", f,
|
"user_delete_access_tokens", f,
|
||||||
table="access_tokens",
|
|
||||||
except_tokens=except_token_ids,
|
|
||||||
call_after_delete=self.get_user_by_access_token.invalidate,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def delete_access_token(self, access_token):
|
def delete_access_token(self, access_token):
|
||||||
@ -328,7 +331,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
txn.call_after(self.get_user_by_access_token.invalidate, (access_token,))
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self.get_user_by_access_token, (access_token,)
|
||||||
|
)
|
||||||
|
|
||||||
return self.runInteraction("delete_access_token", f)
|
return self.runInteraction("delete_access_token", f)
|
||||||
|
|
||||||
|
@ -277,7 +277,6 @@ class RoomMemberStore(SQLBaseStore):
|
|||||||
user_id, membership_list=[Membership.JOIN],
|
user_id, membership_list=[Membership.JOIN],
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def forget(self, user_id, room_id):
|
def forget(self, user_id, room_id):
|
||||||
"""Indicate that user_id wishes to discard history for room_id."""
|
"""Indicate that user_id wishes to discard history for room_id."""
|
||||||
def f(txn):
|
def f(txn):
|
||||||
@ -292,10 +291,13 @@ class RoomMemberStore(SQLBaseStore):
|
|||||||
" room_id = ?"
|
" room_id = ?"
|
||||||
)
|
)
|
||||||
txn.execute(sql, (user_id, room_id))
|
txn.execute(sql, (user_id, room_id))
|
||||||
yield self.runInteraction("forget_membership", f)
|
|
||||||
self.was_forgotten_at.invalidate_all()
|
txn.call_after(self.was_forgotten_at.invalidate_all)
|
||||||
self.who_forgot_in_room.invalidate_all()
|
txn.call_after(self.did_forget.invalidate, (user_id, room_id))
|
||||||
self.did_forget.invalidate((user_id, room_id))
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self.who_forgot_in_room, (room_id,)
|
||||||
|
)
|
||||||
|
return self.runInteraction("forget_membership", f)
|
||||||
|
|
||||||
@cachedInlineCallbacks(num_args=2)
|
@cachedInlineCallbacks(num_args=2)
|
||||||
def did_forget(self, user_id, room_id):
|
def did_forget(self, user_id, room_id):
|
||||||
|
23
synapse/storage/schema/delta/34/appservice_stream.sql
Normal file
23
synapse/storage/schema/delta/34/appservice_stream.sql
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
/* Copyright 2016 OpenMarket Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS appservice_stream_position(
|
||||||
|
Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
|
||||||
|
stream_ordering BIGINT,
|
||||||
|
CHECK (Lock='X')
|
||||||
|
);
|
||||||
|
|
||||||
|
INSERT INTO appservice_stream_position (stream_ordering)
|
||||||
|
SELECT COALESCE(MAX(stream_ordering), 0) FROM events;
|
46
synapse/storage/schema/delta/34/cache_stream.py
Normal file
46
synapse/storage/schema/delta/34/cache_stream.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from synapse.storage.prepare_database import get_statements
|
||||||
|
from synapse.storage.engines import PostgresEngine
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# This stream is used to notify replication slaves that some caches have
|
||||||
|
# been invalidated that they cannot infer from the other streams.
|
||||||
|
CREATE_TABLE = """
|
||||||
|
CREATE TABLE cache_invalidation_stream (
|
||||||
|
stream_id BIGINT,
|
||||||
|
cache_func TEXT,
|
||||||
|
keys TEXT[],
|
||||||
|
invalidation_ts BIGINT
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX cache_invalidation_stream_id ON cache_invalidation_stream(stream_id);
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def run_create(cur, database_engine, *args, **kwargs):
|
||||||
|
if not isinstance(database_engine, PostgresEngine):
|
||||||
|
return
|
||||||
|
|
||||||
|
for statement in get_statements(CREATE_TABLE.splitlines()):
|
||||||
|
cur.execute(statement)
|
||||||
|
|
||||||
|
|
||||||
|
def run_upgrade(cur, database_engine, *args, **kwargs):
|
||||||
|
pass
|
20
synapse/storage/schema/delta/34/push_display_name_rename.sql
Normal file
20
synapse/storage/schema/delta/34/push_display_name_rename.sql
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
/* Copyright 2016 OpenMarket Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
DELETE FROM push_rules WHERE rule_id = 'global/override/.m.rule.contains_display_name';
|
||||||
|
UPDATE push_rules SET rule_id = 'global/override/.m.rule.contains_display_name' WHERE rule_id = 'global/underride/.m.rule.contains_display_name';
|
||||||
|
|
||||||
|
DELETE FROM push_rules_enable WHERE rule_id = 'global/override/.m.rule.contains_display_name';
|
||||||
|
UPDATE push_rules_enable SET rule_id = 'global/override/.m.rule.contains_display_name' WHERE rule_id = 'global/underride/.m.rule.contains_display_name';
|
32
synapse/storage/schema/delta/34/received_txn_purge.py
Normal file
32
synapse/storage/schema/delta/34/received_txn_purge.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from synapse.storage.engines import PostgresEngine
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def run_create(cur, database_engine, *args, **kwargs):
|
||||||
|
if isinstance(database_engine, PostgresEngine):
|
||||||
|
cur.execute("TRUNCATE received_transactions")
|
||||||
|
else:
|
||||||
|
cur.execute("DELETE FROM received_transactions")
|
||||||
|
|
||||||
|
cur.execute("CREATE INDEX received_transactions_ts ON received_transactions(ts)")
|
||||||
|
|
||||||
|
|
||||||
|
def run_upgrade(cur, database_engine, *args, **kwargs):
|
||||||
|
pass
|
@ -25,7 +25,7 @@ from synapse.util.caches.descriptors import cached, cachedList
|
|||||||
class SignatureStore(SQLBaseStore):
|
class SignatureStore(SQLBaseStore):
|
||||||
"""Persistence for event signatures and hashes"""
|
"""Persistence for event signatures and hashes"""
|
||||||
|
|
||||||
@cached(lru=True)
|
@cached()
|
||||||
def get_event_reference_hash(self, event_id):
|
def get_event_reference_hash(self, event_id):
|
||||||
return self._get_event_reference_hashes_txn(event_id)
|
return self._get_event_reference_hashes_txn(event_id)
|
||||||
|
|
||||||
|
@ -174,7 +174,7 @@ class StateStore(SQLBaseStore):
|
|||||||
return [r[0] for r in results]
|
return [r[0] for r in results]
|
||||||
return self.runInteraction("get_current_state_for_key", f)
|
return self.runInteraction("get_current_state_for_key", f)
|
||||||
|
|
||||||
@cached(num_args=2, lru=True, max_entries=1000)
|
@cached(num_args=2, max_entries=1000)
|
||||||
def _get_state_group_from_group(self, group, types):
|
def _get_state_group_from_group(self, group, types):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@ -272,7 +272,7 @@ class StateStore(SQLBaseStore):
|
|||||||
state_map = yield self.get_state_for_events([event_id], types)
|
state_map = yield self.get_state_for_events([event_id], types)
|
||||||
defer.returnValue(state_map[event_id])
|
defer.returnValue(state_map[event_id])
|
||||||
|
|
||||||
@cached(num_args=2, lru=True, max_entries=10000)
|
@cached(num_args=2, max_entries=10000)
|
||||||
def _get_state_group_for_event(self, room_id, event_id):
|
def _get_state_group_for_event(self, room_id, event_id):
|
||||||
return self._simple_select_one_onecol(
|
return self._simple_select_one_onecol(
|
||||||
table="event_to_state_groups",
|
table="event_to_state_groups",
|
||||||
|
@ -39,7 +39,7 @@ from ._base import SQLBaseStore
|
|||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.types import RoomStreamToken
|
from synapse.types import RoomStreamToken
|
||||||
from synapse.util.logcontext import preserve_fn
|
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@ -234,12 +234,12 @@ class StreamStore(SQLBaseStore):
|
|||||||
results = {}
|
results = {}
|
||||||
room_ids = list(room_ids)
|
room_ids = list(room_ids)
|
||||||
for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)):
|
for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)):
|
||||||
res = yield defer.gatherResults([
|
res = yield preserve_context_over_deferred(defer.gatherResults([
|
||||||
preserve_fn(self.get_room_events_stream_for_room)(
|
preserve_fn(self.get_room_events_stream_for_room)(
|
||||||
room_id, from_key, to_key, limit, order=order,
|
room_id, from_key, to_key, limit, order=order,
|
||||||
)
|
)
|
||||||
for room_id in rm_ids
|
for room_id in rm_ids
|
||||||
])
|
]))
|
||||||
results.update(dict(zip(rm_ids, res)))
|
results.update(dict(zip(rm_ids, res)))
|
||||||
|
|
||||||
defer.returnValue(results)
|
defer.returnValue(results)
|
||||||
|
@ -62,10 +62,9 @@ class TransactionStore(SQLBaseStore):
|
|||||||
self.last_transaction = {}
|
self.last_transaction = {}
|
||||||
|
|
||||||
reactor.addSystemEventTrigger("before", "shutdown", self._persist_in_mem_txns)
|
reactor.addSystemEventTrigger("before", "shutdown", self._persist_in_mem_txns)
|
||||||
hs.get_clock().looping_call(
|
self._clock.looping_call(self._persist_in_mem_txns, 1000)
|
||||||
self._persist_in_mem_txns,
|
|
||||||
1000,
|
self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000)
|
||||||
)
|
|
||||||
|
|
||||||
def get_received_txn_response(self, transaction_id, origin):
|
def get_received_txn_response(self, transaction_id, origin):
|
||||||
"""For an incoming transaction from a given origin, check if we have
|
"""For an incoming transaction from a given origin, check if we have
|
||||||
@ -127,6 +126,7 @@ class TransactionStore(SQLBaseStore):
|
|||||||
"origin": origin,
|
"origin": origin,
|
||||||
"response_code": code,
|
"response_code": code,
|
||||||
"response_json": buffer(encode_canonical_json(response_dict)),
|
"response_json": buffer(encode_canonical_json(response_dict)),
|
||||||
|
"ts": self._clock.time_msec(),
|
||||||
},
|
},
|
||||||
or_ignore=True,
|
or_ignore=True,
|
||||||
desc="set_received_txn_response",
|
desc="set_received_txn_response",
|
||||||
@ -383,3 +383,12 @@ class TransactionStore(SQLBaseStore):
|
|||||||
yield self.runInteraction("_persist_in_mem_txns", f)
|
yield self.runInteraction("_persist_in_mem_txns", f)
|
||||||
except:
|
except:
|
||||||
logger.exception("Failed to persist transactions!")
|
logger.exception("Failed to persist transactions!")
|
||||||
|
|
||||||
|
def _cleanup_transactions(self):
|
||||||
|
now = self._clock.time_msec()
|
||||||
|
month_ago = now - 30 * 24 * 60 * 60 * 1000
|
||||||
|
|
||||||
|
def _cleanup_transactions_txn(txn):
|
||||||
|
txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
|
||||||
|
|
||||||
|
return self.runInteraction("_persist_in_mem_txns", _cleanup_transactions_txn)
|
||||||
|
@ -269,3 +269,10 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
|
|||||||
return "t%d-%d" % (self.topological, self.stream)
|
return "t%d-%d" % (self.topological, self.stream)
|
||||||
else:
|
else:
|
||||||
return "s%d" % (self.stream,)
|
return "s%d" % (self.stream,)
|
||||||
|
|
||||||
|
|
||||||
|
# Some arbitrary constants used for internal API enumerations. Don't rely on
|
||||||
|
# exact values; always pass or compare symbolically
|
||||||
|
class ThirdPartyEntityKind(object):
|
||||||
|
USER = 'user'
|
||||||
|
LOCATION = 'location'
|
||||||
|
@ -146,10 +146,10 @@ def concurrently_execute(func, args, limit):
|
|||||||
except StopIteration:
|
except StopIteration:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return defer.gatherResults([
|
return preserve_context_over_deferred(defer.gatherResults([
|
||||||
preserve_fn(_concurrently_execute_inner)()
|
preserve_fn(_concurrently_execute_inner)()
|
||||||
for _ in xrange(limit)
|
for _ in xrange(limit)
|
||||||
], consumeErrors=True).addErrback(unwrapFirstError)
|
], consumeErrors=True)).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
|
|
||||||
class Linearizer(object):
|
class Linearizer(object):
|
||||||
@ -181,7 +181,8 @@ class Linearizer(object):
|
|||||||
self.key_to_defer[key] = new_defer
|
self.key_to_defer[key] = new_defer
|
||||||
|
|
||||||
if current_defer:
|
if current_defer:
|
||||||
yield preserve_context_over_deferred(current_defer)
|
with PreserveLoggingContext():
|
||||||
|
yield current_defer
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def _ctx_manager():
|
def _ctx_manager():
|
||||||
@ -264,7 +265,7 @@ class ReadWriteLock(object):
|
|||||||
curr_readers.clear()
|
curr_readers.clear()
|
||||||
self.key_to_current_writer[key] = new_defer
|
self.key_to_current_writer[key] = new_defer
|
||||||
|
|
||||||
yield defer.gatherResults(to_wait_on)
|
yield preserve_context_over_deferred(defer.gatherResults(to_wait_on))
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def _ctx_manager():
|
def _ctx_manager():
|
||||||
|
@ -25,8 +25,7 @@ from synapse.util.logcontext import (
|
|||||||
from . import DEBUG_CACHES, register_cache
|
from . import DEBUG_CACHES, register_cache
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
from collections import namedtuple
|
||||||
from collections import OrderedDict
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import functools
|
import functools
|
||||||
@ -54,16 +53,11 @@ class Cache(object):
|
|||||||
"metrics",
|
"metrics",
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False):
|
def __init__(self, name, max_entries=1000, keylen=1, tree=False):
|
||||||
if lru:
|
cache_type = TreeCache if tree else dict
|
||||||
cache_type = TreeCache if tree else dict
|
self.cache = LruCache(
|
||||||
self.cache = LruCache(
|
max_size=max_entries, keylen=keylen, cache_type=cache_type
|
||||||
max_size=max_entries, keylen=keylen, cache_type=cache_type
|
)
|
||||||
)
|
|
||||||
self.max_entries = None
|
|
||||||
else:
|
|
||||||
self.cache = OrderedDict()
|
|
||||||
self.max_entries = max_entries
|
|
||||||
|
|
||||||
self.name = name
|
self.name = name
|
||||||
self.keylen = keylen
|
self.keylen = keylen
|
||||||
@ -81,8 +75,8 @@ class Cache(object):
|
|||||||
"Cache objects can only be accessed from the main thread"
|
"Cache objects can only be accessed from the main thread"
|
||||||
)
|
)
|
||||||
|
|
||||||
def get(self, key, default=_CacheSentinel):
|
def get(self, key, default=_CacheSentinel, callback=None):
|
||||||
val = self.cache.get(key, _CacheSentinel)
|
val = self.cache.get(key, _CacheSentinel, callback=callback)
|
||||||
if val is not _CacheSentinel:
|
if val is not _CacheSentinel:
|
||||||
self.metrics.inc_hits()
|
self.metrics.inc_hits()
|
||||||
return val
|
return val
|
||||||
@ -94,19 +88,15 @@ class Cache(object):
|
|||||||
else:
|
else:
|
||||||
return default
|
return default
|
||||||
|
|
||||||
def update(self, sequence, key, value):
|
def update(self, sequence, key, value, callback=None):
|
||||||
self.check_thread()
|
self.check_thread()
|
||||||
if self.sequence == sequence:
|
if self.sequence == sequence:
|
||||||
# Only update the cache if the caches sequence number matches the
|
# Only update the cache if the caches sequence number matches the
|
||||||
# number that the cache had before the SELECT was started (SYN-369)
|
# number that the cache had before the SELECT was started (SYN-369)
|
||||||
self.prefill(key, value)
|
self.prefill(key, value, callback=callback)
|
||||||
|
|
||||||
def prefill(self, key, value):
|
def prefill(self, key, value, callback=None):
|
||||||
if self.max_entries is not None:
|
self.cache.set(key, value, callback=callback)
|
||||||
while len(self.cache) >= self.max_entries:
|
|
||||||
self.cache.popitem(last=False)
|
|
||||||
|
|
||||||
self.cache[key] = value
|
|
||||||
|
|
||||||
def invalidate(self, key):
|
def invalidate(self, key):
|
||||||
self.check_thread()
|
self.check_thread()
|
||||||
@ -151,9 +141,21 @@ class CacheDescriptor(object):
|
|||||||
The wrapped function has another additional callable, called "prefill",
|
The wrapped function has another additional callable, called "prefill",
|
||||||
which can be used to insert values into the cache specifically, without
|
which can be used to insert values into the cache specifically, without
|
||||||
calling the calculation function.
|
calling the calculation function.
|
||||||
|
|
||||||
|
Cached functions can be "chained" (i.e. a cached function can call other cached
|
||||||
|
functions and get appropriately invalidated when they called caches are
|
||||||
|
invalidated) by adding a special "cache_context" argument to the function
|
||||||
|
and passing that as a kwarg to all caches called. For example::
|
||||||
|
|
||||||
|
@cachedInlineCallbacks(cache_context=True)
|
||||||
|
def foo(self, key, cache_context):
|
||||||
|
r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate)
|
||||||
|
r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
|
||||||
|
defer.returnValue(r1 + r2)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False,
|
def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
|
||||||
inlineCallbacks=False):
|
inlineCallbacks=False, cache_context=False):
|
||||||
max_entries = int(max_entries * CACHE_SIZE_FACTOR)
|
max_entries = int(max_entries * CACHE_SIZE_FACTOR)
|
||||||
|
|
||||||
self.orig = orig
|
self.orig = orig
|
||||||
@ -165,15 +167,33 @@ class CacheDescriptor(object):
|
|||||||
|
|
||||||
self.max_entries = max_entries
|
self.max_entries = max_entries
|
||||||
self.num_args = num_args
|
self.num_args = num_args
|
||||||
self.lru = lru
|
|
||||||
self.tree = tree
|
self.tree = tree
|
||||||
|
|
||||||
self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
|
all_args = inspect.getargspec(orig)
|
||||||
|
self.arg_names = all_args.args[1:num_args + 1]
|
||||||
|
|
||||||
|
if "cache_context" in all_args.args:
|
||||||
|
if not cache_context:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot have a 'cache_context' arg without setting"
|
||||||
|
" cache_context=True"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
self.arg_names.remove("cache_context")
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
elif cache_context:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot have cache_context=True without having an arg"
|
||||||
|
" named `cache_context`"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.add_cache_context = cache_context
|
||||||
|
|
||||||
if len(self.arg_names) < self.num_args:
|
if len(self.arg_names) < self.num_args:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"Not enough explicit positional arguments to key off of for %r."
|
"Not enough explicit positional arguments to key off of for %r."
|
||||||
" (@cached cannot key off of *args or **kwars)"
|
" (@cached cannot key off of *args or **kwargs)"
|
||||||
% (orig.__name__,)
|
% (orig.__name__,)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -182,16 +202,29 @@ class CacheDescriptor(object):
|
|||||||
name=self.orig.__name__,
|
name=self.orig.__name__,
|
||||||
max_entries=self.max_entries,
|
max_entries=self.max_entries,
|
||||||
keylen=self.num_args,
|
keylen=self.num_args,
|
||||||
lru=self.lru,
|
|
||||||
tree=self.tree,
|
tree=self.tree,
|
||||||
)
|
)
|
||||||
|
|
||||||
@functools.wraps(self.orig)
|
@functools.wraps(self.orig)
|
||||||
def wrapped(*args, **kwargs):
|
def wrapped(*args, **kwargs):
|
||||||
|
# If we're passed a cache_context then we'll want to call its invalidate()
|
||||||
|
# whenever we are invalidated
|
||||||
|
invalidate_callback = kwargs.pop("on_invalidate", None)
|
||||||
|
|
||||||
|
# Add temp cache_context so inspect.getcallargs doesn't explode
|
||||||
|
if self.add_cache_context:
|
||||||
|
kwargs["cache_context"] = None
|
||||||
|
|
||||||
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
|
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
|
||||||
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
|
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
|
||||||
|
|
||||||
|
# Add our own `cache_context` to argument list if the wrapped function
|
||||||
|
# has asked for one
|
||||||
|
if self.add_cache_context:
|
||||||
|
kwargs["cache_context"] = _CacheContext(cache, cache_key)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cached_result_d = cache.get(cache_key)
|
cached_result_d = cache.get(cache_key, callback=invalidate_callback)
|
||||||
|
|
||||||
observer = cached_result_d.observe()
|
observer = cached_result_d.observe()
|
||||||
if DEBUG_CACHES:
|
if DEBUG_CACHES:
|
||||||
@ -228,7 +261,7 @@ class CacheDescriptor(object):
|
|||||||
ret.addErrback(onErr)
|
ret.addErrback(onErr)
|
||||||
|
|
||||||
ret = ObservableDeferred(ret, consumeErrors=True)
|
ret = ObservableDeferred(ret, consumeErrors=True)
|
||||||
cache.update(sequence, cache_key, ret)
|
cache.update(sequence, cache_key, ret, callback=invalidate_callback)
|
||||||
|
|
||||||
return preserve_context_over_deferred(ret.observe())
|
return preserve_context_over_deferred(ret.observe())
|
||||||
|
|
||||||
@ -297,6 +330,10 @@ class CacheListDescriptor(object):
|
|||||||
|
|
||||||
@functools.wraps(self.orig)
|
@functools.wraps(self.orig)
|
||||||
def wrapped(*args, **kwargs):
|
def wrapped(*args, **kwargs):
|
||||||
|
# If we're passed a cache_context then we'll want to call its invalidate()
|
||||||
|
# whenever we are invalidated
|
||||||
|
invalidate_callback = kwargs.pop("on_invalidate", None)
|
||||||
|
|
||||||
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
|
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
|
||||||
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
|
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
|
||||||
list_args = arg_dict[self.list_name]
|
list_args = arg_dict[self.list_name]
|
||||||
@ -311,7 +348,7 @@ class CacheListDescriptor(object):
|
|||||||
key[self.list_pos] = arg
|
key[self.list_pos] = arg
|
||||||
|
|
||||||
try:
|
try:
|
||||||
res = cache.get(tuple(key))
|
res = cache.get(tuple(key), callback=invalidate_callback)
|
||||||
if not res.has_succeeded():
|
if not res.has_succeeded():
|
||||||
res = res.observe()
|
res = res.observe()
|
||||||
res.addCallback(lambda r, arg: (arg, r), arg)
|
res.addCallback(lambda r, arg: (arg, r), arg)
|
||||||
@ -345,7 +382,10 @@ class CacheListDescriptor(object):
|
|||||||
|
|
||||||
key = list(keyargs)
|
key = list(keyargs)
|
||||||
key[self.list_pos] = arg
|
key[self.list_pos] = arg
|
||||||
cache.update(sequence, tuple(key), observer)
|
cache.update(
|
||||||
|
sequence, tuple(key), observer,
|
||||||
|
callback=invalidate_callback
|
||||||
|
)
|
||||||
|
|
||||||
def invalidate(f, key):
|
def invalidate(f, key):
|
||||||
cache.invalidate(key)
|
cache.invalidate(key)
|
||||||
@ -376,24 +416,29 @@ class CacheListDescriptor(object):
|
|||||||
return wrapped
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
def cached(max_entries=1000, num_args=1, lru=True, tree=False):
|
class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
|
||||||
|
def invalidate(self):
|
||||||
|
self.cache.invalidate(self.key)
|
||||||
|
|
||||||
|
|
||||||
|
def cached(max_entries=1000, num_args=1, tree=False, cache_context=False):
|
||||||
return lambda orig: CacheDescriptor(
|
return lambda orig: CacheDescriptor(
|
||||||
orig,
|
orig,
|
||||||
max_entries=max_entries,
|
max_entries=max_entries,
|
||||||
num_args=num_args,
|
num_args=num_args,
|
||||||
lru=lru,
|
|
||||||
tree=tree,
|
tree=tree,
|
||||||
|
cache_context=cache_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False):
|
def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False):
|
||||||
return lambda orig: CacheDescriptor(
|
return lambda orig: CacheDescriptor(
|
||||||
orig,
|
orig,
|
||||||
max_entries=max_entries,
|
max_entries=max_entries,
|
||||||
num_args=num_args,
|
num_args=num_args,
|
||||||
lru=lru,
|
|
||||||
tree=tree,
|
tree=tree,
|
||||||
inlineCallbacks=True,
|
inlineCallbacks=True,
|
||||||
|
cache_context=cache_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -30,13 +30,14 @@ def enumerate_leaves(node, depth):
|
|||||||
|
|
||||||
|
|
||||||
class _Node(object):
|
class _Node(object):
|
||||||
__slots__ = ["prev_node", "next_node", "key", "value"]
|
__slots__ = ["prev_node", "next_node", "key", "value", "callbacks"]
|
||||||
|
|
||||||
def __init__(self, prev_node, next_node, key, value):
|
def __init__(self, prev_node, next_node, key, value, callbacks=set()):
|
||||||
self.prev_node = prev_node
|
self.prev_node = prev_node
|
||||||
self.next_node = next_node
|
self.next_node = next_node
|
||||||
self.key = key
|
self.key = key
|
||||||
self.value = value
|
self.value = value
|
||||||
|
self.callbacks = callbacks
|
||||||
|
|
||||||
|
|
||||||
class LruCache(object):
|
class LruCache(object):
|
||||||
@ -44,6 +45,9 @@ class LruCache(object):
|
|||||||
Least-recently-used cache.
|
Least-recently-used cache.
|
||||||
Supports del_multi only if cache_type=TreeCache
|
Supports del_multi only if cache_type=TreeCache
|
||||||
If cache_type=TreeCache, all keys must be tuples.
|
If cache_type=TreeCache, all keys must be tuples.
|
||||||
|
|
||||||
|
Can also set callbacks on objects when getting/setting which are fired
|
||||||
|
when that key gets invalidated/evicted.
|
||||||
"""
|
"""
|
||||||
def __init__(self, max_size, keylen=1, cache_type=dict):
|
def __init__(self, max_size, keylen=1, cache_type=dict):
|
||||||
cache = cache_type()
|
cache = cache_type()
|
||||||
@ -62,10 +66,10 @@ class LruCache(object):
|
|||||||
|
|
||||||
return inner
|
return inner
|
||||||
|
|
||||||
def add_node(key, value):
|
def add_node(key, value, callbacks=set()):
|
||||||
prev_node = list_root
|
prev_node = list_root
|
||||||
next_node = prev_node.next_node
|
next_node = prev_node.next_node
|
||||||
node = _Node(prev_node, next_node, key, value)
|
node = _Node(prev_node, next_node, key, value, callbacks)
|
||||||
prev_node.next_node = node
|
prev_node.next_node = node
|
||||||
next_node.prev_node = node
|
next_node.prev_node = node
|
||||||
cache[key] = node
|
cache[key] = node
|
||||||
@ -88,23 +92,41 @@ class LruCache(object):
|
|||||||
prev_node.next_node = next_node
|
prev_node.next_node = next_node
|
||||||
next_node.prev_node = prev_node
|
next_node.prev_node = prev_node
|
||||||
|
|
||||||
|
for cb in node.callbacks:
|
||||||
|
cb()
|
||||||
|
node.callbacks.clear()
|
||||||
|
|
||||||
@synchronized
|
@synchronized
|
||||||
def cache_get(key, default=None):
|
def cache_get(key, default=None, callback=None):
|
||||||
node = cache.get(key, None)
|
node = cache.get(key, None)
|
||||||
if node is not None:
|
if node is not None:
|
||||||
move_node_to_front(node)
|
move_node_to_front(node)
|
||||||
|
if callback:
|
||||||
|
node.callbacks.add(callback)
|
||||||
return node.value
|
return node.value
|
||||||
else:
|
else:
|
||||||
return default
|
return default
|
||||||
|
|
||||||
@synchronized
|
@synchronized
|
||||||
def cache_set(key, value):
|
def cache_set(key, value, callback=None):
|
||||||
node = cache.get(key, None)
|
node = cache.get(key, None)
|
||||||
if node is not None:
|
if node is not None:
|
||||||
|
if value != node.value:
|
||||||
|
for cb in node.callbacks:
|
||||||
|
cb()
|
||||||
|
node.callbacks.clear()
|
||||||
|
|
||||||
|
if callback:
|
||||||
|
node.callbacks.add(callback)
|
||||||
|
|
||||||
move_node_to_front(node)
|
move_node_to_front(node)
|
||||||
node.value = value
|
node.value = value
|
||||||
else:
|
else:
|
||||||
add_node(key, value)
|
if callback:
|
||||||
|
callbacks = set([callback])
|
||||||
|
else:
|
||||||
|
callbacks = set()
|
||||||
|
add_node(key, value, callbacks)
|
||||||
if len(cache) > max_size:
|
if len(cache) > max_size:
|
||||||
todelete = list_root.prev_node
|
todelete = list_root.prev_node
|
||||||
delete_node(todelete)
|
delete_node(todelete)
|
||||||
@ -148,6 +170,9 @@ class LruCache(object):
|
|||||||
def cache_clear():
|
def cache_clear():
|
||||||
list_root.next_node = list_root
|
list_root.next_node = list_root
|
||||||
list_root.prev_node = list_root
|
list_root.prev_node = list_root
|
||||||
|
for node in cache.values():
|
||||||
|
for cb in node.callbacks:
|
||||||
|
cb()
|
||||||
cache.clear()
|
cache.clear()
|
||||||
|
|
||||||
@synchronized
|
@synchronized
|
||||||
|
@ -64,6 +64,9 @@ class TreeCache(object):
|
|||||||
self.size -= cnt
|
self.size -= cnt
|
||||||
return popped
|
return popped
|
||||||
|
|
||||||
|
def values(self):
|
||||||
|
return [e.value for e in self.root.values()]
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.size
|
return self.size
|
||||||
|
|
||||||
|
@ -297,12 +297,13 @@ def preserve_context_over_fn(fn, *args, **kwargs):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def preserve_context_over_deferred(deferred):
|
def preserve_context_over_deferred(deferred, context=None):
|
||||||
"""Given a deferred wrap it such that any callbacks added later to it will
|
"""Given a deferred wrap it such that any callbacks added later to it will
|
||||||
be invoked with the current context.
|
be invoked with the current context.
|
||||||
"""
|
"""
|
||||||
current_context = LoggingContext.current_context()
|
if context is None:
|
||||||
d = _PreservingContextDeferred(current_context)
|
context = LoggingContext.current_context()
|
||||||
|
d = _PreservingContextDeferred(context)
|
||||||
deferred.chainDeferred(d)
|
deferred.chainDeferred(d)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
@ -316,8 +317,13 @@ def preserve_fn(f):
|
|||||||
|
|
||||||
def g(*args, **kwargs):
|
def g(*args, **kwargs):
|
||||||
with PreserveLoggingContext(current):
|
with PreserveLoggingContext(current):
|
||||||
return f(*args, **kwargs)
|
res = f(*args, **kwargs)
|
||||||
|
if isinstance(res, defer.Deferred):
|
||||||
|
return preserve_context_over_deferred(
|
||||||
|
res, context=LoggingContext.sentinel
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return res
|
||||||
return g
|
return g
|
||||||
|
|
||||||
|
|
||||||
|
@ -13,10 +13,12 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.util.logcontext import LoggingContext
|
from synapse.util.logcontext import LoggingContext
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
|
|
||||||
|
from functools import wraps
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
@ -47,6 +49,18 @@ block_db_txn_duration = metrics.register_distribution(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def measure_func(name):
|
||||||
|
def wrapper(func):
|
||||||
|
@wraps(func)
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def measured_func(self, *args, **kwargs):
|
||||||
|
with Measure(self.clock, name):
|
||||||
|
r = yield func(self, *args, **kwargs)
|
||||||
|
defer.returnValue(r)
|
||||||
|
return measured_func
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
class Measure(object):
|
class Measure(object):
|
||||||
__slots__ = [
|
__slots__ = [
|
||||||
"clock", "name", "start_context", "start", "new_context", "ru_utime",
|
"clock", "name", "start_context", "start", "new_context", "ru_utime",
|
||||||
@ -64,7 +78,6 @@ class Measure(object):
|
|||||||
self.start = self.clock.time_msec()
|
self.start = self.clock.time_msec()
|
||||||
self.start_context = LoggingContext.current_context()
|
self.start_context = LoggingContext.current_context()
|
||||||
if not self.start_context:
|
if not self.start_context:
|
||||||
logger.warn("Entered Measure without log context: %s", self.name)
|
|
||||||
self.start_context = LoggingContext("Measure")
|
self.start_context = LoggingContext("Measure")
|
||||||
self.start_context.__enter__()
|
self.start_context.__enter__()
|
||||||
self.created_context = True
|
self.created_context = True
|
||||||
@ -74,7 +87,7 @@ class Measure(object):
|
|||||||
self.db_txn_duration = self.start_context.db_txn_duration
|
self.db_txn_duration = self.start_context.db_txn_duration
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
if exc_type is not None or not self.start_context:
|
if isinstance(exc_type, Exception) or not self.start_context:
|
||||||
return
|
return
|
||||||
|
|
||||||
duration = self.clock.time_msec() - self.start
|
duration = self.clock.time_msec() - self.start
|
||||||
@ -85,7 +98,7 @@ class Measure(object):
|
|||||||
if context != self.start_context:
|
if context != self.start_context:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"Context has unexpectedly changed from '%s' to '%s'. (%r)",
|
"Context has unexpectedly changed from '%s' to '%s'. (%r)",
|
||||||
context, self.start_context, self.name
|
self.start_context, context, self.name
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -17,7 +17,7 @@ from twisted.internet import defer
|
|||||||
|
|
||||||
from synapse.api.constants import Membership, EventTypes
|
from synapse.api.constants import Membership, EventTypes
|
||||||
|
|
||||||
from synapse.util.logcontext import preserve_fn
|
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@ -55,12 +55,12 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
|
|||||||
given events
|
given events
|
||||||
events ([synapse.events.EventBase]): list of events to filter
|
events ([synapse.events.EventBase]): list of events to filter
|
||||||
"""
|
"""
|
||||||
forgotten = yield defer.gatherResults([
|
forgotten = yield preserve_context_over_deferred(defer.gatherResults([
|
||||||
preserve_fn(store.who_forgot_in_room)(
|
preserve_fn(store.who_forgot_in_room)(
|
||||||
room_id,
|
room_id,
|
||||||
)
|
)
|
||||||
for room_id in frozenset(e.room_id for e in events)
|
for room_id in frozenset(e.room_id for e in events)
|
||||||
], consumeErrors=True)
|
], consumeErrors=True))
|
||||||
|
|
||||||
# Set of membership event_ids that have been forgotten
|
# Set of membership event_ids that have been forgotten
|
||||||
event_id_forgotten = frozenset(
|
event_id_forgotten = frozenset(
|
||||||
|
@ -14,6 +14,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from synapse.appservice import ApplicationService
|
from synapse.appservice import ApplicationService
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
from mock import Mock
|
from mock import Mock
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
@ -42,20 +44,25 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||||||
type="m.something", room_id="!foo:bar", sender="@someone:somewhere"
|
type="m.something", room_id="!foo:bar", sender="@someone:somewhere"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.store = Mock()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def test_regex_user_id_prefix_match(self):
|
def test_regex_user_id_prefix_match(self):
|
||||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
self.service.namespaces[ApplicationService.NS_USERS].append(
|
||||||
_regex("@irc_.*")
|
_regex("@irc_.*")
|
||||||
)
|
)
|
||||||
self.event.sender = "@irc_foobar:matrix.org"
|
self.event.sender = "@irc_foobar:matrix.org"
|
||||||
self.assertTrue(self.service.is_interested(self.event))
|
self.assertTrue((yield self.service.is_interested(self.event)))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def test_regex_user_id_prefix_no_match(self):
|
def test_regex_user_id_prefix_no_match(self):
|
||||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
self.service.namespaces[ApplicationService.NS_USERS].append(
|
||||||
_regex("@irc_.*")
|
_regex("@irc_.*")
|
||||||
)
|
)
|
||||||
self.event.sender = "@someone_else:matrix.org"
|
self.event.sender = "@someone_else:matrix.org"
|
||||||
self.assertFalse(self.service.is_interested(self.event))
|
self.assertFalse((yield self.service.is_interested(self.event)))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def test_regex_room_member_is_checked(self):
|
def test_regex_room_member_is_checked(self):
|
||||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
self.service.namespaces[ApplicationService.NS_USERS].append(
|
||||||
_regex("@irc_.*")
|
_regex("@irc_.*")
|
||||||
@ -63,30 +70,36 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||||||
self.event.sender = "@someone_else:matrix.org"
|
self.event.sender = "@someone_else:matrix.org"
|
||||||
self.event.type = "m.room.member"
|
self.event.type = "m.room.member"
|
||||||
self.event.state_key = "@irc_foobar:matrix.org"
|
self.event.state_key = "@irc_foobar:matrix.org"
|
||||||
self.assertTrue(self.service.is_interested(self.event))
|
self.assertTrue((yield self.service.is_interested(self.event)))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def test_regex_room_id_match(self):
|
def test_regex_room_id_match(self):
|
||||||
self.service.namespaces[ApplicationService.NS_ROOMS].append(
|
self.service.namespaces[ApplicationService.NS_ROOMS].append(
|
||||||
_regex("!some_prefix.*some_suffix:matrix.org")
|
_regex("!some_prefix.*some_suffix:matrix.org")
|
||||||
)
|
)
|
||||||
self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org"
|
self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org"
|
||||||
self.assertTrue(self.service.is_interested(self.event))
|
self.assertTrue((yield self.service.is_interested(self.event)))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def test_regex_room_id_no_match(self):
|
def test_regex_room_id_no_match(self):
|
||||||
self.service.namespaces[ApplicationService.NS_ROOMS].append(
|
self.service.namespaces[ApplicationService.NS_ROOMS].append(
|
||||||
_regex("!some_prefix.*some_suffix:matrix.org")
|
_regex("!some_prefix.*some_suffix:matrix.org")
|
||||||
)
|
)
|
||||||
self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org"
|
self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org"
|
||||||
self.assertFalse(self.service.is_interested(self.event))
|
self.assertFalse((yield self.service.is_interested(self.event)))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def test_regex_alias_match(self):
|
def test_regex_alias_match(self):
|
||||||
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
||||||
_regex("#irc_.*:matrix.org")
|
_regex("#irc_.*:matrix.org")
|
||||||
)
|
)
|
||||||
self.assertTrue(self.service.is_interested(
|
self.store.get_aliases_for_room.return_value = [
|
||||||
self.event,
|
"#irc_foobar:matrix.org", "#athing:matrix.org"
|
||||||
aliases_for_event=["#irc_foobar:matrix.org", "#athing:matrix.org"]
|
]
|
||||||
))
|
self.store.get_users_in_room.return_value = []
|
||||||
|
self.assertTrue((yield self.service.is_interested(
|
||||||
|
self.event, self.store
|
||||||
|
)))
|
||||||
|
|
||||||
def test_non_exclusive_alias(self):
|
def test_non_exclusive_alias(self):
|
||||||
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
||||||
@ -136,15 +149,20 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||||||
"!irc_foobar:matrix.org"
|
"!irc_foobar:matrix.org"
|
||||||
))
|
))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def test_regex_alias_no_match(self):
|
def test_regex_alias_no_match(self):
|
||||||
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
||||||
_regex("#irc_.*:matrix.org")
|
_regex("#irc_.*:matrix.org")
|
||||||
)
|
)
|
||||||
self.assertFalse(self.service.is_interested(
|
self.store.get_aliases_for_room.return_value = [
|
||||||
self.event,
|
"#xmpp_foobar:matrix.org", "#athing:matrix.org"
|
||||||
aliases_for_event=["#xmpp_foobar:matrix.org", "#athing:matrix.org"]
|
]
|
||||||
))
|
self.store.get_users_in_room.return_value = []
|
||||||
|
self.assertFalse((yield self.service.is_interested(
|
||||||
|
self.event, self.store
|
||||||
|
)))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def test_regex_multiple_matches(self):
|
def test_regex_multiple_matches(self):
|
||||||
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
||||||
_regex("#irc_.*:matrix.org")
|
_regex("#irc_.*:matrix.org")
|
||||||
@ -153,53 +171,13 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||||||
_regex("@irc_.*")
|
_regex("@irc_.*")
|
||||||
)
|
)
|
||||||
self.event.sender = "@irc_foobar:matrix.org"
|
self.event.sender = "@irc_foobar:matrix.org"
|
||||||
self.assertTrue(self.service.is_interested(
|
self.store.get_aliases_for_room.return_value = ["#irc_barfoo:matrix.org"]
|
||||||
self.event,
|
self.store.get_users_in_room.return_value = []
|
||||||
aliases_for_event=["#irc_barfoo:matrix.org"]
|
self.assertTrue((yield self.service.is_interested(
|
||||||
))
|
self.event, self.store
|
||||||
|
)))
|
||||||
def test_restrict_to_rooms(self):
|
|
||||||
self.service.namespaces[ApplicationService.NS_ROOMS].append(
|
|
||||||
_regex("!flibble_.*:matrix.org")
|
|
||||||
)
|
|
||||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
|
||||||
_regex("@irc_.*")
|
|
||||||
)
|
|
||||||
self.event.sender = "@irc_foobar:matrix.org"
|
|
||||||
self.event.room_id = "!wibblewoo:matrix.org"
|
|
||||||
self.assertFalse(self.service.is_interested(
|
|
||||||
self.event,
|
|
||||||
restrict_to=ApplicationService.NS_ROOMS
|
|
||||||
))
|
|
||||||
|
|
||||||
def test_restrict_to_aliases(self):
|
|
||||||
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
|
||||||
_regex("#xmpp_.*:matrix.org")
|
|
||||||
)
|
|
||||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
|
||||||
_regex("@irc_.*")
|
|
||||||
)
|
|
||||||
self.event.sender = "@irc_foobar:matrix.org"
|
|
||||||
self.assertFalse(self.service.is_interested(
|
|
||||||
self.event,
|
|
||||||
restrict_to=ApplicationService.NS_ALIASES,
|
|
||||||
aliases_for_event=["#irc_barfoo:matrix.org"]
|
|
||||||
))
|
|
||||||
|
|
||||||
def test_restrict_to_senders(self):
|
|
||||||
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
|
||||||
_regex("#xmpp_.*:matrix.org")
|
|
||||||
)
|
|
||||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
|
||||||
_regex("@irc_.*")
|
|
||||||
)
|
|
||||||
self.event.sender = "@xmpp_foobar:matrix.org"
|
|
||||||
self.assertFalse(self.service.is_interested(
|
|
||||||
self.event,
|
|
||||||
restrict_to=ApplicationService.NS_USERS,
|
|
||||||
aliases_for_event=["#xmpp_barfoo:matrix.org"]
|
|
||||||
))
|
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def test_interested_in_self(self):
|
def test_interested_in_self(self):
|
||||||
# make sure invites get through
|
# make sure invites get through
|
||||||
self.service.sender = "@appservice:name"
|
self.service.sender = "@appservice:name"
|
||||||
@ -211,20 +189,21 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||||||
"membership": "invite"
|
"membership": "invite"
|
||||||
}
|
}
|
||||||
self.event.state_key = self.service.sender
|
self.event.state_key = self.service.sender
|
||||||
self.assertTrue(self.service.is_interested(self.event))
|
self.assertTrue((yield self.service.is_interested(self.event)))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def test_member_list_match(self):
|
def test_member_list_match(self):
|
||||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
self.service.namespaces[ApplicationService.NS_USERS].append(
|
||||||
_regex("@irc_.*")
|
_regex("@irc_.*")
|
||||||
)
|
)
|
||||||
join_list = [
|
self.store.get_users_in_room.return_value = [
|
||||||
"@alice:here",
|
"@alice:here",
|
||||||
"@irc_fo:here", # AS user
|
"@irc_fo:here", # AS user
|
||||||
"@bob:here",
|
"@bob:here",
|
||||||
]
|
]
|
||||||
|
self.store.get_aliases_for_room.return_value = []
|
||||||
|
|
||||||
self.event.sender = "@xmpp_foobar:matrix.org"
|
self.event.sender = "@xmpp_foobar:matrix.org"
|
||||||
self.assertTrue(self.service.is_interested(
|
self.assertTrue((yield self.service.is_interested(
|
||||||
event=self.event,
|
event=self.event, store=self.store
|
||||||
member_list=join_list
|
)))
|
||||||
))
|
|
||||||
|
@ -193,7 +193,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.txn_ctrl = Mock()
|
self.txn_ctrl = Mock()
|
||||||
self.queuer = _ServiceQueuer(self.txn_ctrl)
|
self.queuer = _ServiceQueuer(self.txn_ctrl, MockClock())
|
||||||
|
|
||||||
def test_send_single_event_no_queue(self):
|
def test_send_single_event_no_queue(self):
|
||||||
# Expect the event to be sent immediately.
|
# Expect the event to be sent immediately.
|
||||||
|
@ -15,6 +15,7 @@
|
|||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from .. import unittest
|
from .. import unittest
|
||||||
|
from tests.utils import MockClock
|
||||||
|
|
||||||
from synapse.handlers.appservice import ApplicationServicesHandler
|
from synapse.handlers.appservice import ApplicationServicesHandler
|
||||||
|
|
||||||
@ -32,6 +33,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||||||
hs.get_datastore = Mock(return_value=self.mock_store)
|
hs.get_datastore = Mock(return_value=self.mock_store)
|
||||||
hs.get_application_service_api = Mock(return_value=self.mock_as_api)
|
hs.get_application_service_api = Mock(return_value=self.mock_as_api)
|
||||||
hs.get_application_service_scheduler = Mock(return_value=self.mock_scheduler)
|
hs.get_application_service_scheduler = Mock(return_value=self.mock_scheduler)
|
||||||
|
hs.get_clock.return_value = MockClock()
|
||||||
self.handler = ApplicationServicesHandler(hs)
|
self.handler = ApplicationServicesHandler(hs)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -51,8 +53,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||||||
type="m.room.message",
|
type="m.room.message",
|
||||||
room_id="!foo:bar"
|
room_id="!foo:bar"
|
||||||
)
|
)
|
||||||
|
self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
|
||||||
self.mock_as_api.push = Mock()
|
self.mock_as_api.push = Mock()
|
||||||
yield self.handler.notify_interested_services(event)
|
yield self.handler.notify_interested_services(0)
|
||||||
self.mock_scheduler.submit_event_for_as.assert_called_once_with(
|
self.mock_scheduler.submit_event_for_as.assert_called_once_with(
|
||||||
interested_service, event
|
interested_service, event
|
||||||
)
|
)
|
||||||
@ -72,7 +75,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
self.mock_as_api.push = Mock()
|
self.mock_as_api.push = Mock()
|
||||||
self.mock_as_api.query_user = Mock()
|
self.mock_as_api.query_user = Mock()
|
||||||
yield self.handler.notify_interested_services(event)
|
self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
|
||||||
|
yield self.handler.notify_interested_services(0)
|
||||||
self.mock_as_api.query_user.assert_called_once_with(
|
self.mock_as_api.query_user.assert_called_once_with(
|
||||||
services[0], user_id
|
services[0], user_id
|
||||||
)
|
)
|
||||||
@ -94,7 +98,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
self.mock_as_api.push = Mock()
|
self.mock_as_api.push = Mock()
|
||||||
self.mock_as_api.query_user = Mock()
|
self.mock_as_api.query_user = Mock()
|
||||||
yield self.handler.notify_interested_services(event)
|
self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
|
||||||
|
yield self.handler.notify_interested_services(0)
|
||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
self.mock_as_api.query_user.called,
|
self.mock_as_api.query_user.called,
|
||||||
"query_user called when it shouldn't have been."
|
"query_user called when it shouldn't have been."
|
||||||
@ -108,11 +113,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
room_id = "!alpha:bet"
|
room_id = "!alpha:bet"
|
||||||
servers = ["aperture"]
|
servers = ["aperture"]
|
||||||
interested_service = self._mkservice(is_interested=True)
|
interested_service = self._mkservice_alias(is_interested_in_alias=True)
|
||||||
services = [
|
services = [
|
||||||
self._mkservice(is_interested=False),
|
self._mkservice_alias(is_interested_in_alias=False),
|
||||||
interested_service,
|
interested_service,
|
||||||
self._mkservice(is_interested=False)
|
self._mkservice_alias(is_interested_in_alias=False)
|
||||||
]
|
]
|
||||||
|
|
||||||
self.mock_store.get_app_services = Mock(return_value=services)
|
self.mock_store.get_app_services = Mock(return_value=services)
|
||||||
@ -135,3 +140,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||||||
service.token = "mock_service_token"
|
service.token = "mock_service_token"
|
||||||
service.url = "mock_service_url"
|
service.url = "mock_service_url"
|
||||||
return service
|
return service
|
||||||
|
|
||||||
|
def _mkservice_alias(self, is_interested_in_alias):
|
||||||
|
service = Mock()
|
||||||
|
service.is_interested_in_alias = Mock(return_value=is_interested_in_alias)
|
||||||
|
service.token = "mock_service_token"
|
||||||
|
service.url = "mock_service_url"
|
||||||
|
return service
|
||||||
|
@ -14,11 +14,13 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import pymacaroons
|
import pymacaroons
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
import synapse
|
||||||
|
import synapse.api.errors
|
||||||
from synapse.handlers.auth import AuthHandler
|
from synapse.handlers.auth import AuthHandler
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.utils import setup_test_homeserver
|
from tests.utils import setup_test_homeserver
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
|
|
||||||
class AuthHandlers(object):
|
class AuthHandlers(object):
|
||||||
@ -31,11 +33,12 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.hs = yield setup_test_homeserver(handlers=None)
|
self.hs = yield setup_test_homeserver(handlers=None)
|
||||||
self.hs.handlers = AuthHandlers(self.hs)
|
self.hs.handlers = AuthHandlers(self.hs)
|
||||||
|
self.auth_handler = self.hs.handlers.auth_handler
|
||||||
|
|
||||||
def test_token_is_a_macaroon(self):
|
def test_token_is_a_macaroon(self):
|
||||||
self.hs.config.macaroon_secret_key = "this key is a huge secret"
|
self.hs.config.macaroon_secret_key = "this key is a huge secret"
|
||||||
|
|
||||||
token = self.hs.handlers.auth_handler.generate_access_token("some_user")
|
token = self.auth_handler.generate_access_token("some_user")
|
||||||
# Check that we can parse the thing with pymacaroons
|
# Check that we can parse the thing with pymacaroons
|
||||||
macaroon = pymacaroons.Macaroon.deserialize(token)
|
macaroon = pymacaroons.Macaroon.deserialize(token)
|
||||||
# The most basic of sanity checks
|
# The most basic of sanity checks
|
||||||
@ -46,7 +49,7 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
self.hs.config.macaroon_secret_key = "this key is a massive secret"
|
self.hs.config.macaroon_secret_key = "this key is a massive secret"
|
||||||
self.hs.clock.now = 5000
|
self.hs.clock.now = 5000
|
||||||
|
|
||||||
token = self.hs.handlers.auth_handler.generate_access_token("a_user")
|
token = self.auth_handler.generate_access_token("a_user")
|
||||||
macaroon = pymacaroons.Macaroon.deserialize(token)
|
macaroon = pymacaroons.Macaroon.deserialize(token)
|
||||||
|
|
||||||
def verify_gen(caveat):
|
def verify_gen(caveat):
|
||||||
@ -67,3 +70,46 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
v.satisfy_general(verify_type)
|
v.satisfy_general(verify_type)
|
||||||
v.satisfy_general(verify_expiry)
|
v.satisfy_general(verify_expiry)
|
||||||
v.verify(macaroon, self.hs.config.macaroon_secret_key)
|
v.verify(macaroon, self.hs.config.macaroon_secret_key)
|
||||||
|
|
||||||
|
def test_short_term_login_token_gives_user_id(self):
|
||||||
|
self.hs.clock.now = 1000
|
||||||
|
|
||||||
|
token = self.auth_handler.generate_short_term_login_token(
|
||||||
|
"a_user", 5000
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
"a_user",
|
||||||
|
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||||
|
token
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# when we advance the clock, the token should be rejected
|
||||||
|
self.hs.clock.now = 6000
|
||||||
|
with self.assertRaises(synapse.api.errors.AuthError):
|
||||||
|
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||||
|
token
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_short_term_login_token_cannot_replace_user_id(self):
|
||||||
|
token = self.auth_handler.generate_short_term_login_token(
|
||||||
|
"a_user", 5000
|
||||||
|
)
|
||||||
|
macaroon = pymacaroons.Macaroon.deserialize(token)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
"a_user",
|
||||||
|
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||||
|
macaroon.serialize()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# add another "user_id" caveat, which might allow us to override the
|
||||||
|
# user_id.
|
||||||
|
macaroon.add_first_party_caveat("user_id = b_user")
|
||||||
|
|
||||||
|
with self.assertRaises(synapse.api.errors.AuthError):
|
||||||
|
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||||
|
macaroon.serialize()
|
||||||
|
)
|
||||||
|
@ -17,6 +17,8 @@
|
|||||||
from tests import unittest
|
from tests import unittest
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from mock import Mock
|
||||||
|
|
||||||
from synapse.util.async import ObservableDeferred
|
from synapse.util.async import ObservableDeferred
|
||||||
|
|
||||||
from synapse.util.caches.descriptors import Cache, cached
|
from synapse.util.caches.descriptors import Cache, cached
|
||||||
@ -72,7 +74,7 @@ class CacheTestCase(unittest.TestCase):
|
|||||||
cache.get(3)
|
cache.get(3)
|
||||||
|
|
||||||
def test_eviction_lru(self):
|
def test_eviction_lru(self):
|
||||||
cache = Cache("test", max_entries=2, lru=True)
|
cache = Cache("test", max_entries=2)
|
||||||
|
|
||||||
cache.prefill(1, "one")
|
cache.prefill(1, "one")
|
||||||
cache.prefill(2, "two")
|
cache.prefill(2, "two")
|
||||||
@ -199,3 +201,115 @@ class CacheDecoratorTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEquals(a.func("foo").result, d.result)
|
self.assertEquals(a.func("foo").result, d.result)
|
||||||
self.assertEquals(callcount[0], 0)
|
self.assertEquals(callcount[0], 0)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_invalidate_context(self):
|
||||||
|
callcount = [0]
|
||||||
|
callcount2 = [0]
|
||||||
|
|
||||||
|
class A(object):
|
||||||
|
@cached()
|
||||||
|
def func(self, key):
|
||||||
|
callcount[0] += 1
|
||||||
|
return key
|
||||||
|
|
||||||
|
@cached(cache_context=True)
|
||||||
|
def func2(self, key, cache_context):
|
||||||
|
callcount2[0] += 1
|
||||||
|
return self.func(key, on_invalidate=cache_context.invalidate)
|
||||||
|
|
||||||
|
a = A()
|
||||||
|
yield a.func2("foo")
|
||||||
|
|
||||||
|
self.assertEquals(callcount[0], 1)
|
||||||
|
self.assertEquals(callcount2[0], 1)
|
||||||
|
|
||||||
|
a.func.invalidate(("foo",))
|
||||||
|
yield a.func("foo")
|
||||||
|
|
||||||
|
self.assertEquals(callcount[0], 2)
|
||||||
|
self.assertEquals(callcount2[0], 1)
|
||||||
|
|
||||||
|
yield a.func2("foo")
|
||||||
|
|
||||||
|
self.assertEquals(callcount[0], 2)
|
||||||
|
self.assertEquals(callcount2[0], 2)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_eviction_context(self):
|
||||||
|
callcount = [0]
|
||||||
|
callcount2 = [0]
|
||||||
|
|
||||||
|
class A(object):
|
||||||
|
@cached(max_entries=2)
|
||||||
|
def func(self, key):
|
||||||
|
callcount[0] += 1
|
||||||
|
return key
|
||||||
|
|
||||||
|
@cached(cache_context=True)
|
||||||
|
def func2(self, key, cache_context):
|
||||||
|
callcount2[0] += 1
|
||||||
|
return self.func(key, on_invalidate=cache_context.invalidate)
|
||||||
|
|
||||||
|
a = A()
|
||||||
|
yield a.func2("foo")
|
||||||
|
yield a.func2("foo2")
|
||||||
|
|
||||||
|
self.assertEquals(callcount[0], 2)
|
||||||
|
self.assertEquals(callcount2[0], 2)
|
||||||
|
|
||||||
|
yield a.func("foo3")
|
||||||
|
|
||||||
|
self.assertEquals(callcount[0], 3)
|
||||||
|
self.assertEquals(callcount2[0], 2)
|
||||||
|
|
||||||
|
yield a.func2("foo")
|
||||||
|
|
||||||
|
self.assertEquals(callcount[0], 4)
|
||||||
|
self.assertEquals(callcount2[0], 3)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_double_get(self):
|
||||||
|
callcount = [0]
|
||||||
|
callcount2 = [0]
|
||||||
|
|
||||||
|
class A(object):
|
||||||
|
@cached()
|
||||||
|
def func(self, key):
|
||||||
|
callcount[0] += 1
|
||||||
|
return key
|
||||||
|
|
||||||
|
@cached(cache_context=True)
|
||||||
|
def func2(self, key, cache_context):
|
||||||
|
callcount2[0] += 1
|
||||||
|
return self.func(key, on_invalidate=cache_context.invalidate)
|
||||||
|
|
||||||
|
a = A()
|
||||||
|
a.func2.cache.cache = Mock(wraps=a.func2.cache.cache)
|
||||||
|
|
||||||
|
yield a.func2("foo")
|
||||||
|
|
||||||
|
self.assertEquals(callcount[0], 1)
|
||||||
|
self.assertEquals(callcount2[0], 1)
|
||||||
|
|
||||||
|
a.func2.invalidate(("foo",))
|
||||||
|
self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
|
||||||
|
|
||||||
|
yield a.func2("foo")
|
||||||
|
a.func2.invalidate(("foo",))
|
||||||
|
self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
|
||||||
|
|
||||||
|
self.assertEquals(callcount[0], 1)
|
||||||
|
self.assertEquals(callcount2[0], 2)
|
||||||
|
|
||||||
|
a.func.invalidate(("foo",))
|
||||||
|
self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
|
||||||
|
yield a.func("foo")
|
||||||
|
|
||||||
|
self.assertEquals(callcount[0], 2)
|
||||||
|
self.assertEquals(callcount2[0], 2)
|
||||||
|
|
||||||
|
yield a.func2("foo")
|
||||||
|
|
||||||
|
self.assertEquals(callcount[0], 2)
|
||||||
|
self.assertEquals(callcount2[0], 3)
|
||||||
|
@ -15,7 +15,9 @@
|
|||||||
|
|
||||||
from . import unittest
|
from . import unittest
|
||||||
|
|
||||||
from synapse.rest.media.v1.preview_url_resource import summarize_paragraphs
|
from synapse.rest.media.v1.preview_url_resource import (
|
||||||
|
summarize_paragraphs, decode_and_calc_og
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PreviewTestCase(unittest.TestCase):
|
class PreviewTestCase(unittest.TestCase):
|
||||||
@ -137,3 +139,79 @@ class PreviewTestCase(unittest.TestCase):
|
|||||||
" of old wooden houses in Northern Norway, the oldest house dating from"
|
" of old wooden houses in Northern Norway, the oldest house dating from"
|
||||||
" 1789. The Arctic Cathedral, a modern church…"
|
" 1789. The Arctic Cathedral, a modern church…"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PreviewUrlTestCase(unittest.TestCase):
|
||||||
|
def test_simple(self):
|
||||||
|
html = """
|
||||||
|
<html>
|
||||||
|
<head><title>Foo</title></head>
|
||||||
|
<body>
|
||||||
|
Some text.
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
"""
|
||||||
|
|
||||||
|
og = decode_and_calc_og(html, "http://example.com/test.html")
|
||||||
|
|
||||||
|
self.assertEquals(og, {
|
||||||
|
"og:title": "Foo",
|
||||||
|
"og:description": "Some text."
|
||||||
|
})
|
||||||
|
|
||||||
|
def test_comment(self):
|
||||||
|
html = """
|
||||||
|
<html>
|
||||||
|
<head><title>Foo</title></head>
|
||||||
|
<body>
|
||||||
|
<!-- HTML comment -->
|
||||||
|
Some text.
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
"""
|
||||||
|
|
||||||
|
og = decode_and_calc_og(html, "http://example.com/test.html")
|
||||||
|
|
||||||
|
self.assertEquals(og, {
|
||||||
|
"og:title": "Foo",
|
||||||
|
"og:description": "Some text."
|
||||||
|
})
|
||||||
|
|
||||||
|
def test_comment2(self):
|
||||||
|
html = """
|
||||||
|
<html>
|
||||||
|
<head><title>Foo</title></head>
|
||||||
|
<body>
|
||||||
|
Some text.
|
||||||
|
<!-- HTML comment -->
|
||||||
|
Some more text.
|
||||||
|
<p>Text</p>
|
||||||
|
More text
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
"""
|
||||||
|
|
||||||
|
og = decode_and_calc_og(html, "http://example.com/test.html")
|
||||||
|
|
||||||
|
self.assertEquals(og, {
|
||||||
|
"og:title": "Foo",
|
||||||
|
"og:description": "Some text.\n\nSome more text.\n\nText\n\nMore text"
|
||||||
|
})
|
||||||
|
|
||||||
|
def test_script(self):
|
||||||
|
html = """
|
||||||
|
<html>
|
||||||
|
<head><title>Foo</title></head>
|
||||||
|
<body>
|
||||||
|
<script> (function() {})() </script>
|
||||||
|
Some text.
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
"""
|
||||||
|
|
||||||
|
og = decode_and_calc_og(html, "http://example.com/test.html")
|
||||||
|
|
||||||
|
self.assertEquals(og, {
|
||||||
|
"og:title": "Foo",
|
||||||
|
"og:description": "Some text."
|
||||||
|
})
|
||||||
|
@ -19,6 +19,8 @@ from .. import unittest
|
|||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
from synapse.util.caches.treecache import TreeCache
|
from synapse.util.caches.treecache import TreeCache
|
||||||
|
|
||||||
|
from mock import Mock
|
||||||
|
|
||||||
|
|
||||||
class LruCacheTestCase(unittest.TestCase):
|
class LruCacheTestCase(unittest.TestCase):
|
||||||
|
|
||||||
@ -48,6 +50,8 @@ class LruCacheTestCase(unittest.TestCase):
|
|||||||
self.assertEquals(cache.get("key"), 1)
|
self.assertEquals(cache.get("key"), 1)
|
||||||
self.assertEquals(cache.setdefault("key", 2), 1)
|
self.assertEquals(cache.setdefault("key", 2), 1)
|
||||||
self.assertEquals(cache.get("key"), 1)
|
self.assertEquals(cache.get("key"), 1)
|
||||||
|
cache["key"] = 2 # Make sure overriding works.
|
||||||
|
self.assertEquals(cache.get("key"), 2)
|
||||||
|
|
||||||
def test_pop(self):
|
def test_pop(self):
|
||||||
cache = LruCache(1)
|
cache = LruCache(1)
|
||||||
@ -79,3 +83,152 @@ class LruCacheTestCase(unittest.TestCase):
|
|||||||
cache["key"] = 1
|
cache["key"] = 1
|
||||||
cache.clear()
|
cache.clear()
|
||||||
self.assertEquals(len(cache), 0)
|
self.assertEquals(len(cache), 0)
|
||||||
|
|
||||||
|
|
||||||
|
class LruCacheCallbacksTestCase(unittest.TestCase):
|
||||||
|
def test_get(self):
|
||||||
|
m = Mock()
|
||||||
|
cache = LruCache(1)
|
||||||
|
|
||||||
|
cache.set("key", "value")
|
||||||
|
self.assertFalse(m.called)
|
||||||
|
|
||||||
|
cache.get("key", callback=m)
|
||||||
|
self.assertFalse(m.called)
|
||||||
|
|
||||||
|
cache.get("key", "value")
|
||||||
|
self.assertFalse(m.called)
|
||||||
|
|
||||||
|
cache.set("key", "value2")
|
||||||
|
self.assertEquals(m.call_count, 1)
|
||||||
|
|
||||||
|
cache.set("key", "value")
|
||||||
|
self.assertEquals(m.call_count, 1)
|
||||||
|
|
||||||
|
def test_multi_get(self):
|
||||||
|
m = Mock()
|
||||||
|
cache = LruCache(1)
|
||||||
|
|
||||||
|
cache.set("key", "value")
|
||||||
|
self.assertFalse(m.called)
|
||||||
|
|
||||||
|
cache.get("key", callback=m)
|
||||||
|
self.assertFalse(m.called)
|
||||||
|
|
||||||
|
cache.get("key", callback=m)
|
||||||
|
self.assertFalse(m.called)
|
||||||
|
|
||||||
|
cache.set("key", "value2")
|
||||||
|
self.assertEquals(m.call_count, 1)
|
||||||
|
|
||||||
|
cache.set("key", "value")
|
||||||
|
self.assertEquals(m.call_count, 1)
|
||||||
|
|
||||||
|
def test_set(self):
|
||||||
|
m = Mock()
|
||||||
|
cache = LruCache(1)
|
||||||
|
|
||||||
|
cache.set("key", "value", m)
|
||||||
|
self.assertFalse(m.called)
|
||||||
|
|
||||||
|
cache.set("key", "value")
|
||||||
|
self.assertFalse(m.called)
|
||||||
|
|
||||||
|
cache.set("key", "value2")
|
||||||
|
self.assertEquals(m.call_count, 1)
|
||||||
|
|
||||||
|
cache.set("key", "value")
|
||||||
|
self.assertEquals(m.call_count, 1)
|
||||||
|
|
||||||
|
def test_pop(self):
|
||||||
|
m = Mock()
|
||||||
|
cache = LruCache(1)
|
||||||
|
|
||||||
|
cache.set("key", "value", m)
|
||||||
|
self.assertFalse(m.called)
|
||||||
|
|
||||||
|
cache.pop("key")
|
||||||
|
self.assertEquals(m.call_count, 1)
|
||||||
|
|
||||||
|
cache.set("key", "value")
|
||||||
|
self.assertEquals(m.call_count, 1)
|
||||||
|
|
||||||
|
cache.pop("key")
|
||||||
|
self.assertEquals(m.call_count, 1)
|
||||||
|
|
||||||
|
def test_del_multi(self):
|
||||||
|
m1 = Mock()
|
||||||
|
m2 = Mock()
|
||||||
|
m3 = Mock()
|
||||||
|
m4 = Mock()
|
||||||
|
cache = LruCache(4, 2, cache_type=TreeCache)
|
||||||
|
|
||||||
|
cache.set(("a", "1"), "value", m1)
|
||||||
|
cache.set(("a", "2"), "value", m2)
|
||||||
|
cache.set(("b", "1"), "value", m3)
|
||||||
|
cache.set(("b", "2"), "value", m4)
|
||||||
|
|
||||||
|
self.assertEquals(m1.call_count, 0)
|
||||||
|
self.assertEquals(m2.call_count, 0)
|
||||||
|
self.assertEquals(m3.call_count, 0)
|
||||||
|
self.assertEquals(m4.call_count, 0)
|
||||||
|
|
||||||
|
cache.del_multi(("a",))
|
||||||
|
|
||||||
|
self.assertEquals(m1.call_count, 1)
|
||||||
|
self.assertEquals(m2.call_count, 1)
|
||||||
|
self.assertEquals(m3.call_count, 0)
|
||||||
|
self.assertEquals(m4.call_count, 0)
|
||||||
|
|
||||||
|
def test_clear(self):
|
||||||
|
m1 = Mock()
|
||||||
|
m2 = Mock()
|
||||||
|
cache = LruCache(5)
|
||||||
|
|
||||||
|
cache.set("key1", "value", m1)
|
||||||
|
cache.set("key2", "value", m2)
|
||||||
|
|
||||||
|
self.assertEquals(m1.call_count, 0)
|
||||||
|
self.assertEquals(m2.call_count, 0)
|
||||||
|
|
||||||
|
cache.clear()
|
||||||
|
|
||||||
|
self.assertEquals(m1.call_count, 1)
|
||||||
|
self.assertEquals(m2.call_count, 1)
|
||||||
|
|
||||||
|
def test_eviction(self):
|
||||||
|
m1 = Mock(name="m1")
|
||||||
|
m2 = Mock(name="m2")
|
||||||
|
m3 = Mock(name="m3")
|
||||||
|
cache = LruCache(2)
|
||||||
|
|
||||||
|
cache.set("key1", "value", m1)
|
||||||
|
cache.set("key2", "value", m2)
|
||||||
|
|
||||||
|
self.assertEquals(m1.call_count, 0)
|
||||||
|
self.assertEquals(m2.call_count, 0)
|
||||||
|
self.assertEquals(m3.call_count, 0)
|
||||||
|
|
||||||
|
cache.set("key3", "value", m3)
|
||||||
|
|
||||||
|
self.assertEquals(m1.call_count, 1)
|
||||||
|
self.assertEquals(m2.call_count, 0)
|
||||||
|
self.assertEquals(m3.call_count, 0)
|
||||||
|
|
||||||
|
cache.set("key3", "value")
|
||||||
|
|
||||||
|
self.assertEquals(m1.call_count, 1)
|
||||||
|
self.assertEquals(m2.call_count, 0)
|
||||||
|
self.assertEquals(m3.call_count, 0)
|
||||||
|
|
||||||
|
cache.get("key2")
|
||||||
|
|
||||||
|
self.assertEquals(m1.call_count, 1)
|
||||||
|
self.assertEquals(m2.call_count, 0)
|
||||||
|
self.assertEquals(m3.call_count, 0)
|
||||||
|
|
||||||
|
cache.set("key1", "value", m1)
|
||||||
|
|
||||||
|
self.assertEquals(m1.call_count, 1)
|
||||||
|
self.assertEquals(m2.call_count, 0)
|
||||||
|
self.assertEquals(m3.call_count, 1)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user