Merge branch 'develop' of github.com:matrix-org/synapse into erikj/purge_state_groups

This commit is contained in:
Erik Johnston 2018-10-26 16:22:45 +01:00
commit 03e634dad4
115 changed files with 3799 additions and 1683 deletions

View File

@ -1,12 +1,27 @@
sudo: false sudo: false
language: python language: python
# tell travis to cache ~/.cache/pip cache:
cache: pip directories:
# we only bother to cache the wheels; parts of the http cache get
# invalidated every build (because they get served with a max-age of 600
# seconds), which means that we end up re-uploading the whole cache for
# every build, which is time-consuming In any case, it's not obvious that
# downloading the cache from S3 would be much faster than downloading the
# originals from pypi.
#
- $HOME/.cache/pip/wheels
before_script: # don't clone the whole repo history, one commit will do
- git remote set-branches --add origin develop git:
- git fetch origin develop depth: 1
# only build branches we care about (PRs are built seperately)
branches:
only:
- master
- develop
- /^release-v/
matrix: matrix:
fast_finish: true fast_finish: true
@ -14,8 +29,8 @@ matrix:
- python: 2.7 - python: 2.7
env: TOX_ENV=packaging env: TOX_ENV=packaging
- python: 2.7 - python: 3.6
env: TOX_ENV=pep8 env: TOX_ENV="pep8,check_isort"
- python: 2.7 - python: 2.7
env: TOX_ENV=py27 env: TOX_ENV=py27
@ -39,11 +54,14 @@ matrix:
services: services:
- postgresql - postgresql
- python: 3.6 - # we only need to check for the newsfragment if it's a PR build
env: TOX_ENV=check_isort if: type = pull_request
python: 3.6
- python: 3.6
env: TOX_ENV=check-newsfragment env: TOX_ENV=check-newsfragment
script:
- git remote set-branches --add origin develop
- git fetch origin develop
- tox -e $TOX_ENV
install: install:
- pip install tox - pip install tox

View File

@ -11,10 +11,6 @@ If you have email notifications enabled, you should ensure that
have installed customised templates, or leave it unset to use the default have installed customised templates, or leave it unset to use the default
templates. templates.
The configuration parser will try to detect the situation where
`email.template_dir` is incorrectly set to `res/templates` and do the right
thing, but will warn about this.
Synapse 0.33.7rc2 (2018-10-17) Synapse 0.33.7rc2 (2018-10-17)
============================== ==============================

View File

@ -174,6 +174,12 @@ Alternatively, Andreas Peters (previously Silvio Fricke) has contributed a
Dockerfile to automate a synapse server in a single Docker image, at Dockerfile to automate a synapse server in a single Docker image, at
https://hub.docker.com/r/avhost/docker-matrix/tags/ https://hub.docker.com/r/avhost/docker-matrix/tags/
Slavi Pantaleev has created an Ansible playbook,
which installs the offical Docker image of Matrix Synapse
along with many other Matrix-related services (Postgres database, riot-web, coturn, mxisd, SSL support, etc.).
For more details, see
https://github.com/spantaleev/matrix-docker-ansible-deploy
Configuring Synapse Configuring Synapse
------------------- -------------------
@ -651,7 +657,8 @@ Using a reverse proxy with Synapse
It is recommended to put a reverse proxy such as It is recommended to put a reverse proxy such as
`nginx <https://nginx.org/en/docs/http/ngx_http_proxy_module.html>`_, `nginx <https://nginx.org/en/docs/http/ngx_http_proxy_module.html>`_,
`Apache <https://httpd.apache.org/docs/current/mod/mod_proxy_http.html>`_ or `Apache <https://httpd.apache.org/docs/current/mod/mod_proxy_http.html>`_,
`Caddy <https://caddyserver.com/docs/proxy>`_ or
`HAProxy <https://www.haproxy.org/>`_ in front of Synapse. One advantage of `HAProxy <https://www.haproxy.org/>`_ in front of Synapse. One advantage of
doing so is that it means that you can expose the default https port (443) to doing so is that it means that you can expose the default https port (443) to
Matrix clients without needing to run Synapse with root privileges. Matrix clients without needing to run Synapse with root privileges.
@ -682,7 +689,15 @@ so an example nginx configuration might look like::
} }
} }
and an example apache configuration may look like:: an example Caddy configuration might look like::
matrix.example.com {
proxy /_matrix http://localhost:8008 {
transparent
}
}
and an example Apache configuration might look like::
<VirtualHost *:443> <VirtualHost *:443>
SSLEngine on SSLEngine on

View File

@ -61,10 +61,6 @@ If you have email notifications enabled, you should ensure that
have installed customised templates, or leave it unset to use the default have installed customised templates, or leave it unset to use the default
templates. templates.
The configuration parser will try to detect the situation where
``email.template_dir`` is incorrectly set to ``res/templates`` and do the right
thing, but will warn about this.
Upgrading to v0.27.3 Upgrading to v0.27.3
==================== ====================

1
changelog.d/3698.misc Normal file
View File

@ -0,0 +1 @@
Add information about the [matrix-docker-ansible-deploy](https://github.com/spantaleev/matrix-docker-ansible-deploy) playbook

1
changelog.d/3786.misc Normal file
View File

@ -0,0 +1 @@
Add initial implementation of new state resolution algorithm

1
changelog.d/3969.bugfix Normal file
View File

@ -0,0 +1 @@
Fix HTTP error response codes for federated group requests.

1
changelog.d/3975.feature Normal file
View File

@ -0,0 +1 @@
Servers with auto-join rooms will now automatically create those rooms when the first user registers

1
changelog.d/4011.misc Normal file
View File

@ -0,0 +1 @@
Reduce database load when fetching state groups

1
changelog.d/4051.feature Normal file
View File

@ -0,0 +1 @@
Add config option to control alias creation

1
changelog.d/4063.misc Normal file
View File

@ -0,0 +1 @@
Refactor room alias creation code

1
changelog.d/4068.bugfix Normal file
View File

@ -0,0 +1 @@
Fix bug which prevented email notifications from being sent unless an absolute path was given for `email_templates`.

1
changelog.d/4068.misc Normal file
View File

@ -0,0 +1 @@
Make the Python scripts in the top-level scripts folders meet pep8 and pass flake8.

1
changelog.d/4072.misc Normal file
View File

@ -0,0 +1 @@
The README now contains example for the Caddy web server. Contributed by steamp0rt.

1
changelog.d/4073.misc Normal file
View File

@ -0,0 +1 @@
Add psutil as an explicit dependency

1
changelog.d/4074.bugfix Normal file
View File

@ -0,0 +1 @@
Correctly account for cpu usage by background threads

1
changelog.d/4075.misc Normal file
View File

@ -0,0 +1 @@
Clean up threading and logcontexts in pushers

1
changelog.d/4076.misc Normal file
View File

@ -0,0 +1 @@
Correctly manage logcontexts during startup to fix some "Unexpected logging context" warnings

1
changelog.d/4077.misc Normal file
View File

@ -0,0 +1 @@
Give some more things logcontexts

2
changelog.d/4081.bugfix Normal file
View File

@ -0,0 +1,2 @@
Fix race condition where config defined reserved users were not being added to
the monthly active user list prior to the homeserver reactor firing up

1
changelog.d/4082.misc Normal file
View File

@ -0,0 +1 @@
Clean up some bits of code which were flagged by the linter

1
changelog.d/4083.bugfix Normal file
View File

@ -0,0 +1 @@
Fix bug which prevented backslashes being used in event field filters

1
changelog.d/4085.feature Normal file
View File

@ -0,0 +1 @@
The register_new_matrix_user script is now ported to Python 3.

1
changelog.d/4089.feature Normal file
View File

@ -0,0 +1 @@
Configure Docker image to listen on both ipv4 and ipv6.

View File

@ -21,7 +21,7 @@ listeners:
{% if not SYNAPSE_NO_TLS %} {% if not SYNAPSE_NO_TLS %}
- -
port: 8448 port: 8448
bind_addresses: ['0.0.0.0'] bind_addresses: ['::']
type: http type: http
tls: true tls: true
x_forwarded: false x_forwarded: false
@ -34,7 +34,7 @@ listeners:
- port: 8008 - port: 8008
tls: false tls: false
bind_addresses: ['0.0.0.0'] bind_addresses: ['::']
type: http type: http
x_forwarded: false x_forwarded: false

View File

@ -1,21 +1,20 @@
from synapse.events import FrozenEvent from __future__ import print_function
from synapse.api.auth import Auth
from mock import Mock
import argparse import argparse
import itertools import itertools
import json import json
import sys import sys
from mock import Mock
from synapse.api.auth import Auth
from synapse.events import FrozenEvent
def check_auth(auth, auth_chain, events): def check_auth(auth, auth_chain, events):
auth_chain.sort(key=lambda e: e.depth) auth_chain.sort(key=lambda e: e.depth)
auth_map = { auth_map = {e.event_id: e for e in auth_chain}
e.event_id: e
for e in auth_chain
}
create_events = {} create_events = {}
for e in auth_chain: for e in auth_chain:
@ -25,31 +24,26 @@ def check_auth(auth, auth_chain, events):
for e in itertools.chain(auth_chain, events): for e in itertools.chain(auth_chain, events):
auth_events_list = [auth_map[i] for i, _ in e.auth_events] auth_events_list = [auth_map[i] for i, _ in e.auth_events]
auth_events = { auth_events = {(e.type, e.state_key): e for e in auth_events_list}
(e.type, e.state_key): e
for e in auth_events_list
}
auth_events[("m.room.create", "")] = create_events[e.room_id] auth_events[("m.room.create", "")] = create_events[e.room_id]
try: try:
auth.check(e, auth_events=auth_events) auth.check(e, auth_events=auth_events)
except Exception as ex: except Exception as ex:
print "Failed:", e.event_id, e.type, e.state_key print("Failed:", e.event_id, e.type, e.state_key)
print "Auth_events:", auth_events print("Auth_events:", auth_events)
print ex print(ex)
print json.dumps(e.get_dict(), sort_keys=True, indent=4) print(json.dumps(e.get_dict(), sort_keys=True, indent=4))
# raise # raise
print "Success:", e.event_id, e.type, e.state_key print("Success:", e.event_id, e.type, e.state_key)
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
'json', 'json', nargs='?', type=argparse.FileType('r'), default=sys.stdin
nargs='?',
type=argparse.FileType('r'),
default=sys.stdin,
) )
args = parser.parse_args() args = parser.parse_args()

View File

@ -1,10 +1,15 @@
from synapse.crypto.event_signing import *
from unpaddedbase64 import encode_base64
import argparse import argparse
import hashlib import hashlib
import sys
import json import json
import logging
import sys
from unpaddedbase64 import encode_base64
from synapse.crypto.event_signing import (
check_event_content_hash,
compute_event_reference_hash,
)
class dictobj(dict): class dictobj(dict):
@ -24,27 +29,26 @@ class dictobj(dict):
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("input_json", nargs="?", type=argparse.FileType('r'), parser.add_argument(
default=sys.stdin) "input_json", nargs="?", type=argparse.FileType('r'), default=sys.stdin
)
args = parser.parse_args() args = parser.parse_args()
logging.basicConfig() logging.basicConfig()
event_json = dictobj(json.load(args.input_json)) event_json = dictobj(json.load(args.input_json))
algorithms = { algorithms = {"sha256": hashlib.sha256}
"sha256": hashlib.sha256,
}
for alg_name in event_json.hashes: for alg_name in event_json.hashes:
if check_event_content_hash(event_json, algorithms[alg_name]): if check_event_content_hash(event_json, algorithms[alg_name]):
print "PASS content hash %s" % (alg_name,) print("PASS content hash %s" % (alg_name,))
else: else:
print "FAIL content hash %s" % (alg_name,) print("FAIL content hash %s" % (alg_name,))
for algorithm in algorithms.values(): for algorithm in algorithms.values():
name, h_bytes = compute_event_reference_hash(event_json, algorithm) name, h_bytes = compute_event_reference_hash(event_json, algorithm)
print "Reference hash %s: %s" % (name, encode_base64(h_bytes)) print("Reference hash %s: %s" % (name, encode_base64(h_bytes)))
if __name__=="__main__":
if __name__ == "__main__":
main() main()

View File

@ -1,15 +1,15 @@
from signedjson.sign import verify_signed_json import argparse
import json
import logging
import sys
import urllib2
import dns.resolver
from signedjson.key import decode_verify_key_bytes, write_signing_keys from signedjson.key import decode_verify_key_bytes, write_signing_keys
from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
import urllib2
import json
import sys
import dns.resolver
import pprint
import argparse
import logging
def get_targets(server_name): def get_targets(server_name):
if ":" in server_name: if ":" in server_name:
@ -23,6 +23,7 @@ def get_targets(server_name):
except dns.resolver.NXDOMAIN: except dns.resolver.NXDOMAIN:
yield (server_name, 8448) yield (server_name, 8448)
def get_server_keys(server_name, target, port): def get_server_keys(server_name, target, port):
url = "https://%s:%i/_matrix/key/v1" % (target, port) url = "https://%s:%i/_matrix/key/v1" % (target, port)
keys = json.load(urllib2.urlopen(url)) keys = json.load(urllib2.urlopen(url))
@ -33,12 +34,14 @@ def get_server_keys(server_name, target, port):
verify_keys[key_id] = verify_key verify_keys[key_id] = verify_key
return verify_keys return verify_keys
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("signature_name") parser.add_argument("signature_name")
parser.add_argument("input_json", nargs="?", type=argparse.FileType('r'), parser.add_argument(
default=sys.stdin) "input_json", nargs="?", type=argparse.FileType('r'), default=sys.stdin
)
args = parser.parse_args() args = parser.parse_args()
logging.basicConfig() logging.basicConfig()
@ -48,24 +51,23 @@ def main():
for target, port in get_targets(server_name): for target, port in get_targets(server_name):
try: try:
keys = get_server_keys(server_name, target, port) keys = get_server_keys(server_name, target, port)
print "Using keys from https://%s:%s/_matrix/key/v1" % (target, port) print("Using keys from https://%s:%s/_matrix/key/v1" % (target, port))
write_signing_keys(sys.stdout, keys.values()) write_signing_keys(sys.stdout, keys.values())
break break
except: except Exception:
logging.exception("Error talking to %s:%s", target, port) logging.exception("Error talking to %s:%s", target, port)
json_to_check = json.load(args.input_json) json_to_check = json.load(args.input_json)
print "Checking JSON:" print("Checking JSON:")
for key_id in json_to_check["signatures"][args.signature_name]: for key_id in json_to_check["signatures"][args.signature_name]:
try: try:
key = keys[key_id] key = keys[key_id]
verify_signed_json(json_to_check, args.signature_name, key) verify_signed_json(json_to_check, args.signature_name, key)
print "PASS %s" % (key_id,) print("PASS %s" % (key_id,))
except: except Exception:
logging.exception("Check for key %s failed" % (key_id,)) logging.exception("Check for key %s failed" % (key_id,))
print "FAIL %s" % (key_id,) print("FAIL %s" % (key_id,))
if __name__ == '__main__': if __name__ == '__main__':
main() main()

View File

@ -1,13 +1,21 @@
import hashlib
import json
import sys
import time
import six
import psycopg2 import psycopg2
import yaml import yaml
import sys from canonicaljson import encode_canonical_json
import json
import time
import hashlib
from unpaddedbase64 import encode_base64
from signedjson.key import read_signing_keys from signedjson.key import read_signing_keys
from signedjson.sign import sign_json from signedjson.sign import sign_json
from canonicaljson import encode_canonical_json from unpaddedbase64 import encode_base64
if six.PY2:
db_type = six.moves.builtins.buffer
else:
db_type = memoryview
def select_v1_keys(connection): def select_v1_keys(connection):
@ -39,7 +47,9 @@ def select_v2_json(connection):
cursor.close() cursor.close()
results = {} results = {}
for server_name, key_id, key_json in rows: for server_name, key_id, key_json in rows:
results.setdefault(server_name, {})[key_id] = json.loads(str(key_json).decode("utf-8")) results.setdefault(server_name, {})[key_id] = json.loads(
str(key_json).decode("utf-8")
)
return results return results
@ -47,10 +57,7 @@ def convert_v1_to_v2(server_name, valid_until, keys, certificate):
return { return {
"old_verify_keys": {}, "old_verify_keys": {},
"server_name": server_name, "server_name": server_name,
"verify_keys": { "verify_keys": {key_id: {"key": key} for key_id, key in keys.items()},
key_id: {"key": key}
for key_id, key in keys.items()
},
"valid_until_ts": valid_until, "valid_until_ts": valid_until,
"tls_fingerprints": [fingerprint(certificate)], "tls_fingerprints": [fingerprint(certificate)],
} }
@ -65,7 +72,7 @@ def rows_v2(server, json):
valid_until = json["valid_until_ts"] valid_until = json["valid_until_ts"]
key_json = encode_canonical_json(json) key_json = encode_canonical_json(json)
for key_id in json["verify_keys"]: for key_id in json["verify_keys"]:
yield (server, key_id, "-", valid_until, valid_until, buffer(key_json)) yield (server, key_id, "-", valid_until, valid_until, db_type(key_json))
def main(): def main():
@ -87,7 +94,7 @@ def main():
result = {} result = {}
for server in keys: for server in keys:
if not server in json: if server not in json:
v2_json = convert_v1_to_v2( v2_json = convert_v1_to_v2(
server, valid_until, keys[server], certificates[server] server, valid_until, keys[server], certificates[server]
) )
@ -96,10 +103,7 @@ def main():
yaml.safe_dump(result, sys.stdout, default_flow_style=False) yaml.safe_dump(result, sys.stdout, default_flow_style=False)
rows = list( rows = list(row for server, json in result.items() for row in rows_v2(server, json))
row for server, json in result.items()
for row in rows_v2(server, json)
)
cursor = connection.cursor() cursor = connection.cursor()
cursor.executemany( cursor.executemany(
@ -107,7 +111,7 @@ def main():
" server_name, key_id, from_server," " server_name, key_id, from_server,"
" ts_added_ms, ts_valid_until_ms, key_json" " ts_added_ms, ts_valid_until_ms, key_json"
") VALUES (%s, %s, %s, %s, %s, %s)", ") VALUES (%s, %s, %s, %s, %s, %s)",
rows rows,
) )
connection.commit() connection.commit()

View File

@ -1,8 +1,16 @@
#! /usr/bin/python #! /usr/bin/python
from __future__ import print_function
import argparse
import ast import ast
import os
import re
import sys
import yaml import yaml
class DefinitionVisitor(ast.NodeVisitor): class DefinitionVisitor(ast.NodeVisitor):
def __init__(self): def __init__(self):
super(DefinitionVisitor, self).__init__() super(DefinitionVisitor, self).__init__()
@ -42,15 +50,18 @@ def non_empty(defs):
functions = {name: non_empty(f) for name, f in defs['def'].items()} functions = {name: non_empty(f) for name, f in defs['def'].items()}
classes = {name: non_empty(f) for name, f in defs['class'].items()} classes = {name: non_empty(f) for name, f in defs['class'].items()}
result = {} result = {}
if functions: result['def'] = functions if functions:
if classes: result['class'] = classes result['def'] = functions
if classes:
result['class'] = classes
names = defs['names'] names = defs['names']
uses = [] uses = []
for name in names.get('Load', ()): for name in names.get('Load', ()):
if name not in names.get('Param', ()) and name not in names.get('Store', ()): if name not in names.get('Param', ()) and name not in names.get('Store', ()):
uses.append(name) uses.append(name)
uses.extend(defs['attrs']) uses.extend(defs['attrs'])
if uses: result['uses'] = uses if uses:
result['uses'] = uses
result['names'] = names result['names'] = names
result['attrs'] = defs['attrs'] result['attrs'] = defs['attrs']
return result return result
@ -95,7 +106,6 @@ def used_names(prefix, item, defs, names):
if __name__ == '__main__': if __name__ == '__main__':
import sys, os, argparse, re
parser = argparse.ArgumentParser(description='Find definitions.') parser = argparse.ArgumentParser(description='Find definitions.')
parser.add_argument( parser.add_argument(
@ -105,24 +115,28 @@ if __name__ == '__main__':
"--ignore", action="append", metavar="REGEXP", help="Ignore a pattern" "--ignore", action="append", metavar="REGEXP", help="Ignore a pattern"
) )
parser.add_argument( parser.add_argument(
"--pattern", action="append", metavar="REGEXP", "--pattern", action="append", metavar="REGEXP", help="Search for a pattern"
help="Search for a pattern"
) )
parser.add_argument( parser.add_argument(
"directories", nargs='+', metavar="DIR", "directories",
help="Directories to search for definitions" nargs='+',
metavar="DIR",
help="Directories to search for definitions",
) )
parser.add_argument( parser.add_argument(
"--referrers", default=0, type=int, "--referrers",
help="Include referrers up to the given depth" default=0,
type=int,
help="Include referrers up to the given depth",
) )
parser.add_argument( parser.add_argument(
"--referred", default=0, type=int, "--referred",
help="Include referred down to the given depth" default=0,
type=int,
help="Include referred down to the given depth",
) )
parser.add_argument( parser.add_argument(
"--format", default="yaml", "--format", default="yaml", help="Output format, one of 'yaml' or 'dot'"
help="Output format, one of 'yaml' or 'dot'"
) )
args = parser.parse_args() args = parser.parse_args()
@ -162,7 +176,7 @@ if __name__ == '__main__':
for used_by in entry.get("used", ()): for used_by in entry.get("used", ()):
referrers.add(used_by) referrers.add(used_by)
for name, definition in names.items(): for name, definition in names.items():
if not name in referrers: if name not in referrers:
continue continue
if ignore and any(pattern.match(name) for pattern in ignore): if ignore and any(pattern.match(name) for pattern in ignore):
continue continue
@ -176,7 +190,7 @@ if __name__ == '__main__':
for uses in entry.get("uses", ()): for uses in entry.get("uses", ()):
referred.add(uses) referred.add(uses)
for name, definition in names.items(): for name, definition in names.items():
if not name in referred: if name not in referred:
continue continue
if ignore and any(pattern.match(name) for pattern in ignore): if ignore and any(pattern.match(name) for pattern in ignore):
continue continue
@ -185,12 +199,12 @@ if __name__ == '__main__':
if args.format == 'yaml': if args.format == 'yaml':
yaml.dump(result, sys.stdout, default_flow_style=False) yaml.dump(result, sys.stdout, default_flow_style=False)
elif args.format == 'dot': elif args.format == 'dot':
print "digraph {" print("digraph {")
for name, entry in result.items(): for name, entry in result.items():
print name print(name)
for used_by in entry.get("used", ()): for used_by in entry.get("used", ()):
if used_by in result: if used_by in result:
print used_by, "->", name print(used_by, "->", name)
print "}" print("}")
else: else:
raise ValueError("Unknown format %r" % (args.format)) raise ValueError("Unknown format %r" % (args.format))

View File

@ -1,8 +1,11 @@
#!/usr/bin/env python2 #!/usr/bin/env python2
import pymacaroons from __future__ import print_function
import sys import sys
import pymacaroons
if len(sys.argv) == 1: if len(sys.argv) == 1:
sys.stderr.write("usage: %s macaroon [key]\n" % (sys.argv[0],)) sys.stderr.write("usage: %s macaroon [key]\n" % (sys.argv[0],))
sys.exit(1) sys.exit(1)
@ -11,14 +14,14 @@ macaroon_string = sys.argv[1]
key = sys.argv[2] if len(sys.argv) > 2 else None key = sys.argv[2] if len(sys.argv) > 2 else None
macaroon = pymacaroons.Macaroon.deserialize(macaroon_string) macaroon = pymacaroons.Macaroon.deserialize(macaroon_string)
print macaroon.inspect() print(macaroon.inspect())
print "" print("")
verifier = pymacaroons.Verifier() verifier = pymacaroons.Verifier()
verifier.satisfy_general(lambda c: True) verifier.satisfy_general(lambda c: True)
try: try:
verifier.verify(macaroon, key) verifier.verify(macaroon, key)
print "Signature is correct" print("Signature is correct")
except Exception as e: except Exception as e:
print str(e) print(str(e))

