Merge branch 'develop' of github.com:matrix-org/synapse into erikj/client_apis_move

This commit is contained in:
Erik Johnston 2018-07-23 13:21:15 +01:00
commit 0b0b24cb82
81 changed files with 3717 additions and 3297 deletions

2470
CHANGES.md Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -2,6 +2,7 @@ include synctl
include LICENSE include LICENSE
include VERSION include VERSION
include *.rst include *.rst
include *.md
include demo/README include demo/README
include demo/demo.tls.dh include demo/demo.tls.dh
include demo/*.py include demo/*.py

View File

@ -71,7 +71,7 @@ We'd like to invite you to join #matrix:matrix.org (via
https://matrix.org/docs/projects/try-matrix-now.html), run a homeserver, take a look https://matrix.org/docs/projects/try-matrix-now.html), run a homeserver, take a look
at the `Matrix spec <https://matrix.org/docs/spec>`_, and experiment with the at the `Matrix spec <https://matrix.org/docs/spec>`_, and experiment with the
`APIs <https://matrix.org/docs/api>`_ and `Client SDKs `APIs <https://matrix.org/docs/api>`_ and `Client SDKs
<http://matrix.org/docs/projects/try-matrix-now.html#client-sdks>`_. <https://matrix.org/docs/projects/try-matrix-now.html#client-sdks>`_.
Thanks for using Matrix! Thanks for using Matrix!
@ -283,7 +283,7 @@ Connecting to Synapse from a client
The easiest way to try out your new Synapse installation is by connecting to it The easiest way to try out your new Synapse installation is by connecting to it
from a web client. The easiest option is probably the one at from a web client. The easiest option is probably the one at
http://riot.im/app. You will need to specify a "Custom server" when you log on https://riot.im/app. You will need to specify a "Custom server" when you log on
or register: set this to ``https://domain.tld`` if you setup a reverse proxy or register: set this to ``https://domain.tld`` if you setup a reverse proxy
following the recommended setup, or ``https://localhost:8448`` - remember to specify the following the recommended setup, or ``https://localhost:8448`` - remember to specify the
port (``:8448``) if not ``:443`` unless you changed the configuration. (Leave the identity port (``:8448``) if not ``:443`` unless you changed the configuration. (Leave the identity
@ -329,7 +329,7 @@ Security Note
============= =============
Matrix serves raw user generated data in some APIs - specifically the `content Matrix serves raw user generated data in some APIs - specifically the `content
repository endpoints <http://matrix.org/docs/spec/client_server/latest.html#get-matrix-media-r0-download-servername-mediaid>`_. repository endpoints <https://matrix.org/docs/spec/client_server/latest.html#get-matrix-media-r0-download-servername-mediaid>`_.
Whilst we have tried to mitigate against possible XSS attacks (e.g. Whilst we have tried to mitigate against possible XSS attacks (e.g.
https://github.com/matrix-org/synapse/pull/1021) we recommend running https://github.com/matrix-org/synapse/pull/1021) we recommend running
@ -348,7 +348,7 @@ Platform-Specific Instructions
Debian Debian
------ ------
Matrix provides official Debian packages via apt from http://matrix.org/packages/debian/. Matrix provides official Debian packages via apt from https://matrix.org/packages/debian/.
Note that these packages do not include a client - choose one from Note that these packages do not include a client - choose one from
https://matrix.org/docs/projects/try-matrix-now.html (or build your own with one of our SDKs :) https://matrix.org/docs/projects/try-matrix-now.html (or build your own with one of our SDKs :)
@ -524,7 +524,7 @@ Troubleshooting Running
----------------------- -----------------------
If synapse fails with ``missing "sodium.h"`` crypto errors, you may need If synapse fails with ``missing "sodium.h"`` crypto errors, you may need
to manually upgrade PyNaCL, as synapse uses NaCl (http://nacl.cr.yp.to/) for to manually upgrade PyNaCL, as synapse uses NaCl (https://nacl.cr.yp.to/) for
encryption and digital signatures. encryption and digital signatures.
Unfortunately PyNACL currently has a few issues Unfortunately PyNACL currently has a few issues
(https://github.com/pyca/pynacl/issues/53) and (https://github.com/pyca/pynacl/issues/53) and
@ -672,8 +672,8 @@ useful just for development purposes. See `<demo/README>`_.
Using PostgreSQL Using PostgreSQL
================ ================
As of Synapse 0.9, `PostgreSQL <http://www.postgresql.org>`_ is supported as an As of Synapse 0.9, `PostgreSQL <https://www.postgresql.org>`_ is supported as an
alternative to the `SQLite <http://sqlite.org/>`_ database that Synapse has alternative to the `SQLite <https://sqlite.org/>`_ database that Synapse has
traditionally used for convenience and simplicity. traditionally used for convenience and simplicity.
The advantages of Postgres include: The advantages of Postgres include:
@ -697,7 +697,7 @@ Using a reverse proxy with Synapse
It is recommended to put a reverse proxy such as It is recommended to put a reverse proxy such as
`nginx <https://nginx.org/en/docs/http/ngx_http_proxy_module.html>`_, `nginx <https://nginx.org/en/docs/http/ngx_http_proxy_module.html>`_,
`Apache <https://httpd.apache.org/docs/current/mod/mod_proxy_http.html>`_ or `Apache <https://httpd.apache.org/docs/current/mod/mod_proxy_http.html>`_ or
`HAProxy <http://www.haproxy.org/>`_ in front of Synapse. One advantage of `HAProxy <https://www.haproxy.org/>`_ in front of Synapse. One advantage of
doing so is that it means that you can expose the default https port (443) to doing so is that it means that you can expose the default https port (443) to
Matrix clients without needing to run Synapse with root privileges. Matrix clients without needing to run Synapse with root privileges.

View File

@ -1 +0,0 @@
Enforce the specified API for report_event

View File

@ -1 +0,0 @@
Include CPU time from database threads in request/block metrics.

View File

@ -1 +0,0 @@
Add CPU metrics for _fetch_event_list

View File

View File

View File

View File

@ -1 +0,0 @@
Reduce database consumption when processing large numbers of receipts

1
changelog.d/3520.bugfix Normal file
View File

@ -0,0 +1 @@
Correctly announce deleted devices over federation

View File

@ -1 +0,0 @@
Cache optimisation for /sync requests

View File

View File

@ -1 +0,0 @@
Fix queued federation requests being processed in the wrong order

View File

@ -1 +0,0 @@
refactor: use parse_{string,integer} and assert's from http.servlet for deduplication

View File

View File

@ -1 +0,0 @@
check isort for each PR

View File

@ -1 +0,0 @@
Optimisation to make handling incoming federation requests more efficient.

View File

View File

@ -1 +0,0 @@
Ensure that erasure requests are correctly honoured for publicly accessible rooms when accessed over federation.

1
changelog.d/3548.bugfix Normal file
View File

@ -0,0 +1 @@
Catch failures saving metrics captured by Measure, and instead log the faulty metrics information for further analysis.

1
changelog.d/3552.misc Normal file
View File

@ -0,0 +1 @@
Release notes are now in the Markdown format.

1
changelog.d/3553.feature Normal file
View File

@ -0,0 +1 @@
Add metrics to track resource usage by background processes

1
changelog.d/3554.feature Normal file
View File

@ -0,0 +1 @@
Add `code` label to `synapse_http_server_response_time_seconds` prometheus metric

1
changelog.d/3556.feature Normal file
View File

@ -0,0 +1 @@
Add metrics to track resource usage by background processes

1
changelog.d/3559.misc Normal file
View File

@ -0,0 +1 @@
add config for pep8

1
changelog.d/3570.bugfix Normal file
View File

@ -0,0 +1 @@
Fix potential stack overflow and deadlock under heavy load

1
changelog.d/3571.misc Normal file
View File

@ -0,0 +1 @@
Merge Linearizer and Limiter

1
changelog.d/3572.misc Normal file
View File

@ -0,0 +1 @@
Merge Linearizer and Limiter

View File

@ -0,0 +1,63 @@
Shared-Secret Registration
==========================
This API allows for the creation of users in an administrative and
non-interactive way. This is generally used for bootstrapping a Synapse
instance with administrator accounts.
To authenticate yourself to the server, you will need both the shared secret
(``registration_shared_secret`` in the homeserver configuration), and a
one-time nonce. If the registration shared secret is not configured, this API
is not enabled.
To fetch the nonce, you need to request one from the API::
> GET /_matrix/client/r0/admin/register
< {"nonce": "thisisanonce"}
Once you have the nonce, you can make a ``POST`` to the same URL with a JSON
body containing the nonce, username, password, whether they are an admin
(optional, False by default), and a HMAC digest of the content.
As an example::
> POST /_matrix/client/r0/admin/register
> {
"nonce": "thisisanonce",
"username": "pepper_roni",
"password": "pizza",
"admin": true,
"mac": "mac_digest_here"
}
< {
"access_token": "token_here",
"user_id": "@pepper_roni@test",
"home_server": "test",
"device_id": "device_id_here"
}
The MAC is the hex digest output of the HMAC-SHA1 algorithm, with the key being
the shared secret and the content being the nonce, user, password, and either
the string "admin" or "notadmin", each separated by NULs. For an example of
generation in Python::
import hmac, hashlib
def generate_mac(nonce, user, password, admin=False):
mac = hmac.new(
key=shared_secret,
digestmod=hashlib.sha1,
)
mac.update(nonce.encode('utf8'))
mac.update(b"\x00")
mac.update(user.encode('utf8'))
mac.update(b"\x00")
mac.update(password.encode('utf8'))
mac.update(b"\x00")
mac.update(b"admin" if admin else b"notadmin")
return mac.hexdigest()

View File

@ -1,5 +1,30 @@
[tool.towncrier] [tool.towncrier]
package = "synapse" package = "synapse"
filename = "CHANGES.rst" filename = "CHANGES.md"
directory = "changelog.d" directory = "changelog.d"
issue_format = "`#{issue} <https://github.com/matrix-org/synapse/issues/{issue}>`_" issue_format = "[\\#{issue}](https://github.com/matrix-org/synapse/issues/{issue}>)"
[[tool.towncrier.type]]
directory = "feature"
name = "Features"
showcontent = true
[[tool.towncrier.type]]
directory = "bugfix"
name = "Bugfixes"
showcontent = true
[[tool.towncrier.type]]
directory = "doc"
name = "Improved Documentation"
showcontent = true
[[tool.towncrier.type]]
directory = "removal"
name = "Deprecations and Removals"
showcontent = true
[[tool.towncrier.type]]
directory = "misc"
name = "Internal Changes"
showcontent = true

View File

@ -26,11 +26,37 @@ import yaml
def request_registration(user, password, server_location, shared_secret, admin=False): def request_registration(user, password, server_location, shared_secret, admin=False):
req = urllib2.Request(
"%s/_matrix/client/r0/admin/register" % (server_location,),
headers={'Content-Type': 'application/json'}
)
try:
if sys.version_info[:3] >= (2, 7, 9):
# As of version 2.7.9, urllib2 now checks SSL certs
import ssl
f = urllib2.urlopen(req, context=ssl.SSLContext(ssl.PROTOCOL_SSLv23))
else:
f = urllib2.urlopen(req)
body = f.read()
f.close()
nonce = json.loads(body)["nonce"]
except urllib2.HTTPError as e:
print "ERROR! Received %d %s" % (e.code, e.reason,)
if 400 <= e.code < 500:
if e.info().type == "application/json":
resp = json.load(e)
if "error" in resp:
print resp["error"]
sys.exit(1)
mac = hmac.new( mac = hmac.new(
key=shared_secret, key=shared_secret,
digestmod=hashlib.sha1, digestmod=hashlib.sha1,
) )
mac.update(nonce)
mac.update("\x00")
mac.update(user) mac.update(user)
mac.update("\x00") mac.update("\x00")
mac.update(password) mac.update(password)
@ -40,10 +66,10 @@ def request_registration(user, password, server_location, shared_secret, admin=F
mac = mac.hexdigest() mac = mac.hexdigest()
data = { data = {
"user": user, "nonce": nonce,
"username": user,
"password": password, "password": password,
"mac": mac, "mac": mac,
"type": "org.matrix.login.shared_secret",
"admin": admin, "admin": admin,
} }
@ -52,7 +78,7 @@ def request_registration(user, password, server_location, shared_secret, admin=F
print "Sending registration request..." print "Sending registration request..."
req = urllib2.Request( req = urllib2.Request(
"%s/_matrix/client/api/v1/register" % (server_location,), "%s/_matrix/client/r0/admin/register" % (server_location,),
data=json.dumps(data), data=json.dumps(data),
headers={'Content-Type': 'application/json'} headers={'Content-Type': 'application/json'}
) )

View File

@ -14,12 +14,17 @@ ignore =
pylint.cfg pylint.cfg
tox.ini tox.ini
[flake8] [pep8]
max-line-length = 90 max-line-length = 90
# W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it. # W503 requires that binary operators be at the end, not start, of lines. Erik
# E203 is contrary to PEP8. # doesn't like it. E203 is contrary to PEP8.
ignore = W503,E203 ignore = W503,E203
[flake8]
# note that flake8 inherits the "ignore" settings from "pep8" (because it uses
# pep8 to do those checks), but not the "max-line-length" setting
max-line-length = 90
[isort] [isort]
line_length = 89 line_length = 89
not_skip = __init__.py not_skip = __init__.py
@ -31,3 +36,4 @@ known_compat = mock,six
known_twisted=twisted,OpenSSL known_twisted=twisted,OpenSSL
multi_line_output=3 multi_line_output=3
include_trailing_comma=true include_trailing_comma=true
combine_as_imports=true

View File

@ -17,4 +17,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.32.2" __version__ = "0.33.0"

View File

@ -18,6 +18,8 @@ import logging
import os import os
import sys import sys
from six import iteritems
from twisted.application import service from twisted.application import service
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.web.resource import EncodingResourceWrapper, NoResource from twisted.web.resource import EncodingResourceWrapper, NoResource
@ -442,7 +444,7 @@ def run(hs):
stats["total_nonbridged_users"] = total_nonbridged_users stats["total_nonbridged_users"] = total_nonbridged_users
daily_user_type_results = yield hs.get_datastore().count_daily_user_type() daily_user_type_results = yield hs.get_datastore().count_daily_user_type()
for name, count in daily_user_type_results.iteritems(): for name, count in iteritems(daily_user_type_results):
stats["daily_user_type_" + name] = count stats["daily_user_type_" + name] = count
room_count = yield hs.get_datastore().get_room_count() room_count = yield hs.get_datastore().get_room_count()
@ -453,7 +455,7 @@ def run(hs):
stats["daily_messages"] = yield hs.get_datastore().count_daily_messages() stats["daily_messages"] = yield hs.get_datastore().count_daily_messages()
r30_results = yield hs.get_datastore().count_r30_users() r30_results = yield hs.get_datastore().count_r30_users()
for name, count in r30_results.iteritems(): for name, count in iteritems(r30_results):
stats["r30_users_" + name] = count stats["r30_users_" + name] = count
daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages() daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages()

View File

@ -25,6 +25,8 @@ import subprocess
import sys import sys
import time import time
from six import iteritems
import yaml import yaml
SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"] SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"]
@ -173,7 +175,7 @@ def main():
os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor) os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor)
cache_factors = config.get("synctl_cache_factors", {}) cache_factors = config.get("synctl_cache_factors", {})
for cache_name, factor in cache_factors.iteritems(): for cache_name, factor in iteritems(cache_factors):
os.environ["SYNAPSE_CACHE_FACTOR_" + cache_name.upper()] = str(factor) os.environ["SYNAPSE_CACHE_FACTOR_" + cache_name.upper()] = str(factor)
worker_configfiles = [] worker_configfiles = []

View File

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from six import iteritems
from frozendict import frozendict from frozendict import frozendict
from twisted.internet import defer from twisted.internet import defer
@ -159,7 +161,7 @@ def _encode_state_dict(state_dict):
return [ return [
(etype, state_key, v) (etype, state_key, v)
for (etype, state_key), v in state_dict.iteritems() for (etype, state_key), v in iteritems(state_dict)
] ]

View File

@ -30,7 +30,8 @@ from synapse.metrics import (
sent_edus_counter, sent_edus_counter,
sent_transactions_counter, sent_transactions_counter,
) )
from synapse.util import PreserveLoggingContext, logcontext from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util import logcontext
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
@ -165,10 +166,11 @@ class TransactionQueue(object):
if self._is_processing: if self._is_processing:
return return
# fire off a processing loop in the background. It's likely it will # fire off a processing loop in the background
# outlast the current request, so run it in the sentinel logcontext. run_as_background_process(
with PreserveLoggingContext(): "process_event_queue_for_federation",
self._process_event_queue_loop() self._process_event_queue_loop,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _process_event_queue_loop(self): def _process_event_queue_loop(self):
@ -432,14 +434,11 @@ class TransactionQueue(object):
logger.debug("TX [%s] Starting transaction loop", destination) logger.debug("TX [%s] Starting transaction loop", destination)
# Drop the logcontext before starting the transaction. It doesn't run_as_background_process(
# really make sense to log all the outbound transactions against "federation_transaction_transmission_loop",
# whatever path led us to this point: that's pretty arbitrary really. self._transaction_transmission_loop,
# destination,
# (this also means we can fire off _perform_transaction without )
# yielding)
with logcontext.PreserveLoggingContext():
self._transaction_transmission_loop(destination)
@defer.inlineCallbacks @defer.inlineCallbacks
def _transaction_transmission_loop(self, destination): def _transaction_transmission_loop(self, destination):

View File

@ -21,8 +21,8 @@ import logging
import sys import sys
import six import six
from six import iteritems from six import iteritems, itervalues
from six.moves import http_client from six.moves import http_client, zip
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json from signedjson.sign import verify_signed_json
@ -731,7 +731,7 @@ class FederationHandler(BaseHandler):
""" """
joined_users = [ joined_users = [
(state_key, int(event.depth)) (state_key, int(event.depth))
for (e_type, state_key), event in state.iteritems() for (e_type, state_key), event in iteritems(state)
if e_type == EventTypes.Member if e_type == EventTypes.Member
and event.membership == Membership.JOIN and event.membership == Membership.JOIN
] ]
@ -748,7 +748,7 @@ class FederationHandler(BaseHandler):
except Exception: except Exception:
pass pass
return sorted(joined_domains.iteritems(), key=lambda d: d[1]) return sorted(joined_domains.items(), key=lambda d: d[1])
curr_domains = get_domains_from_state(curr_state) curr_domains = get_domains_from_state(curr_state)
@ -811,7 +811,7 @@ class FederationHandler(BaseHandler):
tried_domains = set(likely_domains) tried_domains = set(likely_domains)
tried_domains.add(self.server_name) tried_domains.add(self.server_name)
event_ids = list(extremities.iterkeys()) event_ids = list(extremities.keys())
logger.debug("calling resolve_state_groups in _maybe_backfill") logger.debug("calling resolve_state_groups in _maybe_backfill")
resolve = logcontext.preserve_fn( resolve = logcontext.preserve_fn(
@ -827,15 +827,15 @@ class FederationHandler(BaseHandler):
states = dict(zip(event_ids, [s.state for s in states])) states = dict(zip(event_ids, [s.state for s in states]))
state_map = yield self.store.get_events( state_map = yield self.store.get_events(
[e_id for ids in states.itervalues() for e_id in ids.itervalues()], [e_id for ids in itervalues(states) for e_id in itervalues(ids)],
get_prev_content=False get_prev_content=False
) )
states = { states = {
key: { key: {
k: state_map[e_id] k: state_map[e_id]
for k, e_id in state_dict.iteritems() for k, e_id in iteritems(state_dict)
if e_id in state_map if e_id in state_map
} for key, state_dict in states.iteritems() } for key, state_dict in iteritems(states)
} }
for e_id, _ in sorted_extremeties_tuple: for e_id, _ in sorted_extremeties_tuple:
@ -1515,7 +1515,7 @@ class FederationHandler(BaseHandler):
yield self.store.persist_events( yield self.store.persist_events(
[ [
(ev_info["event"], context) (ev_info["event"], context)
for ev_info, context in itertools.izip(event_infos, contexts) for ev_info, context in zip(event_infos, contexts)
], ],
backfilled=backfilled, backfilled=backfilled,
) )

View File

@ -32,7 +32,7 @@ from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.replication.http.send_event import send_event_to_master from synapse.replication.http.send_event import send_event_to_master
from synapse.types import RoomAlias, UserID from synapse.types import RoomAlias, UserID
from synapse.util.async import Limiter from synapse.util.async import Linearizer
from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.logcontext import run_in_background from synapse.util.logcontext import run_in_background
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
@ -180,7 +180,7 @@ class EventCreationHandler(object):
# We arbitrarily limit concurrent event creation for a room to 5. # We arbitrarily limit concurrent event creation for a room to 5.
# This is to stop us from diverging history *too* much. # This is to stop us from diverging history *too* much.
self.limiter = Limiter(max_count=5) self.limiter = Linearizer(max_count=5, name="room_event_creation_limit")
self.action_generator = hs.get_action_generator() self.action_generator = hs.get_action_generator()

View File

@ -26,9 +26,11 @@ from OpenSSL.SSL import VERIFY_NONE
from twisted.internet import defer, protocol, reactor, ssl, task from twisted.internet import defer, protocol, reactor, ssl, task
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.web._newclient import ResponseDone from twisted.web._newclient import ResponseDone
from twisted.web.client import Agent, BrowserLikeRedirectAgent, ContentDecoderAgent
from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer
from twisted.web.client import ( from twisted.web.client import (
Agent,
BrowserLikeRedirectAgent,
ContentDecoderAgent,
FileBodyProducer as TwistedFileBodyProducer,
GzipDecoder, GzipDecoder,
HTTPConnectionPool, HTTPConnectionPool,
PartialDownloadError, PartialDownloadError,

View File

@ -38,7 +38,8 @@ outgoing_responses_counter = Counter(
) )
response_timer = Histogram( response_timer = Histogram(
"synapse_http_server_response_time_seconds", "sec", ["method", "servlet", "tag"] "synapse_http_server_response_time_seconds", "sec",
["method", "servlet", "tag", "code"],
) )
response_ru_utime = Counter( response_ru_utime = Counter(
@ -171,11 +172,13 @@ class RequestMetrics(object):
) )
return return
outgoing_responses_counter.labels(request.method, str(request.code)).inc() response_code = str(request.code)
outgoing_responses_counter.labels(request.method, response_code).inc()
response_count.labels(request.method, self.name, tag).inc() response_count.labels(request.method, self.name, tag).inc()
response_timer.labels(request.method, self.name, tag).observe( response_timer.labels(request.method, self.name, tag, response_code).observe(
time_sec - self.start time_sec - self.start
) )

View File

@ -0,0 +1,179 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import six
from prometheus_client.core import REGISTRY, Counter, GaugeMetricFamily
from twisted.internet import defer
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
_background_process_start_count = Counter(
"synapse_background_process_start_count",
"Number of background processes started",
["name"],
)
# we set registry=None in all of these to stop them getting registered with
# the default registry. Instead we collect them all via the CustomCollector,
# which ensures that we can update them before they are collected.
#
_background_process_ru_utime = Counter(
"synapse_background_process_ru_utime_seconds",
"User CPU time used by background processes, in seconds",
["name"],
registry=None,
)
_background_process_ru_stime = Counter(
"synapse_background_process_ru_stime_seconds",
"System CPU time used by background processes, in seconds",
["name"],
registry=None,
)
_background_process_db_txn_count = Counter(
"synapse_background_process_db_txn_count",
"Number of database transactions done by background processes",
["name"],
registry=None,
)
_background_process_db_txn_duration = Counter(
"synapse_background_process_db_txn_duration_seconds",
("Seconds spent by background processes waiting for database "
"transactions, excluding scheduling time"),
["name"],
registry=None,
)
_background_process_db_sched_duration = Counter(
"synapse_background_process_db_sched_duration_seconds",
"Seconds spent by background processes waiting for database connections",
["name"],
registry=None,
)
# map from description to a counter, so that we can name our logcontexts
# incrementally. (It actually duplicates _background_process_start_count, but
# it's much simpler to do so than to try to combine them.)
_background_process_counts = dict() # type: dict[str, int]
# map from description to the currently running background processes.
#
# it's kept as a dict of sets rather than a big set so that we can keep track
# of process descriptions that no longer have any active processes.
_background_processes = dict() # type: dict[str, set[_BackgroundProcess]]
class _Collector(object):
"""A custom metrics collector for the background process metrics.
Ensures that all of the metrics are up-to-date with any in-flight processes
before they are returned.
"""
def collect(self):
background_process_in_flight_count = GaugeMetricFamily(
"synapse_background_process_in_flight_count",
"Number of background processes in flight",
labels=["name"],
)
for desc, processes in six.iteritems(_background_processes):
background_process_in_flight_count.add_metric(
(desc,), len(processes),
)
for process in processes:
process.update_metrics()
yield background_process_in_flight_count
# now we need to run collect() over each of the static Counters, and
# yield each metric they return.
for m in (
_background_process_ru_utime,
_background_process_ru_stime,
_background_process_db_txn_count,
_background_process_db_txn_duration,
_background_process_db_sched_duration,
):
for r in m.collect():
yield r
REGISTRY.register(_Collector())
class _BackgroundProcess(object):
def __init__(self, desc, ctx):
self.desc = desc
self._context = ctx
self._reported_stats = None
def update_metrics(self):
"""Updates the metrics with values from this process."""
new_stats = self._context.get_resource_usage()
if self._reported_stats is None:
diff = new_stats
else:
diff = new_stats - self._reported_stats
self._reported_stats = new_stats
_background_process_ru_utime.labels(self.desc).inc(diff.ru_utime)
_background_process_ru_stime.labels(self.desc).inc(diff.ru_stime)
_background_process_db_txn_count.labels(self.desc).inc(
diff.db_txn_count,
)
_background_process_db_txn_duration.labels(self.desc).inc(
diff.db_txn_duration_sec,
)
_background_process_db_sched_duration.labels(self.desc).inc(
diff.db_sched_duration_sec,
)
def run_as_background_process(desc, func, *args, **kwargs):
"""Run the given function in its own logcontext, with resource metrics
This should be used to wrap processes which are fired off to run in the
background, instead of being associated with a particular request.
Args:
desc (str): a description for this background process type
func: a function, which may return a Deferred
args: positional args for func
kwargs: keyword args for func
Returns: None
"""
@defer.inlineCallbacks
def run():
count = _background_process_counts.get(desc, 0)
_background_process_counts[desc] = count + 1
_background_process_start_count.labels(desc).inc()
with LoggingContext(desc) as context:
context.request = "%s-%i" % (desc, count)
proc = _BackgroundProcess(desc, context)
_background_processes.setdefault(desc, set()).add(proc)
try:
yield func(*args, **kwargs)
finally:
proc.update_metrics()
_background_processes[desc].remove(proc)
with PreserveLoggingContext():
run()

View File

@ -274,7 +274,7 @@ class Notifier(object):
logger.exception("Error notifying application services of event") logger.exception("Error notifying application services of event")
def on_new_event(self, stream_key, new_token, users=[], rooms=[]): def on_new_event(self, stream_key, new_token, users=[], rooms=[]):
""" Used to inform listeners that something has happend event wise. """ Used to inform listeners that something has happened event wise.
Will wake up all listeners for the given users and rooms. Will wake up all listeners for the given users and rooms.
""" """

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,13 +14,24 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from six import PY3
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.rest.client import versions from synapse.rest.client import versions
from synapse.rest.client.v1 import admin, directory, events, initial_sync from synapse.rest.client.v1 import (
from synapse.rest.client.v1 import login as v1_login admin,
from synapse.rest.client.v1 import logout, presence, profile, push_rule, pusher directory,
from synapse.rest.client.v1 import register as v1_register events,
from synapse.rest.client.v1 import room, voip initial_sync,
login as v1_login,
logout,
presence,
profile,
push_rule,
pusher,
room,
voip,
)
from synapse.rest.client.v2_alpha import ( from synapse.rest.client.v2_alpha import (
account, account,
account_data, account_data,
@ -42,6 +54,11 @@ from synapse.rest.client.v2_alpha import (
user_directory, user_directory,
) )
if not PY3:
from synapse.rest.client.v1_only import (
register as v1_register,
)
class ClientRestResource(JsonResource): class ClientRestResource(JsonResource):
"""A resource for version 1 of the matrix client API.""" """A resource for version 1 of the matrix client API."""
@ -54,14 +71,22 @@ class ClientRestResource(JsonResource):
def register_servlets(client_resource, hs): def register_servlets(client_resource, hs):
versions.register_servlets(client_resource) versions.register_servlets(client_resource)
# "v1" if not PY3:
room.register_servlets(hs, client_resource) # "v1" (Python 2 only)
v1_register.register_servlets(hs, client_resource)
# Deprecated in r0
initial_sync.register_servlets(hs, client_resource)
room.register_deprecated_servlets(hs, client_resource)
# Partially deprecated in r0
events.register_servlets(hs, client_resource) events.register_servlets(hs, client_resource)
v1_register.register_servlets(hs, client_resource)
# "v1" + "r0"
room.register_servlets(hs, client_resource)
v1_login.register_servlets(hs, client_resource) v1_login.register_servlets(hs, client_resource)
profile.register_servlets(hs, client_resource) profile.register_servlets(hs, client_resource)
presence.register_servlets(hs, client_resource) presence.register_servlets(hs, client_resource)
initial_sync.register_servlets(hs, client_resource)
directory.register_servlets(hs, client_resource) directory.register_servlets(hs, client_resource)
voip.register_servlets(hs, client_resource) voip.register_servlets(hs, client_resource)
admin.register_servlets(hs, client_resource) admin.register_servlets(hs, client_resource)

View File

@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import hashlib
import hmac
import logging import logging
from six.moves import http_client from six.moves import http_client
@ -63,6 +65,125 @@ class UsersRestServlet(ClientV1RestServlet):
defer.returnValue((200, ret)) defer.returnValue((200, ret))
class UserRegisterServlet(ClientV1RestServlet):
"""
Attributes:
NONCE_TIMEOUT (int): Seconds until a generated nonce won't be accepted
nonces (dict[str, int]): The nonces that we will accept. A dict of
nonce to the time it was generated, in int seconds.
"""
PATTERNS = client_path_patterns("/admin/register")
NONCE_TIMEOUT = 60
def __init__(self, hs):
super(UserRegisterServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
self.reactor = hs.get_reactor()
self.nonces = {}
self.hs = hs
def _clear_old_nonces(self):
"""
Clear out old nonces that are older than NONCE_TIMEOUT.
"""
now = int(self.reactor.seconds())
for k, v in list(self.nonces.items()):
if now - v > self.NONCE_TIMEOUT:
del self.nonces[k]
def on_GET(self, request):
"""
Generate a new nonce.
"""
self._clear_old_nonces()
nonce = self.hs.get_secrets().token_hex(64)
self.nonces[nonce] = int(self.reactor.seconds())
return (200, {"nonce": nonce.encode('ascii')})
@defer.inlineCallbacks
def on_POST(self, request):
self._clear_old_nonces()
if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled")
body = parse_json_object_from_request(request)
if "nonce" not in body:
raise SynapseError(
400, "nonce must be specified", errcode=Codes.BAD_JSON,
)
nonce = body["nonce"]
if nonce not in self.nonces:
raise SynapseError(
400, "unrecognised nonce",
)
# Delete the nonce, so it can't be reused, even if it's invalid
del self.nonces[nonce]
if "username" not in body:
raise SynapseError(
400, "username must be specified", errcode=Codes.BAD_JSON,
)
else:
if (not isinstance(body['username'], str) or len(body['username']) > 512):
raise SynapseError(400, "Invalid username")
username = body["username"].encode("utf-8")
if b"\x00" in username:
raise SynapseError(400, "Invalid username")
if "password" not in body:
raise SynapseError(
400, "password must be specified", errcode=Codes.BAD_JSON,
)
else:
if (not isinstance(body['password'], str) or len(body['password']) > 512):
raise SynapseError(400, "Invalid password")
password = body["password"].encode("utf-8")
if b"\x00" in password:
raise SynapseError(400, "Invalid password")
admin = body.get("admin", None)
got_mac = body["mac"]
want_mac = hmac.new(
key=self.hs.config.registration_shared_secret.encode(),
digestmod=hashlib.sha1,
)
want_mac.update(nonce)
want_mac.update(b"\x00")
want_mac.update(username)
want_mac.update(b"\x00")
want_mac.update(password)
want_mac.update(b"\x00")
want_mac.update(b"admin" if admin else b"notadmin")
want_mac = want_mac.hexdigest()
if not hmac.compare_digest(want_mac, got_mac):
raise SynapseError(
403, "HMAC incorrect",
)
# Reuse the parts of RegisterRestServlet to reduce code duplication
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
register = RegisterRestServlet(self.hs)
(user_id, _) = yield register.registration_handler.register(
localpart=username.lower(), password=password, admin=bool(admin),
generate_token=False,
)
result = yield register._create_registration_details(user_id, body)
defer.returnValue((200, result))
class WhoisRestServlet(ClientV1RestServlet): class WhoisRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)") PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)")
@ -614,3 +735,4 @@ def register_servlets(hs, http_server):
ShutdownRoomRestServlet(hs).register(http_server) ShutdownRoomRestServlet(hs).register(http_server)
QuarantineMediaInRoom(hs).register(http_server) QuarantineMediaInRoom(hs).register(http_server)
ListMediaInRoom(hs).register(http_server) ListMediaInRoom(hs).register(http_server)
UserRegisterServlet(hs).register(http_server)

View File

@ -830,10 +830,13 @@ def register_servlets(hs, http_server):
RoomSendEventRestServlet(hs).register(http_server) RoomSendEventRestServlet(hs).register(http_server)
PublicRoomListRestServlet(hs).register(http_server) PublicRoomListRestServlet(hs).register(http_server)
RoomStateRestServlet(hs).register(http_server) RoomStateRestServlet(hs).register(http_server)
RoomInitialSyncRestServlet(hs).register(http_server)
RoomRedactEventRestServlet(hs).register(http_server) RoomRedactEventRestServlet(hs).register(http_server)
RoomTypingRestServlet(hs).register(http_server) RoomTypingRestServlet(hs).register(http_server)
SearchRestServlet(hs).register(http_server) SearchRestServlet(hs).register(http_server)
JoinedRoomsRestServlet(hs).register(http_server) JoinedRoomsRestServlet(hs).register(http_server)
RoomEventServlet(hs).register(http_server) RoomEventServlet(hs).register(http_server)
RoomEventContextServlet(hs).register(http_server) RoomEventContextServlet(hs).register(http_server)
def register_deprecated_servlets(hs, http_server):
RoomInitialSyncRestServlet(hs).register(http_server)

View File

@ -0,0 +1,3 @@
"""
REST APIs that are only used in v1 (the legacy API).
"""

View File

@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains base REST classes for constructing client v1 servlets.
"""
import re
from synapse.api.urls import CLIENT_PREFIX
def v1_only_client_path_patterns(path_regex, include_in_unstable=True):
"""Creates a regex compiled client path with the correct client path
prefix.
Args:
path_regex (str): The regex string to match. This should NOT have a ^
as this will be prefixed.
Returns:
list of SRE_Pattern
"""
patterns = [re.compile("^" + CLIENT_PREFIX + path_regex)]
if include_in_unstable:
unstable_prefix = CLIENT_PREFIX.replace("/api/v1", "/unstable")
patterns.append(re.compile("^" + unstable_prefix + path_regex))
return patterns

View File

@ -24,9 +24,10 @@ import synapse.util.stringutils as stringutils
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
from synapse.rest.client.v1.base import ClientV1RestServlet
from synapse.types import create_requester from synapse.types import create_requester
from .base import ClientV1RestServlet, client_path_patterns from .base import v1_only_client_path_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -49,7 +50,7 @@ class RegisterRestServlet(ClientV1RestServlet):
handler doesn't have a concept of multi-stages or sessions. handler doesn't have a concept of multi-stages or sessions.
""" """
PATTERNS = client_path_patterns("/register$", releases=(), include_in_unstable=False) PATTERNS = v1_only_client_path_patterns("/register$", include_in_unstable=False)
def __init__(self, hs): def __init__(self, hs):
""" """
@ -379,7 +380,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
"""Handles user creation via a server-to-server interface """Handles user creation via a server-to-server interface
""" """
PATTERNS = client_path_patterns("/createUser$", releases=()) PATTERNS = v1_only_client_path_patterns("/createUser$")
def __init__(self, hs): def __init__(self, hs):
super(CreateUserRestServlet, self).__init__(hs) super(CreateUserRestServlet, self).__init__(hs)

42
synapse/secrets.py Normal file
View File

@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Injectable secrets module for Synapse.
See https://docs.python.org/3/library/secrets.html#module-secrets for the API
used in Python 3.6, and the API emulated in Python 2.7.
"""
import six
if six.PY3:
import secrets
def Secrets():
return secrets
else:
import os
import binascii
class Secrets(object):
def token_bytes(self, nbytes=32):
return os.urandom(nbytes)
def token_hex(self, nbytes=32):
return binascii.hexlify(self.token_bytes(nbytes))

View File

@ -75,6 +75,7 @@ from synapse.rest.media.v1.media_repository import (
MediaRepository, MediaRepository,
MediaRepositoryResource, MediaRepositoryResource,
) )
from synapse.secrets import Secrets
from synapse.server_notices.server_notices_manager import ServerNoticesManager from synapse.server_notices.server_notices_manager import ServerNoticesManager
from synapse.server_notices.server_notices_sender import ServerNoticesSender from synapse.server_notices.server_notices_sender import ServerNoticesSender
from synapse.server_notices.worker_server_notices_sender import WorkerServerNoticesSender from synapse.server_notices.worker_server_notices_sender import WorkerServerNoticesSender
@ -159,6 +160,7 @@ class HomeServer(object):
'groups_server_handler', 'groups_server_handler',
'groups_attestation_signing', 'groups_attestation_signing',
'groups_attestation_renewer', 'groups_attestation_renewer',
'secrets',
'spam_checker', 'spam_checker',
'room_member_handler', 'room_member_handler',
'federation_registry', 'federation_registry',
@ -409,6 +411,9 @@ class HomeServer(object):
def build_groups_attestation_renewer(self): def build_groups_attestation_renewer(self):
return GroupAttestionRenewer(self) return GroupAttestionRenewer(self)
def build_secrets(self):
return Secrets()
def build_spam_checker(self): def build_spam_checker(self):
return SpamChecker(self) return SpamChecker(self)

View File

@ -18,7 +18,7 @@ import hashlib
import logging import logging
from collections import namedtuple from collections import namedtuple
from six import iteritems, itervalues from six import iteritems, iterkeys, itervalues
from frozendict import frozendict from frozendict import frozendict
@ -647,7 +647,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
for event_id in event_ids for event_id in event_ids
) )
if event_map is not None: if event_map is not None:
needed_events -= set(event_map.iterkeys()) needed_events -= set(iterkeys(event_map))
logger.info("Asking for %d conflicted events", len(needed_events)) logger.info("Asking for %d conflicted events", len(needed_events))
@ -668,7 +668,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
new_needed_events = set(itervalues(auth_events)) new_needed_events = set(itervalues(auth_events))
new_needed_events -= needed_events new_needed_events -= needed_events
if event_map is not None: if event_map is not None:
new_needed_events -= set(event_map.iterkeys()) new_needed_events -= set(iterkeys(event_map))
logger.info("Asking for %d auth events", len(new_needed_events)) logger.info("Asking for %d auth events", len(new_needed_events))

View File

@ -344,7 +344,7 @@ class SQLBaseStore(object):
parent_context = LoggingContext.current_context() parent_context = LoggingContext.current_context()
if parent_context == LoggingContext.sentinel: if parent_context == LoggingContext.sentinel:
logger.warn( logger.warn(
"Running db txn from sentinel context: metrics will be lost", "Starting db connection from sentinel context: metrics will be lost",
) )
parent_context = None parent_context = None

View File

@ -19,6 +19,8 @@ from canonicaljson import json
from twisted.internet import defer from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
from . import engines from . import engines
from ._base import SQLBaseStore from ._base import SQLBaseStore
@ -87,10 +89,14 @@ class BackgroundUpdateStore(SQLBaseStore):
self._background_update_handlers = {} self._background_update_handlers = {}
self._all_done = False self._all_done = False
@defer.inlineCallbacks
def start_doing_background_updates(self): def start_doing_background_updates(self):
logger.info("Starting background schema updates") run_as_background_process(
"background_updates", self._run_background_updates,
)
@defer.inlineCallbacks
def _run_background_updates(self):
logger.info("Starting background schema updates")
while True: while True:
yield self.hs.get_clock().sleep( yield self.hs.get_clock().sleep(
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.) self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.)

View File

@ -19,6 +19,7 @@ from six import iteritems
from twisted.internet import defer from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches import CACHE_SIZE_FACTOR
from . import background_updates from . import background_updates
@ -93,10 +94,16 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
self._batch_row_update[key] = (user_agent, device_id, now) self._batch_row_update[key] = (user_agent, device_id, now)
def _update_client_ips_batch(self): def _update_client_ips_batch(self):
to_update = self._batch_row_update def update():
self._batch_row_update = {} to_update = self._batch_row_update
return self.runInteraction( self._batch_row_update = {}
"_update_client_ips_batch", self._update_client_ips_batch_txn, to_update return self.runInteraction(
"_update_client_ips_batch", self._update_client_ips_batch_txn,
to_update,
)
run_as_background_process(
"update_client_ips", update,
) )
def _update_client_ips_batch_txn(self, txn, to_update): def _update_client_ips_batch_txn(self, txn, to_update):

View File

@ -248,17 +248,31 @@ class DeviceStore(SQLBaseStore):
def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id, def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
content, stream_id): content, stream_id):
self._simple_upsert_txn( if content.get("deleted"):
txn, self._simple_delete_txn(
table="device_lists_remote_cache", txn,
keyvalues={ table="device_lists_remote_cache",
"user_id": user_id, keyvalues={
"device_id": device_id, "user_id": user_id,
}, "device_id": device_id,
values={ },
"content": json.dumps(content), )
}
) txn.call_after(
self.device_id_exists_cache.invalidate, (user_id, device_id,)
)
else:
self._simple_upsert_txn(
txn,
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
values={
"content": json.dumps(content),
}
)
txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,)) txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,))
txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,)) txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
@ -366,7 +380,7 @@ class DeviceStore(SQLBaseStore):
now_stream_id = max(stream_id for stream_id in itervalues(query_map)) now_stream_id = max(stream_id for stream_id in itervalues(query_map))
devices = self._get_e2e_device_keys_txn( devices = self._get_e2e_device_keys_txn(
txn, query_map.keys(), include_all_devices=True txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True
) )
prev_sent_id_sql = """ prev_sent_id_sql = """
@ -393,12 +407,15 @@ class DeviceStore(SQLBaseStore):
prev_id = stream_id prev_id = stream_id
key_json = device.get("key_json", None) if device is not None:
if key_json: key_json = device.get("key_json", None)
result["keys"] = json.loads(key_json) if key_json:
device_display_name = device.get("device_display_name", None) result["keys"] = json.loads(key_json)
if device_display_name: device_display_name = device.get("device_display_name", None)
result["device_display_name"] = device_display_name if device_display_name:
result["device_display_name"] = device_display_name
else:
result["deleted"] = True
results.append(result) results.append(result)

View File

@ -64,12 +64,18 @@ class EndToEndKeyStore(SQLBaseStore):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def get_e2e_device_keys(self, query_list, include_all_devices=False): def get_e2e_device_keys(
self, query_list, include_all_devices=False,
include_deleted_devices=False,
):
"""Fetch a list of device keys. """Fetch a list of device keys.
Args: Args:
query_list(list): List of pairs of user_ids and device_ids. query_list(list): List of pairs of user_ids and device_ids.
include_all_devices (bool): whether to include entries for devices include_all_devices (bool): whether to include entries for devices
that don't have device keys that don't have device keys
include_deleted_devices (bool): whether to include null entries for
devices which no longer exist (but were in the query_list).
This option only takes effect if include_all_devices is true.
Returns: Returns:
Dict mapping from user-id to dict mapping from device_id to Dict mapping from user-id to dict mapping from device_id to
dict containing "key_json", "device_display_name". dict containing "key_json", "device_display_name".
@ -79,7 +85,7 @@ class EndToEndKeyStore(SQLBaseStore):
results = yield self.runInteraction( results = yield self.runInteraction(
"get_e2e_device_keys", self._get_e2e_device_keys_txn, "get_e2e_device_keys", self._get_e2e_device_keys_txn,
query_list, include_all_devices, query_list, include_all_devices, include_deleted_devices,
) )
for user_id, device_keys in iteritems(results): for user_id, device_keys in iteritems(results):
@ -88,10 +94,19 @@ class EndToEndKeyStore(SQLBaseStore):
defer.returnValue(results) defer.returnValue(results)
def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices): def _get_e2e_device_keys_txn(
self, txn, query_list, include_all_devices=False,
include_deleted_devices=False,
):
query_clauses = [] query_clauses = []
query_params = [] query_params = []
if include_all_devices is False:
include_deleted_devices = False
if include_deleted_devices:
deleted_devices = set(query_list)
for (user_id, device_id) in query_list: for (user_id, device_id) in query_list:
query_clause = "user_id = ?" query_clause = "user_id = ?"
query_params.append(user_id) query_params.append(user_id)
@ -119,8 +134,14 @@ class EndToEndKeyStore(SQLBaseStore):
result = {} result = {}
for row in rows: for row in rows:
if include_deleted_devices:
deleted_devices.remove((row["user_id"], row["device_id"]))
result.setdefault(row["user_id"], {})[row["device_id"]] = row result.setdefault(row["user_id"], {})[row["device_id"]] = row
if include_deleted_devices:
for user_id, device_id in deleted_devices:
result.setdefault(user_id, {})[device_id] = None
return result return result
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -33,12 +33,13 @@ from synapse.api.errors import SynapseError
# these are only included to make the type annotations work # these are only included to make the type annotations work
from synapse.events import EventBase # noqa: F401 from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.events_worker import EventsWorkerStore from synapse.storage.events_worker import EventsWorkerStore
from synapse.types import RoomStreamToken, get_domain_from_id from synapse.types import RoomStreamToken, get_domain_from_id
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable from synapse.util.logcontext import make_deferred_yieldable
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -155,11 +156,8 @@ class _EventPeristenceQueue(object):
self._event_persist_queues[room_id] = queue self._event_persist_queues[room_id] = queue
self._currently_persisting_rooms.discard(room_id) self._currently_persisting_rooms.discard(room_id)
# set handle_queue_loop off on the background. We don't want to # set handle_queue_loop off in the background
# attribute work done in it to the current request, so we drop the run_as_background_process("persist_events", handle_queue_loop)
# logcontext altogether.
with PreserveLoggingContext():
handle_queue_loop()
def _get_drainining_queue(self, room_id): def _get_drainining_queue(self, room_id):
queue = self._event_persist_queues.setdefault(room_id, deque()) queue = self._event_persist_queues.setdefault(room_id, deque())

View File

@ -25,6 +25,7 @@ from synapse.events import EventBase # noqa: F401
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.events.snapshot import EventContext # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.logcontext import ( from synapse.util.logcontext import (
LoggingContext, LoggingContext,
PreserveLoggingContext, PreserveLoggingContext,
@ -322,10 +323,11 @@ class EventsWorkerStore(SQLBaseStore):
should_start = False should_start = False
if should_start: if should_start:
with PreserveLoggingContext(): run_as_background_process(
self.runWithConnection( "fetch_events",
self._do_fetch self.runWithConnection,
) self._do_fetch,
)
logger.debug("Loading %d events", len(events)) logger.debug("Loading %d events", len(events))
with PreserveLoggingContext(): with PreserveLoggingContext():

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import collections
import logging import logging
from contextlib import contextmanager from contextlib import contextmanager
@ -156,54 +157,72 @@ def concurrently_execute(func, args, limit):
class Linearizer(object): class Linearizer(object):
"""Linearizes access to resources based on a key. Useful to ensure only one """Limits concurrent access to resources based on a key. Useful to ensure
thing is happening at a time on a given resource. only a few things happen at a time on a given resource.
Example: Example:
with (yield linearizer.queue("test_key")): with (yield limiter.queue("test_key")):
# do some work. # do some work.
""" """
def __init__(self, name=None, clock=None): def __init__(self, name=None, max_count=1, clock=None):
"""
Args:
max_count(int): The maximum number of concurrent accesses
"""
if name is None: if name is None:
self.name = id(self) self.name = id(self)
else: else:
self.name = name self.name = name
self.key_to_defer = {}
if not clock: if not clock:
from twisted.internet import reactor from twisted.internet import reactor
clock = Clock(reactor) clock = Clock(reactor)
self._clock = clock self._clock = clock
self.max_count = max_count
# key_to_defer is a map from the key to a 2 element list where
# the first element is the number of things executing, and
# the second element is an OrderedDict, where the keys are deferreds for the
# things blocked from executing.
self.key_to_defer = {}
@defer.inlineCallbacks @defer.inlineCallbacks
def queue(self, key): def queue(self, key):
# If there is already a deferred in the queue, we pull it out so that entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()])
# we can wait on it later.
# Then we replace it with a deferred that we resolve *after* the
# context manager has exited.
# We only return the context manager after the previous deferred has
# resolved.
# This all has the net effect of creating a chain of deferreds that
# wait for the previous deferred before starting their work.
current_defer = self.key_to_defer.get(key)
new_defer = defer.Deferred() # If the number of things executing is greater than the maximum
self.key_to_defer[key] = new_defer # then add a deferred to the list of blocked items
# When on of the things currently executing finishes it will callback
# this item so that it can continue executing.
if entry[0] >= self.max_count:
new_defer = defer.Deferred()
entry[1][new_defer] = 1
if current_defer:
logger.info( logger.info(
"Waiting to acquire linearizer lock %r for key %r", self.name, key "Waiting to acquire linearizer lock %r for key %r", self.name, key,
) )
try: try:
with PreserveLoggingContext(): yield make_deferred_yieldable(new_defer)
yield current_defer except Exception as e:
except Exception: if isinstance(e, CancelledError):
logger.exception("Unexpected exception in Linearizer") logger.info(
"Cancelling wait for linearizer lock %r for key %r",
self.name, key,
)
else:
logger.warn(
"Unexpected exception waiting for linearizer lock %r for key %r",
self.name, key,
)
logger.info("Acquired linearizer lock %r for key %r", self.name, # we just have to take ourselves back out of the queue.
key) del entry[1][new_defer]
raise
logger.info("Acquired linearizer lock %r for key %r", self.name, key)
entry[0] += 1
# if the code holding the lock completes synchronously, then it # if the code holding the lock completes synchronously, then it
# will recursively run the next claimant on the list. That can # will recursively run the next claimant on the list. That can
@ -213,15 +232,15 @@ class Linearizer(object):
# In order to break the cycle, we add a cheeky sleep(0) here to # In order to break the cycle, we add a cheeky sleep(0) here to
# ensure that we fall back to the reactor between each iteration. # ensure that we fall back to the reactor between each iteration.
# #
# (There's no particular need for it to happen before we return # (This needs to happen while we hold the lock, and the context manager's exit
# the context manager, but it needs to happen while we hold the # code must be synchronous, so this is the only sensible place.)
# lock, and the context manager's exit code must be synchronous,
# so actually this is the only sensible place.
yield self._clock.sleep(0) yield self._clock.sleep(0)
else: else:
logger.info("Acquired uncontended linearizer lock %r for key %r", logger.info(
self.name, key) "Acquired uncontended linearizer lock %r for key %r", self.name, key,
)
entry[0] += 1
@contextmanager @contextmanager
def _ctx_manager(): def _ctx_manager():
@ -229,73 +248,15 @@ class Linearizer(object):
yield yield
finally: finally:
logger.info("Releasing linearizer lock %r for key %r", self.name, key) logger.info("Releasing linearizer lock %r for key %r", self.name, key)
with PreserveLoggingContext():
new_defer.callback(None)
current_d = self.key_to_defer.get(key)
if current_d is new_defer:
self.key_to_defer.pop(key, None)
defer.returnValue(_ctx_manager())
class Limiter(object):
"""Limits concurrent access to resources based on a key. Useful to ensure
only a few thing happen at a time on a given resource.
Example:
with (yield limiter.queue("test_key")):
# do some work.
"""
def __init__(self, max_count):
"""
Args:
max_count(int): The maximum number of concurrent access
"""
self.max_count = max_count
# key_to_defer is a map from the key to a 2 element list where
# the first element is the number of things executing
# the second element is a list of deferreds for the things blocked from
# executing.
self.key_to_defer = {}
@defer.inlineCallbacks
def queue(self, key):
entry = self.key_to_defer.setdefault(key, [0, []])
# If the number of things executing is greater than the maximum
# then add a deferred to the list of blocked items
# When on of the things currently executing finishes it will callback
# this item so that it can continue executing.
if entry[0] >= self.max_count:
new_defer = defer.Deferred()
entry[1].append(new_defer)
logger.info("Waiting to acquire limiter lock for key %r", key)
with PreserveLoggingContext():
yield new_defer
logger.info("Acquired limiter lock for key %r", key)
else:
logger.info("Acquired uncontended limiter lock for key %r", key)
entry[0] += 1
@contextmanager
def _ctx_manager():
try:
yield
finally:
logger.info("Releasing limiter lock for key %r", key)
# We've finished executing so check if there are any things # We've finished executing so check if there are any things
# blocked waiting to execute and start one of them # blocked waiting to execute and start one of them
entry[0] -= 1 entry[0] -= 1
if entry[1]: if entry[1]:
next_def = entry[1].pop(0) (next_def, _) = entry[1].popitem(last=False)
# we need to run the next thing in the sentinel context.
with PreserveLoggingContext(): with PreserveLoggingContext():
next_def.callback(None) next_def.callback(None)
elif entry[0] == 0: elif entry[0] == 0:

View File

@ -16,6 +16,7 @@
import logging import logging
from collections import OrderedDict from collections import OrderedDict
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.caches import register_cache from synapse.util.caches import register_cache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -63,7 +64,10 @@ class ExpiringCache(object):
return return
def f(): def f():
self._prune_cache() run_as_background_process(
"prune_cache_%s" % self._cache_name,
self._prune_cache,
)
self._clock.looping_call(f, self._expiry_ms / 2) self._clock.looping_call(f, self._expiry_ms / 2)

View File

@ -17,20 +17,18 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.util import unwrapFirstError from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def user_left_room(distributor, user, room_id): def user_left_room(distributor, user, room_id):
with PreserveLoggingContext(): distributor.fire("user_left_room", user=user, room_id=room_id)
distributor.fire("user_left_room", user=user, room_id=room_id)
def user_joined_room(distributor, user, room_id): def user_joined_room(distributor, user, room_id):
with PreserveLoggingContext(): distributor.fire("user_joined_room", user=user, room_id=room_id)
distributor.fire("user_joined_room", user=user, room_id=room_id)
class Distributor(object): class Distributor(object):
@ -44,9 +42,7 @@ class Distributor(object):
model will do for today. model will do for today.
""" """
def __init__(self, suppress_failures=True): def __init__(self):
self.suppress_failures = suppress_failures
self.signals = {} self.signals = {}
self.pre_registration = {} self.pre_registration = {}
@ -56,7 +52,6 @@ class Distributor(object):
self.signals[name] = Signal( self.signals[name] = Signal(
name, name,
suppress_failures=self.suppress_failures,
) )
if name in self.pre_registration: if name in self.pre_registration:
@ -75,10 +70,18 @@ class Distributor(object):
self.pre_registration[name].append(observer) self.pre_registration[name].append(observer)
def fire(self, name, *args, **kwargs): def fire(self, name, *args, **kwargs):
"""Dispatches the given signal to the registered observers.
Runs the observers as a background process. Does not return a deferred.
"""
if name not in self.signals: if name not in self.signals:
raise KeyError("%r does not have a signal named %s" % (self, name)) raise KeyError("%r does not have a signal named %s" % (self, name))
return self.signals[name].fire(*args, **kwargs) run_as_background_process(
name,
self.signals[name].fire,
*args, **kwargs
)
class Signal(object): class Signal(object):
@ -91,9 +94,8 @@ class Signal(object):
method into all of the observers. method into all of the observers.
""" """
def __init__(self, name, suppress_failures): def __init__(self, name):
self.name = name self.name = name
self.suppress_failures = suppress_failures
self.observers = [] self.observers = []
def observe(self, observer): def observe(self, observer):
@ -103,7 +105,6 @@ class Signal(object):
Each observer callable may return a Deferred.""" Each observer callable may return a Deferred."""
self.observers.append(observer) self.observers.append(observer)
@defer.inlineCallbacks
def fire(self, *args, **kwargs): def fire(self, *args, **kwargs):
"""Invokes every callable in the observer list, passing in the args and """Invokes every callable in the observer list, passing in the args and
kwargs. Exceptions thrown by observers are logged but ignored. It is kwargs. Exceptions thrown by observers are logged but ignored. It is
@ -121,22 +122,17 @@ class Signal(object):
failure.type, failure.type,
failure.value, failure.value,
failure.getTracebackObject())) failure.getTracebackObject()))
if not self.suppress_failures:
return failure
return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb) return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
with PreserveLoggingContext(): deferreds = [
deferreds = [ run_in_background(do, o)
do(observer) for o in self.observers
for observer in self.observers ]
]
res = yield defer.gatherResults( return make_deferred_yieldable(defer.gatherResults(
deferreds, consumeErrors=True deferreds, consumeErrors=True,
).addErrback(unwrapFirstError) ))
defer.returnValue(res)
def __repr__(self): def __repr__(self):
return "<Signal name=%r>" % (self.name,) return "<Signal name=%r>" % (self.name,)

View File

@ -99,6 +99,17 @@ class ContextResourceUsage(object):
self.db_sched_duration_sec = 0 self.db_sched_duration_sec = 0
self.evt_db_fetch_count = 0 self.evt_db_fetch_count = 0
def __repr__(self):
return ("<ContextResourceUsage ru_stime='%r', ru_utime='%r', "
"db_txn_count='%r', db_txn_duration_sec='%r', "
"db_sched_duration_sec='%r', evt_db_fetch_count='%r'>") % (
self.ru_stime,
self.ru_utime,
self.db_txn_count,
self.db_txn_duration_sec,
self.db_sched_duration_sec,
self.evt_db_fetch_count,)
def __iadd__(self, other): def __iadd__(self, other):
"""Add another ContextResourceUsage's stats to this one's. """Add another ContextResourceUsage's stats to this one's.

View File

@ -104,12 +104,19 @@ class Measure(object):
logger.warn("Expected context. (%r)", self.name) logger.warn("Expected context. (%r)", self.name)
return return
usage = context.get_resource_usage() - self.start_usage current = context.get_resource_usage()
block_ru_utime.labels(self.name).inc(usage.ru_utime) usage = current - self.start_usage
block_ru_stime.labels(self.name).inc(usage.ru_stime) try:
block_db_txn_count.labels(self.name).inc(usage.db_txn_count) block_ru_utime.labels(self.name).inc(usage.ru_utime)
block_db_txn_duration.labels(self.name).inc(usage.db_txn_duration_sec) block_ru_stime.labels(self.name).inc(usage.ru_stime)
block_db_sched_duration.labels(self.name).inc(usage.db_sched_duration_sec) block_db_txn_count.labels(self.name).inc(usage.db_txn_count)
block_db_txn_duration.labels(self.name).inc(usage.db_txn_duration_sec)
block_db_sched_duration.labels(self.name).inc(usage.db_sched_duration_sec)
except ValueError:
logger.warn(
"Failed to save metrics! OLD: %r, NEW: %r",
self.start_usage, current
)
if self.created_context: if self.created_context:
self.start_context.__exit__(exc_type, exc_val, exc_tb) self.start_context.__exit__(exc_type, exc_val, exc_tb)

View File

@ -12,11 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import itertools
import logging import logging
import operator import operator
import six from six import iteritems, itervalues
from six.moves import map
from twisted.internet import defer from twisted.internet import defer
@ -221,7 +222,7 @@ def filter_events_for_client(store, user_id, events, is_peeking=False,
return event return event
# check each event: gives an iterable[None|EventBase] # check each event: gives an iterable[None|EventBase]
filtered_events = itertools.imap(allowed, events) filtered_events = map(allowed, events)
# remove the None entries # remove the None entries
filtered_events = filter(operator.truth, filtered_events) filtered_events = filter(operator.truth, filtered_events)
@ -261,7 +262,7 @@ def filter_events_for_server(store, server_name, events):
# membership states for the requesting server to determine # membership states for the requesting server to determine
# if the server is either in the room or has been invited # if the server is either in the room or has been invited
# into the room. # into the room.
for ev in state.itervalues(): for ev in itervalues(state):
if ev.type != EventTypes.Member: if ev.type != EventTypes.Member:
continue continue
try: try:
@ -295,7 +296,7 @@ def filter_events_for_server(store, server_name, events):
) )
visibility_ids = set() visibility_ids = set()
for sids in event_to_state_ids.itervalues(): for sids in itervalues(event_to_state_ids):
hist = sids.get((EventTypes.RoomHistoryVisibility, "")) hist = sids.get((EventTypes.RoomHistoryVisibility, ""))
if hist: if hist:
visibility_ids.add(hist) visibility_ids.add(hist)
@ -308,7 +309,7 @@ def filter_events_for_server(store, server_name, events):
event_map = yield store.get_events(visibility_ids) event_map = yield store.get_events(visibility_ids)
all_open = all( all_open = all(
e.content.get("history_visibility") in (None, "shared", "world_readable") e.content.get("history_visibility") in (None, "shared", "world_readable")
for e in event_map.itervalues() for e in itervalues(event_map)
) )
if all_open: if all_open:
@ -346,7 +347,7 @@ def filter_events_for_server(store, server_name, events):
# #
state_key_to_event_id_set = { state_key_to_event_id_set = {
e e
for key_to_eid in six.itervalues(event_to_state_ids) for key_to_eid in itervalues(event_to_state_ids)
for e in key_to_eid.items() for e in key_to_eid.items()
} }
@ -369,10 +370,10 @@ def filter_events_for_server(store, server_name, events):
event_to_state = { event_to_state = {
e_id: { e_id: {
key: event_map[inner_e_id] key: event_map[inner_e_id]
for key, inner_e_id in key_to_eid.iteritems() for key, inner_e_id in iteritems(key_to_eid)
if inner_e_id in event_map if inner_e_id in event_map
} }
for e_id, key_to_eid in event_to_state_ids.iteritems() for e_id, key_to_eid in iteritems(event_to_state_ids)
} }
defer.returnValue([ defer.returnValue([

View File

@ -0,0 +1,305 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import hmac
import json
from mock import Mock
from synapse.http.server import JsonResource
from synapse.rest.client.v1.admin import register_servlets
from synapse.util import Clock
from tests import unittest
from tests.server import (
ThreadedMemoryReactorClock,
make_request,
render,
setup_test_homeserver,
)
class UserRegisterTestCase(unittest.TestCase):
def setUp(self):
self.clock = ThreadedMemoryReactorClock()
self.hs_clock = Clock(self.clock)
self.url = "/_matrix/client/r0/admin/register"
self.registration_handler = Mock()
self.identity_handler = Mock()
self.login_handler = Mock()
self.device_handler = Mock()
self.device_handler.check_device_registered = Mock(return_value="FAKE")
self.datastore = Mock(return_value=Mock())
self.datastore.get_current_state_deltas = Mock(return_value=[])
self.secrets = Mock()
self.hs = setup_test_homeserver(
http_client=None, clock=self.hs_clock, reactor=self.clock
)
self.hs.config.registration_shared_secret = u"shared"
self.hs.get_media_repository = Mock()
self.hs.get_deactivate_account_handler = Mock()
self.resource = JsonResource(self.hs)
register_servlets(self.hs, self.resource)
def test_disabled(self):
"""
If there is no shared secret, registration through this method will be
prevented.
"""
self.hs.config.registration_shared_secret = None
request, channel = make_request("POST", self.url, b'{}')
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
'Shared secret registration is not enabled', channel.json_body["error"]
)
def test_get_nonce(self):
"""
Calling GET on the endpoint will return a randomised nonce, using the
homeserver's secrets provider.
"""
secrets = Mock()
secrets.token_hex = Mock(return_value="abcd")
self.hs.get_secrets = Mock(return_value=secrets)
request, channel = make_request("GET", self.url)
render(request, self.resource, self.clock)
self.assertEqual(channel.json_body, {"nonce": "abcd"})
def test_expired_nonce(self):
"""
Calling GET on the endpoint will return a randomised nonce, which will
only last for SALT_TIMEOUT (60s).
"""
request, channel = make_request("GET", self.url)
render(request, self.resource, self.clock)
nonce = channel.json_body["nonce"]
# 59 seconds
self.clock.advance(59)
body = json.dumps({"nonce": nonce})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('username must be specified', channel.json_body["error"])
# 61 seconds
self.clock.advance(2)
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('unrecognised nonce', channel.json_body["error"])
def test_register_incorrect_nonce(self):
"""
Only the provided nonce can be used, as it's checked in the MAC.
"""
request, channel = make_request("GET", self.url)
render(request, self.resource, self.clock)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
want_mac.update(b"notthenonce\x00bob\x00abc123\x00admin")
want_mac = want_mac.hexdigest()
body = json.dumps(
{
"nonce": nonce,
"username": "bob",
"password": "abc123",
"admin": True,
"mac": want_mac,
}
).encode('utf8')
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("HMAC incorrect", channel.json_body["error"])
def test_register_correct_nonce(self):
"""
When the correct nonce is provided, and the right key is provided, the
user is registered.
"""
request, channel = make_request("GET", self.url)
render(request, self.resource, self.clock)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin")
want_mac = want_mac.hexdigest()
body = json.dumps(
{
"nonce": nonce,
"username": "bob",
"password": "abc123",
"admin": True,
"mac": want_mac,
}
).encode('utf8')
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["user_id"])
def test_nonce_reuse(self):
"""
A valid unrecognised nonce.
"""
request, channel = make_request("GET", self.url)
render(request, self.resource, self.clock)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin")
want_mac = want_mac.hexdigest()
body = json.dumps(
{
"nonce": nonce,
"username": "bob",
"password": "abc123",
"admin": True,
"mac": want_mac,
}
).encode('utf8')
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["user_id"])
# Now, try and reuse it
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('unrecognised nonce', channel.json_body["error"])
def test_missing_parts(self):
"""
Synapse will complain if you don't give nonce, username, password, and
mac. Admin is optional. Additional checks are done for length and
type.
"""
def nonce():
request, channel = make_request("GET", self.url)
render(request, self.resource, self.clock)
return channel.json_body["nonce"]
#
# Nonce check
#
# Must be present
body = json.dumps({})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('nonce must be specified', channel.json_body["error"])
#
# Username checks
#
# Must be present
body = json.dumps({"nonce": nonce()})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('username must be specified', channel.json_body["error"])
# Must be a string
body = json.dumps({"nonce": nonce(), "username": 1234})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid username', channel.json_body["error"])
# Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": b"abcd\x00"})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid username', channel.json_body["error"])
# Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid username', channel.json_body["error"])
#
# Username checks
#
# Must be present
body = json.dumps({"nonce": nonce(), "username": "a"})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('password must be specified', channel.json_body["error"])
# Must be a string
body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid password', channel.json_body["error"])
# Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": "a", "password": b"abcd\x00"})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid password', channel.json_body["error"])
# Super long
body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid password', channel.json_body["error"])

View File

@ -14,100 +14,30 @@
# limitations under the License. # limitations under the License.
""" Tests REST events for /events paths.""" """ Tests REST events for /events paths."""
from mock import Mock, NonCallableMock from mock import Mock, NonCallableMock
from six import PY3
# twisted imports
from twisted.internet import defer from twisted.internet import defer
import synapse.rest.client.v1.events
import synapse.rest.client.v1.register
import synapse.rest.client.v1.room
from tests import unittest
from ....utils import MockHttpResource, setup_test_homeserver from ....utils import MockHttpResource, setup_test_homeserver
from .utils import RestTestCase from .utils import RestTestCase
PATH_PREFIX = "/_matrix/client/api/v1" PATH_PREFIX = "/_matrix/client/api/v1"
class EventStreamPaginationApiTestCase(unittest.TestCase):
""" Tests event streaming query parameters and start/end keys used in the
Pagination stream API. """
user_id = "sid1"
def setUp(self):
# configure stream and inject items
pass
def tearDown(self):
pass
def TODO_test_long_poll(self):
# stream from 'end' key, send (self+other) message, expect message.
# stream from 'END', send (self+other) message, expect message.
# stream from 'end' key, send (self+other) topic, expect topic.
# stream from 'END', send (self+other) topic, expect topic.
# stream from 'end' key, send (self+other) invite, expect invite.
# stream from 'END', send (self+other) invite, expect invite.
pass
def TODO_test_stream_forward(self):
# stream from START, expect injected items
# stream from 'start' key, expect same content
# stream from 'end' key, expect nothing
# stream from 'END', expect nothing
# The following is needed for cases where content is removed e.g. you
# left a room, so the token you're streaming from is > the one that
# would be returned naturally from START>END.
# stream from very new token (higher than end key), expect same token
# returned as end key
pass
def TODO_test_limits(self):
# stream from a key, expect limit_num items
# stream from START, expect limit_num items
pass
def TODO_test_range(self):
# stream from key to key, expect X items
# stream from key to END, expect X items
# stream from START to key, expect X items
# stream from START to END, expect all items
pass
def TODO_test_direction(self):
# stream from END to START and fwds, expect newest first
# stream from END to START and bwds, expect oldest first
# stream from START to END and fwds, expect oldest first
# stream from START to END and bwds, expect newest first
pass
class EventStreamPermissionsTestCase(RestTestCase): class EventStreamPermissionsTestCase(RestTestCase):
""" Tests event streaming (GET /events). """ """ Tests event streaming (GET /events). """
if PY3:
skip = "Skip on Py3 until ported to use not V1 only register."
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
import synapse.rest.client.v1.events
import synapse.rest.client.v1_only.register
import synapse.rest.client.v1.room
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
@ -125,7 +55,7 @@ class EventStreamPermissionsTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
synapse.rest.client.v1.register.register_servlets(hs, self.mock_resource) synapse.rest.client.v1_only.register.register_servlets(hs, self.mock_resource)
synapse.rest.client.v1.events.register_servlets(hs, self.mock_resource) synapse.rest.client.v1.events.register_servlets(hs, self.mock_resource)
synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)

View File

@ -16,11 +16,12 @@
import json import json
from mock import Mock from mock import Mock
from six import PY3
from twisted.test.proto_helpers import MemoryReactorClock from twisted.test.proto_helpers import MemoryReactorClock
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.rest.client.v1.register import register_servlets from synapse.rest.client.v1_only.register import register_servlets
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
@ -31,6 +32,8 @@ class CreateUserServletTestCase(unittest.TestCase):
""" """
Tests for CreateUserRestServlet. Tests for CreateUserRestServlet.
""" """
if PY3:
skip = "Not ported to Python 3."
def setUp(self): def setUp(self):
self.registration_handler = Mock() self.registration_handler = Mock()

View File

@ -20,7 +20,6 @@ import json
from mock import Mock, NonCallableMock from mock import Mock, NonCallableMock
from six.moves.urllib import parse as urlparse from six.moves.urllib import parse as urlparse
# twisted imports
from twisted.internet import defer from twisted.internet import defer
import synapse.rest.client.v1.room import synapse.rest.client.v1.room
@ -86,6 +85,7 @@ class RoomBase(unittest.TestCase):
self.resource = JsonResource(self.hs) self.resource = JsonResource(self.hs)
synapse.rest.client.v1.room.register_servlets(self.hs, self.resource) synapse.rest.client.v1.room.register_servlets(self.hs, self.resource)
synapse.rest.client.v1.room.register_deprecated_servlets(self.hs, self.resource)
self.helper = RestHelper(self.hs, self.resource, self.user_id) self.helper = RestHelper(self.hs, self.resource, self.user_id)

View File

@ -21,8 +21,12 @@ from synapse.types import UserID
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import ThreadedMemoryReactorClock as MemoryReactorClock from tests.server import (
from tests.server import make_request, setup_test_homeserver, wait_until_result ThreadedMemoryReactorClock as MemoryReactorClock,
make_request,
setup_test_homeserver,
wait_until_result,
)
PATH_PREFIX = "/_matrix/client/v2_alpha" PATH_PREFIX = "/_matrix/client/v2_alpha"

View File

@ -20,8 +20,12 @@ from synapse.types import UserID
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import ThreadedMemoryReactorClock as MemoryReactorClock from tests.server import (
from tests.server import make_request, setup_test_homeserver, wait_until_result ThreadedMemoryReactorClock as MemoryReactorClock,
make_request,
setup_test_homeserver,
wait_until_result,
)
PATH_PREFIX = "/_matrix/client/v2_alpha" PATH_PREFIX = "/_matrix/client/v2_alpha"

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -15,8 +16,6 @@
from mock import Mock, patch from mock import Mock, patch
from twisted.internet import defer
from synapse.util.distributor import Distributor from synapse.util.distributor import Distributor
from . import unittest from . import unittest
@ -27,38 +26,15 @@ class DistributorTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.dist = Distributor() self.dist = Distributor()
@defer.inlineCallbacks
def test_signal_dispatch(self): def test_signal_dispatch(self):
self.dist.declare("alert") self.dist.declare("alert")
observer = Mock() observer = Mock()
self.dist.observe("alert", observer) self.dist.observe("alert", observer)
d = self.dist.fire("alert", 1, 2, 3) self.dist.fire("alert", 1, 2, 3)
yield d
self.assertTrue(d.called)
observer.assert_called_with(1, 2, 3) observer.assert_called_with(1, 2, 3)
@defer.inlineCallbacks
def test_signal_dispatch_deferred(self):
self.dist.declare("whine")
d_inner = defer.Deferred()
def observer():
return d_inner
self.dist.observe("whine", observer)
d_outer = self.dist.fire("whine")
self.assertFalse(d_outer.called)
d_inner.callback(None)
yield d_outer
self.assertTrue(d_outer.called)
@defer.inlineCallbacks
def test_signal_catch(self): def test_signal_catch(self):
self.dist.declare("alarm") self.dist.declare("alarm")
@ -71,9 +47,7 @@ class DistributorTestCase(unittest.TestCase):
with patch( with patch(
"synapse.util.distributor.logger", spec=["warning"] "synapse.util.distributor.logger", spec=["warning"]
) as mock_logger: ) as mock_logger:
d = self.dist.fire("alarm", "Go") self.dist.fire("alarm", "Go")
yield d
self.assertTrue(d.called)
observers[0].assert_called_once_with("Go") observers[0].assert_called_once_with("Go")
observers[1].assert_called_once_with("Go") observers[1].assert_called_once_with("Go")
@ -83,34 +57,12 @@ class DistributorTestCase(unittest.TestCase):
mock_logger.warning.call_args[0][0], str mock_logger.warning.call_args[0][0], str
) )
@defer.inlineCallbacks
def test_signal_catch_no_suppress(self):
# Gut-wrenching
self.dist.suppress_failures = False
self.dist.declare("whail")
class MyException(Exception):
pass
@defer.inlineCallbacks
def observer():
raise MyException("Oopsie")
self.dist.observe("whail", observer)
d = self.dist.fire("whail")
yield self.assertFailure(d, MyException)
self.dist.suppress_failures = True
@defer.inlineCallbacks
def test_signal_prereg(self): def test_signal_prereg(self):
observer = Mock() observer = Mock()
self.dist.observe("flare", observer) self.dist.observe("flare", observer)
self.dist.declare("flare") self.dist.declare("flare")
yield self.dist.fire("flare", 4, 5) self.dist.fire("flare", 4, 5)
observer.assert_called_with(4, 5) observer.assert_called_with(4, 5)

View File

@ -137,7 +137,6 @@ class MessageAcceptTests(unittest.TestCase):
) )
self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv") self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
@unittest.DEBUG
def test_cant_hide_past_history(self): def test_cant_hide_past_history(self):
""" """
If you send a message, you must be able to provide the direct If you send a message, you must be able to provide the direct
@ -178,7 +177,7 @@ class MessageAcceptTests(unittest.TestCase):
for x, y in d.items() for x, y in d.items()
if x == ("m.room.member", "@us:test") if x == ("m.room.member", "@us:test")
], ],
"auth_chain_ids": d.values(), "auth_chain_ids": list(d.values()),
} }
) )

View File

@ -1,70 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.util.async import Limiter
from tests import unittest
class LimiterTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_limiter(self):
limiter = Limiter(3)
key = object()
d1 = limiter.queue(key)
cm1 = yield d1
d2 = limiter.queue(key)
cm2 = yield d2
d3 = limiter.queue(key)
cm3 = yield d3
d4 = limiter.queue(key)
self.assertFalse(d4.called)
d5 = limiter.queue(key)
self.assertFalse(d5.called)
with cm1:
self.assertFalse(d4.called)
self.assertFalse(d5.called)
self.assertTrue(d4.called)
self.assertFalse(d5.called)
with cm3:
self.assertFalse(d5.called)
self.assertTrue(d5.called)
with cm2:
pass
with (yield d4):
pass
with (yield d5):
pass
d6 = limiter.queue(key)
with (yield d6):
pass

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd # Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -16,6 +17,7 @@
from six.moves import range from six.moves import range
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.internet.defer import CancelledError
from synapse.util import Clock, logcontext from synapse.util import Clock, logcontext
from synapse.util.async import Linearizer from synapse.util.async import Linearizer
@ -65,3 +67,79 @@ class LinearizerTestCase(unittest.TestCase):
func(i) func(i)
return func(1000) return func(1000)
@defer.inlineCallbacks
def test_multiple_entries(self):
limiter = Linearizer(max_count=3)
key = object()
d1 = limiter.queue(key)
cm1 = yield d1
d2 = limiter.queue(key)
cm2 = yield d2
d3 = limiter.queue(key)
cm3 = yield d3
d4 = limiter.queue(key)
self.assertFalse(d4.called)
d5 = limiter.queue(key)
self.assertFalse(d5.called)
with cm1:
self.assertFalse(d4.called)
self.assertFalse(d5.called)
cm4 = yield d4
self.assertFalse(d5.called)
with cm3:
self.assertFalse(d5.called)
cm5 = yield d5
with cm2:
pass
with cm4:
pass
with cm5:
pass
d6 = limiter.queue(key)
with (yield d6):
pass
@defer.inlineCallbacks
def test_cancellation(self):
linearizer = Linearizer()
key = object()
d1 = linearizer.queue(key)
cm1 = yield d1
d2 = linearizer.queue(key)
self.assertFalse(d2.called)
d3 = linearizer.queue(key)
self.assertFalse(d3.called)
d2.cancel()
with cm1:
pass
self.assertTrue(d2.called)
try:
yield d2
self.fail("Expected d2 to raise CancelledError")
except CancelledError:
pass
with (yield d3):
pass

View File

@ -71,6 +71,8 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None
config.user_directory_search_all_users = False config.user_directory_search_all_users = False
config.user_consent_server_notice_content = None config.user_consent_server_notice_content = None
config.block_events_without_consent_error = None config.block_events_without_consent_error = None
config.media_storage_providers = []
config.auto_join_rooms = []
# disable user directory updates, because they get done in the # disable user directory updates, because they get done in the
# background, which upsets the test runner. # background, which upsets the test runner.
@ -136,6 +138,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None
database_engine=db_engine, database_engine=db_engine,
room_list_handler=object(), room_list_handler=object(),
tls_server_context_factory=Mock(), tls_server_context_factory=Mock(),
reactor=reactor,
**kargs **kargs
) )