View File

@ -18,21 +18,21 @@
from __future__ import print_function from __future__ import print_function
import argparse import argparse
import base64
import json
import sys
from urlparse import urlparse, urlunparse from urlparse import urlparse, urlunparse
import nacl.signing import nacl.signing
import json
import base64
import requests import requests
import sys
from requests.adapters import HTTPAdapter
import srvlookup import srvlookup
import yaml import yaml
from requests.adapters import HTTPAdapter
# uncomment the following to enable debug logging of http requests # uncomment the following to enable debug logging of http requests
#from httplib import HTTPConnection # from httplib import HTTPConnection
#HTTPConnection.debuglevel = 1 # HTTPConnection.debuglevel = 1
def encode_base64(input_bytes): def encode_base64(input_bytes):
"""Encode bytes as a base64 string without any padding.""" """Encode bytes as a base64 string without any padding."""
@ -58,15 +58,15 @@ def decode_base64(input_string):
def encode_canonical_json(value): def encode_canonical_json(value):
return json.dumps( return json.dumps(
value, value,
# Encode code-points outside of ASCII as UTF-8 rather than \u escapes # Encode code-points outside of ASCII as UTF-8 rather than \u escapes
ensure_ascii=False, ensure_ascii=False,
# Remove unecessary white space. # Remove unecessary white space.
separators=(',',':'), separators=(',', ':'),
# Sort the keys of dictionaries. # Sort the keys of dictionaries.
sort_keys=True, sort_keys=True,
# Encode the resulting unicode as UTF-8 bytes. # Encode the resulting unicode as UTF-8 bytes.
).encode("UTF-8") ).encode("UTF-8")
def sign_json(json_object, signing_key, signing_name): def sign_json(json_object, signing_key, signing_name):
@ -88,6 +88,7 @@ def sign_json(json_object, signing_key, signing_name):
NACL_ED25519 = "ed25519" NACL_ED25519 = "ed25519"
def decode_signing_key_base64(algorithm, version, key_base64): def decode_signing_key_base64(algorithm, version, key_base64):
"""Decode a base64 encoded signing key """Decode a base64 encoded signing key
Args: Args:
@ -143,14 +144,12 @@ def request_json(method, origin_name, origin_key, destination, path, content):
authorization_headers = [] authorization_headers = []
for key, sig in signed_json["signatures"][origin_name].items(): for key, sig in signed_json["signatures"][origin_name].items():
header = "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % ( header = "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (origin_name, key, sig)
origin_name, key, sig,
)
authorization_headers.append(bytes(header)) authorization_headers.append(bytes(header))
print ("Authorization: %s" % header, file=sys.stderr) print("Authorization: %s" % header, file=sys.stderr)
dest = "matrix://%s%s" % (destination, path) dest = "matrix://%s%s" % (destination, path)
print ("Requesting %s" % dest, file=sys.stderr) print("Requesting %s" % dest, file=sys.stderr)
s = requests.Session() s = requests.Session()
s.mount("matrix://", MatrixConnectionAdapter()) s.mount("matrix://", MatrixConnectionAdapter())
@ -158,10 +157,7 @@ def request_json(method, origin_name, origin_key, destination, path, content):
result = s.request( result = s.request(
method=method, method=method,
url=dest, url=dest,
headers={ headers={"Host": destination, "Authorization": authorization_headers[0]},
"Host": destination,
"Authorization": authorization_headers[0]
},
verify=False, verify=False,
data=content, data=content,
) )
@ -171,50 +167,50 @@ def request_json(method, origin_name, origin_key, destination, path, content):
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description= description="Signs and sends a federation request to a matrix homeserver"
"Signs and sends a federation request to a matrix homeserver",
) )
parser.add_argument( parser.add_argument(
"-N", "--server-name", "-N",
"--server-name",
help="Name to give as the local homeserver. If unspecified, will be " help="Name to give as the local homeserver. If unspecified, will be "
"read from the config file.", "read from the config file.",
) )
parser.add_argument( parser.add_argument(
"-k", "--signing-key-path", "-k",
"--signing-key-path",
help="Path to the file containing the private ed25519 key to sign the " help="Path to the file containing the private ed25519 key to sign the "
"request with.", "request with.",
) )
parser.add_argument( parser.add_argument(
"-c", "--config", "-c",
"--config",
default="homeserver.yaml", default="homeserver.yaml",
help="Path to server config file. Ignored if --server-name and " help="Path to server config file. Ignored if --server-name and "
"--signing-key-path are both given.", "--signing-key-path are both given.",
) )
parser.add_argument( parser.add_argument(
"-d", "--destination", "-d",
"--destination",
default="matrix.org", default="matrix.org",
help="name of the remote homeserver. We will do SRV lookups and " help="name of the remote homeserver. We will do SRV lookups and "
"connect appropriately.", "connect appropriately.",
) )
parser.add_argument( parser.add_argument(
"-X", "--method", "-X",
"--method",
help="HTTP method to use for the request. Defaults to GET if --data is" help="HTTP method to use for the request. Defaults to GET if --data is"
"unspecified, POST if it is." "unspecified, POST if it is.",
) )
parser.add_argument( parser.add_argument("--body", help="Data to send as the body of the HTTP request")
"--body",
help="Data to send as the body of the HTTP request"
)
parser.add_argument( parser.add_argument(
"path", "path", help="request path. We will add '/_matrix/federation/v1/' to this."
help="request path. We will add '/_matrix/federation/v1/' to this."
) )
args = parser.parse_args() args = parser.parse_args()
@ -227,13 +223,15 @@ def main():
result = request_json( result = request_json(
args.method, args.method,
args.server_name, key, args.destination, args.server_name,
key,
args.destination,
"/_matrix/federation/v1/" + args.path, "/_matrix/federation/v1/" + args.path,
content=args.body, content=args.body,
) )
json.dump(result, sys.stdout) json.dump(result, sys.stdout)
print ("") print("")
def read_args_from_config(args): def read_args_from_config(args):
@ -253,7 +251,7 @@ class MatrixConnectionAdapter(HTTPAdapter):
return s, 8448 return s, 8448
if ":" in s: if ":" in s:
out = s.rsplit(":",1) out = s.rsplit(":", 1)
try: try:
port = int(out[1]) port = int(out[1])
except ValueError: except ValueError:
@ -263,7 +261,7 @@ class MatrixConnectionAdapter(HTTPAdapter):
try: try:
srv = srvlookup.lookup("matrix", "tcp", s)[0] srv = srvlookup.lookup("matrix", "tcp", s)[0]
return srv.host, srv.port return srv.host, srv.port
except: except Exception:
return s, 8448 return s, 8448
def get_connection(self, url, proxies=None): def get_connection(self, url, proxies=None):
@ -272,10 +270,9 @@ class MatrixConnectionAdapter(HTTPAdapter):
(host, port) = self.lookup(parsed.netloc) (host, port) = self.lookup(parsed.netloc)
netloc = "%s:%d" % (host, port) netloc = "%s:%d" % (host, port)
print("Connecting to %s" % (netloc,), file=sys.stderr) print("Connecting to %s" % (netloc,), file=sys.stderr)
url = urlunparse(( url = urlunparse(
"https", netloc, parsed.path, parsed.params, parsed.query, ("https", netloc, parsed.path, parsed.params, parsed.query, parsed.fragment)
parsed.fragment, )
))
return super(MatrixConnectionAdapter, self).get_connection(url, proxies) return super(MatrixConnectionAdapter, self).get_connection(url, proxies)

View File

@ -1,23 +1,31 @@
from synapse.storage.pdu import PduStore from __future__ import print_function
from synapse.storage.signatures import SignatureStore
from synapse.storage._base import SQLBaseStore
from synapse.federation.units import Pdu
from synapse.crypto.event_signing import (
add_event_pdu_content_hash, compute_pdu_event_reference_hash
)
from synapse.api.events.utils import prune_pdu
from unpaddedbase64 import encode_base64, decode_base64
from canonicaljson import encode_canonical_json
import sqlite3 import sqlite3
import sys import sys
from unpaddedbase64 import decode_base64, encode_base64
from synapse.crypto.event_signing import (
add_event_pdu_content_hash,
compute_pdu_event_reference_hash,
)
from synapse.federation.units import Pdu
from synapse.storage._base import SQLBaseStore
from synapse.storage.pdu import PduStore
from synapse.storage.signatures import SignatureStore
class Store(object): class Store(object):
_get_pdu_tuples = PduStore.__dict__["_get_pdu_tuples"] _get_pdu_tuples = PduStore.__dict__["_get_pdu_tuples"]
_get_pdu_content_hashes_txn = SignatureStore.__dict__["_get_pdu_content_hashes_txn"] _get_pdu_content_hashes_txn = SignatureStore.__dict__["_get_pdu_content_hashes_txn"]
_get_prev_pdu_hashes_txn = SignatureStore.__dict__["_get_prev_pdu_hashes_txn"] _get_prev_pdu_hashes_txn = SignatureStore.__dict__["_get_prev_pdu_hashes_txn"]
_get_pdu_origin_signatures_txn = SignatureStore.__dict__["_get_pdu_origin_signatures_txn"] _get_pdu_origin_signatures_txn = SignatureStore.__dict__[
"_get_pdu_origin_signatures_txn"
]
_store_pdu_content_hash_txn = SignatureStore.__dict__["_store_pdu_content_hash_txn"] _store_pdu_content_hash_txn = SignatureStore.__dict__["_store_pdu_content_hash_txn"]
_store_pdu_reference_hash_txn = SignatureStore.__dict__["_store_pdu_reference_hash_txn"] _store_pdu_reference_hash_txn = SignatureStore.__dict__[
"_store_pdu_reference_hash_txn"
]
_store_prev_pdu_hash_txn = SignatureStore.__dict__["_store_prev_pdu_hash_txn"] _store_prev_pdu_hash_txn = SignatureStore.__dict__["_store_prev_pdu_hash_txn"]
_simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"] _simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"]
@ -26,9 +34,7 @@ store = Store()
def select_pdus(cursor): def select_pdus(cursor):
cursor.execute( cursor.execute("SELECT pdu_id, origin FROM pdus ORDER BY depth ASC")
"SELECT pdu_id, origin FROM pdus ORDER BY depth ASC"
)
ids = cursor.fetchall() ids = cursor.fetchall()
@ -41,23 +47,30 @@ def select_pdus(cursor):
for pdu in pdus: for pdu in pdus:
try: try:
if pdu.prev_pdus: if pdu.prev_pdus:
print "PROCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus print("PROCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus)
for pdu_id, origin, hashes in pdu.prev_pdus: for pdu_id, origin, hashes in pdu.prev_pdus:
ref_alg, ref_hsh = reference_hashes[(pdu_id, origin)] ref_alg, ref_hsh = reference_hashes[(pdu_id, origin)]
hashes[ref_alg] = encode_base64(ref_hsh) hashes[ref_alg] = encode_base64(ref_hsh)
store._store_prev_pdu_hash_txn(cursor, pdu.pdu_id, pdu.origin, pdu_id, origin, ref_alg, ref_hsh) store._store_prev_pdu_hash_txn(
print "SUCCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus cursor, pdu.pdu_id, pdu.origin, pdu_id, origin, ref_alg, ref_hsh
)
print("SUCCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus)
pdu = add_event_pdu_content_hash(pdu) pdu = add_event_pdu_content_hash(pdu)
ref_alg, ref_hsh = compute_pdu_event_reference_hash(pdu) ref_alg, ref_hsh = compute_pdu_event_reference_hash(pdu)
reference_hashes[(pdu.pdu_id, pdu.origin)] = (ref_alg, ref_hsh) reference_hashes[(pdu.pdu_id, pdu.origin)] = (ref_alg, ref_hsh)
store._store_pdu_reference_hash_txn(cursor, pdu.pdu_id, pdu.origin, ref_alg, ref_hsh) store._store_pdu_reference_hash_txn(
cursor, pdu.pdu_id, pdu.origin, ref_alg, ref_hsh
)
for alg, hsh_base64 in pdu.hashes.items(): for alg, hsh_base64 in pdu.hashes.items():
print alg, hsh_base64 print(alg, hsh_base64)
store._store_pdu_content_hash_txn(cursor, pdu.pdu_id, pdu.origin, alg, decode_base64(hsh_base64)) store._store_pdu_content_hash_txn(
cursor, pdu.pdu_id, pdu.origin, alg, decode_base64(hsh_base64)
)
except Exception:
print("FAILED_", pdu.pdu_id, pdu.origin, pdu.prev_pdus)
except:
print "FAILED_", pdu.pdu_id, pdu.origin, pdu.prev_pdus
def main(): def main():
conn = sqlite3.connect(sys.argv[1]) conn = sqlite3.connect(sys.argv[1])
@ -65,5 +78,6 @@ def main():
select_pdus(cursor) select_pdus(cursor)
conn.commit() conn.commit()
if __name__=='__main__':
if __name__ == '__main__':
main() main()

View File

@ -1,18 +1,17 @@
#! /usr/bin/python #! /usr/bin/python
import ast
import argparse import argparse
import ast
import os import os
import sys import sys
import yaml import yaml
PATTERNS_V1 = [] PATTERNS_V1 = []
PATTERNS_V2 = [] PATTERNS_V2 = []
RESULT = { RESULT = {"v1": PATTERNS_V1, "v2": PATTERNS_V2}
"v1": PATTERNS_V1,
"v2": PATTERNS_V2,
}
class CallVisitor(ast.NodeVisitor): class CallVisitor(ast.NodeVisitor):
def visit_Call(self, node): def visit_Call(self, node):
@ -21,7 +20,6 @@ class CallVisitor(ast.NodeVisitor):
else: else:
return return
if name == "client_path_patterns": if name == "client_path_patterns":
PATTERNS_V1.append(node.args[0].s) PATTERNS_V1.append(node.args[0].s)
elif name == "client_v2_patterns": elif name == "client_v2_patterns":
@ -42,8 +40,10 @@ def find_patterns_in_file(filepath):
parser = argparse.ArgumentParser(description='Find url patterns.') parser = argparse.ArgumentParser(description='Find url patterns.')
parser.add_argument( parser.add_argument(
"directories", nargs='+', metavar="DIR", "directories",
help="Directories to search for definitions" nargs='+',
metavar="DIR",
help="Directories to search for definitions",
) )
args = parser.parse_args() args = parser.parse_args()

View File

@ -1,8 +1,9 @@
import requests
import collections import collections
import json
import sys import sys
import time import time
import json
import requests
Entry = collections.namedtuple("Entry", "name position rows") Entry = collections.namedtuple("Entry", "name position rows")
@ -30,11 +31,11 @@ def parse_response(content):
def replicate(server, streams): def replicate(server, streams):
return parse_response(requests.get( return parse_response(
server + "/_synapse/replication", requests.get(
verify=False, server + "/_synapse/replication", verify=False, params=streams
params=streams ).content
).content) )
def main(): def main():
@ -45,16 +46,16 @@ def main():
try: try:
streams = { streams = {
row.name: row.position row.name: row.position
for row in replicate(server, {"streams":"-1"})["streams"].rows for row in replicate(server, {"streams": "-1"})["streams"].rows
} }
except requests.exceptions.ConnectionError as e: except requests.exceptions.ConnectionError:
time.sleep(0.1) time.sleep(0.1)
while True: while True:
try: try:
results = replicate(server, streams) results = replicate(server, streams)
except: except Exception:
sys.stdout.write("connection_lost("+ repr(streams) + ")\n") sys.stdout.write("connection_lost(" + repr(streams) + ")\n")
break break
for update in results.values(): for update in results.values():
for row in update.rows: for row in update.rows:
@ -62,6 +63,5 @@ def main():
streams[update.name] = update.position streams[update.name] = update.position
if __name__ == '__main__':
if __name__=='__main__':
main() main()

View File

@ -1,12 +1,10 @@
#!/usr/bin/env python #!/usr/bin/env python
import argparse import argparse
import getpass
import sys import sys
import bcrypt import bcrypt
import getpass
import yaml import yaml
bcrypt_rounds=12 bcrypt_rounds=12
@ -52,4 +50,3 @@ if __name__ == "__main__":
password = prompt_for_pass() password = prompt_for_pass()
print bcrypt.hashpw(password + password_pepper, bcrypt.gensalt(bcrypt_rounds)) print bcrypt.hashpw(password + password_pepper, bcrypt.gensalt(bcrypt_rounds))

View File

@ -36,12 +36,9 @@ from __future__ import print_function
import argparse import argparse
import logging import logging
import sys
import os import os
import shutil import shutil
import sys
from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.rest.media.v1.filepath import MediaFilePaths
@ -77,24 +74,23 @@ def move_media(origin_server, file_id, src_paths, dest_paths):
if not os.path.exists(original_file): if not os.path.exists(original_file):
logger.warn( logger.warn(
"Original for %s/%s (%s) does not exist", "Original for %s/%s (%s) does not exist",
origin_server, file_id, original_file, origin_server,
file_id,
original_file,
) )
else: else:
mkdir_and_move( mkdir_and_move(
original_file, original_file, dest_paths.remote_media_filepath(origin_server, file_id)
dest_paths.remote_media_filepath(origin_server, file_id),
) )
# now look for thumbnails # now look for thumbnails
original_thumb_dir = src_paths.remote_media_thumbnail_dir( original_thumb_dir = src_paths.remote_media_thumbnail_dir(origin_server, file_id)
origin_server, file_id,
)
if not os.path.exists(original_thumb_dir): if not os.path.exists(original_thumb_dir):
return return
mkdir_and_move( mkdir_and_move(
original_thumb_dir, original_thumb_dir,
dest_paths.remote_media_thumbnail_dir(origin_server, file_id) dest_paths.remote_media_thumbnail_dir(origin_server, file_id),
) )
@ -109,24 +105,16 @@ def mkdir_and_move(original_file, dest_file):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description=__doc__, description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
formatter_class = argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"-v", action='store_true', help='enable debug logging')
parser.add_argument(
"src_repo",
help="Path to source content repo",
)
parser.add_argument(
"dest_repo",
help="Path to source content repo",
) )
parser.add_argument("-v", action='store_true', help='enable debug logging')
parser.add_argument("src_repo", help="Path to source content repo")
parser.add_argument("dest_repo", help="Path to source content repo")
args = parser.parse_args() args = parser.parse_args()
logging_config = { logging_config = {
"level": logging.DEBUG if args.v else logging.INFO, "level": logging.DEBUG if args.v else logging.INFO,
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s" "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
} }
logging.basicConfig(**logging_config) logging.basicConfig(**logging_config)

View File

@ -14,197 +14,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import print_function
import argparse from synapse._scripts.register_new_matrix_user import main
import getpass
import hashlib
import hmac
import json
import sys
import urllib2
import yaml
def request_registration(user, password, server_location, shared_secret, admin=False):
req = urllib2.Request(
"%s/_matrix/client/r0/admin/register" % (server_location,),
headers={'Content-Type': 'application/json'}
)
try:
if sys.version_info[:3] >= (2, 7, 9):
# As of version 2.7.9, urllib2 now checks SSL certs
import ssl
f = urllib2.urlopen(req, context=ssl.SSLContext(ssl.PROTOCOL_SSLv23))
else:
f = urllib2.urlopen(req)
body = f.read()
f.close()
nonce = json.loads(body)["nonce"]
except urllib2.HTTPError as e:
print "ERROR! Received %d %s" % (e.code, e.reason,)
if 400 <= e.code < 500:
if e.info().type == "application/json":
resp = json.load(e)
if "error" in resp:
print resp["error"]
sys.exit(1)
mac = hmac.new(
key=shared_secret,
digestmod=hashlib.sha1,
)
mac.update(nonce)
mac.update("\x00")
mac.update(user)
mac.update("\x00")
mac.update(password)
mac.update("\x00")
mac.update("admin" if admin else "notadmin")
mac = mac.hexdigest()
data = {
"nonce": nonce,
"username": user,
"password": password,
"mac": mac,
"admin": admin,
}
server_location = server_location.rstrip("/")
print "Sending registration request..."
req = urllib2.Request(
"%s/_matrix/client/r0/admin/register" % (server_location,),
data=json.dumps(data),
headers={'Content-Type': 'application/json'}
)
try:
if sys.version_info[:3] >= (2, 7, 9):
# As of version 2.7.9, urllib2 now checks SSL certs
import ssl
f = urllib2.urlopen(req, context=ssl.SSLContext(ssl.PROTOCOL_SSLv23))
else:
f = urllib2.urlopen(req)
f.read()
f.close()
print "Success."
except urllib2.HTTPError as e:
print "ERROR! Received %d %s" % (e.code, e.reason,)
if 400 <= e.code < 500:
if e.info().type == "application/json":
resp = json.load(e)
if "error" in resp:
print resp["error"]
sys.exit(1)
def register_new_user(user, password, server_location, shared_secret, admin):
if not user:
try:
default_user = getpass.getuser()
except:
default_user = None
if default_user:
user = raw_input("New user localpart [%s]: " % (default_user,))
if not user:
user = default_user
else:
user = raw_input("New user localpart: ")
if not user:
print "Invalid user name"
sys.exit(1)
if not password:
password = getpass.getpass("Password: ")
if not password:
print "Password cannot be blank."
sys.exit(1)
confirm_password = getpass.getpass("Confirm password: ")
if password != confirm_password:
print "Passwords do not match"
sys.exit(1)
if admin is None:
admin = raw_input("Make admin [no]: ")
if admin in ("y", "yes", "true"):
admin = True
else:
admin = False
request_registration(user, password, server_location, shared_secret, bool(admin))
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( main()
description="Used to register new users with a given home server when"
" registration has been disabled. The home server must be"
" configured with the 'registration_shared_secret' option"
" set.",
)
parser.add_argument(
"-u", "--user",
default=None,
help="Local part of the new user. Will prompt if omitted.",
)
parser.add_argument(
"-p", "--password",
default=None,
help="New password for user. Will prompt if omitted.",
)
admin_group = parser.add_mutually_exclusive_group()
admin_group.add_argument(
"-a", "--admin",
action="store_true",
help="Register new user as an admin. Will prompt if --no-admin is not set either.",
)
admin_group.add_argument(
"--no-admin",
action="store_true",
help="Register new user as a regular user. Will prompt if --admin is not set either.",
)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
"-c", "--config",
type=argparse.FileType('r'),
help="Path to server config file. Used to read in shared secret.",
)
group.add_argument(
"-k", "--shared-secret",
help="Shared secret as defined in server config file.",
)
parser.add_argument(
"server_url",
default="https://localhost:8448",
nargs='?',
help="URL to use to talk to the home server. Defaults to "
" 'https://localhost:8448'.",
)
args = parser.parse_args()
if "config" in args and args.config:
config = yaml.safe_load(args.config)
secret = config.get("registration_shared_secret", None)
if not secret:
print "No 'registration_shared_secret' defined in config."
sys.exit(1)
else:
secret = args.shared_secret
admin = None
if args.admin or args.no_admin:
admin = args.admin
register_new_user(args.user, args.password, args.server_url, secret, admin)

View File

@ -15,23 +15,23 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer, reactor
from twisted.enterprise import adbapi
from synapse.storage._base import LoggingTransaction, SQLBaseStore
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
import argparse import argparse
import curses import curses
import logging import logging
import sys import sys
import time import time
import traceback import traceback
import yaml
from six import string_types from six import string_types
import yaml
from twisted.enterprise import adbapi
from twisted.internet import defer, reactor
from synapse.storage._base import LoggingTransaction, SQLBaseStore
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
logger = logging.getLogger("synapse_port_db") logger = logging.getLogger("synapse_port_db")
@ -105,6 +105,7 @@ class Store(object):
*All* database interactions should go through this object. *All* database interactions should go through this object.
""" """
def __init__(self, db_pool, engine): def __init__(self, db_pool, engine):
self.db_pool = db_pool self.db_pool = db_pool
self.database_engine = engine self.database_engine = engine
@ -135,7 +136,8 @@ class Store(object):
txn = conn.cursor() txn = conn.cursor()
return func( return func(
LoggingTransaction(txn, desc, self.database_engine, [], []), LoggingTransaction(txn, desc, self.database_engine, [], []),
*args, **kwargs *args,
**kwargs
) )
except self.database_engine.module.DatabaseError as e: except self.database_engine.module.DatabaseError as e:
if self.database_engine.is_deadlock(e): if self.database_engine.is_deadlock(e):
@ -158,22 +160,20 @@ class Store(object):
def r(txn): def r(txn):
txn.execute(sql, args) txn.execute(sql, args)
return txn.fetchall() return txn.fetchall()
return self.runInteraction("execute_sql", r) return self.runInteraction("execute_sql", r)
def insert_many_txn(self, txn, table, headers, rows): def insert_many_txn(self, txn, table, headers, rows):
sql = "INSERT INTO %s (%s) VALUES (%s)" % ( sql = "INSERT INTO %s (%s) VALUES (%s)" % (
table, table,
", ".join(k for k in headers), ", ".join(k for k in headers),
", ".join("%s" for _ in headers) ", ".join("%s" for _ in headers),
) )
try: try:
txn.executemany(sql, rows) txn.executemany(sql, rows)
except: except Exception:
logger.exception( logger.exception("Failed to insert: %s", table)
"Failed to insert: %s",
table,
)
raise raise
@ -206,7 +206,7 @@ class Porter(object):
"table_name": table, "table_name": table,
"forward_rowid": 1, "forward_rowid": 1,
"backward_rowid": 0, "backward_rowid": 0,
} },
) )
forward_chunk = 1 forward_chunk = 1
@ -221,10 +221,10 @@ class Porter(object):
table, forward_chunk, backward_chunk table, forward_chunk, backward_chunk
) )
else: else:
def delete_all(txn): def delete_all(txn):
txn.execute( txn.execute(
"DELETE FROM port_from_sqlite3 WHERE table_name = %s", "DELETE FROM port_from_sqlite3 WHERE table_name = %s", (table,)
(table,)
) )
txn.execute("TRUNCATE %s CASCADE" % (table,)) txn.execute("TRUNCATE %s CASCADE" % (table,))
@ -232,11 +232,7 @@ class Porter(object):
yield self.postgres_store._simple_insert( yield self.postgres_store._simple_insert(
table="port_from_sqlite3", table="port_from_sqlite3",
values={ values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0},
"table_name": table,
"forward_rowid": 1,
"backward_rowid": 0,
}
) )
forward_chunk = 1 forward_chunk = 1
@ -251,12 +247,16 @@ class Porter(object):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_table(self, table, postgres_size, table_size, forward_chunk, def handle_table(
backward_chunk): self, table, postgres_size, table_size, forward_chunk, backward_chunk
):
logger.info( logger.info(
"Table %s: %i/%i (rows %i-%i) already ported", "Table %s: %i/%i (rows %i-%i) already ported",
table, postgres_size, table_size, table,
backward_chunk+1, forward_chunk-1, postgres_size,
table_size,
backward_chunk + 1,
forward_chunk - 1,
) )
if not table_size: if not table_size:
@ -271,7 +271,9 @@ class Porter(object):
return return
if table in ( if table in (
"user_directory", "user_directory_search", "users_who_share_rooms", "user_directory",
"user_directory_search",
"users_who_share_rooms",
"users_in_pubic_room", "users_in_pubic_room",
): ):
# We don't port these tables, as they're a faff and we can regenreate # We don't port these tables, as they're a faff and we can regenreate
@ -283,37 +285,35 @@ class Porter(object):
# We need to make sure there is a single row, `(X, null), as that is # We need to make sure there is a single row, `(X, null), as that is
# what synapse expects to be there. # what synapse expects to be there.
yield self.postgres_store._simple_insert( yield self.postgres_store._simple_insert(
table=table, table=table, values={"stream_id": None}
values={"stream_id": None},
) )
self.progress.update(table, table_size) # Mark table as done self.progress.update(table, table_size) # Mark table as done
return return
forward_select = ( forward_select = (
"SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?" "SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?" % (table,)
% (table,)
) )
backward_select = ( backward_select = (
"SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?" "SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?" % (table,)
% (table,)
) )
do_forward = [True] do_forward = [True]
do_backward = [True] do_backward = [True]
while True: while True:
def r(txn): def r(txn):
forward_rows = [] forward_rows = []
backward_rows = [] backward_rows = []
if do_forward[0]: if do_forward[0]:
txn.execute(forward_select, (forward_chunk, self.batch_size,)) txn.execute(forward_select, (forward_chunk, self.batch_size))
forward_rows = txn.fetchall() forward_rows = txn.fetchall()
if not forward_rows: if not forward_rows:
do_forward[0] = False do_forward[0] = False
if do_backward[0]: if do_backward[0]:
txn.execute(backward_select, (backward_chunk, self.batch_size,)) txn.execute(backward_select, (backward_chunk, self.batch_size))
backward_rows = txn.fetchall() backward_rows = txn.fetchall()
if not backward_rows: if not backward_rows:
do_backward[0] = False do_backward[0] = False
@ -325,9 +325,7 @@ class Porter(object):
return headers, forward_rows, backward_rows return headers, forward_rows, backward_rows
headers, frows, brows = yield self.sqlite_store.runInteraction( headers, frows, brows = yield self.sqlite_store.runInteraction("select", r)
"select", r
)
if frows or brows: if frows or brows:
if frows: if frows:
@ -339,9 +337,7 @@ class Porter(object):
rows = self._convert_rows(table, headers, rows) rows = self._convert_rows(table, headers, rows)
def insert(txn): def insert(txn):
self.postgres_store.insert_many_txn( self.postgres_store.insert_many_txn(txn, table, headers[1:], rows)
txn, table, headers[1:], rows
)
self.postgres_store._simple_update_one_txn( self.postgres_store._simple_update_one_txn(
txn, txn,
@ -362,8 +358,9 @@ class Porter(object):
return return
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_search_table(self, postgres_size, table_size, forward_chunk, def handle_search_table(
backward_chunk): self, postgres_size, table_size, forward_chunk, backward_chunk
):
select = ( select = (
"SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering" "SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
" FROM event_search as es" " FROM event_search as es"
@ -373,8 +370,9 @@ class Porter(object):
) )
while True: while True:
def r(txn): def r(txn):
txn.execute(select, (forward_chunk, self.batch_size,)) txn.execute(select, (forward_chunk, self.batch_size))
rows = txn.fetchall() rows = txn.fetchall()
headers = [column[0] for column in txn.description] headers = [column[0] for column in txn.description]
@ -402,18 +400,21 @@ class Porter(object):
else: else:
rows_dict.append(d) rows_dict.append(d)
txn.executemany(sql, [ txn.executemany(
( sql,
row["event_id"], [
row["room_id"], (
row["key"], row["event_id"],
row["sender"], row["room_id"],
row["value"], row["key"],
row["origin_server_ts"], row["sender"],
row["stream_ordering"], row["value"],
) row["origin_server_ts"],
for row in rows_dict row["stream_ordering"],
]) )
for row in rows_dict
],
)
self.postgres_store._simple_update_one_txn( self.postgres_store._simple_update_one_txn(
txn, txn,
@ -437,7 +438,8 @@ class Porter(object):
def setup_db(self, db_config, database_engine): def setup_db(self, db_config, database_engine):
db_conn = database_engine.module.connect( db_conn = database_engine.module.connect(
**{ **{
k: v for k, v in db_config.get("args", {}).items() k: v
for k, v in db_config.get("args", {}).items()
if not k.startswith("cp_") if not k.startswith("cp_")
} }
) )
@ -450,13 +452,11 @@ class Porter(object):
def run(self): def run(self):
try: try:
sqlite_db_pool = adbapi.ConnectionPool( sqlite_db_pool = adbapi.ConnectionPool(
self.sqlite_config["name"], self.sqlite_config["name"], **self.sqlite_config["args"]
**self.sqlite_config["args"]
) )
postgres_db_pool = adbapi.ConnectionPool( postgres_db_pool = adbapi.ConnectionPool(
self.postgres_config["name"], self.postgres_config["name"], **self.postgres_config["args"]
**self.postgres_config["args"]
) )
sqlite_engine = create_engine(sqlite_config) sqlite_engine = create_engine(sqlite_config)
@ -465,9 +465,7 @@ class Porter(object):
self.sqlite_store = Store(sqlite_db_pool, sqlite_engine) self.sqlite_store = Store(sqlite_db_pool, sqlite_engine)
self.postgres_store = Store(postgres_db_pool, postgres_engine) self.postgres_store = Store(postgres_db_pool, postgres_engine)
yield self.postgres_store.execute( yield self.postgres_store.execute(postgres_engine.check_database)
postgres_engine.check_database
)
# Step 1. Set up databases. # Step 1. Set up databases.
self.progress.set_state("Preparing SQLite3") self.progress.set_state("Preparing SQLite3")
@ -477,6 +475,7 @@ class Porter(object):
self.setup_db(postgres_config, postgres_engine) self.setup_db(postgres_config, postgres_engine)
self.progress.set_state("Creating port tables") self.progress.set_state("Creating port tables")
def create_port_table(txn): def create_port_table(txn):
txn.execute( txn.execute(
"CREATE TABLE IF NOT EXISTS port_from_sqlite3 (" "CREATE TABLE IF NOT EXISTS port_from_sqlite3 ("
@ -501,10 +500,9 @@ class Porter(object):
) )
try: try:
yield self.postgres_store.runInteraction( yield self.postgres_store.runInteraction("alter_table", alter_table)
"alter_table", alter_table except Exception:
) # On Error Resume Next
except Exception as e:
pass pass
yield self.postgres_store.runInteraction( yield self.postgres_store.runInteraction(
@ -514,11 +512,7 @@ class Porter(object):
# Step 2. Get tables. # Step 2. Get tables.
self.progress.set_state("Fetching tables") self.progress.set_state("Fetching tables")
sqlite_tables = yield self.sqlite_store._simple_select_onecol( sqlite_tables = yield self.sqlite_store._simple_select_onecol(
table="sqlite_master", table="sqlite_master", keyvalues={"type": "table"}, retcol="name"
keyvalues={
"type": "table",
},
retcol="name",
) )
postgres_tables = yield self.postgres_store._simple_select_onecol( postgres_tables = yield self.postgres_store._simple_select_onecol(
@ -545,18 +539,14 @@ class Porter(object):
# Step 4. Do the copying. # Step 4. Do the copying.
self.progress.set_state("Copying to postgres") self.progress.set_state("Copying to postgres")
yield defer.gatherResults( yield defer.gatherResults(
[ [self.handle_table(*res) for res in setup_res], consumeErrors=True
self.handle_table(*res)
for res in setup_res
],
consumeErrors=True,
) )
# Step 5. Do final post-processing # Step 5. Do final post-processing
yield self._setup_state_group_id_seq() yield self._setup_state_group_id_seq()
self.progress.done() self.progress.done()
except: except Exception:
global end_error_exec_info global end_error_exec_info
end_error_exec_info = sys.exc_info() end_error_exec_info = sys.exc_info()
logger.exception("") logger.exception("")
@ -566,9 +556,7 @@ class Porter(object):
def _convert_rows(self, table, headers, rows): def _convert_rows(self, table, headers, rows):
bool_col_names = BOOLEAN_COLUMNS.get(table, []) bool_col_names = BOOLEAN_COLUMNS.get(table, [])
bool_cols = [ bool_cols = [i for i, h in enumerate(headers) if h in bool_col_names]
i for i, h in enumerate(headers) if h in bool_col_names
]
class BadValueException(Exception): class BadValueException(Exception):
pass pass
@ -577,18 +565,21 @@ class Porter(object):
if j in bool_cols: if j in bool_cols:
return bool(col) return bool(col)
elif isinstance(col, string_types) and "\0" in col: elif isinstance(col, string_types) and "\0" in col:
logger.warn("DROPPING ROW: NUL value in table %s col %s: %r", table, headers[j], col) logger.warn(
raise BadValueException(); "DROPPING ROW: NUL value in table %s col %s: %r",
table,
headers[j],
col,
)
raise BadValueException()
return col return col
outrows = [] outrows = []
for i, row in enumerate(rows): for i, row in enumerate(rows):
try: try:
outrows.append(tuple( outrows.append(
conv(j, col) tuple(conv(j, col) for j, col in enumerate(row) if j > 0)
for j, col in enumerate(row) )
if j > 0
))
except BadValueException: except BadValueException:
pass pass
@ -616,9 +607,7 @@ class Porter(object):
return headers, [r for r in rows if r[ts_ind] < yesterday] return headers, [r for r in rows if r[ts_ind] < yesterday]
headers, rows = yield self.sqlite_store.runInteraction( headers, rows = yield self.sqlite_store.runInteraction("select", r)
"select", r,
)
rows = self._convert_rows("sent_transactions", headers, rows) rows = self._convert_rows("sent_transactions", headers, rows)
@ -639,7 +628,7 @@ class Porter(object):
txn.execute( txn.execute(
"SELECT rowid FROM sent_transactions WHERE ts >= ?" "SELECT rowid FROM sent_transactions WHERE ts >= ?"
" ORDER BY rowid ASC LIMIT 1", " ORDER BY rowid ASC LIMIT 1",
(yesterday,) (yesterday,),
) )
rows = txn.fetchall() rows = txn.fetchall()
@ -657,21 +646,17 @@ class Porter(object):
"table_name": "sent_transactions", "table_name": "sent_transactions",
"forward_rowid": next_chunk, "forward_rowid": next_chunk,
"backward_rowid": 0, "backward_rowid": 0,
} },
) )
def get_sent_table_size(txn): def get_sent_table_size(txn):
txn.execute( txn.execute(
"SELECT count(*) FROM sent_transactions" "SELECT count(*) FROM sent_transactions" " WHERE ts >= ?", (yesterday,)
" WHERE ts >= ?",
(yesterday,)
) )
size, = txn.fetchone() size, = txn.fetchone()
return int(size) return int(size)
remaining_count = yield self.sqlite_store.execute( remaining_count = yield self.sqlite_store.execute(get_sent_table_size)
get_sent_table_size
)
total_count = remaining_count + inserted_rows total_count = remaining_count + inserted_rows
@ -680,13 +665,11 @@ class Porter(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk): def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk):
frows = yield self.sqlite_store.execute_sql( frows = yield self.sqlite_store.execute_sql(
"SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk
forward_chunk,
) )
brows = yield self.sqlite_store.execute_sql( brows = yield self.sqlite_store.execute_sql(
"SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk
backward_chunk,
) )
defer.returnValue(frows[0][0] + brows[0][0]) defer.returnValue(frows[0][0] + brows[0][0])
@ -694,7 +677,7 @@ class Porter(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_already_ported_count(self, table): def _get_already_ported_count(self, table):
rows = yield self.postgres_store.execute_sql( rows = yield self.postgres_store.execute_sql(
"SELECT count(*) FROM %s" % (table,), "SELECT count(*) FROM %s" % (table,)
) )
defer.returnValue(rows[0][0]) defer.returnValue(rows[0][0])
@ -717,22 +700,21 @@ class Porter(object):
def _setup_state_group_id_seq(self): def _setup_state_group_id_seq(self):
def r(txn): def r(txn):
txn.execute("SELECT MAX(id) FROM state_groups") txn.execute("SELECT MAX(id) FROM state_groups")
next_id = txn.fetchone()[0]+1 next_id = txn.fetchone()[0] + 1
txn.execute( txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,))
"ALTER SEQUENCE state_group_id_seq RESTART WITH %s",
(next_id,),
)
return self.postgres_store.runInteraction("setup_state_group_id_seq", r) return self.postgres_store.runInteraction("setup_state_group_id_seq", r)
############################################## ##############################################
###### The following is simply UI stuff ###### # The following is simply UI stuff
############################################## ##############################################
class Progress(object): class Progress(object):
"""Used to report progress of the port """Used to report progress of the port
""" """
def __init__(self): def __init__(self):
self.tables = {} self.tables = {}
@ -758,6 +740,7 @@ class Progress(object):
class CursesProgress(Progress): class CursesProgress(Progress):
"""Reports progress to a curses window """Reports progress to a curses window
""" """
def __init__(self, stdscr): def __init__(self, stdscr):
self.stdscr = stdscr self.stdscr = stdscr
@ -801,7 +784,7 @@ class CursesProgress(Progress):
duration = int(now) - int(self.start_time) duration = int(now) - int(self.start_time)
minutes, seconds = divmod(duration, 60) minutes, seconds = divmod(duration, 60)
duration_str = '%02dm %02ds' % (minutes, seconds,) duration_str = '%02dm %02ds' % (minutes, seconds)
if self.finished: if self.finished:
status = "Time spent: %s (Done!)" % (duration_str,) status = "Time spent: %s (Done!)" % (duration_str,)
@ -814,16 +797,12 @@ class CursesProgress(Progress):
est_remaining_str = '%02dm %02ds remaining' % divmod(est_remaining, 60) est_remaining_str = '%02dm %02ds remaining' % divmod(est_remaining, 60)
else: else:
est_remaining_str = "Unknown" est_remaining_str = "Unknown"
status = ( status = "Time spent: %s (est. remaining: %s)" % (
"Time spent: %s (est. remaining: %s)" duration_str,
% (duration_str, est_remaining_str,) est_remaining_str,
) )
self.stdscr.addstr( self.stdscr.addstr(0, 0, status, curses.A_BOLD)
0, 0,
status,
curses.A_BOLD,
)
max_len = max([len(t) for t in self.tables.keys()]) max_len = max([len(t) for t in self.tables.keys()])
@ -831,9 +810,7 @@ class CursesProgress(Progress):
middle_space = 1 middle_space = 1
items = self.tables.items() items = self.tables.items()
items.sort( items.sort(key=lambda i: (i[1]["perc"], i[0]))
key=lambda i: (i[1]["perc"], i[0]),
)
for i, (table, data) in enumerate(items): for i, (table, data) in enumerate(items):
if i + 2 >= rows: if i + 2 >= rows:
@ -844,9 +821,7 @@ class CursesProgress(Progress):
color = curses.color_pair(2) if perc == 100 else curses.color_pair(1) color = curses.color_pair(2) if perc == 100 else curses.color_pair(1)
self.stdscr.addstr( self.stdscr.addstr(
i + 2, left_margin + max_len - len(table), i + 2, left_margin + max_len - len(table), table, curses.A_BOLD | color
table,
curses.A_BOLD | color,
) )
size = 20 size = 20
@ -857,15 +832,13 @@ class CursesProgress(Progress):
) )
self.stdscr.addstr( self.stdscr.addstr(
i + 2, left_margin + max_len + middle_space, i + 2,
left_margin + max_len + middle_space,
"%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]), "%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]),
) )
if self.finished: if self.finished:
self.stdscr.addstr( self.stdscr.addstr(rows - 1, 0, "Press any key to exit...")
rows - 1, 0,
"Press any key to exit...",
)
self.stdscr.refresh() self.stdscr.refresh()
self.last_update = time.time() self.last_update = time.time()
@ -877,29 +850,25 @@ class CursesProgress(Progress):
def set_state(self, state): def set_state(self, state):
self.stdscr.clear() self.stdscr.clear()
self.stdscr.addstr( self.stdscr.addstr(0, 0, state + "...", curses.A_BOLD)
0, 0,
state + "...",
curses.A_BOLD,
)
self.stdscr.refresh() self.stdscr.refresh()
class TerminalProgress(Progress): class TerminalProgress(Progress):
"""Just prints progress to the terminal """Just prints progress to the terminal
""" """
def update(self, table, num_done): def update(self, table, num_done):
super(TerminalProgress, self).update(table, num_done) super(TerminalProgress, self).update(table, num_done)
data = self.tables[table] data = self.tables[table]
print "%s: %d%% (%d/%d)" % ( print(
table, data["perc"], "%s: %d%% (%d/%d)" % (table, data["perc"], data["num_done"], data["total"])
data["num_done"], data["total"],
) )
def set_state(self, state): def set_state(self, state):
print state + "..." print(state + "...")
############################################## ##############################################
@ -909,34 +878,38 @@ class TerminalProgress(Progress):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="A script to port an existing synapse SQLite database to" description="A script to port an existing synapse SQLite database to"
" a new PostgreSQL database." " a new PostgreSQL database."
) )
parser.add_argument("-v", action='store_true') parser.add_argument("-v", action='store_true')
parser.add_argument( parser.add_argument(
"--sqlite-database", required=True, "--sqlite-database",
required=True,
help="The snapshot of the SQLite database file. This must not be" help="The snapshot of the SQLite database file. This must not be"
" currently used by a running synapse server" " currently used by a running synapse server",
) )
parser.add_argument( parser.add_argument(
"--postgres-config", type=argparse.FileType('r'), required=True, "--postgres-config",
help="The database config file for the PostgreSQL database" type=argparse.FileType('r'),
required=True,
help="The database config file for the PostgreSQL database",
) )
parser.add_argument( parser.add_argument(
"--curses", action='store_true', "--curses", action='store_true', help="display a curses based progress UI"
help="display a curses based progress UI"
) )
parser.add_argument( parser.add_argument(
"--batch-size", type=int, default=1000, "--batch-size",
type=int,
default=1000,
help="The number of rows to select from the SQLite table each" help="The number of rows to select from the SQLite table each"
" iteration [default=1000]", " iteration [default=1000]",
) )
args = parser.parse_args() args = parser.parse_args()
logging_config = { logging_config = {
"level": logging.DEBUG if args.v else logging.INFO, "level": logging.DEBUG if args.v else logging.INFO,
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s" "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
} }
if args.curses: if args.curses:

View File

@ -14,17 +14,16 @@ ignore =
pylint.cfg pylint.cfg
tox.ini tox.ini
[pep8]
max-line-length = 90
# W503 requires that binary operators be at the end, not start, of lines. Erik
# doesn't like it. E203 is contrary to PEP8. E731 is silly.
ignore = W503,E203,E731
[flake8] [flake8]
# note that flake8 inherits the "ignore" settings from "pep8" (because it uses
# pep8 to do those checks), but not the "max-line-length" setting
max-line-length = 90 max-line-length = 90
ignore=W503,E203,E731
# see https://pycodestyle.readthedocs.io/en/latest/intro.html#error-codes
# for error codes. The ones we ignore are:
# W503: line break before binary operator
# W504: line break after binary operator
# E203: whitespace before ':' (which is contrary to pep8?)
# E731: do not assign a lambda expression, use a def
ignore=W503,W504,E203,E731
[isort] [isort]
line_length = 89 line_length = 89

View File

@ -1,6 +1,8 @@
#!/usr/bin/env python #!/usr/bin/env python
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2017 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
# Copyright 2017-2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -86,7 +88,7 @@ setup(
name="matrix-synapse", name="matrix-synapse",
version=version, version=version,
packages=find_packages(exclude=["tests", "tests.*"]), packages=find_packages(exclude=["tests", "tests.*"]),
description="Reference Synapse Home Server", description="Reference homeserver for the Matrix decentralised comms protocol",
install_requires=dependencies['requirements'](include_conditional=True).keys(), install_requires=dependencies['requirements'](include_conditional=True).keys(),
dependency_links=dependencies["DEPENDENCY_LINKS"].values(), dependency_links=dependencies["DEPENDENCY_LINKS"].values(),
include_package_data=True, include_package_data=True,

View File

View File

@ -0,0 +1,215 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2018 New Vector
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import getpass
import hashlib
import hmac
import logging
import sys
from six.moves import input
import requests as _requests
import yaml
def request_registration(
user,
password,
server_location,
shared_secret,
admin=False,
requests=_requests,
_print=print,
exit=sys.exit,
):
url = "%s/_matrix/client/r0/admin/register" % (server_location,)
# Get the nonce
r = requests.get(url, verify=False)
if r.status_code is not 200:
_print("ERROR! Received %d %s" % (r.status_code, r.reason))
if 400 <= r.status_code < 500:
try:
_print(r.json()["error"])
except Exception:
pass
return exit(1)
nonce = r.json()["nonce"]
mac = hmac.new(key=shared_secret.encode('utf8'), digestmod=hashlib.sha1)
mac.update(nonce.encode('utf8'))
mac.update(b"\x00")
mac.update(user.encode('utf8'))
mac.update(b"\x00")
mac.update(password.encode('utf8'))
mac.update(b"\x00")
mac.update(b"admin" if admin else b"notadmin")
mac = mac.hexdigest()
data = {
"nonce": nonce,
"username": user,
"password": password,
"mac": mac,
"admin": admin,
}
_print("Sending registration request...")
r = requests.post(url, json=data, verify=False)
if r.status_code is not 200:
_print("ERROR! Received %d %s" % (r.status_code, r.reason))
if 400 <= r.status_code < 500:
try:
_print(r.json()["error"])
except Exception:
pass
return exit(1)
_print("Success!")
def register_new_user(user, password, server_location, shared_secret, admin):
if not user:
try:
default_user = getpass.getuser()
except Exception:
default_user = None
if default_user:
user = input("New user localpart [%s]: " % (default_user,))
if not user:
user = default_user
else:
user = input("New user localpart: ")
if not user:
print("Invalid user name")
sys.exit(1)
if not password:
password = getpass.getpass("Password: ")
if not password:
print("Password cannot be blank.")
sys.exit(1)
confirm_password = getpass.getpass("Confirm password: ")
if password != confirm_password:
print("Passwords do not match")
sys.exit(1)
if admin is None:
admin = input("Make admin [no]: ")
if admin in ("y", "yes", "true"):
admin = True
else:
admin = False
request_registration(user, password, server_location, shared_secret, bool(admin))
def main():
logging.captureWarnings(True)
parser = argparse.ArgumentParser(
description="Used to register new users with a given home server when"
" registration has been disabled. The home server must be"
" configured with the 'registration_shared_secret' option"
" set."
)
parser.add_argument(
"-u",
"--user",
default=None,
help="Local part of the new user. Will prompt if omitted.",
)
parser.add_argument(
"-p",
"--password",
default=None,
help="New password for user. Will prompt if omitted.",
)
admin_group = parser.add_mutually_exclusive_group()
admin_group.add_argument(
"-a",
"--admin",
action="store_true",
help=(
"Register new user as an admin. "
"Will prompt if --no-admin is not set either."
),
)
admin_group.add_argument(
"--no-admin",
action="store_true",
help=(
"Register new user as a regular user. "
"Will prompt if --admin is not set either."
),
)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
"-c",
"--config",
type=argparse.FileType('r'),
help="Path to server config file. Used to read in shared secret.",
)
group.add_argument(
"-k", "--shared-secret", help="Shared secret as defined in server config file."
)
parser.add_argument(
"server_url",
default="https://localhost:8448",
nargs='?',
help="URL to use to talk to the home server. Defaults to "
" 'https://localhost:8448'.",
)
args = parser.parse_args()
if "config" in args and args.config:
config = yaml.safe_load(args.config)
secret = config.get("registration_shared_secret", None)
if not secret:
print("No 'registration_shared_secret' defined in config.")
sys.exit(1)
else:
secret = args.shared_secret
admin = None
if args.admin or args.no_admin:
admin = args.admin
register_new_user(args.user, args.password, args.server_url, secret, admin)
if __name__ == "__main__":
main()

View File

@ -172,7 +172,10 @@ USER_FILTER_SCHEMA = {
# events a lot easier as we can then use a negative lookbehind # events a lot easier as we can then use a negative lookbehind
# assertion to split '\.' If we allowed \\ then it would # assertion to split '\.' If we allowed \\ then it would
# incorrectly split '\\.' See synapse.events.utils.serialize_event # incorrectly split '\\.' See synapse.events.utils.serialize_event
"pattern": "^((?!\\\).)*$" #
# Note that because this is a regular expression, we have to escape
# each backslash in the pattern.
"pattern": r"^((?!\\\\).)*$"
} }
} }
}, },

View File

@ -20,6 +20,7 @@ import sys
from six import iteritems from six import iteritems
import psutil
from prometheus_client import Gauge from prometheus_client import Gauge
from twisted.application import service from twisted.application import service
@ -502,7 +503,6 @@ def run(hs):
def performance_stats_init(): def performance_stats_init():
try: try:
import psutil
process = psutil.Process() process = psutil.Process()
# Ensure we can fetch both, and make the initial request for cpu_percent # Ensure we can fetch both, and make the initial request for cpu_percent
# so the next request will use this as the initial point. # so the next request will use this as the initial point.
@ -510,12 +510,9 @@ def run(hs):
process.cpu_percent(interval=None) process.cpu_percent(interval=None)
logger.info("report_stats can use psutil") logger.info("report_stats can use psutil")
stats_process.append(process) stats_process.append(process)
except (ImportError, AttributeError): except (AttributeError):
logger.warn( logger.warning(
"report_stats enabled but psutil is not installed or incorrect version." "Unable to read memory/cpu stats. Disabling reporting."
" Disabling reporting of memory/cpu stats."
" Ensuring psutil is available will help matrix.org track performance"
" changes across releases."
) )
def generate_user_daily_visit_stats(): def generate_user_daily_visit_stats():
@ -530,10 +527,13 @@ def run(hs):
clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000) clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000)
# monthly active user limiting functionality # monthly active user limiting functionality
clock.looping_call( def reap_monthly_active_users():
hs.get_datastore().reap_monthly_active_users, 1000 * 60 * 60 return run_as_background_process(
) "reap_monthly_active_users",
hs.get_datastore().reap_monthly_active_users() hs.get_datastore().reap_monthly_active_users,
)
clock.looping_call(reap_monthly_active_users, 1000 * 60 * 60)
reap_monthly_active_users()
@defer.inlineCallbacks @defer.inlineCallbacks
def generate_monthly_active_users(): def generate_monthly_active_users():
@ -547,12 +547,15 @@ def run(hs):
registered_reserved_users_mau_gauge.set(float(reserved_count)) registered_reserved_users_mau_gauge.set(float(reserved_count))
max_mau_gauge.set(float(hs.config.max_mau_value)) max_mau_gauge.set(float(hs.config.max_mau_value))
hs.get_datastore().initialise_reserved_users( def start_generate_monthly_active_users():
hs.config.mau_limits_reserved_threepids return run_as_background_process(
) "generate_monthly_active_users",
generate_monthly_active_users() generate_monthly_active_users,
)
start_generate_monthly_active_users()
if hs.config.limit_usage_by_mau: if hs.config.limit_usage_by_mau:
clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000) clock.looping_call(start_generate_monthly_active_users, 5 * 60 * 1000)
# End of monthly active user settings # End of monthly active user settings
if hs.config.report_stats: if hs.config.report_stats:
@ -568,7 +571,7 @@ def run(hs):
clock.call_later(5 * 60, start_phone_stats_home) clock.call_later(5 * 60, start_phone_stats_home)
if hs.config.daemonize and hs.config.print_pidfile: if hs.config.daemonize and hs.config.print_pidfile:
print (hs.config.pid_file) print(hs.config.pid_file)
_base.start_reactor( _base.start_reactor(
"synapse-homeserver", "synapse-homeserver",

View File

@ -161,11 +161,11 @@ class PusherReplicationHandler(ReplicationClientHandler):
else: else:
yield self.start_pusher(row.user_id, row.app_id, row.pushkey) yield self.start_pusher(row.user_id, row.app_id, row.pushkey)
elif stream_name == "events": elif stream_name == "events":
self.pusher_pool.on_new_notifications( yield self.pusher_pool.on_new_notifications(
token, token, token, token,
) )
elif stream_name == "receipts": elif stream_name == "receipts":
self.pusher_pool.on_new_receipts( yield self.pusher_pool.on_new_receipts(
token, token, set(row.room_id for row in rows) token, token, set(row.room_id for row in rows)
) )
except Exception: except Exception:
@ -183,7 +183,7 @@ class PusherReplicationHandler(ReplicationClientHandler):
def start_pusher(self, user_id, app_id, pushkey): def start_pusher(self, user_id, app_id, pushkey):
key = "%s:%s" % (app_id, pushkey) key = "%s:%s" % (app_id, pushkey)
logger.info("Starting pusher %r / %r", user_id, key) logger.info("Starting pusher %r / %r", user_id, key)
return self.pusher_pool._refresh_pusher(app_id, pushkey, user_id) return self.pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)
def start(config_options): def start(config_options):

View File

@ -28,7 +28,7 @@ if __name__ == "__main__":
sys.stderr.write("\n" + str(e) + "\n") sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1) sys.exit(1)
print (getattr(config, key)) print(getattr(config, key))
sys.exit(0) sys.exit(0)
else: else:
sys.stderr.write("Unknown command %r\n" % (action,)) sys.stderr.write("Unknown command %r\n" % (action,))

View File

@ -106,10 +106,7 @@ class Config(object):
@classmethod @classmethod
def check_file(cls, file_path, config_name): def check_file(cls, file_path, config_name):
if file_path is None: if file_path is None:
raise ConfigError( raise ConfigError("Missing config for %s." % (config_name,))
"Missing config for %s."
% (config_name,)
)
try: try:
os.stat(file_path) os.stat(file_path)
except OSError as e: except OSError as e:
@ -128,9 +125,7 @@ class Config(object):
if e.errno != errno.EEXIST: if e.errno != errno.EEXIST:
raise raise
if not os.path.isdir(dir_path): if not os.path.isdir(dir_path):
raise ConfigError( raise ConfigError("%s is not a directory" % (dir_path,))
"%s is not a directory" % (dir_path,)
)
return dir_path return dir_path
@classmethod @classmethod
@ -156,21 +151,20 @@ class Config(object):
return results return results
def generate_config( def generate_config(
self, self, config_dir_path, server_name, is_generating_file, report_stats=None
config_dir_path,
server_name,
is_generating_file,
report_stats=None,
): ):
default_config = "# vim:ft=yaml\n" default_config = "# vim:ft=yaml\n"
default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all( default_config += "\n\n".join(
"default_config", dedent(conf)
config_dir_path=config_dir_path, for conf in self.invoke_all(
server_name=server_name, "default_config",
is_generating_file=is_generating_file, config_dir_path=config_dir_path,
report_stats=report_stats, server_name=server_name,
)) is_generating_file=is_generating_file,
report_stats=report_stats,
)
)
config = yaml.load(default_config) config = yaml.load(default_config)
@ -178,23 +172,22 @@ class Config(object):
@classmethod @classmethod
def load_config(cls, description, argv): def load_config(cls, description, argv):
config_parser = argparse.ArgumentParser( config_parser = argparse.ArgumentParser(description=description)
description=description,
)
config_parser.add_argument( config_parser.add_argument(
"-c", "--config-path", "-c",
"--config-path",
action="append", action="append",
metavar="CONFIG_FILE", metavar="CONFIG_FILE",
help="Specify config file. Can be given multiple times and" help="Specify config file. Can be given multiple times and"
" may specify directories containing *.yaml files." " may specify directories containing *.yaml files.",
) )
config_parser.add_argument( config_parser.add_argument(
"--keys-directory", "--keys-directory",
metavar="DIRECTORY", metavar="DIRECTORY",
help="Where files such as certs and signing keys are stored when" help="Where files such as certs and signing keys are stored when"
" their location is given explicitly in the config." " their location is given explicitly in the config."
" Defaults to the directory containing the last config file", " Defaults to the directory containing the last config file",
) )
config_args = config_parser.parse_args(argv) config_args = config_parser.parse_args(argv)
@ -203,9 +196,7 @@ class Config(object):
obj = cls() obj = cls()
obj.read_config_files( obj.read_config_files(
config_files, config_files, keys_directory=config_args.keys_directory, generate_keys=False
keys_directory=config_args.keys_directory,
generate_keys=False,
) )
return obj return obj
@ -213,38 +204,38 @@ class Config(object):
def load_or_generate_config(cls, description, argv): def load_or_generate_config(cls, description, argv):
config_parser = argparse.ArgumentParser(add_help=False) config_parser = argparse.ArgumentParser(add_help=False)
config_parser.add_argument( config_parser.add_argument(
"-c", "--config-path", "-c",
"--config-path",
action="append", action="append",
metavar="CONFIG_FILE", metavar="CONFIG_FILE",
help="Specify config file. Can be given multiple times and" help="Specify config file. Can be given multiple times and"
" may specify directories containing *.yaml files." " may specify directories containing *.yaml files.",
) )
config_parser.add_argument( config_parser.add_argument(
"--generate-config", "--generate-config",
action="store_true", action="store_true",
help="Generate a config file for the server name" help="Generate a config file for the server name",
) )
config_parser.add_argument( config_parser.add_argument(
"--report-stats", "--report-stats",
action="store", action="store",
help="Whether the generated config reports anonymized usage statistics", help="Whether the generated config reports anonymized usage statistics",
choices=["yes", "no"] choices=["yes", "no"],
) )
config_parser.add_argument( config_parser.add_argument(
"--generate-keys", "--generate-keys",
action="store_true", action="store_true",
help="Generate any missing key files then exit" help="Generate any missing key files then exit",
) )
config_parser.add_argument( config_parser.add_argument(
"--keys-directory", "--keys-directory",
metavar="DIRECTORY", metavar="DIRECTORY",
help="Used with 'generate-*' options to specify where files such as" help="Used with 'generate-*' options to specify where files such as"
" certs and signing keys should be stored in, unless explicitly" " certs and signing keys should be stored in, unless explicitly"
" specified in the config." " specified in the config.",
) )
config_parser.add_argument( config_parser.add_argument(
"-H", "--server-name", "-H", "--server-name", help="The server name to generate a config file for"
help="The server name to generate a config file for"
) )
config_args, remaining_args = config_parser.parse_known_args(argv) config_args, remaining_args = config_parser.parse_known_args(argv)
@ -257,8 +248,8 @@ class Config(object):
if config_args.generate_config: if config_args.generate_config:
if config_args.report_stats is None: if config_args.report_stats is None:
config_parser.error( config_parser.error(
"Please specify either --report-stats=yes or --report-stats=no\n\n" + "Please specify either --report-stats=yes or --report-stats=no\n\n"
MISSING_REPORT_STATS_SPIEL + MISSING_REPORT_STATS_SPIEL
) )
if not config_files: if not config_files:
config_parser.error( config_parser.error(
@ -287,26 +278,32 @@ class Config(object):
config_dir_path=config_dir_path, config_dir_path=config_dir_path,
server_name=server_name, server_name=server_name,
report_stats=(config_args.report_stats == "yes"), report_stats=(config_args.report_stats == "yes"),
is_generating_file=True is_generating_file=True,
) )
obj.invoke_all("generate_files", config) obj.invoke_all("generate_files", config)
config_file.write(config_str) config_file.write(config_str)
print(( print(
"A config file has been generated in %r for server name" (
" %r with corresponding SSL keys and self-signed" "A config file has been generated in %r for server name"
" certificates. Please review this file and customise it" " %r with corresponding SSL keys and self-signed"
" to your needs." " certificates. Please review this file and customise it"
) % (config_path, server_name)) " to your needs."
)
% (config_path, server_name)
)
print( print(
"If this server name is incorrect, you will need to" "If this server name is incorrect, you will need to"
" regenerate the SSL certificates" " regenerate the SSL certificates"
) )
return return
else: else:
print(( print(
"Config file %r already exists. Generating any missing key" (
" files." "Config file %r already exists. Generating any missing key"
) % (config_path,)) " files."
)
% (config_path,)
)
generate_keys = True generate_keys = True
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@ -338,8 +335,7 @@ class Config(object):
return obj return obj
def read_config_files(self, config_files, keys_directory=None, def read_config_files(self, config_files, keys_directory=None, generate_keys=False):
generate_keys=False):
if not keys_directory: if not keys_directory:
keys_directory = os.path.dirname(config_files[-1]) keys_directory = os.path.dirname(config_files[-1])
@ -364,8 +360,9 @@ class Config(object):
if "report_stats" not in config: if "report_stats" not in config:
raise ConfigError( raise ConfigError(
MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" + MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS
MISSING_REPORT_STATS_SPIEL + "\n"
+ MISSING_REPORT_STATS_SPIEL
) )
if generate_keys: if generate_keys:
@ -399,16 +396,16 @@ def find_config_files(search_paths):
for entry in os.listdir(config_path): for entry in os.listdir(config_path):
entry_path = os.path.join(config_path, entry) entry_path = os.path.join(config_path, entry)
if not os.path.isfile(entry_path): if not os.path.isfile(entry_path):
print ( err = "Found subdirectory in config directory: %r. IGNORING."
"Found subdirectory in config directory: %r. IGNORING." print(err % (entry_path,))
) % (entry_path, )
continue continue
if not entry.endswith(".yaml"): if not entry.endswith(".yaml"):
print ( err = (
"Found file in config directory that does not" "Found file in config directory that does not end in "
" end in '.yaml': %r. IGNORING." "'.yaml': %r. IGNORING."
) % (entry_path, ) )
print(err % (entry_path,))
continue continue
files.append(entry_path) files.append(entry_path)

View File

@ -19,19 +19,13 @@ from __future__ import print_function
import email.utils import email.utils
import logging import logging
import os import os
import sys
import textwrap
from ._base import Config import pkg_resources
from ._base import Config, ConfigError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
TEMPLATE_DIR_WARNING = """\
WARNING: The email notifier is configured to look for templates in '%(template_dir)s',
but no templates could be found there. We will fall back to using the example templates;
to get rid of this warning, leave 'email.template_dir' unset.
"""
class EmailConfig(Config): class EmailConfig(Config):
def read_config(self, config): def read_config(self, config):
@ -78,20 +72,22 @@ class EmailConfig(Config):
self.email_notif_template_html = email_config["notif_template_html"] self.email_notif_template_html = email_config["notif_template_html"]
self.email_notif_template_text = email_config["notif_template_text"] self.email_notif_template_text = email_config["notif_template_text"]
self.email_template_dir = email_config.get("template_dir") template_dir = email_config.get("template_dir")
# we need an absolute path, because we change directory after starting (and
# backwards-compatibility hack # we don't yet know what auxilliary templates like mail.css we will need).
if ( # (Note that loading as package_resources with jinja.PackageLoader doesn't
self.email_template_dir == "res/templates" # work for the same reason.)
and not os.path.isfile( if not template_dir:
os.path.join(self.email_template_dir, self.email_notif_template_text) template_dir = pkg_resources.resource_filename(
'synapse', 'res/templates'
) )
): template_dir = os.path.abspath(template_dir)
t = TEMPLATE_DIR_WARNING % {
"template_dir": self.email_template_dir, for f in self.email_notif_template_text, self.email_notif_template_html:
} p = os.path.join(template_dir, f)
print(textwrap.fill(t, width=80) + "\n", file=sys.stderr) if not os.path.isfile(p):
self.email_template_dir = None raise ConfigError("Unable to find email template file %s" % (p, ))
self.email_template_dir = template_dir
self.email_notif_for_new_users = email_config.get( self.email_notif_for_new_users = email_config.get(
"notif_for_new_users", True "notif_for_new_users", True

View File

@ -31,6 +31,7 @@ from .push import PushConfig
from .ratelimiting import RatelimitConfig from .ratelimiting import RatelimitConfig
from .registration import RegistrationConfig from .registration import RegistrationConfig
from .repository import ContentRepositoryConfig from .repository import ContentRepositoryConfig
from .room_directory import RoomDirectoryConfig
from .saml2 import SAML2Config from .saml2 import SAML2Config
from .server import ServerConfig from .server import ServerConfig
from .server_notices_config import ServerNoticesConfig from .server_notices_config import ServerNoticesConfig
@ -49,7 +50,7 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
WorkerConfig, PasswordAuthProviderConfig, PushConfig, WorkerConfig, PasswordAuthProviderConfig, PushConfig,
SpamCheckerConfig, GroupsConfig, UserDirectoryConfig, SpamCheckerConfig, GroupsConfig, UserDirectoryConfig,
ConsentConfig, ConsentConfig,
ServerNoticesConfig, ServerNoticesConfig, RoomDirectoryConfig,
): ):
pass pass

View File

@ -15,10 +15,10 @@
from distutils.util import strtobool from distutils.util import strtobool
from synapse.config._base import Config, ConfigError
from synapse.types import RoomAlias
from synapse.util.stringutils import random_string_with_symbols from synapse.util.stringutils import random_string_with_symbols
from ._base import Config
class RegistrationConfig(Config): class RegistrationConfig(Config):
@ -44,6 +44,10 @@ class RegistrationConfig(Config):
) )
self.auto_join_rooms = config.get("auto_join_rooms", []) self.auto_join_rooms = config.get("auto_join_rooms", [])
for room_alias in self.auto_join_rooms:
if not RoomAlias.is_valid(room_alias):
raise ConfigError('Invalid auto_join_rooms entry %s' % (room_alias,))
self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True)
def default_config(self, **kwargs): def default_config(self, **kwargs):
registration_shared_secret = random_string_with_symbols(50) registration_shared_secret = random_string_with_symbols(50)
@ -98,6 +102,13 @@ class RegistrationConfig(Config):
# to these rooms # to these rooms
#auto_join_rooms: #auto_join_rooms:
# - "#example:example.com" # - "#example:example.com"
# Where auto_join_rooms are specified, setting this flag ensures that the
# the rooms exist by creating them when the first user on the
# homeserver registers.
# Setting to false means that if the rooms are not manually created,
# users cannot be auto-joined since they do not exist.
autocreate_auto_join_rooms: true
""" % locals() """ % locals()
def add_arguments(self, parser): def add_arguments(self, parser):

View File

@ -178,7 +178,7 @@ class ContentRepositoryConfig(Config):
def default_config(self, **kwargs): def default_config(self, **kwargs):
media_store = self.default_path("media_store") media_store = self.default_path("media_store")
uploads_path = self.default_path("uploads") uploads_path = self.default_path("uploads")
return """ return r"""
# Directory where uploaded images and attachments are stored. # Directory where uploaded images and attachments are stored.
media_store_path: "%(media_store)s" media_store_path: "%(media_store)s"

View File

@ -0,0 +1,102 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util import glob_to_regex
from ._base import Config, ConfigError
class RoomDirectoryConfig(Config):
def read_config(self, config):
alias_creation_rules = config["alias_creation_rules"]
self._alias_creation_rules = [
_AliasRule(rule)
for rule in alias_creation_rules
]
def default_config(self, config_dir_path, server_name, **kwargs):
return """
# The `alias_creation` option controls who's allowed to create aliases
# on this server.
#
# The format of this option is a list of rules that contain globs that
# match against user_id and the new alias (fully qualified with server
# name). The action in the first rule that matches is taken, which can
# currently either be "allow" or "deny".
#
# If no rules match the request is denied.
alias_creation_rules:
- user_id: "*"
alias: "*"
action: allow
"""
def is_alias_creation_allowed(self, user_id, alias):
"""Checks if the given user is allowed to create the given alias
Args:
user_id (str)
alias (str)
Returns:
boolean: True if user is allowed to crate the alias
"""
for rule in self._alias_creation_rules:
if rule.matches(user_id, alias):
return rule.action == "allow"
return False
class _AliasRule(object):
def __init__(self, rule):
action = rule["action"]
user_id = rule["user_id"]
alias = rule["alias"]
if action in ("allow", "deny"):
self.action = action
else:
raise ConfigError(
"alias_creation_rules rules can only have action of 'allow'"
" or 'deny'"
)
try:
self._user_id_regex = glob_to_regex(user_id)
self._alias_regex = glob_to_regex(alias)
except Exception as e:
raise ConfigError("Failed to parse glob into regex: %s", e)
def matches(self, user_id, alias):
"""Tests if this rule matches the given user_id and alias.
Args:
user_id (str)
alias (str)
Returns:
boolean
"""
# Note: The regexes are anchored at both ends
if not self._user_id_regex.match(user_id):
return False
if not self._alias_regex.match(alias):
return False
return True

View File

@ -55,7 +55,7 @@ def fetch_server_key(server_name, tls_client_options_factory, path=KEY_API_V1):
raise IOError("Cannot get key for %r" % server_name) raise IOError("Cannot get key for %r" % server_name)
except (ConnectError, DomainError) as e: except (ConnectError, DomainError) as e:
logger.warn("Error getting key for %r: %s", server_name, e) logger.warn("Error getting key for %r: %s", server_name, e)
except Exception as e: except Exception:
logger.exception("Error getting key for %r", server_name) logger.exception("Error getting key for %r", server_name)
raise IOError("Cannot get key for %r" % server_name) raise IOError("Cannot get key for %r" % server_name)

View File

@ -690,7 +690,7 @@ def auth_types_for_event(event):
auth_types = [] auth_types = []
auth_types.append((EventTypes.PowerLevels, "", )) auth_types.append((EventTypes.PowerLevels, "", ))
auth_types.append((EventTypes.Member, event.user_id, )) auth_types.append((EventTypes.Member, event.sender, ))
auth_types.append((EventTypes.Create, "", )) auth_types.append((EventTypes.Create, "", ))
if event.type == EventTypes.Member: if event.type == EventTypes.Member:

View File

@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
import re
import six import six
from six import iteritems from six import iteritems
@ -44,6 +43,7 @@ from synapse.replication.http.federation import (
ReplicationGetQueryRestServlet, ReplicationGetQueryRestServlet,
) )
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
from synapse.util import glob_to_regex
from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
from synapse.util.logcontext import nested_logging_context from synapse.util.logcontext import nested_logging_context
@ -729,22 +729,10 @@ def _acl_entry_matches(server_name, acl_entry):
if not isinstance(acl_entry, six.string_types): if not isinstance(acl_entry, six.string_types):
logger.warn("Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)) logger.warn("Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry))
return False return False
regex = _glob_to_regex(acl_entry) regex = glob_to_regex(acl_entry)
return regex.match(server_name) return regex.match(server_name)
def _glob_to_regex(glob):
res = ''
for c in glob:
if c == '*':
res = res + '.*'
elif c == '?':
res = res + '.'
else:
res = res + re.escape(c)
return re.compile(res + "\\Z", re.IGNORECASE)
class FederationHandlerRegistry(object): class FederationHandlerRegistry(object):
"""Allows classes to register themselves as handlers for a given EDU or """Allows classes to register themselves as handlers for a given EDU or
query type for incoming federation traffic. query type for incoming federation traffic.
@ -800,7 +788,7 @@ class FederationHandlerRegistry(object):
yield handler(origin, content) yield handler(origin, content)
except SynapseError as e: except SynapseError as e:
logger.info("Failed to handle edu %r: %r", edu_type, e) logger.info("Failed to handle edu %r: %r", edu_type, e)
except Exception as e: except Exception:
logger.exception("Failed to handle edu %r", edu_type) logger.exception("Failed to handle edu %r", edu_type)
def on_query(self, query_type, args): def on_query(self, query_type, args):

View File

@ -22,7 +22,7 @@ import bcrypt
import pymacaroons import pymacaroons
from canonicaljson import json from canonicaljson import json
from twisted.internet import defer, threads from twisted.internet import defer
from twisted.web.client import PartialDownloadError from twisted.web.client import PartialDownloadError
import synapse.util.stringutils as stringutils import synapse.util.stringutils as stringutils
@ -37,8 +37,8 @@ from synapse.api.errors import (
) )
from synapse.module_api import ModuleApi from synapse.module_api import ModuleApi
from synapse.types import UserID from synapse.types import UserID
from synapse.util import logcontext
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logcontext import make_deferred_yieldable
from ._base import BaseHandler from ._base import BaseHandler
@ -884,11 +884,7 @@ class AuthHandler(BaseHandler):
bcrypt.gensalt(self.bcrypt_rounds), bcrypt.gensalt(self.bcrypt_rounds),
).decode('ascii') ).decode('ascii')
return make_deferred_yieldable( return logcontext.defer_to_thread(self.hs.get_reactor(), _do_hash)
threads.deferToThreadPool(
self.hs.get_reactor(), self.hs.get_reactor().getThreadPool(), _do_hash
),
)
def validate_hash(self, password, stored_hash): def validate_hash(self, password, stored_hash):
"""Validates that self.hash(password) == stored_hash. """Validates that self.hash(password) == stored_hash.
@ -913,13 +909,7 @@ class AuthHandler(BaseHandler):
if not isinstance(stored_hash, bytes): if not isinstance(stored_hash, bytes):
stored_hash = stored_hash.encode('ascii') stored_hash = stored_hash.encode('ascii')
return make_deferred_yieldable( return logcontext.defer_to_thread(self.hs.get_reactor(), _do_validate_hash)
threads.deferToThreadPool(
self.hs.get_reactor(),
self.hs.get_reactor().getThreadPool(),
_do_validate_hash,
),
)
else: else:
return defer.succeed(False) return defer.succeed(False)

View File

@ -17,8 +17,8 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
from synapse.util.logcontext import run_in_background
from ._base import BaseHandler from ._base import BaseHandler
@ -121,7 +121,7 @@ class DeactivateAccountHandler(BaseHandler):
None None
""" """
if not self._user_parter_running: if not self._user_parter_running:
run_in_background(self._user_parter_loop) run_as_background_process("user_parter_loop", self._user_parter_loop)
@defer.inlineCallbacks @defer.inlineCallbacks
def _user_parter_loop(self): def _user_parter_loop(self):

View File

@ -43,6 +43,7 @@ class DirectoryHandler(BaseHandler):
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.appservice_handler = hs.get_application_service_handler() self.appservice_handler = hs.get_application_service_handler()
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.config = hs.config
self.federation = hs.get_federation_client() self.federation = hs.get_federation_client()
hs.get_federation_registry().register_query_handler( hs.get_federation_registry().register_query_handler(
@ -80,42 +81,68 @@ class DirectoryHandler(BaseHandler):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def create_association(self, user_id, room_alias, room_id, servers=None): def create_association(self, requester, room_alias, room_id, servers=None,
# association creation for human users send_event=True):
# TODO(erikj): Do user auth. """Attempt to create a new alias
if not self.spam_checker.user_may_create_room_alias(user_id, room_alias): Args:
raise SynapseError( requester (Requester)
403, "This user is not permitted to create this alias", room_alias (RoomAlias)
) room_id (str)
servers (list[str]|None): List of servers that others servers
should try and join via
send_event (bool): Whether to send an updated m.room.aliases event
can_create = yield self.can_modify_alias( Returns:
room_alias, Deferred
user_id=user_id """
)
if not can_create: user_id = requester.user.to_string()
raise SynapseError(
400, "This alias is reserved by an application service.", service = requester.app_service
errcode=Codes.EXCLUSIVE if service:
if not service.is_interested_in_alias(room_alias.to_string()):
raise SynapseError(
400, "This application service has not reserved"
" this kind of alias.", errcode=Codes.EXCLUSIVE
)
else:
if not self.spam_checker.user_may_create_room_alias(user_id, room_alias):
raise AuthError(
403, "This user is not permitted to create this alias",
)
if not self.config.is_alias_creation_allowed(user_id, room_alias.to_string()):
# Lets just return a generic message, as there may be all sorts of
# reasons why we said no. TODO: Allow configurable error messages
# per alias creation rule?
raise SynapseError(
403, "Not allowed to create alias",
)
can_create = yield self.can_modify_alias(
room_alias,
user_id=user_id
) )
if not can_create:
raise AuthError(
400, "This alias is reserved by an application service.",
errcode=Codes.EXCLUSIVE
)
yield self._create_association(room_alias, room_id, servers, creator=user_id) yield self._create_association(room_alias, room_id, servers, creator=user_id)
if send_event:
@defer.inlineCallbacks yield self.send_room_alias_update_event(
def create_appservice_association(self, service, room_alias, room_id, requester,
servers=None): room_id
if not service.is_interested_in_alias(room_alias.to_string()):
raise SynapseError(
400, "This application service has not reserved"
" this kind of alias.", errcode=Codes.EXCLUSIVE
) )
# association creation for app services
yield self._create_association(room_alias, room_id, servers)
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_association(self, requester, user_id, room_alias): def delete_association(self, requester, room_alias):
# association deletion for human users # association deletion for human users
user_id = requester.user.to_string()
try: try:
can_delete = yield self._user_can_delete_alias(room_alias, user_id) can_delete = yield self._user_can_delete_alias(room_alias, user_id)
except StoreError as e: except StoreError as e:
@ -143,7 +170,6 @@ class DirectoryHandler(BaseHandler):
try: try:
yield self.send_room_alias_update_event( yield self.send_room_alias_update_event(
requester, requester,
requester.user.to_string(),
room_id room_id
) )
@ -261,7 +287,7 @@ class DirectoryHandler(BaseHandler):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def send_room_alias_update_event(self, requester, user_id, room_id): def send_room_alias_update_event(self, requester, room_id):
aliases = yield self.store.get_aliases_for_room(room_id) aliases = yield self.store.get_aliases_for_room(room_id)
yield self.event_creation_handler.create_and_send_nonmember_event( yield self.event_creation_handler.create_and_send_nonmember_event(
@ -270,7 +296,7 @@ class DirectoryHandler(BaseHandler):
"type": EventTypes.Aliases, "type": EventTypes.Aliases,
"state_key": self.hs.hostname, "state_key": self.hs.hostname,
"room_id": room_id, "room_id": room_id,
"sender": user_id, "sender": requester.user.to_string(),
"content": {"aliases": aliases}, "content": {"aliases": aliases},
}, },
ratelimit=False ratelimit=False

View File

@ -53,7 +53,7 @@ from synapse.replication.http.federation import (
ReplicationFederationSendEventsRestServlet, ReplicationFederationSendEventsRestServlet,
) )
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
from synapse.state import resolve_events_with_factory from synapse.state import StateResolutionStore, resolve_events_with_store
from synapse.types import UserID, get_domain_from_id from synapse.types import UserID, get_domain_from_id
from synapse.util import logcontext, unwrapFirstError from synapse.util import logcontext, unwrapFirstError
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
@ -384,24 +384,24 @@ class FederationHandler(BaseHandler):
for x in remote_state: for x in remote_state:
event_map[x.event_id] = x event_map[x.event_id] = x
# Resolve any conflicting state
@defer.inlineCallbacks
def fetch(ev_ids):
fetched = yield self.store.get_events(
ev_ids, get_prev_content=False, check_redacted=False,
)
# add any events we fetch here to the `event_map` so that we
# can use them to build the state event list below.
event_map.update(fetched)
defer.returnValue(fetched)
room_version = yield self.store.get_room_version(room_id) room_version = yield self.store.get_room_version(room_id)
state_map = yield resolve_events_with_factory( state_map = yield resolve_events_with_store(
room_version, state_maps, event_map, fetch, room_version, state_maps, event_map,
state_res_store=StateResolutionStore(self.store),
) )
# we need to give _process_received_pdu the actual state events # We need to give _process_received_pdu the actual state events
# rather than event ids, so generate that now. # rather than event ids, so generate that now.
# First though we need to fetch all the events that are in
# state_map, so we can build up the state below.
evs = yield self.store.get_events(
list(state_map.values()),
get_prev_content=False,
check_redacted=False,
)
event_map.update(evs)
state = [ state = [
event_map[e] for e in six.itervalues(state_map) event_map[e] for e in six.itervalues(state_map)
] ]
@ -2520,7 +2520,7 @@ class FederationHandler(BaseHandler):
if not backfilled: # Never notify for backfilled events if not backfilled: # Never notify for backfilled events
for event, _ in event_and_contexts: for event, _ in event_and_contexts:
self._notify_persisted_event(event, max_stream_id) yield self._notify_persisted_event(event, max_stream_id)
def _notify_persisted_event(self, event, max_stream_id): def _notify_persisted_event(self, event, max_stream_id):
"""Checks to see if notifier/pushers should be notified about the """Checks to see if notifier/pushers should be notified about the
@ -2553,7 +2553,7 @@ class FederationHandler(BaseHandler):
extra_users=extra_users extra_users=extra_users
) )
self.pusher_pool.on_new_notifications( return self.pusher_pool.on_new_notifications(
event_stream_id, max_stream_id, event_stream_id, max_stream_id,
) )

View File

@ -20,7 +20,7 @@ from six import iteritems
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import HttpResponseException, SynapseError
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -37,9 +37,23 @@ def _create_rerouter(func_name):
) )
else: else:
destination = get_domain_from_id(group_id) destination = get_domain_from_id(group_id)
return getattr(self.transport_client, func_name)( d = getattr(self.transport_client, func_name)(
destination, group_id, *args, **kwargs destination, group_id, *args, **kwargs
) )
# Capture errors returned by the remote homeserver and
# re-throw specific errors as SynapseErrors. This is so
# when the remote end responds with things like 403 Not
# In Group, we can communicate that to the client instead
# of a 500.
def h(failure):
failure.trap(HttpResponseException)
e = failure.value
if e.code == 403:
raise e.to_synapse_error()
return failure
d.addErrback(h)
return d
return f return f

View File

@ -156,7 +156,7 @@ class InitialSyncHandler(BaseHandler):
room_end_token = "s%d" % (event.stream_ordering,) room_end_token = "s%d" % (event.stream_ordering,)
deferred_room_state = run_in_background( deferred_room_state = run_in_background(
self.store.get_state_for_events, self.store.get_state_for_events,
[event.event_id], None, [event.event_id],
) )
deferred_room_state.addCallback( deferred_room_state.addCallback(
lambda states: states[event.event_id] lambda states: states[event.event_id]
@ -301,7 +301,7 @@ class InitialSyncHandler(BaseHandler):
def _room_initial_sync_parted(self, user_id, room_id, pagin_config, def _room_initial_sync_parted(self, user_id, room_id, pagin_config,
membership, member_event_id, is_peeking): membership, member_event_id, is_peeking):
room_state = yield self.store.get_state_for_events( room_state = yield self.store.get_state_for_events(
[member_event_id], None [member_event_id],
) )
room_state = room_state[member_event_id] room_state = room_state[member_event_id]

View File

@ -35,6 +35,7 @@ from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, UserID from synapse.types import RoomAlias, UserID
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.frozenutils import frozendict_json_encoder
@ -80,7 +81,7 @@ class MessageHandler(object):
elif membership == Membership.LEAVE: elif membership == Membership.LEAVE:
key = (event_type, state_key) key = (event_type, state_key)
room_state = yield self.store.get_state_for_events( room_state = yield self.store.get_state_for_events(
[membership_event_id], [key] [membership_event_id], StateFilter.from_types([key])
) )
data = room_state[membership_event_id].get(key) data = room_state[membership_event_id].get(key)
@ -88,7 +89,7 @@ class MessageHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_events( def get_state_events(
self, user_id, room_id, types=None, filtered_types=None, self, user_id, room_id, state_filter=StateFilter.all(),
at_token=None, is_guest=False, at_token=None, is_guest=False,
): ):
"""Retrieve all state events for a given room. If the user is """Retrieve all state events for a given room. If the user is
@ -100,13 +101,8 @@ class MessageHandler(object):
Args: Args:
user_id(str): The user requesting state events. user_id(str): The user requesting state events.
room_id(str): The room ID to get all state events from. room_id(str): The room ID to get all state events from.
types(list[(str, str|None)]|None): List of (type, state_key) tuples state_filter (StateFilter): The state filter used to fetch state
which are used to filter the state fetched. If `state_key` is None, from the database.
all events are returned of the given type.
May be None, which matches any key.
filtered_types(list[str]|None): Only apply filtering via `types` to this
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
at_token(StreamToken|None): the stream token of the at which we are requesting at_token(StreamToken|None): the stream token of the at which we are requesting
the stats. If the user is not allowed to view the state as of that the stats. If the user is not allowed to view the state as of that
stream token, we raise a 403 SynapseError. If None, returns the current stream token, we raise a 403 SynapseError. If None, returns the current
@ -139,7 +135,7 @@ class MessageHandler(object):
event = last_events[0] event = last_events[0]
if visible_events: if visible_events:
room_state = yield self.store.get_state_for_events( room_state = yield self.store.get_state_for_events(
[event.event_id], types, filtered_types=filtered_types, [event.event_id], state_filter=state_filter,
) )
room_state = room_state[event.event_id] room_state = room_state[event.event_id]
else: else:
@ -158,12 +154,12 @@ class MessageHandler(object):
if membership == Membership.JOIN: if membership == Membership.JOIN:
state_ids = yield self.store.get_filtered_current_state_ids( state_ids = yield self.store.get_filtered_current_state_ids(
room_id, types, filtered_types=filtered_types, room_id, state_filter=state_filter,
) )
room_state = yield self.store.get_events(state_ids.values()) room_state = yield self.store.get_events(state_ids.values())
elif membership == Membership.LEAVE: elif membership == Membership.LEAVE:
room_state = yield self.store.get_state_for_events( room_state = yield self.store.get_state_for_events(
[membership_event_id], types, filtered_types=filtered_types, [membership_event_id], state_filter=state_filter,
) )
room_state = room_state[membership_event_id] room_state = room_state[membership_event_id]
@ -779,7 +775,7 @@ class EventCreationHandler(object):
event, context=context event, context=context
) )
self.pusher_pool.on_new_notifications( yield self.pusher_pool.on_new_notifications(
event_stream_id, max_stream_id, event_stream_id, max_stream_id,
) )

View File

@ -21,6 +21,7 @@ from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.storage.state import StateFilter
from synapse.types import RoomStreamToken from synapse.types import RoomStreamToken
from synapse.util.async_helpers import ReadWriteLock from synapse.util.async_helpers import ReadWriteLock
from synapse.util.logcontext import run_in_background from synapse.util.logcontext import run_in_background
@ -255,16 +256,14 @@ class PaginationHandler(object):
if event_filter and event_filter.lazy_load_members(): if event_filter and event_filter.lazy_load_members():
# TODO: remove redundant members # TODO: remove redundant members
types = [ # FIXME: we also care about invite targets etc.
(EventTypes.Member, state_key) state_filter = StateFilter.from_types(
for state_key in set( (EventTypes.Member, event.sender)
event.sender # FIXME: we also care about invite targets etc. for event in events
for event in events )
)
]
state_ids = yield self.store.get_state_ids_for_event( state_ids = yield self.store.get_state_ids_for_event(
events[0].event_id, types=types, events[0].event_id, state_filter=state_filter,
) )
if state_ids: if state_ids:

View File

@ -119,7 +119,7 @@ class ReceiptsHandler(BaseHandler):
"receipt_key", max_batch_id, rooms=affected_room_ids "receipt_key", max_batch_id, rooms=affected_room_ids
) )
# Note that the min here shouldn't be relied upon to be accurate. # Note that the min here shouldn't be relied upon to be accurate.
self.hs.get_pusherpool().on_new_receipts( yield self.hs.get_pusherpool().on_new_receipts(
min_batch_id, max_batch_id, affected_room_ids, min_batch_id, max_batch_id, affected_room_ids,
) )

View File

@ -50,6 +50,7 @@ class RegistrationHandler(BaseHandler):
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
self.user_directory_handler = hs.get_user_directory_handler() self.user_directory_handler = hs.get_user_directory_handler()
self.room_creation_handler = self.hs.get_room_creation_handler()
self.captcha_client = CaptchaServerHttpClient(hs) self.captcha_client = CaptchaServerHttpClient(hs)
self._next_generated_user_id = None self._next_generated_user_id = None
@ -220,9 +221,36 @@ class RegistrationHandler(BaseHandler):
# auto-join the user to any rooms we're supposed to dump them into # auto-join the user to any rooms we're supposed to dump them into
fake_requester = create_requester(user_id) fake_requester = create_requester(user_id)
# try to create the room if we're the first user on the server
should_auto_create_rooms = False
if self.hs.config.autocreate_auto_join_rooms:
count = yield self.store.count_all_users()
should_auto_create_rooms = count == 1
for r in self.hs.config.auto_join_rooms: for r in self.hs.config.auto_join_rooms:
try: try:
yield self._join_user_to_room(fake_requester, r) if should_auto_create_rooms:
room_alias = RoomAlias.from_string(r)
if self.hs.hostname != room_alias.domain:
logger.warning(
'Cannot create room alias %s, '
'it does not match server domain',
r,
)
else:
# create room expects the localpart of the room alias
room_alias_localpart = room_alias.localpart
yield self.room_creation_handler.create_room(
fake_requester,
config={
"preset": "public_chat",
"room_alias_name": room_alias_localpart
},
ratelimit=False,
)
else:
yield self._join_user_to_room(fake_requester, r)
except Exception as e: except Exception as e:
logger.error("Failed to join new user to %r: %r", r, e) logger.error("Failed to join new user to %r: %r", r, e)

View File

@ -33,6 +33,7 @@ from synapse.api.constants import (
RoomCreationPreset, RoomCreationPreset,
) )
from synapse.api.errors import AuthError, Codes, StoreError, SynapseError from synapse.api.errors import AuthError, Codes, StoreError, SynapseError
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID
from synapse.util import stringutils from synapse.util import stringutils
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
@ -190,10 +191,11 @@ class RoomCreationHandler(BaseHandler):
if room_alias: if room_alias:
directory_handler = self.hs.get_handlers().directory_handler directory_handler = self.hs.get_handlers().directory_handler
yield directory_handler.create_association( yield directory_handler.create_association(
user_id=user_id, requester=requester,
room_id=room_id, room_id=room_id,
room_alias=room_alias, room_alias=room_alias,
servers=[self.hs.hostname], servers=[self.hs.hostname],
send_event=False,
) )
preset_config = config.get( preset_config = config.get(
@ -289,7 +291,7 @@ class RoomCreationHandler(BaseHandler):
if room_alias: if room_alias:
result["room_alias"] = room_alias.to_string() result["room_alias"] = room_alias.to_string()
yield directory_handler.send_room_alias_update_event( yield directory_handler.send_room_alias_update_event(
requester, user_id, room_id requester, room_id
) )
defer.returnValue(result) defer.returnValue(result)
@ -488,23 +490,24 @@ class RoomContextHandler(object):
else: else:
last_event_id = event_id last_event_id = event_id
types = None
filtered_types = None
if event_filter and event_filter.lazy_load_members(): if event_filter and event_filter.lazy_load_members():
members = set(ev.sender for ev in itertools.chain( state_filter = StateFilter.from_lazy_load_member_list(
results["events_before"], ev.sender
(results["event"],), for ev in itertools.chain(
results["events_after"], results["events_before"],
)) (results["event"],),
filtered_types = [EventTypes.Member] results["events_after"],
types = [(EventTypes.Member, member) for member in members] )
)
else:
state_filter = StateFilter.all()
# XXX: why do we return the state as of the last event rather than the # XXX: why do we return the state as of the last event rather than the
# first? Shouldn't we be consistent with /sync? # first? Shouldn't we be consistent with /sync?
# https://github.com/matrix-org/matrix-doc/issues/687 # https://github.com/matrix-org/matrix-doc/issues/687
state = yield self.store.get_state_for_events( state = yield self.store.get_state_for_events(
[last_event_id], types, filtered_types=filtered_types, [last_event_id], state_filter=state_filter,
) )
results["state"] = list(state[last_event_id].values()) results["state"] = list(state[last_event_id].values())

View File

@ -27,6 +27,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.push.clientformat import format_push_rules_for_user from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.roommember import MemberSummary from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter
from synapse.types import RoomStreamToken from synapse.types import RoomStreamToken
from synapse.util.async_helpers import concurrently_execute from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
@ -469,25 +470,20 @@ class SyncHandler(object):
)) ))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_after_event(self, event, types=None, filtered_types=None): def get_state_after_event(self, event, state_filter=StateFilter.all()):
""" """
Get the room state after the given event Get the room state after the given event
Args: Args:
event(synapse.events.EventBase): event of interest event(synapse.events.EventBase): event of interest
types(list[(str, str|None)]|None): List of (type, state_key) tuples state_filter (StateFilter): The state filter used to fetch state
which are used to filter the state fetched. If `state_key` is None, from the database.
all events are returned of the given type.
May be None, which matches any key.
filtered_types(list[str]|None): Only apply filtering via `types` to this
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
Returns: Returns:
A Deferred map from ((type, state_key)->Event) A Deferred map from ((type, state_key)->Event)
""" """
state_ids = yield self.store.get_state_ids_for_event( state_ids = yield self.store.get_state_ids_for_event(
event.event_id, types, filtered_types=filtered_types, event.event_id, state_filter=state_filter,
) )
if event.is_state(): if event.is_state():
state_ids = state_ids.copy() state_ids = state_ids.copy()
@ -495,18 +491,14 @@ class SyncHandler(object):
defer.returnValue(state_ids) defer.returnValue(state_ids)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_at(self, room_id, stream_position, types=None, filtered_types=None): def get_state_at(self, room_id, stream_position, state_filter=StateFilter.all()):
""" Get the room state at a particular stream position """ Get the room state at a particular stream position
Args: Args:
room_id(str): room for which to get state room_id(str): room for which to get state
stream_position(StreamToken): point at which to get state stream_position(StreamToken): point at which to get state
types(list[(str, str|None)]|None): List of (type, state_key) tuples state_filter (StateFilter): The state filter used to fetch state
which are used to filter the state fetched. If `state_key` is None, from the database.
all events are returned of the given type.
filtered_types(list[str]|None): Only apply filtering via `types` to this
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
Returns: Returns:
A Deferred map from ((type, state_key)->Event) A Deferred map from ((type, state_key)->Event)
@ -522,7 +514,7 @@ class SyncHandler(object):
if last_events: if last_events:
last_event = last_events[-1] last_event = last_events[-1]
state = yield self.get_state_after_event( state = yield self.get_state_after_event(
last_event, types, filtered_types=filtered_types, last_event, state_filter=state_filter,
) )
else: else:
@ -563,10 +555,11 @@ class SyncHandler(object):
last_event = last_events[-1] last_event = last_events[-1]
state_ids = yield self.store.get_state_ids_for_event( state_ids = yield self.store.get_state_ids_for_event(
last_event.event_id, [ last_event.event_id,
state_filter=StateFilter.from_types([
(EventTypes.Name, ''), (EventTypes.Name, ''),
(EventTypes.CanonicalAlias, ''), (EventTypes.CanonicalAlias, ''),
] ]),
) )
# this is heavily cached, thus: fast. # this is heavily cached, thus: fast.
@ -717,8 +710,7 @@ class SyncHandler(object):
with Measure(self.clock, "compute_state_delta"): with Measure(self.clock, "compute_state_delta"):
types = None members_to_fetch = None
filtered_types = None
lazy_load_members = sync_config.filter_collection.lazy_load_members() lazy_load_members = sync_config.filter_collection.lazy_load_members()
include_redundant_members = ( include_redundant_members = (
@ -729,16 +721,21 @@ class SyncHandler(object):
# We only request state for the members needed to display the # We only request state for the members needed to display the
# timeline: # timeline:
types = [ members_to_fetch = set(
(EventTypes.Member, state_key) event.sender # FIXME: we also care about invite targets etc.
for state_key in set( for event in batch.events
event.sender # FIXME: we also care about invite targets etc. )
for event in batch.events
)
]
# only apply the filtering to room members if full_state:
filtered_types = [EventTypes.Member] # always make sure we LL ourselves so we know we're in the room
# (if we are) to fix https://github.com/vector-im/riot-web/issues/7209
# We only need apply this on full state syncs given we disabled
# LL for incr syncs in #3840.
members_to_fetch.add(sync_config.user.to_string())
state_filter = StateFilter.from_lazy_load_member_list(members_to_fetch)
else:
state_filter = StateFilter.all()
timeline_state = { timeline_state = {
(event.type, event.state_key): event.event_id (event.type, event.state_key): event.event_id
@ -746,28 +743,19 @@ class SyncHandler(object):
} }
if full_state: if full_state:
if lazy_load_members:
# always make sure we LL ourselves so we know we're in the room
# (if we are) to fix https://github.com/vector-im/riot-web/issues/7209
# We only need apply this on full state syncs given we disabled
# LL for incr syncs in #3840.
types.append((EventTypes.Member, sync_config.user.to_string()))
if batch: if batch:
current_state_ids = yield self.store.get_state_ids_for_event( current_state_ids = yield self.store.get_state_ids_for_event(
batch.events[-1].event_id, types=types, batch.events[-1].event_id, state_filter=state_filter,
filtered_types=filtered_types,
) )
state_ids = yield self.store.get_state_ids_for_event( state_ids = yield self.store.get_state_ids_for_event(
batch.events[0].event_id, types=types, batch.events[0].event_id, state_filter=state_filter,
filtered_types=filtered_types,
) )
else: else:
current_state_ids = yield self.get_state_at( current_state_ids = yield self.get_state_at(
room_id, stream_position=now_token, types=types, room_id, stream_position=now_token,
filtered_types=filtered_types, state_filter=state_filter,
) )
state_ids = current_state_ids state_ids = current_state_ids
@ -781,8 +769,7 @@ class SyncHandler(object):
) )
elif batch.limited: elif batch.limited:
state_at_timeline_start = yield self.store.get_state_ids_for_event( state_at_timeline_start = yield self.store.get_state_ids_for_event(
batch.events[0].event_id, types=types, batch.events[0].event_id, state_filter=state_filter,
filtered_types=filtered_types,
) )
# for now, we disable LL for gappy syncs - see # for now, we disable LL for gappy syncs - see
@ -797,17 +784,15 @@ class SyncHandler(object):
# members to just be ones which were timeline senders, which then ensures # members to just be ones which were timeline senders, which then ensures
# all of the rest get included in the state block (if we need to know # all of the rest get included in the state block (if we need to know
# about them). # about them).
types = None state_filter = StateFilter.all()
filtered_types = None
state_at_previous_sync = yield self.get_state_at( state_at_previous_sync = yield self.get_state_at(
room_id, stream_position=since_token, types=types, room_id, stream_position=since_token,
filtered_types=filtered_types, state_filter=state_filter,
) )
current_state_ids = yield self.store.get_state_ids_for_event( current_state_ids = yield self.store.get_state_ids_for_event(
batch.events[-1].event_id, types=types, batch.events[-1].event_id, state_filter=state_filter,
filtered_types=filtered_types,
) )
state_ids = _calculate_state( state_ids = _calculate_state(
@ -821,7 +806,7 @@ class SyncHandler(object):
else: else:
state_ids = {} state_ids = {}
if lazy_load_members: if lazy_load_members:
if types and batch.events: if members_to_fetch and batch.events:
# We're returning an incremental sync, with no # We're returning an incremental sync, with no
# "gap" since the previous sync, so normally there would be # "gap" since the previous sync, so normally there would be
# no state to return. # no state to return.
@ -831,8 +816,12 @@ class SyncHandler(object):
# timeline here, and then dedupe any redundant ones below. # timeline here, and then dedupe any redundant ones below.
state_ids = yield self.store.get_state_ids_for_event( state_ids = yield self.store.get_state_ids_for_event(
batch.events[0].event_id, types=types, batch.events[0].event_id,
filtered_types=None, # we only want members! # we only want members!
state_filter=StateFilter.from_types(
(EventTypes.Member, member)
for member in members_to_fetch
),
) )
if lazy_load_members and not include_redundant_members: if lazy_load_members and not include_redundant_members:

View File

@ -20,6 +20,7 @@ from six import iteritems
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.roommember import ProfileInfo from synapse.storage.roommember import ProfileInfo
from synapse.types import get_localpart_from_id from synapse.types import get_localpart_from_id
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -98,7 +99,6 @@ class UserDirectoryHandler(object):
""" """
return self.store.search_user_dir(user_id, search_term, limit) return self.store.search_user_dir(user_id, search_term, limit)
@defer.inlineCallbacks
def notify_new_event(self): def notify_new_event(self):
"""Called when there may be more deltas to process """Called when there may be more deltas to process
""" """
@ -108,11 +108,15 @@ class UserDirectoryHandler(object):
if self._is_processing: if self._is_processing:
return return
@defer.inlineCallbacks
def process():
try:
yield self._unsafe_process()
finally:
self._is_processing = False
self._is_processing = True self._is_processing = True
try: run_as_background_process("user_directory.notify_new_event", process)
yield self._unsafe_process()
finally:
self._is_processing = False
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_local_profile_change(self, user_id, profile): def handle_local_profile_change(self, user_id, profile):

View File

@ -230,7 +230,7 @@ class MatrixFederationHttpClient(object):
Returns: Returns:
Deferred: resolves with the http response object on success. Deferred: resolves with the http response object on success.
Fails with ``HTTPRequestException``: if we get an HTTP response Fails with ``HttpResponseException``: if we get an HTTP response
code >= 300. code >= 300.
Fails with ``NotRetryingDestination`` if we are not yet ready Fails with ``NotRetryingDestination`` if we are not yet ready
@ -480,7 +480,7 @@ class MatrixFederationHttpClient(object):
Deferred: Succeeds when we get a 2xx HTTP response. The result Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body. will be the decoded JSON body.
Fails with ``HTTPRequestException`` if we get an HTTP response Fails with ``HttpResponseException`` if we get an HTTP response
code >= 300. code >= 300.
Fails with ``NotRetryingDestination`` if we are not yet ready Fails with ``NotRetryingDestination`` if we are not yet ready
@ -534,7 +534,7 @@ class MatrixFederationHttpClient(object):
Deferred: Succeeds when we get a 2xx HTTP response. The result Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body. will be the decoded JSON body.
Fails with ``HTTPRequestException`` if we get an HTTP response Fails with ``HttpResponseException`` if we get an HTTP response
code >= 300. code >= 300.
Fails with ``NotRetryingDestination`` if we are not yet ready Fails with ``NotRetryingDestination`` if we are not yet ready
@ -589,7 +589,7 @@ class MatrixFederationHttpClient(object):
Deferred: Succeeds when we get a 2xx HTTP response. The result Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body. will be the decoded JSON body.
Fails with ``HTTPRequestException`` if we get an HTTP response Fails with ``HttpResponseException`` if we get an HTTP response
code >= 300. code >= 300.
Fails with ``NotRetryingDestination`` if we are not yet ready Fails with ``NotRetryingDestination`` if we are not yet ready
@ -640,7 +640,7 @@ class MatrixFederationHttpClient(object):
Deferred: Succeeds when we get a 2xx HTTP response. The result Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body. will be the decoded JSON body.
Fails with ``HTTPRequestException`` if we get an HTTP response Fails with ``HttpResponseException`` if we get an HTTP response
code >= 300. code >= 300.
Fails with ``NotRetryingDestination`` if we are not yet ready Fails with ``NotRetryingDestination`` if we are not yet ready
@ -684,7 +684,7 @@ class MatrixFederationHttpClient(object):
Deferred: resolves with an (int,dict) tuple of the file length and Deferred: resolves with an (int,dict) tuple of the file length and
a dict of the response headers. a dict of the response headers.
Fails with ``HTTPRequestException`` if we get an HTTP response code Fails with ``HttpResponseException`` if we get an HTTP response code
>= 300 >= 300
Fails with ``NotRetryingDestination`` if we are not yet ready Fails with ``NotRetryingDestination`` if we are not yet ready

View File

@ -18,8 +18,7 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.error import AlreadyCalled, AlreadyCancelled from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from synapse.util.logcontext import LoggingContext from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -71,18 +70,11 @@ class EmailPusher(object):
# See httppusher # See httppusher
self.max_stream_ordering = None self.max_stream_ordering = None
self.processing = False self._is_processing = False
@defer.inlineCallbacks
def on_started(self): def on_started(self):
if self.mailer is not None: if self.mailer is not None:
try: self._start_processing()
self.throttle_params = yield self.store.get_throttle_params_by_room(
self.pusher_id
)
yield self._process()
except Exception:
logger.exception("Error starting email pusher")
def on_stop(self): def on_stop(self):
if self.timed_call: if self.timed_call:
@ -92,43 +84,52 @@ class EmailPusher(object):
pass pass
self.timed_call = None self.timed_call = None
@defer.inlineCallbacks
def on_new_notifications(self, min_stream_ordering, max_stream_ordering): def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering) self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering)
yield self._process() self._start_processing()
def on_new_receipts(self, min_stream_id, max_stream_id): def on_new_receipts(self, min_stream_id, max_stream_id):
# We could wake up and cancel the timer but there tend to be quite a # We could wake up and cancel the timer but there tend to be quite a
# lot of read receipts so it's probably less work to just let the # lot of read receipts so it's probably less work to just let the
# timer fire # timer fire
return defer.succeed(None) pass
@defer.inlineCallbacks
def on_timer(self): def on_timer(self):
self.timed_call = None self.timed_call = None
yield self._process() self._start_processing()
def _start_processing(self):
if self._is_processing:
return
run_as_background_process("emailpush.process", self._process)
@defer.inlineCallbacks @defer.inlineCallbacks
def _process(self): def _process(self):
if self.processing: # we should never get here if we are already processing
return assert not self._is_processing
with LoggingContext("emailpush._process"): try:
with Measure(self.clock, "emailpush._process"): self._is_processing = True
if self.throttle_params is None:
# this is our first loop: load up the throttle params
self.throttle_params = yield self.store.get_throttle_params_by_room(
self.pusher_id
)
# if the max ordering changes while we're running _unsafe_process,
# call it again, and so on until we've caught up.
while True:
starting_max_ordering = self.max_stream_ordering
try: try:
self.processing = True yield self._unsafe_process()
# if the max ordering changes while we're running _unsafe_process, except Exception:
# call it again, and so on until we've caught up. logger.exception("Exception processing notifs")
while True: if self.max_stream_ordering == starting_max_ordering:
starting_max_ordering = self.max_stream_ordering break
try: finally:
yield self._unsafe_process() self._is_processing = False
except Exception:
logger.exception("Exception processing notifs")
if self.max_stream_ordering == starting_max_ordering:
break
finally:
self.processing = False
@defer.inlineCallbacks @defer.inlineCallbacks
def _unsafe_process(self): def _unsafe_process(self):

View File

@ -22,9 +22,8 @@ from prometheus_client import Counter
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.error import AlreadyCalled, AlreadyCancelled from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import PusherConfigException from synapse.push import PusherConfigException
from synapse.util.logcontext import LoggingContext
from synapse.util.metrics import Measure
from . import push_rule_evaluator, push_tools from . import push_rule_evaluator, push_tools
@ -61,7 +60,7 @@ class HttpPusher(object):
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.failing_since = pusherdict['failing_since'] self.failing_since = pusherdict['failing_since']
self.timed_call = None self.timed_call = None
self.processing = False self._is_processing = False
# This is the highest stream ordering we know it's safe to process. # This is the highest stream ordering we know it's safe to process.
# When new events arrive, we'll be given a window of new events: we # When new events arrive, we'll be given a window of new events: we
@ -92,34 +91,27 @@ class HttpPusher(object):
self.data_minus_url.update(self.data) self.data_minus_url.update(self.data)
del self.data_minus_url['url'] del self.data_minus_url['url']
@defer.inlineCallbacks
def on_started(self): def on_started(self):
try: self._start_processing()
yield self._process()
except Exception:
logger.exception("Error starting http pusher")
@defer.inlineCallbacks
def on_new_notifications(self, min_stream_ordering, max_stream_ordering): def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering or 0) self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering or 0)
yield self._process() self._start_processing()
@defer.inlineCallbacks
def on_new_receipts(self, min_stream_id, max_stream_id): def on_new_receipts(self, min_stream_id, max_stream_id):
# Note that the min here shouldn't be relied upon to be accurate. # Note that the min here shouldn't be relied upon to be accurate.
# We could check the receipts are actually m.read receipts here, # We could check the receipts are actually m.read receipts here,
# but currently that's the only type of receipt anyway... # but currently that's the only type of receipt anyway...
with LoggingContext("push.on_new_receipts"): run_as_background_process("http_pusher.on_new_receipts", self._update_badge)
with Measure(self.clock, "push.on_new_receipts"):
badge = yield push_tools.get_badge_count(
self.hs.get_datastore(), self.user_id
)
yield self._send_badge(badge)
@defer.inlineCallbacks @defer.inlineCallbacks
def _update_badge(self):
badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
yield self._send_badge(badge)
def on_timer(self): def on_timer(self):
yield self._process() self._start_processing()
def on_stop(self): def on_stop(self):
if self.timed_call: if self.timed_call:
@ -129,27 +121,31 @@ class HttpPusher(object):
pass pass
self.timed_call = None self.timed_call = None
@defer.inlineCallbacks def _start_processing(self):
def _process(self): if self._is_processing:
if self.processing:
return return
with LoggingContext("push._process"): run_as_background_process("httppush.process", self._process)
with Measure(self.clock, "push._process"):
@defer.inlineCallbacks
def _process(self):
# we should never get here if we are already processing
assert not self._is_processing
try:
self._is_processing = True
# if the max ordering changes while we're running _unsafe_process,
# call it again, and so on until we've caught up.
while True:
starting_max_ordering = self.max_stream_ordering
try: try:
self.processing = True yield self._unsafe_process()
# if the max ordering changes while we're running _unsafe_process, except Exception:
# call it again, and so on until we've caught up. logger.exception("Exception processing notifs")
while True: if self.max_stream_ordering == starting_max_ordering:
starting_max_ordering = self.max_stream_ordering break
try: finally:
yield self._unsafe_process() self._is_processing = False
except Exception:
logger.exception("Exception processing notifs")
if self.max_stream_ordering == starting_max_ordering:
break
finally:
self.processing = False
@defer.inlineCallbacks @defer.inlineCallbacks
def _unsafe_process(self): def _unsafe_process(self):

View File

@ -526,12 +526,8 @@ def load_jinja2_templates(config):
Returns: Returns:
(notif_template_html, notif_template_text) (notif_template_html, notif_template_text)
""" """
logger.info("loading jinja2") logger.info("loading email templates from '%s'", config.email_template_dir)
loader = jinja2.FileSystemLoader(config.email_template_dir)
if config.email_template_dir:
loader = jinja2.FileSystemLoader(config.email_template_dir)
else:
loader = jinja2.PackageLoader('synapse', 'res/templates')
env = jinja2.Environment(loader=loader) env = jinja2.Environment(loader=loader)
env.filters["format_ts"] = format_ts_filter env.filters["format_ts"] = format_ts_filter
env.filters["mxc_to_http"] = _create_mxc_to_http_filter(config) env.filters["mxc_to_http"] = _create_mxc_to_http_filter(config)

View File

@ -20,24 +20,39 @@ from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push.pusher import PusherFactory from synapse.push.pusher import PusherFactory
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PusherPool: class PusherPool:
"""
The pusher pool. This is responsible for dispatching notifications of new events to
the http and email pushers.
It provides three methods which are designed to be called by the rest of the
application: `start`, `on_new_notifications`, and `on_new_receipts`: each of these
delegates to each of the relevant pushers.
Note that it is expected that each pusher will have its own 'processing' loop which
will send out the notifications in the background, rather than blocking until the
notifications are sent; accordingly Pusher.on_started, Pusher.on_new_notifications and
Pusher.on_new_receipts are not expected to return deferreds.
"""
def __init__(self, _hs): def __init__(self, _hs):
self.hs = _hs self.hs = _hs
self.pusher_factory = PusherFactory(_hs) self.pusher_factory = PusherFactory(_hs)
self.start_pushers = _hs.config.start_pushers self._should_start_pushers = _hs.config.start_pushers
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock() self.clock = self.hs.get_clock()
self.pushers = {} self.pushers = {}
@defer.inlineCallbacks
def start(self): def start(self):
pushers = yield self.store.get_all_pushers() """Starts the pushers off in a background process.
self._start_pushers(pushers) """
if not self._should_start_pushers:
logger.info("Not starting pushers because they are disabled in the config")
return
run_as_background_process("start_pushers", self._start_pushers)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_pusher(self, user_id, access_token, kind, app_id, def add_pusher(self, user_id, access_token, kind, app_id,
@ -86,7 +101,7 @@ class PusherPool:
last_stream_ordering=last_stream_ordering, last_stream_ordering=last_stream_ordering,
profile_tag=profile_tag, profile_tag=profile_tag,
) )
yield self._refresh_pusher(app_id, pushkey, user_id) yield self.start_pusher_by_id(app_id, pushkey, user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def remove_pushers_by_app_id_and_pushkey_not_user(self, app_id, pushkey, def remove_pushers_by_app_id_and_pushkey_not_user(self, app_id, pushkey,
@ -123,45 +138,23 @@ class PusherPool:
p['app_id'], p['pushkey'], p['user_name'], p['app_id'], p['pushkey'], p['user_name'],
) )
def on_new_notifications(self, min_stream_id, max_stream_id):
run_as_background_process(
"on_new_notifications",
self._on_new_notifications, min_stream_id, max_stream_id,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _on_new_notifications(self, min_stream_id, max_stream_id): def on_new_notifications(self, min_stream_id, max_stream_id):
try: try:
users_affected = yield self.store.get_push_action_users_in_range( users_affected = yield self.store.get_push_action_users_in_range(
min_stream_id, max_stream_id min_stream_id, max_stream_id
) )
deferreds = []
for u in users_affected: for u in users_affected:
if u in self.pushers: if u in self.pushers:
for p in self.pushers[u].values(): for p in self.pushers[u].values():
deferreds.append( p.on_new_notifications(min_stream_id, max_stream_id)
run_in_background(
p.on_new_notifications,
min_stream_id, max_stream_id,
)
)
yield make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True),
)
except Exception: except Exception:
logger.exception("Exception in pusher on_new_notifications") logger.exception("Exception in pusher on_new_notifications")
def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
run_as_background_process(
"on_new_receipts",
self._on_new_receipts, min_stream_id, max_stream_id, affected_room_ids,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
try: try:
# Need to subtract 1 from the minimum because the lower bound here # Need to subtract 1 from the minimum because the lower bound here
# is not inclusive # is not inclusive
@ -171,26 +164,20 @@ class PusherPool:
# This returns a tuple, user_id is at index 3 # This returns a tuple, user_id is at index 3
users_affected = set([r[3] for r in updated_receipts]) users_affected = set([r[3] for r in updated_receipts])
deferreds = []
for u in users_affected: for u in users_affected:
if u in self.pushers: if u in self.pushers:
for p in self.pushers[u].values(): for p in self.pushers[u].values():
deferreds.append( p.on_new_receipts(min_stream_id, max_stream_id)
run_in_background(
p.on_new_receipts,
min_stream_id, max_stream_id,
)
)
yield make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True),
)
except Exception: except Exception:
logger.exception("Exception in pusher on_new_receipts") logger.exception("Exception in pusher on_new_receipts")
@defer.inlineCallbacks @defer.inlineCallbacks
def _refresh_pusher(self, app_id, pushkey, user_id): def start_pusher_by_id(self, app_id, pushkey, user_id):
"""Look up the details for the given pusher, and start it"""
if not self._should_start_pushers:
return
resultlist = yield self.store.get_pushers_by_app_id_and_pushkey( resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(
app_id, pushkey app_id, pushkey
) )
@ -201,34 +188,50 @@ class PusherPool:
p = r p = r
if p: if p:
self._start_pusher(p)
self._start_pushers([p]) @defer.inlineCallbacks
def _start_pushers(self):
"""Start all the pushers
def _start_pushers(self, pushers): Returns:
if not self.start_pushers: Deferred
logger.info("Not starting pushers because they are disabled in the config") """
return pushers = yield self.store.get_all_pushers()
logger.info("Starting %d pushers", len(pushers)) logger.info("Starting %d pushers", len(pushers))
for pusherdict in pushers: for pusherdict in pushers:
try: self._start_pusher(pusherdict)
p = self.pusher_factory.create_pusher(pusherdict)
except Exception:
logger.exception("Couldn't start a pusher: caught Exception")
continue
if p:
appid_pushkey = "%s:%s" % (
pusherdict['app_id'],
pusherdict['pushkey'],
)
byuser = self.pushers.setdefault(pusherdict['user_name'], {})
if appid_pushkey in byuser:
byuser[appid_pushkey].on_stop()
byuser[appid_pushkey] = p
run_in_background(p.on_started)
logger.info("Started pushers") logger.info("Started pushers")
def _start_pusher(self, pusherdict):
"""Start the given pusher
Args:
pusherdict (dict):
Returns:
None
"""
try:
p = self.pusher_factory.create_pusher(pusherdict)
except Exception:
logger.exception("Couldn't start a pusher: caught Exception")
return
if not p:
return
appid_pushkey = "%s:%s" % (
pusherdict['app_id'],
pusherdict['pushkey'],
)
byuser = self.pushers.setdefault(pusherdict['user_name'], {})
if appid_pushkey in byuser:
byuser[appid_pushkey].on_stop()
byuser[appid_pushkey] = p
p.on_started()
@defer.inlineCallbacks @defer.inlineCallbacks
def remove_pusher(self, app_id, pushkey, user_id): def remove_pusher(self, app_id, pushkey, user_id):
appid_pushkey = "%s:%s" % (app_id, pushkey) appid_pushkey = "%s:%s" % (app_id, pushkey)

View File

@ -53,6 +53,7 @@ REQUIREMENTS = {
"pillow>=3.1.2": ["PIL"], "pillow>=3.1.2": ["PIL"],
"pydenticon>=0.2": ["pydenticon"], "pydenticon>=0.2": ["pydenticon"],
"sortedcontainers>=1.4.4": ["sortedcontainers"], "sortedcontainers>=1.4.4": ["sortedcontainers"],
"psutil>=2.0.0": ["psutil>=2.0.0"],
"pysaml2>=3.0.0": ["saml2"], "pysaml2>=3.0.0": ["saml2"],
"pymacaroons-pynacl>=0.9.3": ["pymacaroons"], "pymacaroons-pynacl>=0.9.3": ["pymacaroons"],
"msgpack-python>=0.4.2": ["msgpack"], "msgpack-python>=0.4.2": ["msgpack"],
@ -79,9 +80,6 @@ CONDITIONAL_REQUIREMENTS = {
"matrix-synapse-ldap3": { "matrix-synapse-ldap3": {
"matrix-synapse-ldap3>=0.1": ["ldap_auth_provider"], "matrix-synapse-ldap3>=0.1": ["ldap_auth_provider"],
}, },
"psutil": {
"psutil>=2.0.0": ["psutil>=2.0.0"],
},
"postgres": { "postgres": {
"psycopg2>=2.6": ["psycopg2"] "psycopg2>=2.6": ["psycopg2"]
} }

View File

@ -74,38 +74,11 @@ class ClientDirectoryServer(ClientV1RestServlet):
if room is None: if room is None:
raise SynapseError(400, "Room does not exist") raise SynapseError(400, "Room does not exist")
dir_handler = self.handlers.directory_handler requester = yield self.auth.get_user_by_req(request)
try: yield self.handlers.directory_handler.create_association(
# try to auth as a user requester, room_alias, room_id, servers
requester = yield self.auth.get_user_by_req(request) )
try:
user_id = requester.user.to_string()
yield dir_handler.create_association(
user_id, room_alias, room_id, servers
)
yield dir_handler.send_room_alias_update_event(
requester,
user_id,
room_id
)
except SynapseError as e:
raise e
except Exception:
logger.exception("Failed to create association")
raise
except AuthError:
# try to auth as an application service
service = yield self.auth.get_appservice_by_req(request)
yield dir_handler.create_appservice_association(
service, room_alias, room_id, servers
)
logger.info(
"Application service at %s created alias %s pointing to %s",
service.url,
room_alias.to_string(),
room_id
)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -135,7 +108,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
room_alias = RoomAlias.from_string(room_alias) room_alias = RoomAlias.from_string(room_alias)
yield dir_handler.delete_association( yield dir_handler.delete_association(
requester, user.to_string(), room_alias requester, room_alias
) )
logger.info( logger.info(

View File

@ -33,6 +33,7 @@ from synapse.http.servlet import (
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
) )
from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
@ -409,7 +410,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
room_id=room_id, room_id=room_id,
user_id=requester.user.to_string(), user_id=requester.user.to_string(),
at_token=at_token, at_token=at_token,
types=[(EventTypes.Member, None)], state_filter=StateFilter.from_types([(EventTypes.Member, None)]),
) )
chunk = [] chunk = []

View File

@ -99,7 +99,7 @@ class AuthRestServlet(RestServlet):
cannot be handled in the normal flow (with requests to the same endpoint). cannot be handled in the normal flow (with requests to the same endpoint).
Current use is for web fallback auth. Current use is for web fallback auth.
""" """
PATTERNS = client_v2_patterns("/auth/(?P<stagetype>[\w\.]*)/fallback/web") PATTERNS = client_v2_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
def __init__(self, hs): def __init__(self, hs):
super(AuthRestServlet, self).__init__() super(AuthRestServlet, self).__init__()

View File

@ -25,7 +25,7 @@ from six.moves.urllib import parse as urlparse
import twisted.internet.error import twisted.internet.error
import twisted.web.http import twisted.web.http
from twisted.internet import defer, threads from twisted.internet import defer
from twisted.web.resource import Resource from twisted.web.resource import Resource
from synapse.api.errors import ( from synapse.api.errors import (
@ -36,8 +36,8 @@ from synapse.api.errors import (
) )
from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util import logcontext
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.logcontext import make_deferred_yieldable
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
from synapse.util.stringutils import is_ascii, random_string from synapse.util.stringutils import is_ascii, random_string
@ -492,10 +492,11 @@ class MediaRepository(object):
)) ))
thumbnailer = Thumbnailer(input_path) thumbnailer = Thumbnailer(input_path)
t_byte_source = yield make_deferred_yieldable(threads.deferToThread( t_byte_source = yield logcontext.defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail, self._generate_thumbnail,
thumbnailer, t_width, t_height, t_method, t_type thumbnailer, t_width, t_height, t_method, t_type
)) )
if t_byte_source: if t_byte_source:
try: try:
@ -534,10 +535,11 @@ class MediaRepository(object):
)) ))
thumbnailer = Thumbnailer(input_path) thumbnailer = Thumbnailer(input_path)
t_byte_source = yield make_deferred_yieldable(threads.deferToThread( t_byte_source = yield logcontext.defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail, self._generate_thumbnail,
thumbnailer, t_width, t_height, t_method, t_type thumbnailer, t_width, t_height, t_method, t_type
)) )
if t_byte_source: if t_byte_source:
try: try:
@ -620,15 +622,17 @@ class MediaRepository(object):
for (t_width, t_height, t_type), t_method in iteritems(thumbnails): for (t_width, t_height, t_type), t_method in iteritems(thumbnails):
# Generate the thumbnail # Generate the thumbnail
if t_method == "crop": if t_method == "crop":
t_byte_source = yield make_deferred_yieldable(threads.deferToThread( t_byte_source = yield logcontext.defer_to_thread(
self.hs.get_reactor(),
thumbnailer.crop, thumbnailer.crop,
t_width, t_height, t_type, t_width, t_height, t_type,
)) )
elif t_method == "scale": elif t_method == "scale":
t_byte_source = yield make_deferred_yieldable(threads.deferToThread( t_byte_source = yield logcontext.defer_to_thread(
self.hs.get_reactor(),
thumbnailer.scale, thumbnailer.scale,
t_width, t_height, t_type, t_width, t_height, t_type,
)) )
else: else:
logger.error("Unrecognized method: %r", t_method) logger.error("Unrecognized method: %r", t_method)
continue continue

View File

@ -21,9 +21,10 @@ import sys
import six import six
from twisted.internet import defer, threads from twisted.internet import defer
from twisted.protocols.basic import FileSender from twisted.protocols.basic import FileSender
from synapse.util import logcontext
from synapse.util.file_consumer import BackgroundFileConsumer from synapse.util.file_consumer import BackgroundFileConsumer
from synapse.util.logcontext import make_deferred_yieldable from synapse.util.logcontext import make_deferred_yieldable
@ -64,9 +65,10 @@ class MediaStorage(object):
with self.store_into_file(file_info) as (f, fname, finish_cb): with self.store_into_file(file_info) as (f, fname, finish_cb):
# Write to the main repository # Write to the main repository
yield make_deferred_yieldable(threads.deferToThread( yield logcontext.defer_to_thread(
self.hs.get_reactor(),
_write_file_synchronously, source, f, _write_file_synchronously, source, f,
)) )
yield finish_cb() yield finish_cb()
defer.returnValue(fname) defer.returnValue(fname)

View File

@ -674,7 +674,7 @@ def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
# This splits the paragraph into words, but keeping the # This splits the paragraph into words, but keeping the
# (preceeding) whitespace intact so we can easily concat # (preceeding) whitespace intact so we can easily concat
# words back together. # words back together.
for match in re.finditer("\s*\S+", description): for match in re.finditer(r"\s*\S+", description):
word = match.group() word = match.group()
# Keep adding words while the total length is less than # Keep adding words while the total length is less than

View File

@ -17,9 +17,10 @@ import logging
import os import os
import shutil import shutil
from twisted.internet import defer, threads from twisted.internet import defer
from synapse.config._base import Config from synapse.config._base import Config
from synapse.util import logcontext
from synapse.util.logcontext import run_in_background from synapse.util.logcontext import run_in_background
from .media_storage import FileResponder from .media_storage import FileResponder
@ -120,7 +121,8 @@ class FileStorageProviderBackend(StorageProvider):
if not os.path.exists(dirname): if not os.path.exists(dirname):
os.makedirs(dirname) os.makedirs(dirname)
return threads.deferToThread( return logcontext.defer_to_thread(
self.hs.get_reactor(),
shutil.copyfile, primary_fname, backup_fname, shutil.copyfile, primary_fname, backup_fname,
) )

View File

@ -207,6 +207,7 @@ class HomeServer(object):
logger.info("Setting up.") logger.info("Setting up.")
with self.get_db_conn() as conn: with self.get_db_conn() as conn:
self.datastore = self.DATASTORE_CLASS(conn, self) self.datastore = self.DATASTORE_CLASS(conn, self)
conn.commit()
logger.info("Finished setting up.") logger.info("Finished setting up.")
def get_reactor(self): def get_reactor(self):

View File

@ -19,13 +19,14 @@ from collections import namedtuple
from six import iteritems, itervalues from six import iteritems, itervalues
import attr
from frozendict import frozendict from frozendict import frozendict
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, RoomVersions from synapse.api.constants import EventTypes, RoomVersions
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.state import v1 from synapse.state import v1, v2
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches import get_cache_factor_for from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
@ -372,15 +373,10 @@ class StateHandler(object):
result = yield self._state_resolution_handler.resolve_state_groups( result = yield self._state_resolution_handler.resolve_state_groups(
room_id, room_version, state_groups_ids, None, room_id, room_version, state_groups_ids, None,
self._state_map_factory, state_res_store=StateResolutionStore(self.store),
) )
defer.returnValue(result) defer.returnValue(result)
def _state_map_factory(self, ev_ids):
return self.store.get_events(
ev_ids, get_prev_content=False, check_redacted=False,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def resolve_events(self, room_version, state_sets, event): def resolve_events(self, room_version, state_sets, event):
logger.info( logger.info(
@ -398,10 +394,10 @@ class StateHandler(object):
} }
with Measure(self.clock, "state._resolve_events"): with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_factory( new_state = yield resolve_events_with_store(
room_version, state_set_ids, room_version, state_set_ids,
event_map=state_map, event_map=state_map,
state_map_factory=self._state_map_factory state_res_store=StateResolutionStore(self.store),
) )
new_state = { new_state = {
@ -436,7 +432,7 @@ class StateResolutionHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def resolve_state_groups( def resolve_state_groups(
self, room_id, room_version, state_groups_ids, event_map, state_map_factory, self, room_id, room_version, state_groups_ids, event_map, state_res_store,
): ):
"""Resolves conflicts between a set of state groups """Resolves conflicts between a set of state groups
@ -454,9 +450,11 @@ class StateResolutionHandler(object):
a dict from event_id to event, for any events that we happen to a dict from event_id to event, for any events that we happen to
have in flight (eg, those currently being persisted). This will be have in flight (eg, those currently being persisted). This will be
used as a starting point fof finding the state we need; any missing used as a starting point fof finding the state we need; any missing
events will be requested via state_map_factory. events will be requested via state_res_store.
If None, all events will be fetched via state_map_factory. If None, all events will be fetched via state_res_store.
state_res_store (StateResolutionStore)
Returns: Returns:
Deferred[_StateCacheEntry]: resolved state Deferred[_StateCacheEntry]: resolved state
@ -480,10 +478,10 @@ class StateResolutionHandler(object):
# start by assuming we won't have any conflicted state, and build up the new # start by assuming we won't have any conflicted state, and build up the new
# state map by iterating through the state groups. If we discover a conflict, # state map by iterating through the state groups. If we discover a conflict,
# we give up and instead use `resolve_events_with_factory`. # we give up and instead use `resolve_events_with_store`.
# #
# XXX: is this actually worthwhile, or should we just let # XXX: is this actually worthwhile, or should we just let
# resolve_events_with_factory do it? # resolve_events_with_store do it?
new_state = {} new_state = {}
conflicted_state = False conflicted_state = False
for st in itervalues(state_groups_ids): for st in itervalues(state_groups_ids):
@ -498,11 +496,11 @@ class StateResolutionHandler(object):
if conflicted_state: if conflicted_state:
logger.info("Resolving conflicted state for %r", room_id) logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"): with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_factory( new_state = yield resolve_events_with_store(
room_version, room_version,
list(itervalues(state_groups_ids)), list(itervalues(state_groups_ids)),
event_map=event_map, event_map=event_map,
state_map_factory=state_map_factory, state_res_store=state_res_store,
) )
# if the new state matches any of the input state groups, we can # if the new state matches any of the input state groups, we can
@ -583,7 +581,7 @@ def _make_state_cache_entry(
) )
def resolve_events_with_factory(room_version, state_sets, event_map, state_map_factory): def resolve_events_with_store(room_version, state_sets, event_map, state_res_store):
""" """
Args: Args:
room_version(str): Version of the room room_version(str): Version of the room
@ -599,17 +597,19 @@ def resolve_events_with_factory(room_version, state_sets, event_map, state_map_f
If None, all events will be fetched via state_map_factory. If None, all events will be fetched via state_map_factory.
state_map_factory(func): will be called state_res_store (StateResolutionStore)
with a list of event_ids that are needed, and should return with
a Deferred of dict of event_id to event.
Returns Returns
Deferred[dict[(str, str), str]]: Deferred[dict[(str, str), str]]:
a map from (type, state_key) to event_id. a map from (type, state_key) to event_id.
""" """
if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,): if room_version == RoomVersions.V1:
return v1.resolve_events_with_factory( return v1.resolve_events_with_store(
state_sets, event_map, state_map_factory, state_sets, event_map, state_res_store.get_events,
)
elif room_version == RoomVersions.VDH_TEST:
return v2.resolve_events_with_store(
state_sets, event_map, state_res_store,
) )
else: else:
# This should only happen if we added a version but forgot to add it to # This should only happen if we added a version but forgot to add it to
@ -617,3 +617,54 @@ def resolve_events_with_factory(room_version, state_sets, event_map, state_map_f
raise Exception( raise Exception(
"No state resolution algorithm defined for version %r" % (room_version,) "No state resolution algorithm defined for version %r" % (room_version,)
) )
@attr.s
class StateResolutionStore(object):
"""Interface that allows state resolution algorithms to access the database
in well defined way.
Args:
store (DataStore)
"""
store = attr.ib()
def get_events(self, event_ids, allow_rejected=False):
"""Get events from the database
Args:
event_ids (list): The event_ids of the events to fetch
allow_rejected (bool): If True return rejected events.
Returns:
Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
"""
return self.store.get_events(
event_ids,
check_redacted=False,
get_prev_content=False,
allow_rejected=allow_rejected,
)
def get_auth_chain(self, event_ids):
"""Gets the full auth chain for a set of events (including rejected
events).
Includes the given event IDs in the result.
Note that:
1. All events must be state events.
2. For v1 rooms this may not have the full auth chain in the
presence of rejected events
Args:
event_ids (list): The event IDs of the events to fetch the auth
chain for. Must be state events.
Returns:
Deferred[list[str]]: List of event IDs of the auth chain.
"""
return self.store.get_auth_chain_ids(event_ids, include_given=True)

View File

@ -31,7 +31,7 @@ POWER_KEY = (EventTypes.PowerLevels, "")
@defer.inlineCallbacks @defer.inlineCallbacks
def resolve_events_with_factory(state_sets, event_map, state_map_factory): def resolve_events_with_store(state_sets, event_map, state_map_factory):
""" """
Args: Args:
state_sets(list): List of dicts of (type, state_key) -> event_id, state_sets(list): List of dicts of (type, state_key) -> event_id,

544
synapse/state/v2.py Normal file
View File

@ -0,0 +1,544 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import heapq
import itertools
import logging
from six import iteritems, itervalues
from twisted.internet import defer
from synapse import event_auth
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
logger = logging.getLogger(__name__)
@defer.inlineCallbacks
def resolve_events_with_store(state_sets, event_map, state_res_store):
"""Resolves the state using the v2 state resolution algorithm
Args:
state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
event_map(dict[str,FrozenEvent]|None):
a dict from event_id to event, for any events that we happen to
have in flight (eg, those currently being persisted). This will be
used as a starting point fof finding the state we need; any missing
events will be requested via state_res_store.
If None, all events will be fetched via state_res_store.
state_res_store (StateResolutionStore)
Returns
Deferred[dict[(str, str), str]]:
a map from (type, state_key) to event_id.
"""
logger.debug("Computing conflicted state")
# First split up the un/conflicted state
unconflicted_state, conflicted_state = _seperate(state_sets)
if not conflicted_state:
defer.returnValue(unconflicted_state)
logger.debug("%d conflicted state entries", len(conflicted_state))
logger.debug("Calculating auth chain difference")
# Also fetch all auth events that appear in only some of the state sets'
# auth chains.
auth_diff = yield _get_auth_chain_difference(
state_sets, event_map, state_res_store,
)
full_conflicted_set = set(itertools.chain(
itertools.chain.from_iterable(itervalues(conflicted_state)),
auth_diff,
))
events = yield state_res_store.get_events([
eid for eid in full_conflicted_set
if eid not in event_map
], allow_rejected=True)
event_map.update(events)
full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map)
logger.debug("%d full_conflicted_set entries", len(full_conflicted_set))
# Get and sort all the power events (kicks/bans/etc)
power_events = (
eid for eid in full_conflicted_set
if _is_power_event(event_map[eid])
)
sorted_power_events = yield _reverse_topological_power_sort(
power_events,
event_map,
state_res_store,
full_conflicted_set,
)
logger.debug("sorted %d power events", len(sorted_power_events))
# Now sequentially auth each one
resolved_state = yield _iterative_auth_checks(
sorted_power_events, unconflicted_state, event_map,
state_res_store,
)
logger.debug("resolved power events")
# OK, so we've now resolved the power events. Now sort the remaining
# events using the mainline of the resolved power level.
leftover_events = [
ev_id
for ev_id in full_conflicted_set
if ev_id not in sorted_power_events
]
logger.debug("sorting %d remaining events", len(leftover_events))
pl = resolved_state.get((EventTypes.PowerLevels, ""), None)
leftover_events = yield _mainline_sort(
leftover_events, pl, event_map, state_res_store,
)
logger.debug("resolving remaining events")
resolved_state = yield _iterative_auth_checks(
leftover_events, resolved_state, event_map,
state_res_store,
)
logger.debug("resolved")
# We make sure that unconflicted state always still applies.
resolved_state.update(unconflicted_state)
logger.debug("done")
defer.returnValue(resolved_state)
@defer.inlineCallbacks
def _get_power_level_for_sender(event_id, event_map, state_res_store):
"""Return the power level of the sender of the given event according to
their auth events.
Args:
event_id (str)
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
Returns:
Deferred[int]
"""
event = yield _get_event(event_id, event_map, state_res_store)
pl = None
for aid, _ in event.auth_events:
aev = yield _get_event(aid, event_map, state_res_store)
if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
pl = aev
break
if pl is None:
# Couldn't find power level. Check if they're the creator of the room
for aid, _ in event.auth_events:
aev = yield _get_event(aid, event_map, state_res_store)
if (aev.type, aev.state_key) == (EventTypes.Create, ""):
if aev.content.get("creator") == event.sender:
defer.returnValue(100)
break
defer.returnValue(0)
level = pl.content.get("users", {}).get(event.sender)
if level is None:
level = pl.content.get("users_default", 0)
if level is None:
defer.returnValue(0)
else:
defer.returnValue(int(level))
@defer.inlineCallbacks
def _get_auth_chain_difference(state_sets, event_map, state_res_store):
"""Compare the auth chains of each state set and return the set of events
that only appear in some but not all of the auth chains.
Args:
state_sets (list)
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
Returns:
Deferred[set[str]]: Set of event IDs
"""
common = set(itervalues(state_sets[0])).intersection(
*(itervalues(s) for s in state_sets[1:])
)
auth_sets = []
for state_set in state_sets:
auth_ids = set(
eid
for key, eid in iteritems(state_set)
if (key[0] in (
EventTypes.Member,
EventTypes.ThirdPartyInvite,
) or key in (
(EventTypes.PowerLevels, ''),
(EventTypes.Create, ''),
(EventTypes.JoinRules, ''),
)) and eid not in common
)
auth_chain = yield state_res_store.get_auth_chain(auth_ids)
auth_ids.update(auth_chain)
auth_sets.append(auth_ids)
intersection = set(auth_sets[0]).intersection(*auth_sets[1:])
union = set().union(*auth_sets)
defer.returnValue(union - intersection)
def _seperate(state_sets):
"""Return the unconflicted and conflicted state. This is different than in
the original algorithm, as this defines a key to be conflicted if one of
the state sets doesn't have that key.
Args:
state_sets (list)
Returns:
tuple[dict, dict]: A tuple of unconflicted and conflicted state. The
conflicted state dict is a map from type/state_key to set of event IDs
"""
unconflicted_state = {}
conflicted_state = {}
for key in set(itertools.chain.from_iterable(state_sets)):
event_ids = set(state_set.get(key) for state_set in state_sets)
if len(event_ids) == 1:
unconflicted_state[key] = event_ids.pop()
else:
event_ids.discard(None)
conflicted_state[key] = event_ids
return unconflicted_state, conflicted_state
def _is_power_event(event):
"""Return whether or not the event is a "power event", as defined by the
v2 state resolution algorithm
Args:
event (FrozenEvent)
Returns:
boolean
"""
if (event.type, event.state_key) in (
(EventTypes.PowerLevels, ""),
(EventTypes.JoinRules, ""),
(EventTypes.Create, ""),
):
return True
if event.type == EventTypes.Member:
if event.membership in ('leave', 'ban'):
return event.sender != event.state_key
return False
@defer.inlineCallbacks
def _add_event_and_auth_chain_to_graph(graph, event_id, event_map,
state_res_store, auth_diff):
"""Helper function for _reverse_topological_power_sort that add the event
and its auth chain (that is in the auth diff) to the graph
Args:
graph (dict[str, set[str]]): A map from event ID to the events auth
event IDs
event_id (str): Event to add to the graph
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
auth_diff (set[str]): Set of event IDs that are in the auth difference.
"""
state = [event_id]
while state:
eid = state.pop()
graph.setdefault(eid, set())
event = yield _get_event(eid, event_map, state_res_store)
for aid, _ in event.auth_events:
if aid in auth_diff:
if aid not in graph:
state.append(aid)
graph.setdefault(eid, set()).add(aid)
@defer.inlineCallbacks
def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_diff):
"""Returns a list of the event_ids sorted by reverse topological ordering,
and then by power level and origin_server_ts
Args:
event_ids (list[str]): The events to sort
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
auth_diff (set[str]): Set of event IDs that are in the auth difference.
Returns:
Deferred[list[str]]: The sorted list
"""
graph = {}
for event_id in event_ids:
yield _add_event_and_auth_chain_to_graph(
graph, event_id, event_map, state_res_store, auth_diff,
)
event_to_pl = {}
for event_id in graph:
pl = yield _get_power_level_for_sender(event_id, event_map, state_res_store)
event_to_pl[event_id] = pl
def _get_power_order(event_id):
ev = event_map[event_id]
pl = event_to_pl[event_id]
return -pl, ev.origin_server_ts, event_id
# Note: graph is modified during the sort
it = lexicographical_topological_sort(
graph,
key=_get_power_order,
)
sorted_events = list(it)
defer.returnValue(sorted_events)
@defer.inlineCallbacks
def _iterative_auth_checks(event_ids, base_state, event_map, state_res_store):
"""Sequentially apply auth checks to each event in given list, updating the
state as it goes along.
Args:
event_ids (list[str]): Ordered list of events to apply auth checks to
base_state (dict[tuple[str, str], str]): The set of state to start with
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
Returns:
Deferred[dict[tuple[str, str], str]]: Returns the final updated state
"""
resolved_state = base_state.copy()
for event_id in event_ids:
event = event_map[event_id]
auth_events = {}
for aid, _ in event.auth_events:
ev = yield _get_event(aid, event_map, state_res_store)
if ev.rejected_reason is None:
auth_events[(ev.type, ev.state_key)] = ev
for key in event_auth.auth_types_for_event(event):
if key in resolved_state:
ev_id = resolved_state[key]
ev = yield _get_event(ev_id, event_map, state_res_store)
if ev.rejected_reason is None:
auth_events[key] = event_map[ev_id]
try:
event_auth.check(
event, auth_events,
do_sig_check=False,
do_size_check=False
)
resolved_state[(event.type, event.state_key)] = event_id
except AuthError:
pass
defer.returnValue(resolved_state)
@defer.inlineCallbacks
def _mainline_sort(event_ids, resolved_power_event_id, event_map,
state_res_store):
"""Returns a sorted list of event_ids sorted by mainline ordering based on
the given event resolved_power_event_id
Args:
event_ids (list[str]): Events to sort
resolved_power_event_id (str): The final resolved power level event ID
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
Returns:
Deferred[list[str]]: The sorted list
"""
mainline = []
pl = resolved_power_event_id
while pl:
mainline.append(pl)
pl_ev = yield _get_event(pl, event_map, state_res_store)
auth_events = pl_ev.auth_events
pl = None
for aid, _ in auth_events:
ev = yield _get_event(aid, event_map, state_res_store)
if (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""):
pl = aid
break
mainline_map = {ev_id: i + 1 for i, ev_id in enumerate(reversed(mainline))}
event_ids = list(event_ids)
order_map = {}
for ev_id in event_ids:
depth = yield _get_mainline_depth_for_event(
event_map[ev_id], mainline_map,
event_map, state_res_store,
)
order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id)
event_ids.sort(key=lambda ev_id: order_map[ev_id])
defer.returnValue(event_ids)
@defer.inlineCallbacks
def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_store):
"""Get the mainline depths for the given event based on the mainline map
Args:
event (FrozenEvent)
mainline_map (dict[str, int]): Map from event_id to mainline depth for
events in the mainline.
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
Returns:
Deferred[int]
"""
# We do an iterative search, replacing `event with the power level in its
# auth events (if any)
while event:
depth = mainline_map.get(event.event_id)
if depth is not None:
defer.returnValue(depth)
auth_events = event.auth_events
event = None
for aid, _ in auth_events:
aev = yield _get_event(aid, event_map, state_res_store)
if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
event = aev
break
# Didn't find a power level auth event, so we just return 0
defer.returnValue(0)
@defer.inlineCallbacks
def _get_event(event_id, event_map, state_res_store):
"""Helper function to look up event in event_map, falling back to looking
it up in the store
Args:
event_id (str)
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
Returns:
Deferred[FrozenEvent]
"""
if event_id not in event_map:
events = yield state_res_store.get_events([event_id], allow_rejected=True)
event_map.update(events)
defer.returnValue(event_map[event_id])
def lexicographical_topological_sort(graph, key):
"""Performs a lexicographic reverse topological sort on the graph.
This returns a reverse topological sort (i.e. if node A references B then B
appears before A in the sort), with ties broken lexicographically based on
return value of the `key` function.
NOTE: `graph` is modified during the sort.
Args:
graph (dict[str, set[str]]): A representation of the graph where each
node is a key in the dict and its value are the nodes edges.
key (func): A function that takes a node and returns a value that is
comparable and used to order nodes
Yields:
str: The next node in the topological sort
"""
# Note, this is basically Kahn's algorithm except we look at nodes with no
# outgoing edges, c.f.
# https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
outdegree_map = graph
reverse_graph = {}
# Lists of nodes with zero out degree. Is actually a tuple of
# `(key(node), node)` so that sorting does the right thing
zero_outdegree = []
for node, edges in iteritems(graph):
if len(edges) == 0:
zero_outdegree.append((key(node), node))
reverse_graph.setdefault(node, set())
for edge in edges:
reverse_graph.setdefault(edge, set()).add(node)
# heapq is a built in implementation of a sorted queue.
heapq.heapify(zero_outdegree)
while zero_outdegree:
_, node = heapq.heappop(zero_outdegree)
for parent in reverse_graph[node]:
out = outdegree_map[parent]
out.discard(node)
if len(out) == 0:
heapq.heappush(zero_outdegree, (key(parent), parent))
yield node

View File

@ -18,7 +18,7 @@ import threading
import time import time
from six import PY2, iteritems, iterkeys, itervalues from six import PY2, iteritems, iterkeys, itervalues
from six.moves import intern, range from six.moves import builtins, intern, range
from canonicaljson import json from canonicaljson import json
from prometheus_client import Histogram from prometheus_client import Histogram
@ -1233,7 +1233,7 @@ def db_to_json(db_content):
# psycopg2 on Python 2 returns buffer objects, which we need to cast to # psycopg2 on Python 2 returns buffer objects, which we need to cast to
# bytes to decode # bytes to decode
if PY2 and isinstance(db_content, buffer): if PY2 and isinstance(db_content, builtins.buffer):
db_content = bytes(db_content) db_content = bytes(db_content)
# Decode it to a Unicode string before feeding it to json.loads, so we # Decode it to a Unicode string before feeding it to json.loads, so we

View File

@ -90,7 +90,7 @@ class DirectoryWorkerStore(SQLBaseStore):
class DirectoryStore(DirectoryWorkerStore): class DirectoryStore(DirectoryWorkerStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def create_room_alias_association(self, room_alias, room_id, servers, creator=None): def create_room_alias_association(self, room_alias, room_id, servers, creator=None):
""" Creates an associatin between a room alias and room_id/servers """ Creates an association between a room alias and room_id/servers
Args: Args:
room_alias (RoomAlias) room_alias (RoomAlias)

View File

@ -34,6 +34,7 @@ from synapse.api.errors import SynapseError
from synapse.events import EventBase # noqa: F401 from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.state import StateResolutionStore
from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.event_federation import EventFederationStore from synapse.storage.event_federation import EventFederationStore
from synapse.storage.events_worker import EventsWorkerStore from synapse.storage.events_worker import EventsWorkerStore
@ -733,11 +734,6 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
# Ok, we need to defer to the state handler to resolve our state sets. # Ok, we need to defer to the state handler to resolve our state sets.
def get_events(ev_ids):
return self.get_events(
ev_ids, get_prev_content=False, check_redacted=False,
)
state_groups = { state_groups = {
sg: state_groups_map[sg] for sg in new_state_groups sg: state_groups_map[sg] for sg in new_state_groups
} }
@ -747,7 +743,8 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
logger.debug("calling resolve_state_groups from preserve_events") logger.debug("calling resolve_state_groups from preserve_events")
res = yield self._state_resolution_handler.resolve_state_groups( res = yield self._state_resolution_handler.resolve_state_groups(
room_id, room_version, state_groups, events_map, get_events room_id, room_version, state_groups, events_map,
state_res_store=StateResolutionStore(self)
) )
defer.returnValue((res.state, None)) defer.returnValue((res.state, None))
@ -856,6 +853,27 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
# Insert into event_to_state_groups. # Insert into event_to_state_groups.
self._store_event_state_mappings_txn(txn, events_and_contexts) self._store_event_state_mappings_txn(txn, events_and_contexts)
# We want to store event_auth mappings for rejected events, as they're
# used in state res v2.
# This is only necessary if the rejected event appears in an accepted
# event's auth chain, but its easier for now just to store them (and
# it doesn't take much storage compared to storing the entire event
# anyway).
self._simple_insert_many_txn(
txn,
table="event_auth",
values=[
{
"event_id": event.event_id,
"room_id": event.room_id,
"auth_id": auth_id,
}
for event, _ in events_and_contexts
for auth_id, _ in event.auth_events
if event.is_state()
],
)
# _store_rejected_events_txn filters out any events which were # _store_rejected_events_txn filters out any events which were
# rejected, and returns the filtered list. # rejected, and returns the filtered list.
events_and_contexts = self._store_rejected_events_txn( events_and_contexts = self._store_rejected_events_txn(
@ -1331,21 +1349,6 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
txn, event.room_id, event.redacts txn, event.room_id, event.redacts
) )
self._simple_insert_many_txn(
txn,
table="event_auth",
values=[
{
"event_id": event.event_id,
"room_id": event.room_id,
"auth_id": auth_id,
}
for event, _ in events_and_contexts
for auth_id, _ in event.auth_events
if event.is_state()
],
)
# Update the event_forward_extremities, event_backward_extremities and # Update the event_forward_extremities, event_backward_extremities and
# event_edges tables. # event_edges tables.
self._handle_mult_prev_events( self._handle_mult_prev_events(
@ -2068,7 +2071,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
for sg in remaining_state_groups: for sg in remaining_state_groups:
logger.info("[purge] de-delta-ing remaining state group %s", sg) logger.info("[purge] de-delta-ing remaining state group %s", sg)
curr_state = self._get_state_groups_from_groups_txn( curr_state = self._get_state_groups_from_groups_txn(
txn, [sg], types=None txn, [sg],
) )
curr_state = curr_state[sg] curr_state = curr_state[sg]

View File

@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
# py2 sqlite has buffer hardcoded as only binary type, so we must use it, # py2 sqlite has buffer hardcoded as only binary type, so we must use it,
# despite being deprecated and removed in favor of memoryview # despite being deprecated and removed in favor of memoryview
if six.PY2: if six.PY2:
db_binary_type = buffer db_binary_type = six.moves.builtins.buffer
else: else:
db_binary_type = memoryview db_binary_type = memoryview

View File

@ -33,19 +33,29 @@ class MonthlyActiveUsersStore(SQLBaseStore):
self._clock = hs.get_clock() self._clock = hs.get_clock()
self.hs = hs self.hs = hs
self.reserved_users = () self.reserved_users = ()
# Do not add more reserved users than the total allowable number
self._initialise_reserved_users(
dbconn.cursor(),
hs.config.mau_limits_reserved_threepids[:self.hs.config.max_mau_value],
)
@defer.inlineCallbacks def _initialise_reserved_users(self, txn, threepids):
def initialise_reserved_users(self, threepids): """Ensures that reserved threepids are accounted for in the MAU table, should
store = self.hs.get_datastore() be called on start up.
Args:
txn (cursor):
threepids (list[dict]): List of threepid dicts to reserve
"""
reserved_user_list = [] reserved_user_list = []
# Do not add more reserved users than the total allowable number for tp in threepids:
for tp in threepids[:self.hs.config.max_mau_value]: user_id = self.get_user_id_by_threepid_txn(
user_id = yield store.get_user_id_by_threepid( txn,
tp["medium"], tp["address"] tp["medium"], tp["address"]
) )
if user_id: if user_id:
yield self.upsert_monthly_active_user(user_id) self.upsert_monthly_active_user_txn(txn, user_id)
reserved_user_list.append(user_id) reserved_user_list.append(user_id)
else: else:
logger.warning( logger.warning(
@ -55,8 +65,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def reap_monthly_active_users(self): def reap_monthly_active_users(self):
""" """Cleans out monthly active user table to ensure that no stale
Cleans out monthly active user table to ensure that no stale
entries exist. entries exist.
Returns: Returns:
@ -165,19 +174,44 @@ class MonthlyActiveUsersStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def upsert_monthly_active_user(self, user_id): def upsert_monthly_active_user(self, user_id):
"""Updates or inserts the user into the monthly active user table, which
is used to track the current MAU usage of the server
Args:
user_id (str): user to add/update
""" """
Updates or inserts monthly active user member is_insert = yield self.runInteraction(
Arguments: "upsert_monthly_active_user", self.upsert_monthly_active_user_txn,
user_id (str): user to add/update user_id
Deferred[bool]: True if a new entry was created, False if an )
existing one was updated.
if is_insert:
self.user_last_seen_monthly_active.invalidate((user_id,))
self.get_monthly_active_count.invalidate(())
def upsert_monthly_active_user_txn(self, txn, user_id):
"""Updates or inserts monthly active user member
Note that, after calling this method, it will generally be necessary
to invalidate the caches on user_last_seen_monthly_active and
get_monthly_active_count. We can't do that here, because we are running
in a database thread rather than the main thread, and we can't call
txn.call_after because txn may not be a LoggingTransaction.
Args:
txn (cursor):
user_id (str): user to add/update
Returns:
bool: True if a new entry was created, False if an
existing one was updated.
""" """
# Am consciously deciding to lock the table on the basis that is ought # Am consciously deciding to lock the table on the basis that is ought
# never be a big table and alternative approaches (batching multiple # never be a big table and alternative approaches (batching multiple
# upserts into a single txn) introduced a lot of extra complexity. # upserts into a single txn) introduced a lot of extra complexity.
# See https://github.com/matrix-org/synapse/issues/3854 for more # See https://github.com/matrix-org/synapse/issues/3854 for more
is_insert = yield self._simple_upsert( is_insert = self._simple_upsert_txn(
desc="upsert_monthly_active_user", txn,
table="monthly_active_users", table="monthly_active_users",
keyvalues={ keyvalues={
"user_id": user_id, "user_id": user_id,
@ -186,9 +220,8 @@ class MonthlyActiveUsersStore(SQLBaseStore):
"timestamp": int(self._clock.time_msec()), "timestamp": int(self._clock.time_msec()),
}, },
) )
if is_insert:
self.user_last_seen_monthly_active.invalidate((user_id,)) return is_insert
self.get_monthly_active_count.invalidate(())
@cached(num_args=1) @cached(num_args=1)
def user_last_seen_monthly_active(self, user_id): def user_last_seen_monthly_active(self, user_id):

View File

@ -29,7 +29,7 @@ from ._base import SQLBaseStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if six.PY2: if six.PY2:
db_binary_type = buffer db_binary_type = six.moves.builtins.buffer
else: else:
db_binary_type = memoryview db_binary_type = memoryview

View File

@ -474,17 +474,44 @@ class RegistrationStore(RegistrationWorkerStore,
@defer.inlineCallbacks @defer.inlineCallbacks
def get_user_id_by_threepid(self, medium, address): def get_user_id_by_threepid(self, medium, address):
ret = yield self._simple_select_one( """Returns user id from threepid
Args:
medium (str): threepid medium e.g. email
address (str): threepid address e.g. me@example.com
Returns:
Deferred[str|None]: user id or None if no user id/threepid mapping exists
"""
user_id = yield self.runInteraction(
"get_user_id_by_threepid", self.get_user_id_by_threepid_txn,
medium, address
)
defer.returnValue(user_id)
def get_user_id_by_threepid_txn(self, txn, medium, address):
"""Returns user id from threepid
Args:
txn (cursor):
medium (str): threepid medium e.g. email
address (str): threepid address e.g. me@example.com
Returns:
str|None: user id or None if no user id/threepid mapping exists
"""
ret = self._simple_select_one_txn(
txn,
"user_threepids", "user_threepids",
{ {
"medium": medium, "medium": medium,
"address": address "address": address
}, },
['user_id'], True, 'get_user_id_by_threepid' ['user_id'], True
) )
if ret: if ret:
defer.returnValue(ret['user_id']) return ret['user_id']
defer.returnValue(None) return None
def user_delete_threepid(self, user_id, medium, address): def user_delete_threepid(self, user_id, medium, address):
return self._simple_delete( return self._simple_delete(
@ -567,7 +594,7 @@ class RegistrationStore(RegistrationWorkerStore,
def _find_next_generated_user_id(txn): def _find_next_generated_user_id(txn):
txn.execute("SELECT name FROM users") txn.execute("SELECT name FROM users")
regex = re.compile("^@(\d+):") regex = re.compile(r"^@(\d+):")
found = set() found = set()

View File

@ -27,7 +27,7 @@ from ._base import SQLBaseStore
# py2 sqlite has buffer hardcoded as only binary type, so we must use it, # py2 sqlite has buffer hardcoded as only binary type, so we must use it,
# despite being deprecated and removed in favor of memoryview # despite being deprecated and removed in favor of memoryview
if six.PY2: if six.PY2:
db_binary_type = buffer db_binary_type = six.moves.builtins.buffer
else: else:
db_binary_type = memoryview db_binary_type = memoryview

File diff suppressed because it is too large Load Diff

View File

@ -30,7 +30,7 @@ from ._base import SQLBaseStore, db_to_json
# py2 sqlite has buffer hardcoded as only binary type, so we must use it, # py2 sqlite has buffer hardcoded as only binary type, so we must use it,
# despite being deprecated and removed in favor of memoryview # despite being deprecated and removed in favor of memoryview
if six.PY2: if six.PY2:
db_binary_type = buffer db_binary_type = six.moves.builtins.buffer
else: else:
db_binary_type = memoryview db_binary_type = memoryview

View File

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
import re
from itertools import islice from itertools import islice
import attr import attr
@ -138,3 +139,27 @@ def log_failure(failure, msg, consumeErrors=True):
if not consumeErrors: if not consumeErrors:
return failure return failure
def glob_to_regex(glob):
"""Converts a glob to a compiled regex object.
The regex is anchored at the beginning and end of the string.
Args:
glob (str)
Returns:
re.RegexObject
"""
res = ''
for c in glob:
if c == '*':
res = res + '.*'
elif c == '?':
res = res + '.'
else:
res = res + re.escape(c)
# \A anchors at start of string, \Z at end of string
return re.compile(r"\A" + res + r"\Z", re.IGNORECASE)

View File

@ -15,6 +15,8 @@
import logging import logging
from six import integer_types
from sortedcontainers import SortedDict from sortedcontainers import SortedDict
from synapse.util import caches from synapse.util import caches
@ -47,7 +49,7 @@ class StreamChangeCache(object):
def has_entity_changed(self, entity, stream_pos): def has_entity_changed(self, entity, stream_pos):
"""Returns True if the entity may have been updated since stream_pos """Returns True if the entity may have been updated since stream_pos
""" """
assert type(stream_pos) is int or type(stream_pos) is long assert type(stream_pos) in integer_types
if stream_pos < self._earliest_known_stream_pos: if stream_pos < self._earliest_known_stream_pos:
self.metrics.inc_misses() self.metrics.inc_misses()

View File

@ -25,7 +25,7 @@ See doc/log_contexts.rst for details on how this works.
import logging import logging
import threading import threading
from twisted.internet import defer from twisted.internet import defer, threads
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -562,58 +562,76 @@ def _set_context_cb(result, context):
return result return result
# modules to ignore in `logcontext_tracer` def defer_to_thread(reactor, f, *args, **kwargs):
_to_ignore = [
"synapse.util.logcontext",
"synapse.http.server",
"synapse.storage._base",
"synapse.util.async_helpers",
]
def logcontext_tracer(frame, event, arg):
"""A tracer that logs whenever a logcontext "unexpectedly" changes within
a function. Probably inaccurate.
Use by calling `sys.settrace(logcontext_tracer)` in the main thread.
""" """
if event == 'call': Calls the function `f` using a thread from the reactor's default threadpool and
name = frame.f_globals["__name__"] returns the result as a Deferred.
if name.startswith("synapse"):
if name == "synapse.util.logcontext":
if frame.f_code.co_name in ["__enter__", "__exit__"]:
tracer = frame.f_back.f_trace
if tracer:
tracer.just_changed = True
tracer = frame.f_trace Creates a new logcontext for `f`, which is created as a child of the current
if tracer: logcontext (so its CPU usage metrics will get attributed to the current
return tracer logcontext). `f` should preserve the logcontext it is given.
if not any(name.startswith(ig) for ig in _to_ignore): The result deferred follows the Synapse logcontext rules: you should `yield`
return LineTracer() on it.
Args:
reactor (twisted.internet.base.ReactorBase): The reactor in whose main thread
the Deferred will be invoked, and whose threadpool we should use for the
function.
Normally this will be hs.get_reactor().
f (callable): The function to call.
args: positional arguments to pass to f.
kwargs: keyword arguments to pass to f.
Returns:
Deferred: A Deferred which fires a callback with the result of `f`, or an
errback if `f` throws an exception.
"""
return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs)
class LineTracer(object): def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
__slots__ = ["context", "just_changed"] """
A wrapper for twisted.internet.threads.deferToThreadpool, which handles
logcontexts correctly.
def __init__(self): Calls the function `f` using a thread from the given threadpool and returns
self.context = LoggingContext.current_context() the result as a Deferred.
self.just_changed = False
def __call__(self, frame, event, arg): Creates a new logcontext for `f`, which is created as a child of the current
if event in 'line': logcontext (so its CPU usage metrics will get attributed to the current
if self.just_changed: logcontext). `f` should preserve the logcontext it is given.
self.context = LoggingContext.current_context()
self.just_changed = False
else:
c = LoggingContext.current_context()
if c != self.context:
logger.info(
"Context changed! %s -> %s, %s, %s",
self.context, c,
frame.f_code.co_filename, frame.f_lineno
)
self.context = c
return self The result deferred follows the Synapse logcontext rules: you should `yield`
on it.
Args:
reactor (twisted.internet.base.ReactorBase): The reactor in whose main thread
the Deferred will be invoked. Normally this will be hs.get_reactor().
threadpool (twisted.python.threadpool.ThreadPool): The threadpool to use for
running `f`. Normally this will be hs.get_reactor().getThreadPool().
f (callable): The function to call.
args: positional arguments to pass to f.
kwargs: keyword arguments to pass to f.
Returns:
Deferred: A Deferred which fires a callback with the result of `f`, or an
errback if `f` throws an exception.
"""
logcontext = LoggingContext.current_context()
def g():
with LoggingContext(parent_context=logcontext):
return f(*args, **kwargs)
return make_deferred_yieldable(
threads.deferToThreadPool(reactor, threadpool, g)
)

View File

@ -23,6 +23,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.storage.state import StateFilter
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -72,7 +73,7 @@ def filter_events_for_client(store, user_id, events, is_peeking=False,
) )
event_id_to_state = yield store.get_state_for_events( event_id_to_state = yield store.get_state_for_events(
frozenset(e.event_id for e in events), frozenset(e.event_id for e in events),
types=types, state_filter=StateFilter.from_types(types),
) )
ignore_dict_content = yield store.get_global_account_data_by_type_for_user( ignore_dict_content = yield store.get_global_account_data_by_type_for_user(
@ -273,8 +274,8 @@ def filter_events_for_server(store, server_name, events):
# need to check membership (as we know the server is in the room). # need to check membership (as we know the server is in the room).
event_to_state_ids = yield store.get_state_ids_for_events( event_to_state_ids = yield store.get_state_ids_for_events(
frozenset(e.event_id for e in events), frozenset(e.event_id for e in events),
types=( state_filter=StateFilter.from_types(
(EventTypes.RoomHistoryVisibility, ""), types=((EventTypes.RoomHistoryVisibility, ""),),
) )
) )
@ -314,9 +315,11 @@ def filter_events_for_server(store, server_name, events):
# of the history vis and membership state at those events. # of the history vis and membership state at those events.
event_to_state_ids = yield store.get_state_ids_for_events( event_to_state_ids = yield store.get_state_ids_for_events(
frozenset(e.event_id for e in events), frozenset(e.event_id for e in events),
types=( state_filter=StateFilter.from_types(
(EventTypes.RoomHistoryVisibility, ""), types=(
(EventTypes.Member, None), (EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, None),
),
) )
) )

Some files were not shown because too many files have changed in this diff Show More