Merge branch 'develop' of github.com:matrix-org/synapse into release-v0.13.0

This commit is contained in:
Erik Johnston 2016-02-10 14:12:48 +00:00
commit e66d0bd03a
91 changed files with 1796 additions and 1101 deletions

View File

@ -51,3 +51,6 @@ Steven Hammerton <steven.hammerton at openmarket.com>
Mads Robin Christensen <mads at v42 dot dk> Mads Robin Christensen <mads at v42 dot dk>
* CentOS 7 installation instructions. * CentOS 7 installation instructions.
Florent Violleau <floviolleau at gmail dot com>
* Add Raspberry Pi installation instructions and general troubleshooting items

View File

@ -125,6 +125,15 @@ Installing prerequisites on Mac OS X::
sudo easy_install pip sudo easy_install pip
sudo pip install virtualenv sudo pip install virtualenv
Installing prerequisites on Raspbian::
sudo apt-get install build-essential python2.7-dev libffi-dev \
python-pip python-setuptools sqlite3 \
libssl-dev python-virtualenv libjpeg-dev
sudo pip install --upgrade pip
sudo pip install --upgrade ndg-httpsclient
sudo pip install --upgrade virtualenv
To install the synapse homeserver run:: To install the synapse homeserver run::
virtualenv -p python2.7 ~/.synapse virtualenv -p python2.7 ~/.synapse
@ -310,6 +319,18 @@ may need to manually upgrade it::
sudo pip install --upgrade pip sudo pip install --upgrade pip
Installing may fail with ``Could not find any downloads that satisfy the requirement pymacaroons-pynacl (from matrix-synapse==0.12.0)``.
You can fix this by manually upgrading pip and virtualenv::
sudo pip install --upgrade virtualenv
You can next rerun ``virtualenv -p python2.7 synapse`` to update the virtual env.
Installing may fail during installing virtualenv with ``InsecurePlatformWarning: A true SSLContext object is not available. This prevents urllib3 from configuring SSL appropriately and may cause certain SSL connections to fail. For more information, see https://urllib3.readthedocs.org/en/latest/security.html#insecureplatformwarning.``
You can fix this by manually installing ndg-httpsclient::
pip install --upgrade ndg-httpsclient
Installing may fail with ``mock requires setuptools>=17.1. Aborting installation``. Installing may fail with ``mock requires setuptools>=17.1. Aborting installation``.
You can fix this by upgrading setuptools:: You can fix this by upgrading setuptools::
@ -544,4 +565,4 @@ sphinxcontrib-napoleon::
Building internal API documentation:: Building internal API documentation::
python setup.py build_sphinx python setup.py build_sphinx

24
scripts-dev/dump_macaroon.py Executable file
View File

@ -0,0 +1,24 @@
#!/usr/bin/env python2
import pymacaroons
import sys
if len(sys.argv) == 1:
sys.stderr.write("usage: %s macaroon [key]\n" % (sys.argv[0],))
sys.exit(1)
macaroon_string = sys.argv[1]
key = sys.argv[2] if len(sys.argv) > 2 else None
macaroon = pymacaroons.Macaroon.deserialize(macaroon_string)
print macaroon.inspect()
print ""
verifier = pymacaroons.Verifier()
verifier.satisfy_general(lambda c: True)
try:
verifier.verify(macaroon, key)
print "Signature is correct"
except Exception as e:
print e.message

View File

@ -0,0 +1,62 @@
#! /usr/bin/python
import ast
import argparse
import os
import sys
import yaml
PATTERNS_V1 = []
PATTERNS_V2 = []
RESULT = {
"v1": PATTERNS_V1,
"v2": PATTERNS_V2,
}
class CallVisitor(ast.NodeVisitor):
def visit_Call(self, node):
if isinstance(node.func, ast.Name):
name = node.func.id
else:
return
if name == "client_path_patterns":
PATTERNS_V1.append(node.args[0].s)
elif name == "client_v2_patterns":
PATTERNS_V2.append(node.args[0].s)
def find_patterns_in_code(input_code):
input_ast = ast.parse(input_code)
visitor = CallVisitor()
visitor.visit(input_ast)
def find_patterns_in_file(filepath):
with open(filepath) as f:
find_patterns_in_code(f.read())
parser = argparse.ArgumentParser(description='Find url patterns.')
parser.add_argument(
"directories", nargs='+', metavar="DIR",
help="Directories to search for definitions"
)
args = parser.parse_args()
for directory in args.directories:
for root, dirs, files in os.walk(directory):
for filename in files:
if filename.endswith(".py"):
filepath = os.path.join(root, filename)
find_patterns_in_file(filepath)
PATTERNS_V1.sort()
PATTERNS_V2.sort()
yaml.dump(RESULT, sys.stdout, default_flow_style=False)

View File

@ -16,3 +16,4 @@ ignore =
[flake8] [flake8]
max-line-length = 90 max-line-length = 90
ignore = W503 ; W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.

View File

@ -24,6 +24,7 @@ from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import Requester, RoomID, UserID, EventID from synapse.types import Requester, RoomID, UserID, EventID
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
import logging import logging
@ -529,7 +530,8 @@ class Auth(object):
default=[""] default=[""]
)[0] )[0]
if user and access_token and ip_addr: if user and access_token and ip_addr:
self.store.insert_client_ip( preserve_context_over_fn(
self.store.insert_client_ip,
user=user, user=user,
access_token=access_token, access_token=access_token,
ip=ip_addr, ip=ip_addr,
@ -574,7 +576,7 @@ class Auth(object):
raise AuthError( raise AuthError(
403, 403,
"Application service has not registered this user" "Application service has not registered this user"
) )
defer.returnValue(user_id) defer.returnValue(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -696,6 +698,7 @@ class Auth(object):
def _look_up_user_by_access_token(self, token): def _look_up_user_by_access_token(self, token):
ret = yield self.store.get_user_by_access_token(token) ret = yield self.store.get_user_by_access_token(token)
if not ret: if not ret:
logger.warn("Unrecognised access token - not in store: %s" % (token,))
raise AuthError( raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.", self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN errcode=Codes.UNKNOWN_TOKEN
@ -713,6 +716,7 @@ class Auth(object):
token = request.args["access_token"][0] token = request.args["access_token"][0]
service = yield self.store.get_app_service_by_token(token) service = yield self.store.get_app_service_by_token(token)
if not service: if not service:
logger.warn("Unrecognised appservice access token: %s" % (token,))
raise AuthError( raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, self.TOKEN_NOT_FOUND_HTTP_STATUS,
"Unrecognised access token.", "Unrecognised access token.",

View File

@ -23,5 +23,6 @@ WEB_CLIENT_PREFIX = "/_matrix/client"
CONTENT_REPO_PREFIX = "/_matrix/content" CONTENT_REPO_PREFIX = "/_matrix/content"
SERVER_KEY_PREFIX = "/_matrix/key/v1" SERVER_KEY_PREFIX = "/_matrix/key/v1"
SERVER_KEY_V2_PREFIX = "/_matrix/key/v2" SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
MEDIA_PREFIX = "/_matrix/media/v1" MEDIA_PREFIX = "/_matrix/media/r0"
LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"
APP_SERVICE_PREFIX = "/_matrix/appservice/v1" APP_SERVICE_PREFIX = "/_matrix/appservice/v1"

View File

@ -12,3 +12,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys
sys.dont_write_bytecode = True
from synapse.python_dependencies import (
check_requirements, MissingRequirementError
) # NOQA
try:
check_requirements()
except MissingRequirementError as e:
message = "\n".join([
"Missing Requirement: %s" % (e.message,),
"To install run:",
" pip install --upgrade --force \"%s\"" % (e.dependency,),
"",
])
sys.stderr.writelines(message)
sys.exit(1)

View File

@ -14,27 +14,23 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys import synapse
from synapse.rest import ClientRestResource
import contextlib
import logging
import os
import re
import resource
import subprocess
import sys
import time
from synapse.config._base import ConfigError
sys.dont_write_bytecode = True
from synapse.python_dependencies import ( from synapse.python_dependencies import (
check_requirements, DEPENDENCY_LINKS, MissingRequirementError check_requirements, DEPENDENCY_LINKS
) )
if __name__ == '__main__': from synapse.rest import ClientRestResource
try:
check_requirements()
except MissingRequirementError as e:
message = "\n".join([
"Missing Requirement: %s" % (e.message,),
"To install run:",
" pip install --upgrade --force \"%s\"" % (e.dependency,),
"",
])
sys.stderr.writelines(message)
sys.exit(1)
from synapse.storage.engines import create_engine, IncorrectDatabaseSetup from synapse.storage.engines import create_engine, IncorrectDatabaseSetup
from synapse.storage import are_all_users_on_domain from synapse.storage import are_all_users_on_domain
from synapse.storage.prepare_database import UpgradeDatabaseException from synapse.storage.prepare_database import UpgradeDatabaseException
@ -60,7 +56,7 @@ from synapse.rest.key.v1.server_key_resource import LocalKey
from synapse.rest.key.v2 import KeyApiV2Resource from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.api.urls import ( from synapse.api.urls import (
FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
SERVER_KEY_PREFIX, MEDIA_PREFIX, STATIC_PREFIX, SERVER_KEY_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX, STATIC_PREFIX,
SERVER_KEY_V2_PREFIX, SERVER_KEY_V2_PREFIX,
) )
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
@ -73,17 +69,6 @@ from synapse import events
from daemonize import Daemonize from daemonize import Daemonize
import synapse
import contextlib
import logging
import os
import re
import resource
import subprocess
import time
logger = logging.getLogger("synapse.app.homeserver") logger = logging.getLogger("synapse.app.homeserver")
@ -163,8 +148,10 @@ class SynapseHomeServer(HomeServer):
}) })
if name in ["media", "federation", "client"]: if name in ["media", "federation", "client"]:
media_repo = MediaRepositoryResource(self)
resources.update({ resources.update({
MEDIA_PREFIX: MediaRepositoryResource(self), MEDIA_PREFIX: media_repo,
LEGACY_MEDIA_PREFIX: media_repo,
CONTENT_REPO_PREFIX: ContentRepoResource( CONTENT_REPO_PREFIX: ContentRepoResource(
self, self.config.uploads_path, self.auth, self.content_addr self, self.config.uploads_path, self.auth, self.content_addr
), ),
@ -366,11 +353,20 @@ def setup(config_options):
Returns: Returns:
HomeServer HomeServer
""" """
config = HomeServerConfig.load_config( try:
"Synapse Homeserver", config = HomeServerConfig.load_config(
config_options, "Synapse Homeserver",
generate_section="Homeserver" config_options,
) generate_section="Homeserver"
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
if not config:
# If a config isn't returned, and an exception isn't raised, we're just
# generating config files and shouldn't try to continue.
sys.exit(0)
config.setup_logging() config.setup_logging()
@ -690,8 +686,8 @@ def run(hs):
stats["uptime_seconds"] = uptime stats["uptime_seconds"] = uptime
stats["total_users"] = yield hs.get_datastore().count_all_users() stats["total_users"] = yield hs.get_datastore().count_all_users()
all_rooms = yield hs.get_datastore().get_rooms(False) room_count = yield hs.get_datastore().get_room_count()
stats["total_room_count"] = len(all_rooms) stats["total_room_count"] = room_count
stats["daily_active_users"] = yield hs.get_datastore().count_daily_users() stats["daily_active_users"] = yield hs.get_datastore().count_daily_users()
daily_messages = yield hs.get_datastore().count_daily_messages() daily_messages = yield hs.get_datastore().count_daily_messages()
@ -713,6 +709,8 @@ def run(hs):
phone_home_task.start(60 * 60 * 24, now=False) phone_home_task.start(60 * 60 * 24, now=False)
def in_thread(): def in_thread():
# Uncomment to enable tracing of log context changes.
# sys.settrace(logcontext_tracer)
with LoggingContext("run"): with LoggingContext("run"):
change_resource_limit(hs.config.soft_file_limit) change_resource_limit(hs.config.soft_file_limit)
reactor.run() reactor.run()

View File

@ -29,7 +29,7 @@ class ApplicationServiceApi(SimpleHttpClient):
pushing. pushing.
""" """
def __init__(self, hs): def __init__(self, hs):
super(ApplicationServiceApi, self).__init__(hs) super(ApplicationServiceApi, self).__init__(hs)
self.clock = hs.get_clock() self.clock = hs.get_clock()

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.config._base import ConfigError
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
@ -21,7 +22,11 @@ if __name__ == "__main__":
if action == "read": if action == "read":
key = sys.argv[2] key = sys.argv[2]
config = HomeServerConfig.load_config("", sys.argv[3:]) try:
config = HomeServerConfig.load_config("", sys.argv[3:])
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
print getattr(config, key) print getattr(config, key)
sys.exit(0) sys.exit(0)

View File

@ -17,7 +17,6 @@ import argparse
import errno import errno
import os import os
import yaml import yaml
import sys
from textwrap import dedent from textwrap import dedent
@ -136,13 +135,20 @@ class Config(object):
results.append(getattr(cls, name)(self, *args, **kargs)) results.append(getattr(cls, name)(self, *args, **kargs))
return results return results
def generate_config(self, config_dir_path, server_name, report_stats=None): def generate_config(
self,
config_dir_path,
server_name,
is_generating_file,
report_stats=None,
):
default_config = "# vim:ft=yaml\n" default_config = "# vim:ft=yaml\n"
default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all( default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all(
"default_config", "default_config",
config_dir_path=config_dir_path, config_dir_path=config_dir_path,
server_name=server_name, server_name=server_name,
is_generating_file=is_generating_file,
report_stats=report_stats, report_stats=report_stats,
)) ))
@ -244,8 +250,10 @@ class Config(object):
server_name = config_args.server_name server_name = config_args.server_name
if not server_name: if not server_name:
print "Must specify a server_name to a generate config for." raise ConfigError(
sys.exit(1) "Must specify a server_name to a generate config for."
" Pass -H server.name."
)
if not os.path.exists(config_dir_path): if not os.path.exists(config_dir_path):
os.makedirs(config_dir_path) os.makedirs(config_dir_path)
with open(config_path, "wb") as config_file: with open(config_path, "wb") as config_file:
@ -253,6 +261,7 @@ class Config(object):
config_dir_path=config_dir_path, config_dir_path=config_dir_path,
server_name=server_name, server_name=server_name,
report_stats=(config_args.report_stats == "yes"), report_stats=(config_args.report_stats == "yes"),
is_generating_file=True
) )
obj.invoke_all("generate_files", config) obj.invoke_all("generate_files", config)
config_file.write(config_bytes) config_file.write(config_bytes)
@ -266,7 +275,7 @@ class Config(object):
"If this server name is incorrect, you will need to" "If this server name is incorrect, you will need to"
" regenerate the SSL certificates" " regenerate the SSL certificates"
) )
sys.exit(0) return
else: else:
print ( print (
"Config file %r already exists. Generating any missing key" "Config file %r already exists. Generating any missing key"
@ -302,25 +311,25 @@ class Config(object):
specified_config.update(yaml_config) specified_config.update(yaml_config)
if "server_name" not in specified_config: if "server_name" not in specified_config:
sys.stderr.write("\n" + MISSING_SERVER_NAME + "\n") raise ConfigError(MISSING_SERVER_NAME)
sys.exit(1)
server_name = specified_config["server_name"] server_name = specified_config["server_name"]
_, config = obj.generate_config( _, config = obj.generate_config(
config_dir_path=config_dir_path, config_dir_path=config_dir_path,
server_name=server_name server_name=server_name,
is_generating_file=False,
) )
config.pop("log_config") config.pop("log_config")
config.update(specified_config) config.update(specified_config)
if "report_stats" not in config: if "report_stats" not in config:
sys.stderr.write( raise ConfigError(
"\n" + MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" + MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" +
MISSING_REPORT_STATS_SPIEL + "\n") MISSING_REPORT_STATS_SPIEL
sys.exit(1) )
if generate_keys: if generate_keys:
obj.invoke_all("generate_files", config) obj.invoke_all("generate_files", config)
sys.exit(0) return
obj.invoke_all("read_config", config) obj.invoke_all("read_config", config)

View File

@ -22,8 +22,14 @@ from signedjson.key import (
read_signing_keys, write_signing_keys, NACL_ED25519 read_signing_keys, write_signing_keys, NACL_ED25519
) )
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
from synapse.util.stringutils import random_string_with_symbols
import os import os
import hashlib
import logging
logger = logging.getLogger(__name__)
class KeyConfig(Config): class KeyConfig(Config):
@ -40,9 +46,29 @@ class KeyConfig(Config):
config["perspectives"] config["perspectives"]
) )
def default_config(self, config_dir_path, server_name, **kwargs): self.macaroon_secret_key = config.get(
"macaroon_secret_key", self.registration_shared_secret
)
if not self.macaroon_secret_key:
# Unfortunately, there are people out there that don't have this
# set. Lets just be "nice" and derive one from their secret key.
logger.warn("Config is missing missing macaroon_secret_key")
seed = self.signing_key[0].seed
self.macaroon_secret_key = hashlib.sha256(seed)
def default_config(self, config_dir_path, server_name, is_generating_file=False,
**kwargs):
base_key_name = os.path.join(config_dir_path, server_name) base_key_name = os.path.join(config_dir_path, server_name)
if is_generating_file:
macaroon_secret_key = random_string_with_symbols(50)
else:
macaroon_secret_key = None
return """\ return """\
macaroon_secret_key: "%(macaroon_secret_key)s"
## Signing Keys ## ## Signing Keys ##
# Path to the signing key to sign messages with # Path to the signing key to sign messages with

View File

@ -23,22 +23,23 @@ from distutils.util import strtobool
class RegistrationConfig(Config): class RegistrationConfig(Config):
def read_config(self, config): def read_config(self, config):
self.disable_registration = not bool( self.enable_registration = bool(
strtobool(str(config["enable_registration"])) strtobool(str(config["enable_registration"]))
) )
if "disable_registration" in config: if "disable_registration" in config:
self.disable_registration = bool( self.enable_registration = not bool(
strtobool(str(config["disable_registration"])) strtobool(str(config["disable_registration"]))
) )
self.registration_shared_secret = config.get("registration_shared_secret") self.registration_shared_secret = config.get("registration_shared_secret")
self.macaroon_secret_key = config.get("macaroon_secret_key")
self.bcrypt_rounds = config.get("bcrypt_rounds", 12) self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"]
self.allow_guest_access = config.get("allow_guest_access", False) self.allow_guest_access = config.get("allow_guest_access", False)
def default_config(self, **kwargs): def default_config(self, **kwargs):
registration_shared_secret = random_string_with_symbols(50) registration_shared_secret = random_string_with_symbols(50)
macaroon_secret_key = random_string_with_symbols(50)
return """\ return """\
## Registration ## ## Registration ##
@ -49,8 +50,6 @@ class RegistrationConfig(Config):
# secret, even if registration is otherwise disabled. # secret, even if registration is otherwise disabled.
registration_shared_secret: "%(registration_shared_secret)s" registration_shared_secret: "%(registration_shared_secret)s"
macaroon_secret_key: "%(macaroon_secret_key)s"
# Set the number of bcrypt rounds used to generate password hash. # Set the number of bcrypt rounds used to generate password hash.
# Larger numbers increase the work factor needed to generate the hash. # Larger numbers increase the work factor needed to generate the hash.
# The default number of rounds is 12. # The default number of rounds is 12.
@ -60,6 +59,12 @@ class RegistrationConfig(Config):
# participate in rooms hosted on this server which have been made # participate in rooms hosted on this server which have been made
# accessible to anonymous users. # accessible to anonymous users.
allow_guest_access: False allow_guest_access: False
# The list of identity servers trusted to verify third party
# identifiers by this server.
trusted_third_party_id_servers:
- matrix.org
- vector.im
""" % locals() """ % locals()
def add_arguments(self, parser): def add_arguments(self, parser):
@ -71,6 +76,6 @@ class RegistrationConfig(Config):
def read_arguments(self, args): def read_arguments(self, args):
if args.enable_registration is not None: if args.enable_registration is not None:
self.disable_registration = not bool( self.enable_registration = bool(
strtobool(str(args.enable_registration)) strtobool(str(args.enable_registration))
) )

View File

@ -18,6 +18,10 @@ from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter from synapse.util.retryutils import get_retry_limiter
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import (
preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext,
preserve_fn
)
from twisted.internet import defer from twisted.internet import defer
@ -142,40 +146,43 @@ class Keyring(object):
for server_name, _ in server_and_json for server_name, _ in server_and_json
} }
# We want to wait for any previous lookups to complete before with PreserveLoggingContext():
# proceeding.
wait_on_deferred = self.wait_for_previous_lookups(
[server_name for server_name, _ in server_and_json],
server_to_deferred,
)
# Actually start fetching keys. # We want to wait for any previous lookups to complete before
wait_on_deferred.addBoth( # proceeding.
lambda _: self.get_server_verify_keys(group_id_to_group, deferreds) wait_on_deferred = self.wait_for_previous_lookups(
) [server_name for server_name, _ in server_and_json],
server_to_deferred,
)
# When we've finished fetching all the keys for a given server_name, # Actually start fetching keys.
# resolve the deferred passed to `wait_for_previous_lookups` so that wait_on_deferred.addBoth(
# any lookups waiting will proceed. lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
server_to_gids = {} )
def remove_deferreds(res, server_name, group_id): # When we've finished fetching all the keys for a given server_name,
server_to_gids[server_name].discard(group_id) # resolve the deferred passed to `wait_for_previous_lookups` so that
if not server_to_gids[server_name]: # any lookups waiting will proceed.
d = server_to_deferred.pop(server_name, None) server_to_gids = {}
if d:
d.callback(None)
return res
for g_id, deferred in deferreds.items(): def remove_deferreds(res, server_name, group_id):
server_name = group_id_to_group[g_id].server_name server_to_gids[server_name].discard(group_id)
server_to_gids.setdefault(server_name, set()).add(g_id) if not server_to_gids[server_name]:
deferred.addBoth(remove_deferreds, server_name, g_id) d = server_to_deferred.pop(server_name, None)
if d:
d.callback(None)
return res
for g_id, deferred in deferreds.items():
server_name = group_id_to_group[g_id].server_name
server_to_gids.setdefault(server_name, set()).add(g_id)
deferred.addBoth(remove_deferreds, server_name, g_id)
# Pass those keys to handle_key_deferred so that the json object # Pass those keys to handle_key_deferred so that the json object
# signatures can be verified # signatures can be verified
return [ return [
handle_key_deferred( preserve_context_over_fn(
handle_key_deferred,
group_id_to_group[g_id], group_id_to_group[g_id],
deferreds[g_id], deferreds[g_id],
) )
@ -198,12 +205,13 @@ class Keyring(object):
if server_name in self.key_downloads if server_name in self.key_downloads
] ]
if wait_on: if wait_on:
yield defer.DeferredList(wait_on) with PreserveLoggingContext():
yield defer.DeferredList(wait_on)
else: else:
break break
for server_name, deferred in server_to_deferred.items(): for server_name, deferred in server_to_deferred.items():
d = ObservableDeferred(deferred) d = ObservableDeferred(preserve_context_over_deferred(deferred))
self.key_downloads[server_name] = d self.key_downloads[server_name] = d
def rm(r, server_name): def rm(r, server_name):
@ -244,12 +252,13 @@ class Keyring(object):
for group in group_id_to_group.values(): for group in group_id_to_group.values():
for key_id in group.key_ids: for key_id in group.key_ids:
if key_id in merged_results[group.server_name]: if key_id in merged_results[group.server_name]:
group_id_to_deferred[group.group_id].callback(( with PreserveLoggingContext():
group.group_id, group_id_to_deferred[group.group_id].callback((
group.server_name, group.group_id,
key_id, group.server_name,
merged_results[group.server_name][key_id], key_id,
)) merged_results[group.server_name][key_id],
))
break break
else: else:
missing_groups.setdefault( missing_groups.setdefault(
@ -504,7 +513,7 @@ class Keyring(object):
yield defer.gatherResults( yield defer.gatherResults(
[ [
self.store_keys( preserve_fn(self.store_keys)(
server_name=key_server_name, server_name=key_server_name,
from_server=server_name, from_server=server_name,
verify_keys=verify_keys, verify_keys=verify_keys,
@ -573,7 +582,7 @@ class Keyring(object):
yield defer.gatherResults( yield defer.gatherResults(
[ [
self.store.store_server_keys_json( preserve_fn(self.store.store_server_keys_json)(
server_name=server_name, server_name=server_name,
key_id=key_id, key_id=key_id,
from_server=server_name, from_server=server_name,
@ -675,7 +684,7 @@ class Keyring(object):
# TODO(markjh): Store whether the keys have expired. # TODO(markjh): Store whether the keys have expired.
yield defer.gatherResults( yield defer.gatherResults(
[ [
self.store.store_server_verify_key( preserve_fn(self.store.store_server_verify_key)(
server_name, server_name, key.time_added, key server_name, server_name, key.time_added, key
) )
for key_id, key in verify_keys.items() for key_id, key in verify_keys.items()

View File

@ -20,3 +20,4 @@ class EventContext(object):
self.current_state = current_state self.current_state = current_state
self.state_group = None self.state_group = None
self.rejected = False self.rejected = False
self.push_actions = []

View File

@ -57,7 +57,7 @@ class FederationClient(FederationBase):
cache_name="get_pdu_cache", cache_name="get_pdu_cache",
clock=self._clock, clock=self._clock,
max_len=1000, max_len=1000,
expiry_ms=120*1000, expiry_ms=120 * 1000,
reset_expiry_on_get=False, reset_expiry_on_get=False,
) )

View File

@ -126,10 +126,8 @@ class FederationServer(FederationBase):
results = [] results = []
for pdu in pdu_list: for pdu in pdu_list:
d = self._handle_new_pdu(transaction.origin, pdu)
try: try:
yield d yield self._handle_new_pdu(transaction.origin, pdu)
results.append({}) results.append({})
except FederationError as e: except FederationError as e:
self.send_failure(e, transaction.origin) self.send_failure(e, transaction.origin)

View File

@ -103,7 +103,6 @@ class TransactionQueue(object):
else: else:
return not destination.startswith("localhost") return not destination.startswith("localhost")
@defer.inlineCallbacks
def enqueue_pdu(self, pdu, destinations, order): def enqueue_pdu(self, pdu, destinations, order):
# We loop through all destinations to see whether we already have # We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus # a transaction in progress. If we do, stick it in the pending_pdus
@ -141,8 +140,6 @@ class TransactionQueue(object):
deferreds.append(deferred) deferreds.append(deferred)
yield defer.DeferredList(deferreds, consumeErrors=True)
# NO inlineCallbacks # NO inlineCallbacks
def enqueue_edu(self, edu): def enqueue_edu(self, edu):
destination = edu.destination destination = edu.destination

View File

@ -53,25 +53,10 @@ class BaseHandler(object):
self.event_builder_factory = hs.get_event_builder_factory() self.event_builder_factory = hs.get_event_builder_factory()
@defer.inlineCallbacks @defer.inlineCallbacks
def _filter_events_for_clients(self, user_tuples, events): def _filter_events_for_clients(self, user_tuples, events, event_id_to_state):
""" Returns dict of user_id -> list of events that user is allowed to """ Returns dict of user_id -> list of events that user is allowed to
see. see.
""" """
# If there is only one user, just get the state for that one user,
# otherwise just get all the state.
if len(user_tuples) == 1:
types = (
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, user_tuples[0][0]),
)
else:
types = None
event_id_to_state = yield self.store.get_state_for_events(
frozenset(e.event_id for e in events),
types=types
)
forgotten = yield defer.gatherResults([ forgotten = yield defer.gatherResults([
self.store.who_forgot_in_room( self.store.who_forgot_in_room(
room_id, room_id,
@ -135,7 +120,17 @@ class BaseHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _filter_events_for_client(self, user_id, events, is_peeking=False): def _filter_events_for_client(self, user_id, events, is_peeking=False):
# Assumes that user has at some point joined the room if not is_guest. # Assumes that user has at some point joined the room if not is_guest.
res = yield self._filter_events_for_clients([(user_id, is_peeking)], events) types = (
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, user_id),
)
event_id_to_state = yield self.store.get_state_for_events(
frozenset(e.event_id for e in events),
types=types
)
res = yield self._filter_events_for_clients(
[(user_id, is_peeking)], events, event_id_to_state
)
defer.returnValue(res.get(user_id, [])) defer.returnValue(res.get(user_id, []))
def ratelimit(self, user_id): def ratelimit(self, user_id):
@ -147,7 +142,7 @@ class BaseHandler(object):
) )
if not allowed: if not allowed:
raise LimitExceededError( raise LimitExceededError(
retry_after_ms=int(1000*(time_allowed - time_now)), retry_after_ms=int(1000 * (time_allowed - time_now)),
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -269,13 +264,13 @@ class BaseHandler(object):
"You don't have permission to redact events" "You don't have permission to redact events"
) )
(event_stream_id, max_stream_id) = yield self.store.persist_event(
event, context=context
)
action_generator = ActionGenerator(self.hs) action_generator = ActionGenerator(self.hs)
yield action_generator.handle_push_actions_for_event( yield action_generator.handle_push_actions_for_event(
event, self event, context, self
)
(event_stream_id, max_stream_id) = yield self.store.persist_event(
event, context=context
) )
destinations = set() destinations = set()
@ -293,19 +288,11 @@ class BaseHandler(object):
with PreserveLoggingContext(): with PreserveLoggingContext():
# Don't block waiting on waking up all the listeners. # Don't block waiting on waking up all the listeners.
notify_d = self.notifier.on_new_room_event( self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, event, event_stream_id, max_stream_id,
extra_users=extra_users extra_users=extra_users
) )
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
notify_d.addErrback(log_failure)
# If invite, remove room_state from unsigned before sending. # If invite, remove room_state from unsigned before sending.
event.unsigned.pop("invite_room_state", None) event.unsigned.pop("invite_room_state", None)

View File

@ -175,8 +175,8 @@ class DirectoryHandler(BaseHandler):
# If this server is in the list of servers, return it first. # If this server is in the list of servers, return it first.
if self.server_name in servers: if self.server_name in servers:
servers = ( servers = (
[self.server_name] [self.server_name] +
+ [s for s in servers if s != self.server_name] [s for s in servers if s != self.server_name]
) )
else: else:
servers = list(servers) servers = list(servers)

View File

@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.types import UserID from synapse.types import UserID
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.util.logcontext import preserve_context_over_fn
from ._base import BaseHandler from ._base import BaseHandler
@ -29,11 +30,17 @@ logger = logging.getLogger(__name__)
def started_user_eventstream(distributor, user): def started_user_eventstream(distributor, user):
return distributor.fire("started_user_eventstream", user) return preserve_context_over_fn(
distributor.fire,
"started_user_eventstream", user
)
def stopped_user_eventstream(distributor, user): def stopped_user_eventstream(distributor, user):
return distributor.fire("stopped_user_eventstream", user) return preserve_context_over_fn(
distributor.fire,
"stopped_user_eventstream", user
)
class EventStreamHandler(BaseHandler): class EventStreamHandler(BaseHandler):
@ -130,7 +137,7 @@ class EventStreamHandler(BaseHandler):
# Add some randomness to this value to try and mitigate against # Add some randomness to this value to try and mitigate against
# thundering herds on restart. # thundering herds on restart.
timeout = random.randint(int(timeout*0.9), int(timeout*1.1)) timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1))
events, tokens = yield self.notifier.get_events_for( events, tokens = yield self.notifier.get_events_for(
auth_user, pagin_config, timeout, auth_user, pagin_config, timeout,

View File

@ -221,19 +221,11 @@ class FederationHandler(BaseHandler):
extra_users.append(target_user) extra_users.append(target_user)
with PreserveLoggingContext(): with PreserveLoggingContext():
d = self.notifier.on_new_room_event( self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, event, event_stream_id, max_stream_id,
extra_users=extra_users extra_users=extra_users
) )
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
prev_state = context.current_state.get((event.type, event.state_key)) prev_state = context.current_state.get((event.type, event.state_key))
@ -244,12 +236,6 @@ class FederationHandler(BaseHandler):
user = UserID.from_string(event.state_key) user = UserID.from_string(event.state_key)
yield user_joined_room(self.distributor, user, event.room_id) yield user_joined_room(self.distributor, user, event.room_id)
if not backfilled and not event.internal_metadata.is_outlier():
action_generator = ActionGenerator(self.hs)
yield action_generator.handle_push_actions_for_event(
event, self
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _filter_events_for_server(self, server_name, room_id, events): def _filter_events_for_server(self, server_name, room_id, events):
event_to_state = yield self.store.get_state_for_events( event_to_state = yield self.store.get_state_for_events(
@ -643,19 +629,11 @@ class FederationHandler(BaseHandler):
) )
with PreserveLoggingContext(): with PreserveLoggingContext():
d = self.notifier.on_new_room_event( self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, event, event_stream_id, max_stream_id,
extra_users=[joinee] extra_users=[joinee]
) )
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
logger.debug("Finished joining %s to %s", joinee, room_id) logger.debug("Finished joining %s to %s", joinee, room_id)
finally: finally:
room_queue = self.room_queues[room_id] room_queue = self.room_queues[room_id]
@ -730,18 +708,10 @@ class FederationHandler(BaseHandler):
extra_users.append(target_user) extra_users.append(target_user)
with PreserveLoggingContext(): with PreserveLoggingContext():
d = self.notifier.on_new_room_event( self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, extra_users=extra_users event, event_stream_id, max_stream_id, extra_users=extra_users
) )
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
if event.content["membership"] == Membership.JOIN: if event.content["membership"] == Membership.JOIN:
user = UserID.from_string(event.state_key) user = UserID.from_string(event.state_key)
@ -811,19 +781,11 @@ class FederationHandler(BaseHandler):
target_user = UserID.from_string(event.state_key) target_user = UserID.from_string(event.state_key)
with PreserveLoggingContext(): with PreserveLoggingContext():
d = self.notifier.on_new_room_event( self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, event, event_stream_id, max_stream_id,
extra_users=[target_user], extra_users=[target_user],
) )
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
defer.returnValue(event) defer.returnValue(event)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -948,18 +910,10 @@ class FederationHandler(BaseHandler):
extra_users.append(target_user) extra_users.append(target_user)
with PreserveLoggingContext(): with PreserveLoggingContext():
d = self.notifier.on_new_room_event( self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, extra_users=extra_users event, event_stream_id, max_stream_id, extra_users=extra_users
) )
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
new_pdu = event new_pdu = event
destinations = set() destinations = set()
@ -1113,6 +1067,12 @@ class FederationHandler(BaseHandler):
auth_events=auth_events, auth_events=auth_events,
) )
if not backfilled and not event.internal_metadata.is_outlier():
action_generator = ActionGenerator(self.hs)
yield action_generator.handle_push_actions_for_event(
event, context, self
)
event_stream_id, max_stream_id = yield self.store.persist_event( event_stream_id, max_stream_id = yield self.store.persist_event(
event, event,
context=context, context=context,

View File

@ -36,14 +36,15 @@ class IdentityHandler(BaseHandler):
self.http_client = hs.get_simple_http_client() self.http_client = hs.get_simple_http_client()
self.trusted_id_servers = set(hs.config.trusted_third_party_id_servers)
self.trust_any_id_server_just_for_testing_do_not_use = (
hs.config.use_insecure_ssl_client_just_for_testing_do_not_use
)
@defer.inlineCallbacks @defer.inlineCallbacks
def threepid_from_creds(self, creds): def threepid_from_creds(self, creds):
yield run_on_reactor() yield run_on_reactor()
# XXX: make this configurable!
# trustedIdServers = ['matrix.org', 'localhost:8090']
trustedIdServers = ['matrix.org', 'vector.im']
if 'id_server' in creds: if 'id_server' in creds:
id_server = creds['id_server'] id_server = creds['id_server']
elif 'idServer' in creds: elif 'idServer' in creds:
@ -58,10 +59,19 @@ class IdentityHandler(BaseHandler):
else: else:
raise SynapseError(400, "No client_secret in creds") raise SynapseError(400, "No client_secret in creds")
if id_server not in trustedIdServers: if id_server not in self.trusted_id_servers:
logger.warn('%s is not a trusted ID server: rejecting 3pid ' + if self.trust_any_id_server_just_for_testing_do_not_use:
'credentials', id_server) logger.warn(
defer.returnValue(None) "Trusting untrustworthy ID server %r even though it isn't"
" in the trusted id list for testing because"
" 'use_insecure_ssl_client_just_for_testing_do_not_use'"
" is set in the config",
id_server,
)
else:
logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
'credentials', id_server)
defer.returnValue(None)
data = {} data = {}
try: try:

View File

@ -34,7 +34,7 @@ metrics = synapse.metrics.get_metrics_for(__name__)
# Don't bother bumping "last active" time if it differs by less than 60 seconds # Don't bother bumping "last active" time if it differs by less than 60 seconds
LAST_ACTIVE_GRANULARITY = 60*1000 LAST_ACTIVE_GRANULARITY = 60 * 1000
# Keep no more than this number of offline serial revisions # Keep no more than this number of offline serial revisions
MAX_OFFLINE_SERIALS = 1000 MAX_OFFLINE_SERIALS = 1000
@ -378,9 +378,9 @@ class PresenceHandler(BaseHandler):
was_polling = target_user in self._user_cachemap was_polling = target_user in self._user_cachemap
if now_online and not was_polling: if now_online and not was_polling:
self.start_polling_presence(target_user, state=state) yield self.start_polling_presence(target_user, state=state)
elif not now_online and was_polling: elif not now_online and was_polling:
self.stop_polling_presence(target_user) yield self.stop_polling_presence(target_user)
# TODO(paul): perform a presence push as part of start/stop poll so # TODO(paul): perform a presence push as part of start/stop poll so
# we don't have to do this all the time # we don't have to do this all the time
@ -394,7 +394,8 @@ class PresenceHandler(BaseHandler):
if now - prev_state.state.get("last_active", 0) < LAST_ACTIVE_GRANULARITY: if now - prev_state.state.get("last_active", 0) < LAST_ACTIVE_GRANULARITY:
return return
self.changed_presencelike_data(user, {"last_active": now}) with PreserveLoggingContext():
self.changed_presencelike_data(user, {"last_active": now})
def get_joined_rooms_for_user(self, user): def get_joined_rooms_for_user(self, user):
"""Get the list of rooms a user is joined to. """Get the list of rooms a user is joined to.
@ -466,11 +467,12 @@ class PresenceHandler(BaseHandler):
local_user, room_ids=[room_id], add_to_cache=False local_user, room_ids=[room_id], add_to_cache=False
) )
self.push_update_to_local_and_remote( with PreserveLoggingContext():
observed_user=local_user, self.push_update_to_local_and_remote(
users_to_push=[user], observed_user=local_user,
statuscache=statuscache, users_to_push=[user],
) statuscache=statuscache,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def send_presence_invite(self, observer_user, observed_user): def send_presence_invite(self, observer_user, observed_user):
@ -556,7 +558,7 @@ class PresenceHandler(BaseHandler):
observer_user.localpart, observed_user.to_string() observer_user.localpart, observed_user.to_string()
) )
self.start_polling_presence( yield self.start_polling_presence(
observer_user, target_user=observed_user observer_user, target_user=observed_user
) )

View File

@ -21,7 +21,6 @@ from synapse.api.errors import (
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
) )
from ._base import BaseHandler from ._base import BaseHandler
import synapse.util.stringutils as stringutils
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.http.client import CaptchaServerHttpClient from synapse.http.client import CaptchaServerHttpClient
@ -45,6 +44,8 @@ class RegistrationHandler(BaseHandler):
self.distributor.declare("registered_user") self.distributor.declare("registered_user")
self.captcha_client = CaptchaServerHttpClient(hs) self.captcha_client = CaptchaServerHttpClient(hs)
self._next_generated_user_id = None
@defer.inlineCallbacks @defer.inlineCallbacks
def check_username(self, localpart, guest_access_token=None): def check_username(self, localpart, guest_access_token=None):
yield run_on_reactor() yield run_on_reactor()
@ -91,7 +92,7 @@ class RegistrationHandler(BaseHandler):
Args: Args:
localpart : The local part of the user ID to register. If None, localpart : The local part of the user ID to register. If None,
one will be randomly generated. one will be generated.
password (str) : The password to assign to this user so they can password (str) : The password to assign to this user so they can
login again. This can be None which means they cannot login again login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user). via a password (e.g. the user is an application service user).
@ -108,6 +109,18 @@ class RegistrationHandler(BaseHandler):
if localpart: if localpart:
yield self.check_username(localpart, guest_access_token=guest_access_token) yield self.check_username(localpart, guest_access_token=guest_access_token)
was_guest = guest_access_token is not None
if not was_guest:
try:
int(localpart)
raise RegistrationError(
400,
"Numeric user IDs are reserved for guest users."
)
except ValueError:
pass
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
@ -118,38 +131,36 @@ class RegistrationHandler(BaseHandler):
user_id=user_id, user_id=user_id,
token=token, token=token,
password_hash=password_hash, password_hash=password_hash,
was_guest=guest_access_token is not None, was_guest=was_guest,
make_guest=make_guest, make_guest=make_guest,
) )
yield registered_user(self.distributor, user) yield registered_user(self.distributor, user)
else: else:
# autogen a random user ID # autogen a sequential user ID
attempts = 0 attempts = 0
user_id = None
token = None token = None
while not user_id: user = None
while not user:
localpart = yield self._generate_user_id(attempts > 0)
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
yield self.check_user_id_is_valid(user_id)
if generate_token:
token = self.auth_handler().generate_access_token(user_id)
try: try:
localpart = self._generate_user_id()
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
yield self.check_user_id_is_valid(user_id)
if generate_token:
token = self.auth_handler().generate_access_token(user_id)
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
token=token, token=token,
password_hash=password_hash) password_hash=password_hash,
make_guest=make_guest
yield registered_user(self.distributor, user) )
except SynapseError: except SynapseError:
# if user id is taken, just generate another # if user id is taken, just generate another
user_id = None user_id = None
token = None token = None
attempts += 1 attempts += 1
if attempts > 5: yield registered_user(self.distributor, user)
raise RegistrationError(
500, "Cannot generate user ID.")
# We used to generate default identicons here, but nowadays # We used to generate default identicons here, but nowadays
# we want clients to generate their own as part of their branding # we want clients to generate their own as part of their branding
@ -175,7 +186,7 @@ class RegistrationHandler(BaseHandler):
token=token, token=token,
password_hash="" password_hash=""
) )
registered_user(self.distributor, user) yield registered_user(self.distributor, user)
defer.returnValue((user_id, token)) defer.returnValue((user_id, token))
@defer.inlineCallbacks @defer.inlineCallbacks
@ -211,7 +222,7 @@ class RegistrationHandler(BaseHandler):
400, 400,
"User ID must only contain characters which do not" "User ID must only contain characters which do not"
" require URL encoding." " require URL encoding."
) )
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
@ -281,8 +292,16 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.EXCLUSIVE errcode=Codes.EXCLUSIVE
) )
def _generate_user_id(self): @defer.inlineCallbacks
return "-" + stringutils.random_string(18) def _generate_user_id(self, reseed=False):
if reseed or self._next_generated_user_id is None:
self._next_generated_user_id = (
yield self.store.find_next_generated_user_id_localpart()
)
id = self._next_generated_user_id
self._next_generated_user_id += 1
defer.returnValue(str(id))
@defer.inlineCallbacks @defer.inlineCallbacks
def _validate_captcha(self, ip_addr, private_key, challenge, response): def _validate_captcha(self, ip_addr, private_key, challenge, response):

View File

@ -18,13 +18,14 @@ from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.types import UserID, RoomAlias, RoomID from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken
from synapse.api.constants import ( from synapse.api.constants import (
EventTypes, Membership, JoinRules, RoomCreationPreset, EventTypes, Membership, JoinRules, RoomCreationPreset,
) )
from synapse.api.errors import AuthError, StoreError, SynapseError, Codes from synapse.api.errors import AuthError, StoreError, SynapseError, Codes
from synapse.util import stringutils, unwrapFirstError from synapse.util import stringutils, unwrapFirstError
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.logcontext import preserve_context_over_fn
from signedjson.sign import verify_signed_json from signedjson.sign import verify_signed_json
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
@ -46,11 +47,17 @@ def collect_presencelike_data(distributor, user, content):
def user_left_room(distributor, user, room_id): def user_left_room(distributor, user, room_id):
return distributor.fire("user_left_room", user=user, room_id=room_id) return preserve_context_over_fn(
distributor.fire,
"user_left_room", user=user, room_id=room_id
)
def user_joined_room(distributor, user, room_id): def user_joined_room(distributor, user, room_id):
return distributor.fire("user_joined_room", user=user, room_id=room_id) return preserve_context_over_fn(
distributor.fire,
"user_joined_room", user=user, room_id=room_id
)
class RoomCreationHandler(BaseHandler): class RoomCreationHandler(BaseHandler):
@ -876,39 +883,71 @@ class RoomListHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_public_room_list(self): def get_public_room_list(self):
chunk = yield self.store.get_rooms(is_public=True) room_ids = yield self.store.get_public_room_ids()
room_members = yield defer.gatherResults( @defer.inlineCallbacks
[ def handle_room(room_id):
self.store.get_users_in_room(room["room_id"]) aliases = yield self.store.get_aliases_for_room(room_id)
for room in chunk if not aliases:
], defer.returnValue(None)
consumeErrors=True,
).addErrback(unwrapFirstError)
avatar_urls = yield defer.gatherResults( state = yield self.state_handler.get_current_state(room_id)
[
self.get_room_avatar_url(room["room_id"])
for room in chunk
],
consumeErrors=True,
).addErrback(unwrapFirstError)
for i, room in enumerate(chunk): result = {"aliases": aliases, "room_id": room_id}
room["num_joined_members"] = len(room_members[i])
if avatar_urls[i]: name_event = state.get((EventTypes.Name, ""), None)
room["avatar_url"] = avatar_urls[i] if name_event:
name = name_event.content.get("name", None)
if name:
result["name"] = name
topic_event = state.get((EventTypes.Topic, ""), None)
if topic_event:
topic = topic_event.content.get("topic", None)
if topic:
result["topic"] = topic
canonical_event = state.get((EventTypes.CanonicalAlias, ""), None)
if canonical_event:
canonical_alias = canonical_event.content.get("alias", None)
if canonical_alias:
result["canonical_alias"] = canonical_alias
visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
visibility = None
if visibility_event:
visibility = visibility_event.content.get("history_visibility", None)
result["world_readable"] = visibility == "world_readable"
guest_event = state.get((EventTypes.GuestAccess, ""), None)
guest = None
if guest_event:
guest = guest_event.content.get("guest_access", None)
result["guest_can_join"] = guest == "can_join"
avatar_event = state.get(("m.room.avatar", ""), None)
if avatar_event:
avatar_url = avatar_event.content.get("url", None)
if avatar_url:
result["avatar_url"] = avatar_url
result["num_joined_members"] = sum(
1 for (event_type, _), ev in state.items()
if event_type == EventTypes.Member and ev.membership == Membership.JOIN
)
defer.returnValue(result)
result = []
for chunk in (room_ids[i:i + 10] for i in xrange(0, len(room_ids), 10)):
chunk_result = yield defer.gatherResults([
handle_room(room_id)
for room_id in chunk
], consumeErrors=True).addErrback(unwrapFirstError)
result.extend(v for v in chunk_result if v)
# FIXME (erikj): START is no longer a valid value # FIXME (erikj): START is no longer a valid value
defer.returnValue({"start": "START", "end": "END", "chunk": chunk}) defer.returnValue({"start": "START", "end": "END", "chunk": result})
@defer.inlineCallbacks
def get_room_avatar_url(self, room_id):
event = yield self.hs.get_state_handler().get_current_state(
room_id, "m.room.avatar"
)
if event and "url" in event.content:
defer.returnValue(event.content["url"])
class RoomContextHandler(BaseHandler): class RoomContextHandler(BaseHandler):
@ -927,7 +966,7 @@ class RoomContextHandler(BaseHandler):
Returns: Returns:
dict, or None if the event isn't found dict, or None if the event isn't found
""" """
before_limit = math.floor(limit/2.) before_limit = math.floor(limit / 2.)
after_limit = limit - before_limit after_limit = limit - before_limit
now_token = yield self.hs.get_event_sources().get_current_token() now_token = yield self.hs.get_event_sources().get_current_token()
@ -997,6 +1036,11 @@ class RoomEventSource(object):
to_key = yield self.get_current_key() to_key = yield self.get_current_key()
from_token = RoomStreamToken.parse(from_key)
if from_token.topological:
logger.warn("Stream has topological part!!!! %r", from_key)
from_key = "s%s" % (from_token.stream,)
app_service = yield self.store.get_app_service_by_user_id( app_service = yield self.store.get_app_service_by_user_id(
user.to_string() user.to_string()
) )
@ -1008,15 +1052,30 @@ class RoomEventSource(object):
limit=limit, limit=limit,
) )
else: else:
events, end_key = yield self.store.get_room_events_stream( room_events = yield self.store.get_membership_changes_for_user(
user_id=user.to_string(), user.to_string(), from_key, to_key
)
room_to_events = yield self.store.get_room_events_stream_for_rooms(
room_ids=room_ids,
from_key=from_key, from_key=from_key,
to_key=to_key, to_key=to_key,
limit=limit, limit=limit or 10,
room_ids=room_ids,
is_guest=is_guest,
) )
events = list(room_events)
events.extend(e for evs, _ in room_to_events.values() for e in evs)
events.sort(key=lambda e: e.internal_metadata.order)
if limit:
events[:] = events[:limit]
if events:
end_key = events[-1].internal_metadata.after
else:
end_key = to_key
defer.returnValue((events, end_key)) defer.returnValue((events, end_key))
def get_current_key(self, direction='f'): def get_current_key(self, direction='f'):

View File

@ -18,11 +18,14 @@ from ._base import BaseHandler
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.api.constants import Membership, EventTypes from synapse.api.constants import Membership, EventTypes
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.logcontext import LoggingContext, preserve_fn
from synapse.util.metrics import Measure
from twisted.internet import defer from twisted.internet import defer
import collections import collections
import logging import logging
import itertools
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -139,6 +142,15 @@ class SyncHandler(BaseHandler):
A Deferred SyncResult. A Deferred SyncResult.
""" """
context = LoggingContext.current_context()
if context:
if since_token is None:
context.tag = "initial_sync"
elif full_state:
context.tag = "full_state_sync"
else:
context.tag = "incremental_sync"
if timeout == 0 or since_token is None or full_state: if timeout == 0 or since_token is None or full_state:
# we are going to return immediately, so don't bother calling # we are going to return immediately, so don't bother calling
# notifier.wait_for_events. # notifier.wait_for_events.
@ -167,18 +179,6 @@ class SyncHandler(BaseHandler):
else: else:
return self.incremental_sync_with_gap(sync_config, since_token) return self.incremental_sync_with_gap(sync_config, since_token)
def last_read_event_id_for_room_and_user(self, room_id, user_id, ephemeral_by_room):
if room_id not in ephemeral_by_room:
return None
for e in ephemeral_by_room[room_id]:
if e['type'] != 'm.receipt':
continue
for receipt_event_id, val in e['content'].items():
if 'm.read' in val:
if user_id in val['m.read']:
return receipt_event_id
return None
@defer.inlineCallbacks @defer.inlineCallbacks
def full_state_sync(self, sync_config, timeline_since_token): def full_state_sync(self, sync_config, timeline_since_token):
"""Get a sync for a client which is starting without any state. """Get a sync for a client which is starting without any state.
@ -228,44 +228,51 @@ class SyncHandler(BaseHandler):
invited = [] invited = []
archived = [] archived = []
deferreds = [] deferreds = []
for event in room_list:
if event.membership == Membership.JOIN:
room_sync_deferred = self.full_state_sync_for_joined_room(
room_id=event.room_id,
sync_config=sync_config,
now_token=now_token,
timeline_since_token=timeline_since_token,
ephemeral_by_room=ephemeral_by_room,
tags_by_room=tags_by_room,
account_data_by_room=account_data_by_room,
)
room_sync_deferred.addCallback(joined.append)
deferreds.append(room_sync_deferred)
elif event.membership == Membership.INVITE:
invite = yield self.store.get_event(event.event_id)
invited.append(InvitedSyncResult(
room_id=event.room_id,
invite=invite,
))
elif event.membership in (Membership.LEAVE, Membership.BAN):
leave_token = now_token.copy_and_replace(
"room_key", "s%d" % (event.stream_ordering,)
)
room_sync_deferred = self.full_state_sync_for_archived_room(
sync_config=sync_config,
room_id=event.room_id,
leave_event_id=event.event_id,
leave_token=leave_token,
timeline_since_token=timeline_since_token,
tags_by_room=tags_by_room,
account_data_by_room=account_data_by_room,
)
room_sync_deferred.addCallback(archived.append)
deferreds.append(room_sync_deferred)
yield defer.gatherResults( room_list_chunks = [room_list[i:i + 10] for i in xrange(0, len(room_list), 10)]
deferreds, consumeErrors=True for room_list_chunk in room_list_chunks:
).addErrback(unwrapFirstError) for event in room_list_chunk:
if event.membership == Membership.JOIN:
room_sync_deferred = preserve_fn(
self.full_state_sync_for_joined_room
)(
room_id=event.room_id,
sync_config=sync_config,
now_token=now_token,
timeline_since_token=timeline_since_token,
ephemeral_by_room=ephemeral_by_room,
tags_by_room=tags_by_room,
account_data_by_room=account_data_by_room,
)
room_sync_deferred.addCallback(joined.append)
deferreds.append(room_sync_deferred)
elif event.membership == Membership.INVITE:
invite = yield self.store.get_event(event.event_id)
invited.append(InvitedSyncResult(
room_id=event.room_id,
invite=invite,
))
elif event.membership in (Membership.LEAVE, Membership.BAN):
leave_token = now_token.copy_and_replace(
"room_key", "s%d" % (event.stream_ordering,)
)
room_sync_deferred = preserve_fn(
self.full_state_sync_for_archived_room
)(
sync_config=sync_config,
room_id=event.room_id,
leave_event_id=event.event_id,
leave_token=leave_token,
timeline_since_token=timeline_since_token,
tags_by_room=tags_by_room,
account_data_by_room=account_data_by_room,
)
room_sync_deferred.addCallback(archived.append)
deferreds.append(room_sync_deferred)
yield defer.gatherResults(
deferreds, consumeErrors=True
).addErrback(unwrapFirstError)
account_data_for_user = sync_config.filter_collection.filter_account_data( account_data_for_user = sync_config.filter_collection.filter_account_data(
self.account_data_for_user(account_data) self.account_data_for_user(account_data)
@ -305,7 +312,6 @@ class SyncHandler(BaseHandler):
ephemeral_by_room=ephemeral_by_room, ephemeral_by_room=ephemeral_by_room,
tags_by_room=tags_by_room, tags_by_room=tags_by_room,
account_data_by_room=account_data_by_room, account_data_by_room=account_data_by_room,
all_ephemeral_by_room=ephemeral_by_room,
batch=batch, batch=batch,
full_state=True, full_state=True,
) )
@ -355,50 +361,51 @@ class SyncHandler(BaseHandler):
typing events for that room. typing events for that room.
""" """
typing_key = since_token.typing_key if since_token else "0" with Measure(self.clock, "ephemeral_by_room"):
typing_key = since_token.typing_key if since_token else "0"
rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string()) rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string())
room_ids = [room.room_id for room in rooms] room_ids = [room.room_id for room in rooms]
typing_source = self.event_sources.sources["typing"] typing_source = self.event_sources.sources["typing"]
typing, typing_key = yield typing_source.get_new_events( typing, typing_key = yield typing_source.get_new_events(
user=sync_config.user, user=sync_config.user,
from_key=typing_key, from_key=typing_key,
limit=sync_config.filter_collection.ephemeral_limit(), limit=sync_config.filter_collection.ephemeral_limit(),
room_ids=room_ids, room_ids=room_ids,
is_guest=sync_config.is_guest, is_guest=sync_config.is_guest,
) )
now_token = now_token.copy_and_replace("typing_key", typing_key) now_token = now_token.copy_and_replace("typing_key", typing_key)
ephemeral_by_room = {} ephemeral_by_room = {}
for event in typing: for event in typing:
# we want to exclude the room_id from the event, but modifying the # we want to exclude the room_id from the event, but modifying the
# result returned by the event source is poor form (it might cache # result returned by the event source is poor form (it might cache
# the object) # the object)
room_id = event["room_id"] room_id = event["room_id"]
event_copy = {k: v for (k, v) in event.iteritems() event_copy = {k: v for (k, v) in event.iteritems()
if k != "room_id"} if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy) ephemeral_by_room.setdefault(room_id, []).append(event_copy)
receipt_key = since_token.receipt_key if since_token else "0" receipt_key = since_token.receipt_key if since_token else "0"
receipt_source = self.event_sources.sources["receipt"] receipt_source = self.event_sources.sources["receipt"]
receipts, receipt_key = yield receipt_source.get_new_events( receipts, receipt_key = yield receipt_source.get_new_events(
user=sync_config.user, user=sync_config.user,
from_key=receipt_key, from_key=receipt_key,
limit=sync_config.filter_collection.ephemeral_limit(), limit=sync_config.filter_collection.ephemeral_limit(),
room_ids=room_ids, room_ids=room_ids,
is_guest=sync_config.is_guest, is_guest=sync_config.is_guest,
) )
now_token = now_token.copy_and_replace("receipt_key", receipt_key) now_token = now_token.copy_and_replace("receipt_key", receipt_key)
for event in receipts: for event in receipts:
room_id = event["room_id"] room_id = event["room_id"]
# exclude room id, as above # exclude room id, as above
event_copy = {k: v for (k, v) in event.iteritems() event_copy = {k: v for (k, v) in event.iteritems()
if k != "room_id"} if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy) ephemeral_by_room.setdefault(room_id, []).append(event_copy)
defer.returnValue((now_token, ephemeral_by_room)) defer.returnValue((now_token, ephemeral_by_room))
@ -438,13 +445,6 @@ class SyncHandler(BaseHandler):
) )
now_token = now_token.copy_and_replace("presence_key", presence_key) now_token = now_token.copy_and_replace("presence_key", presence_key)
# We now fetch all ephemeral events for this room in order to get
# this users current read receipt. This could almost certainly be
# optimised.
_, all_ephemeral_by_room = yield self.ephemeral_by_room(
sync_config, now_token
)
now_token, ephemeral_by_room = yield self.ephemeral_by_room( now_token, ephemeral_by_room = yield self.ephemeral_by_room(
sync_config, now_token, since_token sync_config, now_token, since_token
) )
@ -478,7 +478,7 @@ class SyncHandler(BaseHandler):
) )
# Get a list of membership change events that have happened. # Get a list of membership change events that have happened.
rooms_changed = yield self.store.get_room_changes_for_user( rooms_changed = yield self.store.get_membership_changes_for_user(
user_id, since_token.room_key, now_token.room_key user_id, since_token.room_key, now_token.room_key
) )
@ -576,7 +576,6 @@ class SyncHandler(BaseHandler):
ephemeral_by_room=ephemeral_by_room, ephemeral_by_room=ephemeral_by_room,
tags_by_room=tags_by_room, tags_by_room=tags_by_room,
account_data_by_room=account_data_by_room, account_data_by_room=account_data_by_room,
all_ephemeral_by_room=all_ephemeral_by_room,
batch=batch, batch=batch,
full_state=full_state, full_state=full_state,
) )
@ -606,58 +605,64 @@ class SyncHandler(BaseHandler):
""" """
:returns a Deferred TimelineBatch :returns a Deferred TimelineBatch
""" """
filtering_factor = 2 with Measure(self.clock, "load_filtered_recents"):
timeline_limit = sync_config.filter_collection.timeline_limit() filtering_factor = 2
load_limit = max(timeline_limit * filtering_factor, 10) timeline_limit = sync_config.filter_collection.timeline_limit()
max_repeat = 5 # Only try a few times per room, otherwise load_limit = max(timeline_limit * filtering_factor, 10)
room_key = now_token.room_key max_repeat = 5 # Only try a few times per room, otherwise
end_key = room_key room_key = now_token.room_key
end_key = room_key
limited = recents is None or newly_joined_room or timeline_limit < len(recents) if recents is None or newly_joined_room or timeline_limit < len(recents):
limited = True
if recents is not None: else:
recents = sync_config.filter_collection.filter_room_timeline(recents)
recents = yield self._filter_events_for_client(
sync_config.user.to_string(),
recents,
is_peeking=sync_config.is_guest,
)
else:
recents = []
since_key = None
if since_token and not newly_joined_room:
since_key = since_token.room_key
while limited and len(recents) < timeline_limit and max_repeat:
events, end_key = yield self.store.get_room_events_stream_for_room(
room_id,
limit=load_limit + 1,
from_key=since_key,
to_key=end_key,
)
loaded_recents = sync_config.filter_collection.filter_room_timeline(events)
loaded_recents = yield self._filter_events_for_client(
sync_config.user.to_string(),
loaded_recents,
is_peeking=sync_config.is_guest,
)
loaded_recents.extend(recents)
recents = loaded_recents
if len(events) <= load_limit:
limited = False limited = False
break
max_repeat -= 1
if len(recents) > timeline_limit: if recents is not None:
limited = True recents = sync_config.filter_collection.filter_room_timeline(recents)
recents = recents[-timeline_limit:] recents = yield self._filter_events_for_client(
room_key = recents[0].internal_metadata.before sync_config.user.to_string(),
recents,
is_peeking=sync_config.is_guest,
)
else:
recents = []
prev_batch_token = now_token.copy_and_replace( since_key = None
"room_key", room_key if since_token and not newly_joined_room:
) since_key = since_token.room_key
while limited and len(recents) < timeline_limit and max_repeat:
events, end_key = yield self.store.get_room_events_stream_for_room(
room_id,
limit=load_limit + 1,
from_key=since_key,
to_key=end_key,
)
loaded_recents = sync_config.filter_collection.filter_room_timeline(
events
)
loaded_recents = yield self._filter_events_for_client(
sync_config.user.to_string(),
loaded_recents,
is_peeking=sync_config.is_guest,
)
loaded_recents.extend(recents)
recents = loaded_recents
if len(events) <= load_limit:
limited = False
break
max_repeat -= 1
if len(recents) > timeline_limit:
limited = True
recents = recents[-timeline_limit:]
room_key = recents[0].internal_metadata.before
prev_batch_token = now_token.copy_and_replace(
"room_key", room_key
)
defer.returnValue(TimelineBatch( defer.returnValue(TimelineBatch(
events=recents, events=recents,
@ -670,37 +675,11 @@ class SyncHandler(BaseHandler):
since_token, now_token, since_token, now_token,
ephemeral_by_room, tags_by_room, ephemeral_by_room, tags_by_room,
account_data_by_room, account_data_by_room,
all_ephemeral_by_room,
batch, full_state=False): batch, full_state=False):
if full_state: state = yield self.compute_state_delta(
state = yield self.get_state_at(room_id, now_token) room_id, batch, sync_config, since_token, now_token,
full_state=full_state
elif batch.limited: )
current_state = yield self.get_state_at(room_id, now_token)
state_at_previous_sync = yield self.get_state_at(
room_id, stream_position=since_token
)
state = yield self.compute_state_delta(
since_token=since_token,
previous_state=state_at_previous_sync,
current_state=current_state,
)
else:
state = {
(event.type, event.state_key): event
for event in batch.events if event.is_state()
}
just_joined = yield self.check_joined_room(sync_config, state)
if just_joined:
state = yield self.get_state_at(room_id, now_token)
state = {
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(state.values())
}
account_data = self.account_data_for_room( account_data = self.account_data_for_room(
room_id, tags_by_room, account_data_by_room room_id, tags_by_room, account_data_by_room
@ -726,14 +705,12 @@ class SyncHandler(BaseHandler):
if room_sync: if room_sync:
notifs = yield self.unread_notifs_for_room_id( notifs = yield self.unread_notifs_for_room_id(
room_id, sync_config, all_ephemeral_by_room room_id, sync_config
) )
if notifs is not None: if notifs is not None:
unread_notifications["notification_count"] = len(notifs) unread_notifications["notification_count"] = notifs["notify_count"]
unread_notifications["highlight_count"] = len([ unread_notifications["highlight_count"] = notifs["highlight_count"]
1 for notif in notifs if _action_has_highlight(notif["actions"])
])
logger.debug("Room sync: %r", room_sync) logger.debug("Room sync: %r", room_sync)
@ -766,30 +743,11 @@ class SyncHandler(BaseHandler):
logger.debug("Recents %r", batch) logger.debug("Recents %r", batch)
state_events_at_leave = yield self.store.get_state_for_event( state_events_delta = yield self.compute_state_delta(
leave_event_id room_id, batch, sync_config, since_token, leave_token,
full_state=full_state
) )
if not full_state:
state_at_previous_sync = yield self.get_state_at(
room_id, stream_position=since_token
)
state_events_delta = yield self.compute_state_delta(
since_token=since_token,
previous_state=state_at_previous_sync,
current_state=state_events_at_leave,
)
else:
state_events_delta = state_events_at_leave
state_events_delta = {
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(
state_events_delta.values()
)
}
account_data = self.account_data_for_room( account_data = self.account_data_for_room(
room_id, tags_by_room, account_data_by_room room_id, tags_by_room, account_data_by_room
) )
@ -843,15 +801,19 @@ class SyncHandler(BaseHandler):
state = {} state = {}
defer.returnValue(state) defer.returnValue(state)
def compute_state_delta(self, since_token, previous_state, current_state): @defer.inlineCallbacks
""" Works out the differnce in state between the current state and the def compute_state_delta(self, room_id, batch, sync_config, since_token, now_token,
state the client got when it last performed a sync. full_state):
""" Works out the differnce in state between the start of the timeline
and the previous sync.
:param str since_token: the point we are comparing against :param str room_id
:param dict[(str,str), synapse.events.FrozenEvent] previous_state: the :param TimelineBatch batch: The timeline batch for the room that will
state to compare to be sent to the user.
:param dict[(str,str), synapse.events.FrozenEvent] current_state: the :param sync_config
new state :param str since_token: Token of the end of the previous batch. May be None.
:param str now_token: Token of the end of the current batch.
:param bool full_state: Whether to force returning the full state.
:returns A new event dictionary :returns A new event dictionary
""" """
@ -860,12 +822,53 @@ class SyncHandler(BaseHandler):
# updates even if they occured logically before the previous event. # updates even if they occured logically before the previous event.
# TODO(mjark) Check for new redactions in the state events. # TODO(mjark) Check for new redactions in the state events.
state_delta = {} with Measure(self.clock, "compute_state_delta"):
for key, event in current_state.iteritems(): if full_state:
if (key not in previous_state or if batch:
previous_state[key].event_id != event.event_id): state = yield self.store.get_state_for_event(
state_delta[key] = event batch.events[0].event_id
return state_delta )
else:
state = yield self.get_state_at(
room_id, stream_position=now_token
)
timeline_state = {
(event.type, event.state_key): event
for event in batch.events if event.is_state()
}
state = _calculate_state(
timeline_contains=timeline_state,
timeline_start=state,
previous={},
)
elif batch.limited:
state_at_previous_sync = yield self.get_state_at(
room_id, stream_position=since_token
)
state_at_timeline_start = yield self.store.get_state_for_event(
batch.events[0].event_id
)
timeline_state = {
(event.type, event.state_key): event
for event in batch.events if event.is_state()
}
state = _calculate_state(
timeline_contains=timeline_state,
timeline_start=state_at_timeline_start,
previous=state_at_previous_sync,
)
else:
state = {}
defer.returnValue({
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(state.values())
})
def check_joined_room(self, sync_config, state_delta): def check_joined_room(self, sync_config, state_delta):
""" """
@ -886,21 +889,24 @@ class SyncHandler(BaseHandler):
return False return False
@defer.inlineCallbacks @defer.inlineCallbacks
def unread_notifs_for_room_id(self, room_id, sync_config, ephemeral_by_room): def unread_notifs_for_room_id(self, room_id, sync_config):
last_unread_event_id = self.last_read_event_id_for_room_and_user( with Measure(self.clock, "unread_notifs_for_room_id"):
room_id, sync_config.user.to_string(), ephemeral_by_room last_unread_event_id = yield self.store.get_last_receipt_event_id_for_user(
) user_id=sync_config.user.to_string(),
room_id=room_id,
notifs = [] receipt_type="m.read"
if last_unread_event_id:
notifs = yield self.store.get_unread_event_push_actions_by_room_for_user(
room_id, sync_config.user.to_string(), last_unread_event_id
) )
defer.returnValue(notifs)
# There is no new information in this period, so your notification notifs = []
# count is whatever it was last time. if last_unread_event_id:
defer.returnValue(None) notifs = yield self.store.get_unread_event_push_actions_by_room_for_user(
room_id, sync_config.user.to_string(), last_unread_event_id
)
defer.returnValue(notifs)
# There is no new information in this period, so your notification
# count is whatever it was last time.
defer.returnValue(None)
def _action_has_highlight(actions): def _action_has_highlight(actions):
@ -912,3 +918,37 @@ def _action_has_highlight(actions):
pass pass
return False return False
def _calculate_state(timeline_contains, timeline_start, previous):
"""Works out what state to include in a sync response.
Args:
timeline_contains (dict): state in the timeline
timeline_start (dict): state at the start of the timeline
previous (dict): state at the end of the previous sync (or empty dict
if this is an initial sync)
Returns:
dict
"""
event_id_to_state = {
e.event_id: e
for e in itertools.chain(
timeline_contains.values(),
previous.values(),
timeline_start.values(),
)
}
tc_ids = set(e.event_id for e in timeline_contains.values())
p_ids = set(e.event_id for e in previous.values())
ts_ids = set(e.event_id for e in timeline_start.values())
state_ids = (ts_ids - p_ids) - tc_ids
evs = (event_id_to_state[e] for e in state_ids)
return {
(e.type, e.state_key): e
for e in evs
}

View File

@ -19,6 +19,7 @@ from ._base import BaseHandler
from synapse.api.errors import SynapseError, AuthError from synapse.api.errors import SynapseError, AuthError
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.metrics import Measure
from synapse.types import UserID from synapse.types import UserID
import logging import logging
@ -222,6 +223,7 @@ class TypingNotificationHandler(BaseHandler):
class TypingNotificationEventSource(object): class TypingNotificationEventSource(object):
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
self.clock = hs.get_clock()
self._handler = None self._handler = None
self._room_member_handler = None self._room_member_handler = None
@ -247,19 +249,20 @@ class TypingNotificationEventSource(object):
} }
def get_new_events(self, from_key, room_ids, **kwargs): def get_new_events(self, from_key, room_ids, **kwargs):
from_key = int(from_key) with Measure(self.clock, "typing.get_new_events"):
handler = self.handler() from_key = int(from_key)
handler = self.handler()
events = [] events = []
for room_id in room_ids: for room_id in room_ids:
if room_id not in handler._room_serials: if room_id not in handler._room_serials:
continue continue
if handler._room_serials[room_id] <= from_key: if handler._room_serials[room_id] <= from_key:
continue continue
events.append(self._make_event_for(room_id)) events.append(self._make_event_for(room_id))
return events, handler._latest_room_serial return events, handler._latest_room_serial
def get_current_key(self): def get_current_key(self):
return self.handler()._latest_room_serial return self.handler()._latest_room_serial

View File

@ -152,7 +152,7 @@ class MatrixFederationHttpClient(object):
return self.clock.time_bound_deferred( return self.clock.time_bound_deferred(
request_deferred, request_deferred,
time_out=timeout/1000. if timeout else 60, time_out=timeout / 1000. if timeout else 60,
) )
response = yield preserve_context_over_fn( response = yield preserve_context_over_fn(

View File

@ -41,7 +41,7 @@ metrics = synapse.metrics.get_metrics_for(__name__)
incoming_requests_counter = metrics.register_counter( incoming_requests_counter = metrics.register_counter(
"requests", "requests",
labels=["method", "servlet"], labels=["method", "servlet", "tag"],
) )
outgoing_responses_counter = metrics.register_counter( outgoing_responses_counter = metrics.register_counter(
"responses", "responses",
@ -50,23 +50,23 @@ outgoing_responses_counter = metrics.register_counter(
response_timer = metrics.register_distribution( response_timer = metrics.register_distribution(
"response_time", "response_time",
labels=["method", "servlet"] labels=["method", "servlet", "tag"]
) )
response_ru_utime = metrics.register_distribution( response_ru_utime = metrics.register_distribution(
"response_ru_utime", labels=["method", "servlet"] "response_ru_utime", labels=["method", "servlet", "tag"]
) )
response_ru_stime = metrics.register_distribution( response_ru_stime = metrics.register_distribution(
"response_ru_stime", labels=["method", "servlet"] "response_ru_stime", labels=["method", "servlet", "tag"]
) )
response_db_txn_count = metrics.register_distribution( response_db_txn_count = metrics.register_distribution(
"response_db_txn_count", labels=["method", "servlet"] "response_db_txn_count", labels=["method", "servlet", "tag"]
) )
response_db_txn_duration = metrics.register_distribution( response_db_txn_duration = metrics.register_distribution(
"response_db_txn_duration", labels=["method", "servlet"] "response_db_txn_duration", labels=["method", "servlet", "tag"]
) )
@ -99,9 +99,8 @@ def request_handler(request_handler):
request_context.request = request_id request_context.request = request_id
with request.processing(): with request.processing():
try: try:
d = request_handler(self, request) with PreserveLoggingContext(request_context):
with PreserveLoggingContext(): yield request_handler(self, request)
yield d
except CodeMessageException as e: except CodeMessageException as e:
code = e.code code = e.code
if isinstance(e, SynapseError): if isinstance(e, SynapseError):
@ -208,6 +207,9 @@ class JsonResource(HttpServer, resource.Resource):
if request.method == "OPTIONS": if request.method == "OPTIONS":
self._send_response(request, 200, {}) self._send_response(request, 200, {})
return return
start_context = LoggingContext.current_context()
# Loop through all the registered callbacks to check if the method # Loop through all the registered callbacks to check if the method
# and path regex match # and path regex match
for path_entry in self.path_regexs.get(request.method, []): for path_entry in self.path_regexs.get(request.method, []):
@ -226,7 +228,6 @@ class JsonResource(HttpServer, resource.Resource):
servlet_classname = servlet_instance.__class__.__name__ servlet_classname = servlet_instance.__class__.__name__
else: else:
servlet_classname = "%r" % callback servlet_classname = "%r" % callback
incoming_requests_counter.inc(request.method, servlet_classname)
args = [ args = [
urllib.unquote(u).decode("UTF-8") if u else u for u in m.groups() urllib.unquote(u).decode("UTF-8") if u else u for u in m.groups()
@ -237,21 +238,40 @@ class JsonResource(HttpServer, resource.Resource):
code, response = callback_return code, response = callback_return
self._send_response(request, code, response) self._send_response(request, code, response)
response_timer.inc_by(
self.clock.time_msec() - start, request.method, servlet_classname
)
try: try:
context = LoggingContext.current_context() context = LoggingContext.current_context()
tag = ""
if context:
tag = context.tag
if context != start_context:
logger.warn(
"Context have unexpectedly changed %r, %r",
context, self.start_context
)
return
incoming_requests_counter.inc(request.method, servlet_classname, tag)
response_timer.inc_by(
self.clock.time_msec() - start, request.method,
servlet_classname, tag
)
ru_utime, ru_stime = context.get_resource_usage() ru_utime, ru_stime = context.get_resource_usage()
response_ru_utime.inc_by(ru_utime, request.method, servlet_classname) response_ru_utime.inc_by(
response_ru_stime.inc_by(ru_stime, request.method, servlet_classname) ru_utime, request.method, servlet_classname, tag
)
response_ru_stime.inc_by(
ru_stime, request.method, servlet_classname, tag
)
response_db_txn_count.inc_by( response_db_txn_count.inc_by(
context.db_txn_count, request.method, servlet_classname context.db_txn_count, request.method, servlet_classname, tag
) )
response_db_txn_duration.inc_by( response_db_txn_duration.inc_by(
context.db_txn_duration, request.method, servlet_classname context.db_txn_duration, request.method, servlet_classname, tag
) )
except: except:
pass pass

View File

@ -18,10 +18,13 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor, ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import StreamToken from synapse.types import StreamToken
import synapse.metrics import synapse.metrics
from collections import namedtuple
import logging import logging
@ -71,7 +74,8 @@ class _NotifierUserStream(object):
self.current_token = current_token self.current_token = current_token
self.last_notified_ms = time_now_ms self.last_notified_ms = time_now_ms
self.notify_deferred = ObservableDeferred(defer.Deferred()) with PreserveLoggingContext():
self.notify_deferred = ObservableDeferred(defer.Deferred())
def notify(self, stream_key, stream_id, time_now_ms): def notify(self, stream_key, stream_id, time_now_ms):
"""Notify any listeners for this user of a new event from an """Notify any listeners for this user of a new event from an
@ -86,8 +90,10 @@ class _NotifierUserStream(object):
) )
self.last_notified_ms = time_now_ms self.last_notified_ms = time_now_ms
noify_deferred = self.notify_deferred noify_deferred = self.notify_deferred
self.notify_deferred = ObservableDeferred(defer.Deferred())
noify_deferred.callback(self.current_token) with PreserveLoggingContext():
self.notify_deferred = ObservableDeferred(defer.Deferred())
noify_deferred.callback(self.current_token)
def remove(self, notifier): def remove(self, notifier):
""" Remove this listener from all the indexes in the Notifier """ Remove this listener from all the indexes in the Notifier
@ -118,6 +124,11 @@ class _NotifierUserStream(object):
return _NotificationListener(self.notify_deferred.observe()) return _NotificationListener(self.notify_deferred.observe())
class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))):
def __nonzero__(self):
return bool(self.events)
class Notifier(object): class Notifier(object):
""" This class is responsible for notifying any listeners when there are """ This class is responsible for notifying any listeners when there are
new events available for it. new events available for it.
@ -177,8 +188,6 @@ class Notifier(object):
lambda: count(bool, self.appservice_to_user_streams.values()), lambda: count(bool, self.appservice_to_user_streams.values()),
) )
@log_function
@defer.inlineCallbacks
def on_new_room_event(self, event, room_stream_id, max_room_stream_id, def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
extra_users=[]): extra_users=[]):
""" Used by handlers to inform the notifier something has happened """ Used by handlers to inform the notifier something has happened
@ -192,12 +201,11 @@ class Notifier(object):
until all previous events have been persisted before notifying until all previous events have been persisted before notifying
the client streams. the client streams.
""" """
yield run_on_reactor() with PreserveLoggingContext():
self.pending_new_room_events.append((
self.pending_new_room_events.append(( room_stream_id, event, extra_users
room_stream_id, event, extra_users ))
)) self._notify_pending_new_room_events(max_room_stream_id)
self._notify_pending_new_room_events(max_room_stream_id)
def _notify_pending_new_room_events(self, max_room_stream_id): def _notify_pending_new_room_events(self, max_room_stream_id):
"""Notify for the room events that were queued waiting for a previous """Notify for the room events that were queued waiting for a previous
@ -244,31 +252,29 @@ class Notifier(object):
extra_streams=app_streams, extra_streams=app_streams,
) )
@defer.inlineCallbacks
@log_function
def on_new_event(self, stream_key, new_token, users=[], rooms=[], def on_new_event(self, stream_key, new_token, users=[], rooms=[],
extra_streams=set()): extra_streams=set()):
""" Used to inform listeners that something has happend event wise. """ Used to inform listeners that something has happend event wise.
Will wake up all listeners for the given users and rooms. Will wake up all listeners for the given users and rooms.
""" """
yield run_on_reactor() with PreserveLoggingContext():
user_streams = set() user_streams = set()
for user in users: for user in users:
user_stream = self.user_to_user_stream.get(str(user)) user_stream = self.user_to_user_stream.get(str(user))
if user_stream is not None: if user_stream is not None:
user_streams.add(user_stream) user_streams.add(user_stream)
for room in rooms: for room in rooms:
user_streams |= self.room_to_user_streams.get(room, set()) user_streams |= self.room_to_user_streams.get(room, set())
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
for user_stream in user_streams: for user_stream in user_streams:
try: try:
user_stream.notify(stream_key, new_token, time_now_ms) user_stream.notify(stream_key, new_token, time_now_ms)
except: except:
logger.exception("Failed to notify listener") logger.exception("Failed to notify listener")
@defer.inlineCallbacks @defer.inlineCallbacks
def wait_for_events(self, user_id, timeout, callback, room_ids=None, def wait_for_events(self, user_id, timeout, callback, room_ids=None,
@ -301,7 +307,7 @@ class Notifier(object):
def timed_out(): def timed_out():
if listener: if listener:
listener.deferred.cancel() listener.deferred.cancel()
timer = self.clock.call_later(timeout/1000., timed_out) timer = self.clock.call_later(timeout / 1000., timed_out)
prev_token = from_token prev_token = from_token
while not result: while not result:
@ -318,7 +324,8 @@ class Notifier(object):
# that we don't miss any current_token updates. # that we don't miss any current_token updates.
prev_token = current_token prev_token = current_token
listener = user_stream.new_listener(prev_token) listener = user_stream.new_listener(prev_token)
yield listener.deferred with PreserveLoggingContext():
yield listener.deferred
except defer.CancelledError: except defer.CancelledError:
break break
@ -356,7 +363,7 @@ class Notifier(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def check_for_updates(before_token, after_token): def check_for_updates(before_token, after_token):
if not after_token.is_after(before_token): if not after_token.is_after(before_token):
defer.returnValue(None) defer.returnValue(EventStreamResult([], (from_token, from_token)))
events = [] events = []
end_token = from_token end_token = from_token
@ -369,6 +376,7 @@ class Notifier(object):
continue continue
if only_keys and name not in only_keys: if only_keys and name not in only_keys:
continue continue
new_events, new_key = yield source.get_new_events( new_events, new_key = yield source.get_new_events(
user=user, user=user,
from_key=getattr(from_token, keyname), from_key=getattr(from_token, keyname),
@ -388,10 +396,7 @@ class Notifier(object):
events.extend(new_events) events.extend(new_events)
end_token = end_token.copy_and_replace(keyname, new_key) end_token = end_token.copy_and_replace(keyname, new_key)
if events: defer.returnValue(EventStreamResult(events, (from_token, end_token)))
defer.returnValue((events, (from_token, end_token)))
else:
defer.returnValue(None)
user_id_for_stream = user.to_string() user_id_for_stream = user.to_string()
if is_peeking: if is_peeking:
@ -415,9 +420,6 @@ class Notifier(object):
from_token=from_token, from_token=from_token,
) )
if result is None:
result = ([], (from_token, from_token))
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -17,6 +17,8 @@ from twisted.internet import defer
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.types import StreamToken from synapse.types import StreamToken
from synapse.util.logcontext import LoggingContext
from synapse.util.metrics import Measure
import synapse.util.async import synapse.util.async
import push_rule_evaluator as push_rule_evaluator import push_rule_evaluator as push_rule_evaluator
@ -27,6 +29,16 @@ import random
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_NEXT_ID = 1
def _get_next_id():
global _NEXT_ID
_id = _NEXT_ID
_NEXT_ID += 1
return _id
# Pushers could now be moved to pull out of the event_push_actions table instead # Pushers could now be moved to pull out of the event_push_actions table instead
# of listening on the event stream: this would avoid them having to run the # of listening on the event stream: this would avoid them having to run the
# rules again. # rules again.
@ -57,6 +69,8 @@ class Pusher(object):
self.alive = True self.alive = True
self.badge = None self.badge = None
self.name = "Pusher-%d" % (_get_next_id(),)
# The last value of last_active_time that we saw # The last value of last_active_time that we saw
self.last_last_active_time = 0 self.last_last_active_time = 0
self.has_unread = True self.has_unread = True
@ -86,38 +100,46 @@ class Pusher(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def start(self): def start(self):
if not self.last_token: with LoggingContext(self.name):
# First-time setup: get a token to start from (we can't if not self.last_token:
# just start from no token, ie. 'now' # First-time setup: get a token to start from (we can't
# because we need the result to be reproduceable in case # just start from no token, ie. 'now'
# we fail to dispatch the push) # because we need the result to be reproduceable in case
config = PaginationConfig(from_token=None, limit='1') # we fail to dispatch the push)
chunk = yield self.evStreamHandler.get_stream( config = PaginationConfig(from_token=None, limit='1')
self.user_id, config, timeout=0, affect_presence=False chunk = yield self.evStreamHandler.get_stream(
) self.user_id, config, timeout=0, affect_presence=False
self.last_token = chunk['end']
self.store.update_pusher_last_token(
self.app_id, self.pushkey, self.user_id, self.last_token
)
logger.info("Pusher %s for user %s starting from token %s",
self.pushkey, self.user_id, self.last_token)
wait = 0
while self.alive:
try:
if wait > 0:
yield synapse.util.async.sleep(wait)
yield self.get_and_dispatch()
wait = 0
except:
if wait == 0:
wait = 1
else:
wait = min(wait * 2, 1800)
logger.exception(
"Exception in pusher loop for pushkey %s. Pausing for %ds",
self.pushkey, wait
) )
self.last_token = chunk['end']
yield self.store.update_pusher_last_token(
self.app_id, self.pushkey, self.user_id, self.last_token
)
logger.info("New pusher %s for user %s starting from token %s",
self.pushkey, self.user_id, self.last_token)
else:
logger.info(
"Old pusher %s for user %s starting",
self.pushkey, self.user_id,
)
wait = 0
while self.alive:
try:
if wait > 0:
yield synapse.util.async.sleep(wait)
with Measure(self.clock, "push"):
yield self.get_and_dispatch()
wait = 0
except:
if wait == 0:
wait = 1
else:
wait = min(wait * 2, 1800)
logger.exception(
"Exception in pusher loop for pushkey %s. Pausing for %ds",
self.pushkey, wait
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_and_dispatch(self): def get_and_dispatch(self):
@ -316,7 +338,7 @@ class Pusher(object):
r.room_id, self.user_id, last_unread_event_id r.room_id, self.user_id, last_unread_event_id
) )
) )
badge += len(notifs) badge += notifs["notify_count"]
defer.returnValue(badge) defer.returnValue(badge)

View File

@ -19,8 +19,6 @@ import bulk_push_rule_evaluator
import logging import logging
from synapse.api.constants import EventTypes
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -36,21 +34,15 @@ class ActionGenerator:
# tag (ie. we just need all the users). # tag (ie. we just need all the users).
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_push_actions_for_event(self, event, handler): def handle_push_actions_for_event(self, event, context, handler):
if event.type == EventTypes.Redaction and event.redacts is not None:
yield self.store.remove_push_actions_for_event_id(
event.room_id, event.redacts
)
bulk_evaluator = yield bulk_push_rule_evaluator.evaluator_for_room_id( bulk_evaluator = yield bulk_push_rule_evaluator.evaluator_for_room_id(
event.room_id, self.hs, self.store event.room_id, self.hs, self.store
) )
actions_by_user = yield bulk_evaluator.action_for_event_by_user(event, handler) actions_by_user = yield bulk_evaluator.action_for_event_by_user(
event, handler, context.current_state
yield self.store.set_push_actions_for_event_and_users(
event,
[
(uid, None, actions) for uid, actions in actions_by_user.items()
]
) )
context.push_actions = [
(uid, None, actions) for uid, actions in actions_by_user.items()
]

View File

@ -98,25 +98,21 @@ class BulkPushRuleEvaluator:
self.store = store self.store = store
@defer.inlineCallbacks @defer.inlineCallbacks
def action_for_event_by_user(self, event, handler): def action_for_event_by_user(self, event, handler, current_state):
actions_by_user = {} actions_by_user = {}
users_dict = yield self.store.are_guests(self.rules_by_user.keys()) users_dict = yield self.store.are_guests(self.rules_by_user.keys())
filtered_by_user = yield handler._filter_events_for_clients( filtered_by_user = yield handler._filter_events_for_clients(
users_dict.items(), [event] users_dict.items(), [event], {event.event_id: current_state}
) )
evaluator = PushRuleEvaluatorForEvent(event, len(self.users_in_room)) evaluator = PushRuleEvaluatorForEvent(event, len(self.users_in_room))
condition_cache = {} condition_cache = {}
member_state = yield self.store.get_state_for_event(
event.event_id,
)
display_names = {} display_names = {}
for ev in member_state.values(): for ev in current_state.values():
nm = ev.content.get("displayname", None) nm = ev.content.get("displayname", None)
if nm and ev.type == EventTypes.Member: if nm and ev.type == EventTypes.Member:
display_names[ev.state_key] = nm display_names[ev.state_key] = nm

View File

@ -304,7 +304,7 @@ def _flatten_dict(d, prefix=[], result={}):
if isinstance(value, basestring): if isinstance(value, basestring):
result[".".join(prefix + [key])] = value.lower() result[".".join(prefix + [key])] = value.lower()
elif hasattr(value, "items"): elif hasattr(value, "items"):
_flatten_dict(value, prefix=(prefix+[key]), result=result) _flatten_dict(value, prefix=(prefix + [key]), result=result)
return result return result

View File

@ -18,6 +18,7 @@ from twisted.internet import defer
from httppusher import HttpPusher from httppusher import HttpPusher
from synapse.push import PusherConfigException from synapse.push import PusherConfigException
from synapse.util.logcontext import preserve_fn
import logging import logging
@ -76,7 +77,7 @@ class PusherPool:
"Removing pusher for app id %s, pushkey %s, user %s", "Removing pusher for app id %s, pushkey %s, user %s",
app_id, pushkey, p['user_name'] app_id, pushkey, p['user_name']
) )
self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks @defer.inlineCallbacks
def remove_pushers_by_user(self, user_id): def remove_pushers_by_user(self, user_id):
@ -91,7 +92,7 @@ class PusherPool:
"Removing pusher for app id %s, pushkey %s, user %s", "Removing pusher for app id %s, pushkey %s, user %s",
p['app_id'], p['pushkey'], p['user_name'] p['app_id'], p['pushkey'], p['user_name']
) )
self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks @defer.inlineCallbacks
def _add_pusher_to_store(self, user_id, access_token, profile_tag, kind, def _add_pusher_to_store(self, user_id, access_token, profile_tag, kind,
@ -110,7 +111,7 @@ class PusherPool:
lang=lang, lang=lang,
data=data, data=data,
) )
self._refresh_pusher(app_id, pushkey, user_id) yield self._refresh_pusher(app_id, pushkey, user_id)
def _create_pusher(self, pusherdict): def _create_pusher(self, pusherdict):
if pusherdict['kind'] == 'http': if pusherdict['kind'] == 'http':
@ -166,7 +167,7 @@ class PusherPool:
if fullid in self.pushers: if fullid in self.pushers:
self.pushers[fullid].stop() self.pushers[fullid].stop()
self.pushers[fullid] = p self.pushers[fullid] = p
p.start() preserve_fn(p.start)()
logger.info("Started pushers") logger.info("Started pushers")

View File

@ -89,7 +89,7 @@ class LoginRestServlet(ClientV1RestServlet):
LoginRestServlet.SAML2_TYPE): LoginRestServlet.SAML2_TYPE):
relay_state = "" relay_state = ""
if "relay_state" in login_submission: if "relay_state" in login_submission:
relay_state = "&RelayState="+urllib.quote( relay_state = "&RelayState=" + urllib.quote(
login_submission["relay_state"]) login_submission["relay_state"])
result = { result = {
"uri": "%s%s" % (self.idp_redirect_url, relay_state) "uri": "%s%s" % (self.idp_redirect_url, relay_state)

View File

@ -33,7 +33,11 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
user, user,
) )
defer.returnValue((200, {"displayname": displayname})) ret = {}
if displayname is not None:
ret["displayname"] = displayname
defer.returnValue((200, ret))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, user_id): def on_PUT(self, request, user_id):
@ -66,7 +70,11 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
user, user,
) )
defer.returnValue((200, {"avatar_url": avatar_url})) ret = {}
if avatar_url is not None:
ret["avatar_url"] = avatar_url
defer.returnValue((200, ret))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, user_id): def on_PUT(self, request, user_id):
@ -102,10 +110,13 @@ class ProfileRestServlet(ClientV1RestServlet):
user, user,
) )
defer.returnValue((200, { ret = {}
"displayname": displayname, if displayname is not None:
"avatar_url": avatar_url ret["displayname"] = displayname
})) if avatar_url is not None:
ret["avatar_url"] = avatar_url
defer.returnValue((200, ret))
def register_servlets(hs, http_server): def register_servlets(hs, http_server):

View File

@ -52,7 +52,7 @@ class PusherRestServlet(ClientV1RestServlet):
if i not in content: if i not in content:
missing.append(i) missing.append(i)
if len(missing): if len(missing):
raise SynapseError(400, "Missing parameters: "+','.join(missing), raise SynapseError(400, "Missing parameters: " + ','.join(missing),
errcode=Codes.MISSING_PARAM) errcode=Codes.MISSING_PARAM)
logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind']) logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind'])
@ -83,7 +83,7 @@ class PusherRestServlet(ClientV1RestServlet):
data=content['data'] data=content['data']
) )
except PusherConfigException as pce: except PusherConfigException as pce:
raise SynapseError(400, "Config Error: "+pce.message, raise SynapseError(400, "Config Error: " + pce.message,
errcode=Codes.MISSING_PARAM) errcode=Codes.MISSING_PARAM)
defer.returnValue((200, {})) defer.returnValue((200, {}))

View File

@ -38,7 +38,8 @@ logger = logging.getLogger(__name__)
if hasattr(hmac, "compare_digest"): if hasattr(hmac, "compare_digest"):
compare_digest = hmac.compare_digest compare_digest = hmac.compare_digest
else: else:
compare_digest = lambda a, b: a == b def compare_digest(a, b):
return a == b
class RegisterRestServlet(ClientV1RestServlet): class RegisterRestServlet(ClientV1RestServlet):
@ -58,7 +59,7 @@ class RegisterRestServlet(ClientV1RestServlet):
# } # }
# TODO: persistent storage # TODO: persistent storage
self.sessions = {} self.sessions = {}
self.disable_registration = hs.config.disable_registration self.enable_registration = hs.config.enable_registration
def on_GET(self, request): def on_GET(self, request):
if self.hs.config.enable_registration_captcha: if self.hs.config.enable_registration_captcha:
@ -112,7 +113,7 @@ class RegisterRestServlet(ClientV1RestServlet):
is_using_shared_secret = login_type == LoginType.SHARED_SECRET is_using_shared_secret = login_type == LoginType.SHARED_SECRET
can_register = ( can_register = (
not self.disable_registration self.enable_registration
or is_application_server or is_application_server
or is_using_shared_secret or is_using_shared_secret
) )

View File

@ -429,8 +429,6 @@ class RoomEventContext(ClientV1RestServlet):
serialize_event(event, time_now) for event in results["state"] serialize_event(event, time_now) for event in results["state"]
] ]
logger.info("Responding with %r", results)
defer.returnValue((200, results)) defer.returnValue((200, results))

View File

@ -116,9 +116,10 @@ class ThreepidRestServlet(RestServlet):
body = parse_json_dict_from_request(request) body = parse_json_dict_from_request(request)
if 'threePidCreds' not in body: threePidCreds = body.get('threePidCreds')
threePidCreds = body.get('three_pid_creds', threePidCreds)
if threePidCreds is None:
raise SynapseError(400, "Missing param", Codes.MISSING_PARAM) raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
threePidCreds = body['threePidCreds']
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()

View File

@ -57,7 +57,7 @@ class AccountDataServlet(RestServlet):
user_id, account_data_type, body user_id, account_data_type, body
) )
yield self.notifier.on_new_event( self.notifier.on_new_event(
"account_data_key", max_id, users=[user_id] "account_data_key", max_id, users=[user_id]
) )
@ -99,7 +99,7 @@ class RoomAccountDataServlet(RestServlet):
user_id, room_id, account_data_type, body user_id, room_id, account_data_type, body
) )
yield self.notifier.on_new_event( self.notifier.on_new_event(
"account_data_key", max_id, users=[user_id] "account_data_key", max_id, users=[user_id]
) )

View File

@ -34,7 +34,8 @@ from synapse.util.async import run_on_reactor
if hasattr(hmac, "compare_digest"): if hasattr(hmac, "compare_digest"):
compare_digest = hmac.compare_digest compare_digest = hmac.compare_digest
else: else:
compare_digest = lambda a, b: a == b def compare_digest(a, b):
return a == b
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -116,7 +117,7 @@ class RegisterRestServlet(RestServlet):
return return
# == Normal User Registration == (everyone else) # == Normal User Registration == (everyone else)
if self.hs.config.disable_registration: if not self.hs.config.enable_registration:
raise SynapseError(403, "Registration has been disabled") raise SynapseError(403, "Registration has been disabled")
guest_access_token = body.get("guest_access_token", None) guest_access_token = body.get("guest_access_token", None)
@ -152,6 +153,7 @@ class RegisterRestServlet(RestServlet):
desired_username = params.get("username", None) desired_username = params.get("username", None)
new_password = params.get("password", None) new_password = params.get("password", None)
guest_access_token = params.get("guest_access_token", None)
(user_id, token) = yield self.registration_handler.register( (user_id, token) = yield self.registration_handler.register(
localpart=desired_username, localpart=desired_username,

View File

@ -20,7 +20,6 @@ from synapse.http.servlet import (
) )
from synapse.handlers.sync import SyncConfig from synapse.handlers.sync import SyncConfig
from synapse.types import StreamToken from synapse.types import StreamToken
from synapse.events import FrozenEvent
from synapse.events.utils import ( from synapse.events.utils import (
serialize_event, format_event_for_client_v2_without_room_id, serialize_event, format_event_for_client_v2_without_room_id,
) )
@ -287,9 +286,6 @@ class SyncRestServlet(RestServlet):
state_dict = room.state state_dict = room.state
timeline_events = room.timeline.events timeline_events = room.timeline.events
state_dict = SyncRestServlet._rollback_state_for_timeline(
state_dict, timeline_events)
state_events = state_dict.values() state_events = state_dict.values()
serialized_state = [serialize(e) for e in state_events] serialized_state = [serialize(e) for e in state_events]
@ -314,77 +310,6 @@ class SyncRestServlet(RestServlet):
return result return result
@staticmethod
def _rollback_state_for_timeline(state, timeline):
"""
Wind the state dictionary backwards, so that it represents the
state at the start of the timeline, rather than at the end.
:param dict[(str, str), synapse.events.EventBase] state: the
state dictionary. Will be updated to the state before the timeline.
:param list[synapse.events.EventBase] timeline: the event timeline
:return: updated state dictionary
"""
result = state.copy()
for timeline_event in reversed(timeline):
if not timeline_event.is_state():
continue
event_key = (timeline_event.type, timeline_event.state_key)
logger.debug("Considering %s for removal", event_key)
state_event = result.get(event_key)
if (state_event is None or
state_event.event_id != timeline_event.event_id):
# the event in the timeline isn't present in the state
# dictionary.
#
# the most likely cause for this is that there was a fork in
# the event graph, and the state is no longer valid. Really,
# the event shouldn't be in the timeline. We're going to ignore
# it for now, however.
logger.debug("Found state event %r in timeline which doesn't "
"match state dictionary", timeline_event)
continue
prev_event_id = timeline_event.unsigned.get("replaces_state", None)
prev_content = timeline_event.unsigned.get('prev_content')
prev_sender = timeline_event.unsigned.get('prev_sender')
# Empircally it seems possible for the event to have a
# "replaces_state" key but not a prev_content or prev_sender
# markjh conjectures that it could be due to the server not
# having a copy of that event.
# If this is the case the we ignore the previous event. This will
# cause the displayname calculations on the client to be incorrect
if prev_event_id is None or not prev_content or not prev_sender:
logger.debug(
"Removing %r from the state dict, as it is missing"
" prev_content (prev_event_id=%r)",
timeline_event.event_id, prev_event_id
)
del result[event_key]
else:
logger.debug(
"Replacing %r with %r in state dict",
timeline_event.event_id, prev_event_id
)
result[event_key] = FrozenEvent({
"type": timeline_event.type,
"state_key": timeline_event.state_key,
"content": prev_content,
"sender": prev_sender,
"event_id": prev_event_id,
"room_id": timeline_event.room_id,
})
logger.debug("New value: %r", result.get(event_key))
return result
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
SyncRestServlet(hs).register(http_server) SyncRestServlet(hs).register(http_server)

View File

@ -80,7 +80,7 @@ class TagServlet(RestServlet):
max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body) max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body)
yield self.notifier.on_new_event( self.notifier.on_new_event(
"account_data_key", max_id, users=[user_id] "account_data_key", max_id, users=[user_id]
) )
@ -94,7 +94,7 @@ class TagServlet(RestServlet):
max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag) max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag)
yield self.notifier.on_new_event( self.notifier.on_new_event(
"account_data_key", max_id, users=[user_id] "account_data_key", max_id, users=[user_id]
) )

View File

@ -26,9 +26,7 @@ class VersionsRestServlet(RestServlet):
def on_GET(self, request): def on_GET(self, request):
return (200, { return (200, {
"versions": [ "versions": ["r0.0.1"]
"r0.0.1",
]
}) })

View File

@ -28,6 +28,7 @@ from twisted.protocols.basic import FileSender
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.stringutils import is_ascii from synapse.util.stringutils import is_ascii
from synapse.util.logcontext import preserve_context_over_fn
import os import os
@ -276,7 +277,8 @@ class BaseMediaResource(Resource):
) )
self._makedirs(t_path) self._makedirs(t_path)
t_len = yield threads.deferToThread( t_len = yield preserve_context_over_fn(
threads.deferToThread,
self._generate_thumbnail, self._generate_thumbnail,
input_path, t_path, t_width, t_height, t_method, t_type input_path, t_path, t_width, t_height, t_method, t_type
) )
@ -298,7 +300,8 @@ class BaseMediaResource(Resource):
) )
self._makedirs(t_path) self._makedirs(t_path)
t_len = yield threads.deferToThread( t_len = yield preserve_context_over_fn(
threads.deferToThread,
self._generate_thumbnail, self._generate_thumbnail,
input_path, t_path, t_width, t_height, t_method, t_type input_path, t_path, t_width, t_height, t_method, t_type
) )
@ -372,7 +375,7 @@ class BaseMediaResource(Resource):
media_id, t_width, t_height, t_type, t_method, t_len media_id, t_width, t_height, t_type, t_method, t_len
)) ))
yield threads.deferToThread(generate_thumbnails) yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
for l in local_thumbnails: for l in local_thumbnails:
yield self.store.store_local_thumbnail(*l) yield self.store.store_local_thumbnail(*l)
@ -445,7 +448,7 @@ class BaseMediaResource(Resource):
t_width, t_height, t_type, t_method, t_len t_width, t_height, t_type, t_method, t_len
]) ])
yield threads.deferToThread(generate_thumbnails) yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
for r in remote_thumbnails: for r in remote_thumbnails:
yield self.store.store_remote_media_thumbnail(*r) yield self.store.store_remote_media_thumbnail(*r)

View File

@ -23,7 +23,7 @@ from twisted.web.client import BrowserLikePolicyForHTTPS
from twisted.enterprise import adbapi from twisted.enterprise import adbapi
from synapse.federation import initialize_http_replication from synapse.federation import initialize_http_replication
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.notifier import Notifier from synapse.notifier import Notifier
from synapse.api.auth import Auth from synapse.api.auth import Auth
from synapse.handlers import Handlers from synapse.handlers import Handlers

View File

@ -63,7 +63,7 @@ class StateHandler(object):
cache_name="state_cache", cache_name="state_cache",
clock=self.clock, clock=self.clock,
max_len=SIZE_OF_CACHE, max_len=SIZE_OF_CACHE,
expiry_ms=EVICTION_TIMEOUT_SECONDS*1000, expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
reset_expiry_on_get=True, reset_expiry_on_get=True,
) )

View File

@ -45,9 +45,10 @@ from .search import SearchStore
from .tags import TagsStore from .tags import TagsStore
from .account_data import AccountDataStore from .account_data import AccountDataStore
from util.id_generators import IdGenerator, StreamIdGenerator from util.id_generators import IdGenerator, StreamIdGenerator
from synapse.util.caches.stream_change_cache import StreamChangeCache
import logging import logging
@ -58,7 +59,7 @@ logger = logging.getLogger(__name__)
# Number of msec of granularity to store the user IP 'last seen' time. Smaller # Number of msec of granularity to store the user IP 'last seen' time. Smaller
# times give more inserts into the database even for readonly API hits # times give more inserts into the database even for readonly API hits
# 120 seconds == 2 minutes # 120 seconds == 2 minutes
LAST_SEEN_GRANULARITY = 120*1000 LAST_SEEN_GRANULARITY = 120 * 1000
class DataStore(RoomMemberStore, RoomStore, class DataStore(RoomMemberStore, RoomStore,
@ -84,6 +85,7 @@ class DataStore(RoomMemberStore, RoomStore,
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
self.hs = hs self.hs = hs
self.database_engine = hs.database_engine
cur = db_conn.cursor() cur = db_conn.cursor()
try: try:
@ -117,8 +119,61 @@ class DataStore(RoomMemberStore, RoomStore,
self._push_rule_id_gen = IdGenerator("push_rules", "id", self) self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self) self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
events_max = self._stream_id_gen.get_max_token(None)
event_cache_prefill, min_event_val = self._get_cache_dict(
db_conn, "events",
entity_column="room_id",
stream_column="stream_ordering",
max_value=events_max,
)
self._events_stream_cache = StreamChangeCache(
"EventsRoomStreamChangeCache", min_event_val,
prefilled_cache=event_cache_prefill,
)
self._membership_stream_cache = StreamChangeCache(
"MembershipStreamChangeCache", events_max,
)
account_max = self._account_data_id_gen.get_max_token(None)
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max,
)
super(DataStore, self).__init__(hs) super(DataStore, self).__init__(hs)
def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value):
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
# It doesn't really matter how many we get, the StreamChangeCache will
# do the right thing to ensure it respects the max size of cache.
sql = (
"SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
" WHERE %(stream)s > ? - 100000"
" GROUP BY %(entity)s"
) % {
"table": table,
"entity": entity_column,
"stream": stream_column,
}
sql = self.database_engine.convert_param_style(sql)
txn = db_conn.cursor()
txn.execute(sql, (int(max_value),))
rows = txn.fetchall()
cache = {
row[0]: int(row[1])
for row in rows
}
if cache:
min_val = min(cache.values())
else:
min_val = max_value
return cache, min_val
@defer.inlineCallbacks @defer.inlineCallbacks
def insert_client_ip(self, user, access_token, ip, user_agent): def insert_client_ip(self, user, access_token, ip, user_agent):
now = int(self._clock.time_msec()) now = int(self._clock.time_msec())

View File

@ -15,7 +15,7 @@
import logging import logging
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache from synapse.util.caches.descriptors import Cache
import synapse.metrics import synapse.metrics
@ -185,7 +185,7 @@ class SQLBaseStore(object):
time_then = self._previous_loop_ts time_then = self._previous_loop_ts
self._previous_loop_ts = time_now self._previous_loop_ts = time_now
ratio = (curr - prev)/(time_now - time_then) ratio = (curr - prev) / (time_now - time_then)
top_three_counters = self._txn_perf_counters.interval( top_three_counters = self._txn_perf_counters.interval(
time_now - time_then, limit=3 time_now - time_then, limit=3
@ -298,10 +298,10 @@ class SQLBaseStore(object):
func, *args, **kwargs func, *args, **kwargs
) )
result = yield preserve_context_over_fn( with PreserveLoggingContext():
self._db_pool.runWithConnection, result = yield self._db_pool.runWithConnection(
inner_func, *args, **kwargs inner_func, *args, **kwargs
) )
for after_callback, after_args in after_callbacks: for after_callback, after_args in after_callbacks:
after_callback(*after_args) after_callback(*after_args)
@ -326,10 +326,10 @@ class SQLBaseStore(object):
return func(conn, *args, **kwargs) return func(conn, *args, **kwargs)
result = yield preserve_context_over_fn( with PreserveLoggingContext():
self._db_pool.runWithConnection, result = yield self._db_pool.runWithConnection(
inner_func, *args, **kwargs inner_func, *args, **kwargs
) )
defer.returnValue(result) defer.returnValue(result)
@ -643,7 +643,10 @@ class SQLBaseStore(object):
if not iterable: if not iterable:
defer.returnValue(results) defer.returnValue(results)
chunks = [iterable[i:i+batch_size] for i in xrange(0, len(iterable), batch_size)] chunks = [
iterable[i:i + batch_size]
for i in xrange(0, len(iterable), batch_size)
]
for chunk in chunks: for chunk in chunks:
rows = yield self.runInteraction( rows = yield self.runInteraction(
desc, desc,

View File

@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
from twisted.internet import defer from twisted.internet import defer
import ujson as json import ujson as json
@ -24,14 +23,6 @@ logger = logging.getLogger(__name__)
class AccountDataStore(SQLBaseStore): class AccountDataStore(SQLBaseStore):
def __init__(self, hs):
super(AccountDataStore, self).__init__(hs)
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache",
self._account_data_id_gen.get_max_token(None),
max_size=10000,
)
def get_account_data_for_user(self, user_id): def get_account_data_for_user(self, user_id):
"""Get all the client account_data for a user. """Get all the client account_data for a user.
@ -166,6 +157,10 @@ class AccountDataStore(SQLBaseStore):
"content": content_json, "content": content_json,
} }
) )
txn.call_after(
self._account_data_stream_cache.entity_has_changed,
user_id, next_id,
)
self._update_max_stream_id(txn, next_id) self._update_max_stream_id(txn, next_id)
with (yield self._account_data_id_gen.get_next(self)) as next_id: with (yield self._account_data_id_gen.get_next(self)) as next_id:

View File

@ -276,7 +276,8 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
"application_services_state", "application_services_state",
dict(as_id=service.id), dict(as_id=service.id),
["state"], ["state"],
allow_none=True allow_none=True,
desc="get_appservice_state",
) )
if result: if result:
defer.returnValue(result.get("state")) defer.returnValue(result.get("state"))

View File

@ -54,7 +54,7 @@ class Sqlite3Engine(object):
def _parse_match_info(buf): def _parse_match_info(buf):
bufsize = len(buf) bufsize = len(buf)
return [struct.unpack('@I', buf[i:i+4])[0] for i in range(0, bufsize, 4)] return [struct.unpack('@I', buf[i:i + 4])[0] for i in range(0, bufsize, 4)]
def _rank(raw_match_info): def _rank(raw_match_info):

View File

@ -58,7 +58,7 @@ class EventFederationStore(SQLBaseStore):
new_front = set() new_front = set()
front_list = list(front) front_list = list(front)
chunks = [ chunks = [
front_list[x:x+100] front_list[x:x + 100]
for x in xrange(0, len(front), 100) for x in xrange(0, len(front), 100)
] ]
for chunk in chunks: for chunk in chunks:

View File

@ -24,8 +24,7 @@ logger = logging.getLogger(__name__)
class EventPushActionsStore(SQLBaseStore): class EventPushActionsStore(SQLBaseStore):
@defer.inlineCallbacks def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples):
def set_push_actions_for_event_and_users(self, event, tuples):
""" """
:param event: the event set actions for :param event: the event set actions for
:param tuples: list of tuples of (user_id, profile_tag, actions) :param tuples: list of tuples of (user_id, profile_tag, actions)
@ -37,21 +36,19 @@ class EventPushActionsStore(SQLBaseStore):
'event_id': event.event_id, 'event_id': event.event_id,
'user_id': uid, 'user_id': uid,
'profile_tag': profile_tag, 'profile_tag': profile_tag,
'actions': json.dumps(actions) 'actions': json.dumps(actions),
'stream_ordering': event.internal_metadata.stream_ordering,
'topological_ordering': event.depth,
'notif': 1,
'highlight': 1 if _action_has_highlight(actions) else 0,
}) })
def f(txn): for uid, _, __ in tuples:
for uid, _, __ in tuples: txn.call_after(
txn.call_after( self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
self.get_unread_event_push_actions_by_room_for_user.invalidate_many, (event.room_id, uid)
(event.room_id, uid) )
) self._simple_insert_many_txn(txn, "event_push_actions", values)
return self._simple_insert_many_txn(txn, "event_push_actions", values)
yield self.runInteraction(
"set_actions_for_event_and_users",
f,
)
@cachedInlineCallbacks(num_args=3, lru=True, tree=True) @cachedInlineCallbacks(num_args=3, lru=True, tree=True)
def get_unread_event_push_actions_by_room_for_user( def get_unread_event_push_actions_by_room_for_user(
@ -68,32 +65,34 @@ class EventPushActionsStore(SQLBaseStore):
) )
results = txn.fetchall() results = txn.fetchall()
if len(results) == 0: if len(results) == 0:
return [] return {"notify_count": 0, "highlight_count": 0}
stream_ordering = results[0][0] stream_ordering = results[0][0]
topological_ordering = results[0][1] topological_ordering = results[0][1]
sql = ( sql = (
"SELECT ea.event_id, ea.actions" "SELECT sum(notif), sum(highlight)"
" FROM event_push_actions ea, events e" " FROM event_push_actions ea"
" WHERE ea.room_id = e.room_id" " WHERE"
" AND ea.event_id = e.event_id" " user_id = ?"
" AND ea.user_id = ?" " AND room_id = ?"
" AND ea.room_id = ?"
" AND (" " AND ("
" e.topological_ordering > ?" " topological_ordering > ?"
" OR (e.topological_ordering = ? AND e.stream_ordering > ?)" " OR (topological_ordering = ? AND stream_ordering > ?)"
")" ")"
) )
txn.execute(sql, ( txn.execute(sql, (
user_id, room_id, user_id, room_id,
topological_ordering, topological_ordering, stream_ordering topological_ordering, topological_ordering, stream_ordering
) ))
) row = txn.fetchone()
return [ if row:
{"event_id": row[0], "actions": json.loads(row[1])} return {
for row in txn.fetchall() "notify_count": row[0] or 0,
] "highlight_count": row[1] or 0,
}
else:
return {"notify_count": 0, "highlight_count": 0}
ret = yield self.runInteraction( ret = yield self.runInteraction(
"get_unread_event_push_actions_by_room", "get_unread_event_push_actions_by_room",
@ -101,19 +100,24 @@ class EventPushActionsStore(SQLBaseStore):
) )
defer.returnValue(ret) defer.returnValue(ret)
@defer.inlineCallbacks def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
def remove_push_actions_for_event_id(self, room_id, event_id): # Sad that we have to blow away the cache for the whole room here
def f(txn): txn.call_after(
# Sad that we have to blow away the cache for the whole room here self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
txn.call_after( (room_id,)
self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
(room_id,)
)
txn.execute(
"DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?",
(room_id, event_id)
)
yield self.runInteraction(
"remove_push_actions_for_event_id",
f
) )
txn.execute(
"DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?",
(room_id, event_id)
)
def _action_has_highlight(actions):
for action in actions:
try:
if action.get("set_tweak", None) == "highlight":
return action.get("value", True)
except AttributeError:
pass
return False

View File

@ -19,7 +19,7 @@ from twisted.internet import defer, reactor
from synapse.events import FrozenEvent, USE_FROZEN_DICTS from synapse.events import FrozenEvent, USE_FROZEN_DICTS
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.util.logcontext import preserve_context_over_deferred from synapse.util.logcontext import preserve_fn, PreserveLoggingContext
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
@ -84,7 +84,7 @@ class EventsStore(SQLBaseStore):
event.internal_metadata.stream_ordering = stream event.internal_metadata.stream_ordering = stream
chunks = [ chunks = [
events_and_contexts[x:x+100] events_and_contexts[x:x + 100]
for x in xrange(0, len(events_and_contexts), 100) for x in xrange(0, len(events_and_contexts), 100)
] ]
@ -205,23 +205,29 @@ class EventsStore(SQLBaseStore):
@log_function @log_function
def _persist_events_txn(self, txn, events_and_contexts, backfilled, def _persist_events_txn(self, txn, events_and_contexts, backfilled,
is_new_state=True): is_new_state=True):
depth_updates = {}
# Remove the any existing cache entries for the event_ids for event, context in events_and_contexts:
for event, _ in events_and_contexts: # Remove the any existing cache entries for the event_ids
txn.call_after(self._invalidate_get_event_cache, event.event_id) txn.call_after(self._invalidate_get_event_cache, event.event_id)
if not backfilled: if not backfilled:
txn.call_after( txn.call_after(
self._events_stream_cache.entity_has_changed, self._events_stream_cache.entity_has_changed,
event.room_id, event.internal_metadata.stream_ordering, event.room_id, event.internal_metadata.stream_ordering,
) )
depth_updates = {} if not event.internal_metadata.is_outlier():
for event, _ in events_and_contexts: depth_updates[event.room_id] = max(
if event.internal_metadata.is_outlier(): event.depth, depth_updates.get(event.room_id, event.depth)
continue )
depth_updates[event.room_id] = max(
event.depth, depth_updates.get(event.room_id, event.depth) if context.push_actions:
self._set_push_actions_for_event_and_users_txn(
txn, event, context.push_actions
)
if event.type == EventTypes.Redaction and event.redacts is not None:
self._remove_push_actions_for_event_id_txn(
txn, event.room_id, event.redacts
) )
for room_id, depth in depth_updates.items(): for room_id, depth in depth_updates.items():
@ -664,14 +670,16 @@ class EventsStore(SQLBaseStore):
for ids, d in lst: for ids, d in lst:
if not d.called: if not d.called:
try: try:
d.callback([ with PreserveLoggingContext():
res[i] d.callback([
for i in ids res[i]
if i in res for i in ids
]) if i in res
])
except: except:
logger.exception("Failed to callback") logger.exception("Failed to callback")
reactor.callFromThread(fire, event_list, row_dict) with PreserveLoggingContext():
reactor.callFromThread(fire, event_list, row_dict)
except Exception as e: except Exception as e:
logger.exception("do_fetch") logger.exception("do_fetch")
@ -679,10 +687,12 @@ class EventsStore(SQLBaseStore):
def fire(evs): def fire(evs):
for _, d in evs: for _, d in evs:
if not d.called: if not d.called:
d.errback(e) with PreserveLoggingContext():
d.errback(e)
if event_list: if event_list:
reactor.callFromThread(fire, event_list) with PreserveLoggingContext():
reactor.callFromThread(fire, event_list)
@defer.inlineCallbacks @defer.inlineCallbacks
def _enqueue_events(self, events, check_redacted=True, def _enqueue_events(self, events, check_redacted=True,
@ -709,18 +719,20 @@ class EventsStore(SQLBaseStore):
should_start = False should_start = False
if should_start: if should_start:
self.runWithConnection( with PreserveLoggingContext():
self._do_fetch self.runWithConnection(
) self._do_fetch
)
rows = yield preserve_context_over_deferred(events_d) with PreserveLoggingContext():
rows = yield events_d
if not allow_rejected: if not allow_rejected:
rows[:] = [r for r in rows if not r["rejects"]] rows[:] = [r for r in rows if not r["rejects"]]
res = yield defer.gatherResults( res = yield defer.gatherResults(
[ [
self._get_event_from_row( preserve_fn(self._get_event_from_row)(
row["internal_metadata"], row["json"], row["redacts"], row["internal_metadata"], row["json"], row["redacts"],
check_redacted=check_redacted, check_redacted=check_redacted,
get_prev_content=get_prev_content, get_prev_content=get_prev_content,
@ -740,7 +752,7 @@ class EventsStore(SQLBaseStore):
rows = [] rows = []
N = 200 N = 200
for i in range(1 + len(events) / N): for i in range(1 + len(events) / N):
evs = events[i*N:(i + 1)*N] evs = events[i * N:(i + 1) * N]
if not evs: if not evs:
break break
@ -755,7 +767,7 @@ class EventsStore(SQLBaseStore):
" LEFT JOIN rejections as rej USING (event_id)" " LEFT JOIN rejections as rej USING (event_id)"
" LEFT JOIN redactions as r ON e.event_id = r.redacts" " LEFT JOIN redactions as r ON e.event_id = r.redacts"
" WHERE e.event_id IN (%s)" " WHERE e.event_id IN (%s)"
) % (",".join(["?"]*len(evs)),) ) % (",".join(["?"] * len(evs)),)
txn.execute(sql, evs) txn.execute(sql, evs)
rows.extend(self.cursor_to_dict(txn)) rows.extend(self.cursor_to_dict(txn))

View File

@ -39,6 +39,7 @@ class KeyStore(SQLBaseStore):
table="server_tls_certificates", table="server_tls_certificates",
keyvalues={"server_name": server_name}, keyvalues={"server_name": server_name},
retcols=("tls_certificate",), retcols=("tls_certificate",),
desc="get_server_certificate",
) )
tls_certificate = OpenSSL.crypto.load_certificate( tls_certificate = OpenSSL.crypto.load_certificate(
OpenSSL.crypto.FILETYPE_ASN1, tls_certificate_bytes, OpenSSL.crypto.FILETYPE_ASN1, tls_certificate_bytes,

View File

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 28 SCHEMA_VERSION = 29
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))
@ -211,7 +211,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
logger.debug("applied_delta_files: %s", applied_delta_files) logger.debug("applied_delta_files: %s", applied_delta_files)
for v in range(start_ver, SCHEMA_VERSION + 1): for v in range(start_ver, SCHEMA_VERSION + 1):
logger.debug("Upgrading schema to v%d", v) logger.info("Upgrading schema to v%d", v)
delta_dir = os.path.join(dir_path, "schema", "delta", str(v)) delta_dir = os.path.join(dir_path, "schema", "delta", str(v))

View File

@ -68,8 +68,9 @@ class PresenceStore(SQLBaseStore):
for row in rows for row in rows
}) })
@defer.inlineCallbacks
def set_presence_state(self, user_localpart, new_state): def set_presence_state(self, user_localpart, new_state):
res = self._simple_update_one( res = yield self._simple_update_one(
table="presence", table="presence",
keyvalues={"user_id": user_localpart}, keyvalues={"user_id": user_localpart},
updatevalues={"state": new_state["state"], updatevalues={"state": new_state["state"],
@ -79,7 +80,7 @@ class PresenceStore(SQLBaseStore):
) )
self.get_presence_state.invalidate((user_localpart,)) self.get_presence_state.invalidate((user_localpart,))
return res defer.returnValue(res)
def allow_presence_visible(self, observed_localpart, observer_userid): def allow_presence_visible(self, observed_localpart, observer_userid):
return self._simple_insert( return self._simple_insert(

View File

@ -46,6 +46,20 @@ class ReceiptsStore(SQLBaseStore):
desc="get_receipts_for_room", desc="get_receipts_for_room",
) )
@cached(num_args=3)
def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
return self._simple_select_one_onecol(
table="receipts_linearized",
keyvalues={
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id
},
retcol="event_id",
desc="get_own_receipt_for_user",
allow_none=True,
)
@cachedInlineCallbacks(num_args=2) @cachedInlineCallbacks(num_args=2)
def get_receipts_for_user(self, user_id, receipt_type): def get_receipts_for_user(self, user_id, receipt_type):
def f(txn): def f(txn):
@ -226,6 +240,11 @@ class ReceiptsStore(SQLBaseStore):
room_id, stream_id room_id, stream_id
) )
txn.call_after(
self.get_last_receipt_event_id_for_user.invalidate,
(user_id, room_id, receipt_type)
)
# We don't want to clobber receipts for more recent events, so we # We don't want to clobber receipts for more recent events, so we
# have to compare orderings of existing receipts # have to compare orderings of existing receipts
sql = ( sql = (

View File

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import re
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import StoreError, Codes from synapse.api.errors import StoreError, Codes
@ -134,6 +136,7 @@ class RegistrationStore(SQLBaseStore):
}, },
retcols=["name", "password_hash", "is_guest"], retcols=["name", "password_hash", "is_guest"],
allow_none=True, allow_none=True,
desc="get_user_by_id",
) )
def get_users_by_id_case_insensitive(self, user_id): def get_users_by_id_case_insensitive(self, user_id):
@ -350,3 +353,37 @@ class RegistrationStore(SQLBaseStore):
ret = yield self.runInteraction("count_users", _count_users) ret = yield self.runInteraction("count_users", _count_users)
defer.returnValue(ret) defer.returnValue(ret)
@defer.inlineCallbacks
def find_next_generated_user_id_localpart(self):
"""
Gets the localpart of the next generated user ID.
Generated user IDs are integers, and we aim for them to be as small as
we can. Unfortunately, it's possible some of them are already taken by
existing users, and there may be gaps in the already taken range. This
function returns the start of the first allocatable gap. This is to
avoid the case of ID 10000000 being pre-allocated, so us wasting the
first (and shortest) many generated user IDs.
"""
def _find_next_generated_user_id(txn):
txn.execute("SELECT name FROM users")
rows = self.cursor_to_dict(txn)
regex = re.compile("^@(\d+):")
found = set()
for r in rows:
user_id = r["name"]
match = regex.search(user_id)
if match:
found.add(int(match.group(1)))
for i in xrange(len(found) + 1):
if i not in found:
return i
defer.returnValue((yield self.runInteraction(
"find_next_generated_user_id",
_find_next_generated_user_id
)))

View File

@ -87,90 +87,20 @@ class RoomStore(SQLBaseStore):
desc="get_public_room_ids", desc="get_public_room_ids",
) )
@defer.inlineCallbacks def get_room_count(self):
def get_rooms(self, is_public): """Retrieve a list of all rooms
"""Retrieve a list of all public rooms.
Args:
is_public (bool): True if the rooms returned should be public.
Returns:
A list of room dicts containing at least a "room_id" key, a
"topic" key if one is set, and a "name" key if one is set
""" """
def f(txn): def f(txn):
def subquery(table_name, column_name=None): sql = "SELECT count(*) FROM rooms"
column_name = column_name or table_name txn.execute(sql)
return ( row = txn.fetchone()
"SELECT %(table_name)s.event_id as event_id, " return row[0] or 0
"%(table_name)s.room_id as room_id, %(column_name)s "
"FROM %(table_name)s "
"INNER JOIN current_state_events as c "
"ON c.event_id = %(table_name)s.event_id " % {
"column_name": column_name,
"table_name": table_name,
}
)
sql = ( return self.runInteraction(
"SELECT"
" r.room_id,"
" max(n.name),"
" max(t.topic),"
" max(v.history_visibility),"
" max(g.guest_access)"
" FROM rooms AS r"
" LEFT JOIN (%(topic)s) AS t ON t.room_id = r.room_id"
" LEFT JOIN (%(name)s) AS n ON n.room_id = r.room_id"
" LEFT JOIN (%(history_visibility)s) AS v ON v.room_id = r.room_id"
" LEFT JOIN (%(guest_access)s) AS g ON g.room_id = r.room_id"
" WHERE r.is_public = ?"
" GROUP BY r.room_id" % {
"topic": subquery("topics", "topic"),
"name": subquery("room_names", "name"),
"history_visibility": subquery("history_visibility"),
"guest_access": subquery("guest_access"),
}
)
txn.execute(sql, (is_public,))
rows = txn.fetchall()
for i, row in enumerate(rows):
room_id = row[0]
aliases = self._simple_select_onecol_txn(
txn,
table="room_aliases",
keyvalues={
"room_id": room_id
},
retcol="room_alias",
)
rows[i] = list(row) + [aliases]
return rows
rows = yield self.runInteraction(
"get_rooms", f "get_rooms", f
) )
ret = [
{
"room_id": r[0],
"name": r[1],
"topic": r[2],
"world_readable": r[3] == "world_readable",
"guest_can_join": r[4] == "can_join",
"aliases": r[5],
}
for r in rows
if r[5] # We only return rooms that have at least one alias.
]
defer.returnValue(ret)
def _store_room_topic_txn(self, txn, event): def _store_room_topic_txn(self, txn, event):
if hasattr(event, "content") and "topic" in event.content: if hasattr(event, "content") and "topic" in event.content:
self._simple_insert_txn( self._simple_insert_txn(

View File

@ -58,6 +58,10 @@ class RoomMemberStore(SQLBaseStore):
txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,)) txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
txn.call_after(
self._membership_stream_cache.entity_has_changed,
event.state_key, event.internal_metadata.stream_ordering
)
def get_room_member(self, user_id, room_id): def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member. """Retrieve the current state of a room member.

View File

@ -0,0 +1,16 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE INDEX public_room_index on rooms(is_public);

View File

@ -0,0 +1,31 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE event_push_actions ADD COLUMN topological_ordering BIGINT;
ALTER TABLE event_push_actions ADD COLUMN stream_ordering BIGINT;
ALTER TABLE event_push_actions ADD COLUMN notif SMALLINT;
ALTER TABLE event_push_actions ADD COLUMN highlight SMALLINT;
UPDATE event_push_actions SET stream_ordering = (
SELECT stream_ordering FROM events WHERE event_id = event_push_actions.event_id
), topological_ordering = (
SELECT topological_ordering FROM events WHERE event_id = event_push_actions.event_id
);
UPDATE event_push_actions SET notif = 1, highlight = 0;
CREATE INDEX event_push_actions_rm_tokens on event_push_actions(
user_id, room_id, topological_ordering, stream_ordering
);

View File

@ -171,41 +171,43 @@ class StateStore(SQLBaseStore):
events = yield self._get_events(event_ids, get_prev_content=False) events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events) defer.returnValue(events)
def _get_state_groups_from_groups(self, groups_and_types): def _get_state_groups_from_groups(self, groups, types):
"""Returns dictionary state_group -> state event ids """Returns dictionary state_group -> state event ids
Args:
groups_and_types (list): list of 2-tuple (`group`, `types`)
""" """
def f(txn): def f(txn, groups):
if types is not None:
where_clause = "AND (%s)" % (
" OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
)
else:
where_clause = ""
sql = (
"SELECT state_group, event_id FROM state_groups_state WHERE"
" state_group IN (%s) %s" % (
",".join("?" for _ in groups),
where_clause,
)
)
args = list(groups)
if types is not None:
args.extend([i for typ in types for i in typ])
txn.execute(sql, args)
rows = self.cursor_to_dict(txn)
results = {} results = {}
for group, types in groups_and_types: for row in rows:
if types is not None: results.setdefault(row["state_group"], []).append(row["event_id"])
where_clause = "AND (%s)" % (
" OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
)
else:
where_clause = ""
sql = (
"SELECT event_id FROM state_groups_state WHERE"
" state_group = ? %s"
) % (where_clause,)
args = [group]
if types is not None:
args.extend([i for typ in types for i in typ])
txn.execute(sql, args)
results[group] = [r[0] for r in txn.fetchall()]
return results return results
return self.runInteraction( chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)]
"_get_state_groups_from_groups", for chunk in chunks:
f, return self.runInteraction(
) "_get_state_groups_from_groups",
f, chunk
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_for_events(self, event_ids, types): def get_state_for_events(self, event_ids, types):
@ -264,26 +266,20 @@ class StateStore(SQLBaseStore):
) )
@cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids", @cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids",
num_args=1) num_args=1, inlineCallbacks=True)
def _get_state_group_for_events(self, event_ids): def _get_state_group_for_events(self, event_ids):
"""Returns mapping event_id -> state_group """Returns mapping event_id -> state_group
""" """
def f(txn): rows = yield self._simple_select_many_batch(
results = {} table="event_to_state_groups",
for event_id in event_ids: column="event_id",
results[event_id] = self._simple_select_one_onecol_txn( iterable=event_ids,
txn, keyvalues={},
table="event_to_state_groups", retcols=("event_id", "state_group",),
keyvalues={ desc="_get_state_group_for_events",
"event_id": event_id, )
},
retcol="state_group",
allow_none=True,
)
return results defer.returnValue({row["event_id"]: row["state_group"] for row in rows})
return self.runInteraction("_get_state_group_for_events", f)
def _get_some_state_from_cache(self, group, types): def _get_some_state_from_cache(self, group, types):
"""Checks if group is in cache. See `_get_state_for_groups` """Checks if group is in cache. See `_get_state_for_groups`
@ -355,7 +351,7 @@ class StateStore(SQLBaseStore):
all events are returned. all events are returned.
""" """
results = {} results = {}
missing_groups_and_types = [] missing_groups = []
if types is not None: if types is not None:
for group in set(groups): for group in set(groups):
state_dict, missing_types, got_all = self._get_some_state_from_cache( state_dict, missing_types, got_all = self._get_some_state_from_cache(
@ -364,7 +360,7 @@ class StateStore(SQLBaseStore):
results[group] = state_dict results[group] = state_dict
if not got_all: if not got_all:
missing_groups_and_types.append((group, missing_types)) missing_groups.append(group)
else: else:
for group in set(groups): for group in set(groups):
state_dict, got_all = self._get_all_state_from_cache( state_dict, got_all = self._get_all_state_from_cache(
@ -373,9 +369,9 @@ class StateStore(SQLBaseStore):
results[group] = state_dict results[group] = state_dict
if not got_all: if not got_all:
missing_groups_and_types.append((group, None)) missing_groups.append(group)
if not missing_groups_and_types: if not missing_groups:
defer.returnValue({ defer.returnValue({
group: { group: {
type_tuple: event type_tuple: event
@ -389,7 +385,7 @@ class StateStore(SQLBaseStore):
cache_seq_num = self._state_group_cache.sequence cache_seq_num = self._state_group_cache.sequence
group_state_dict = yield self._get_state_groups_from_groups( group_state_dict = yield self._get_state_groups_from_groups(
missing_groups_and_types missing_groups, types
) )
state_events = yield self._get_events( state_events = yield self._get_events(

View File

@ -37,10 +37,9 @@ from twisted.internet import defer
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.types import RoomStreamToken from synapse.types import RoomStreamToken
from synapse.util.logutils import log_function from synapse.util.logcontext import preserve_fn
import logging import logging
@ -78,13 +77,6 @@ def upper_bound(token):
class StreamStore(SQLBaseStore): class StreamStore(SQLBaseStore):
def __init__(self, hs):
super(StreamStore, self).__init__(hs)
self._events_stream_cache = StreamChangeCache(
"EventsRoomStreamChangeCache", self._stream_id_gen.get_max_token(None)
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_appservice_room_stream(self, service, from_key, to_key, limit=0): def get_appservice_room_stream(self, service, from_key, to_key, limit=0):
# NB this lives here instead of appservice.py so we can reuse the # NB this lives here instead of appservice.py so we can reuse the
@ -177,14 +169,14 @@ class StreamStore(SQLBaseStore):
results = {} results = {}
room_ids = list(room_ids) room_ids = list(room_ids)
for rm_ids in (room_ids[i:i+20] for i in xrange(0, len(room_ids), 20)): for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)):
res = yield defer.gatherResults([ res = yield defer.gatherResults([
self.get_room_events_stream_for_room( preserve_fn(self.get_room_events_stream_for_room)(
room_id, from_key, to_key, limit room_id, from_key, to_key, limit,
).addCallback(lambda r, rm: (rm, r), room_id) )
for room_id in room_ids for room_id in room_ids
]) ])
results.update(dict(res)) results.update(dict(zip(rm_ids, res)))
defer.returnValue(results) defer.returnValue(results)
@ -229,28 +221,30 @@ class StreamStore(SQLBaseStore):
rows = self.cursor_to_dict(txn) rows = self.cursor_to_dict(txn)
ret = self._get_events_txn( return rows
txn,
[r["event_id"] for r in rows],
get_prev_content=True
)
self._set_before_and_after(ret, rows, topo_order=False) rows = yield self.runInteraction("get_room_events_stream_for_room", f)
ret.reverse() ret = yield self._get_events(
[r["event_id"] for r in rows],
get_prev_content=True
)
if rows: self._set_before_and_after(ret, rows, topo_order=False)
key = "s%d" % min(r["stream_ordering"] for r in rows)
else:
# Assume we didn't get anything because there was nothing to
# get.
key = from_key
return ret, key ret.reverse()
res = yield self.runInteraction("get_room_events_stream_for_room", f)
defer.returnValue(res)
def get_room_changes_for_user(self, user_id, from_key, to_key): if rows:
key = "s%d" % min(r["stream_ordering"] for r in rows)
else:
# Assume we didn't get anything because there was nothing to
# get.
key = from_key
defer.returnValue((ret, key))
@defer.inlineCallbacks
def get_membership_changes_for_user(self, user_id, from_key, to_key):
if from_key is not None: if from_key is not None:
from_id = RoomStreamToken.parse_stream_token(from_key).stream from_id = RoomStreamToken.parse_stream_token(from_key).stream
else: else:
@ -258,7 +252,14 @@ class StreamStore(SQLBaseStore):
to_id = RoomStreamToken.parse_stream_token(to_key).stream to_id = RoomStreamToken.parse_stream_token(to_key).stream
if from_key == to_key: if from_key == to_key:
return defer.succeed([]) defer.returnValue([])
if from_id:
has_changed = self._membership_stream_cache.has_entity_changed(
user_id, int(from_id)
)
if not has_changed:
defer.returnValue([])
def f(txn): def f(txn):
if from_id is not None: if from_id is not None:
@ -283,17 +284,19 @@ class StreamStore(SQLBaseStore):
txn.execute(sql, (user_id, to_id,)) txn.execute(sql, (user_id, to_id,))
rows = self.cursor_to_dict(txn) rows = self.cursor_to_dict(txn)
ret = self._get_events_txn( return rows
txn,
[r["event_id"] for r in rows],
get_prev_content=True
)
return ret rows = yield self.runInteraction("get_membership_changes_for_user", f)
return self.runInteraction("get_room_changes_for_user", f) ret = yield self._get_events(
[r["event_id"] for r in rows],
get_prev_content=True
)
self._set_before_and_after(ret, rows, topo_order=False)
defer.returnValue(ret)
@log_function
def get_room_events_stream( def get_room_events_stream(
self, self,
user_id, user_id,
@ -324,11 +327,6 @@ class StreamStore(SQLBaseStore):
" WHERE m.user_id = ? AND m.membership = 'join'" " WHERE m.user_id = ? AND m.membership = 'join'"
) )
current_room_membership_args = [user_id] current_room_membership_args = [user_id]
if room_ids:
current_room_membership_sql += " AND m.room_id in (%s)" % (
",".join(map(lambda _: "?", room_ids))
)
current_room_membership_args = [user_id] + room_ids
# We also want to get any membership events about that user, e.g. # We also want to get any membership events about that user, e.g.
# invites or leave notifications. # invites or leave notifications.
@ -567,6 +565,7 @@ class StreamStore(SQLBaseStore):
table="events", table="events",
keyvalues={"event_id": event_id}, keyvalues={"event_id": event_id},
retcols=("stream_ordering", "topological_ordering"), retcols=("stream_ordering", "topological_ordering"),
desc="get_topological_token_for_event",
).addCallback(lambda row: "t%d-%d" % ( ).addCallback(lambda row: "t%d-%d" % (
row["topological_ordering"], row["stream_ordering"],) row["topological_ordering"], row["stream_ordering"],)
) )
@ -604,6 +603,10 @@ class StreamStore(SQLBaseStore):
internal = event.internal_metadata internal = event.internal_metadata
internal.before = str(RoomStreamToken(topo, stream - 1)) internal.before = str(RoomStreamToken(topo, stream - 1))
internal.after = str(RoomStreamToken(topo, stream)) internal.after = str(RoomStreamToken(topo, stream))
internal.order = (
int(topo) if topo else 0,
int(stream),
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_events_around(self, room_id, event_id, before_limit, after_limit): def get_events_around(self, room_id, event_id, before_limit, after_limit):

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from twisted.internet import defer, reactor, task from twisted.internet import defer, reactor, task
@ -46,7 +46,7 @@ class Clock(object):
def looping_call(self, f, msec): def looping_call(self, f, msec):
l = task.LoopingCall(f) l = task.LoopingCall(f)
l.start(msec/1000.0, now=False) l.start(msec / 1000.0, now=False)
return l return l
def stop_looping_call(self, loop): def stop_looping_call(self, loop):
@ -61,10 +61,8 @@ class Clock(object):
*args: Postional arguments to pass to function. *args: Postional arguments to pass to function.
**kwargs: Key arguments to pass to function. **kwargs: Key arguments to pass to function.
""" """
current_context = LoggingContext.current_context()
def wrapped_callback(*args, **kwargs): def wrapped_callback(*args, **kwargs):
with PreserveLoggingContext(current_context): with PreserveLoggingContext():
callback(*args, **kwargs) callback(*args, **kwargs)
with PreserveLoggingContext(): with PreserveLoggingContext():

View File

@ -16,13 +16,16 @@
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from .logcontext import preserve_context_over_deferred from .logcontext import PreserveLoggingContext
@defer.inlineCallbacks
def sleep(seconds): def sleep(seconds):
d = defer.Deferred() d = defer.Deferred()
reactor.callLater(seconds, d.callback, seconds) with PreserveLoggingContext():
return preserve_context_over_deferred(d) reactor.callLater(seconds, d.callback, seconds)
res = yield d
defer.returnValue(res)
def run_on_reactor(): def run_on_reactor():
@ -54,6 +57,7 @@ class ObservableDeferred(object):
object.__setattr__(self, "_result", (True, r)) object.__setattr__(self, "_result", (True, r))
while self._observers: while self._observers:
try: try:
# TODO: Handle errors here.
self._observers.pop().callback(r) self._observers.pop().callback(r)
except: except:
pass pass
@ -63,6 +67,7 @@ class ObservableDeferred(object):
object.__setattr__(self, "_result", (False, f)) object.__setattr__(self, "_result", (False, f))
while self._observers: while self._observers:
try: try:
# TODO: Handle errors here.
self._observers.pop().errback(f) self._observers.pop().errback(f)
except: except:
pass pass

View File

@ -18,6 +18,9 @@ from synapse.util.async import ObservableDeferred
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache from synapse.util.caches.treecache import TreeCache
from synapse.util.logcontext import (
PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
)
from . import caches_by_name, DEBUG_CACHES, cache_counter from . import caches_by_name, DEBUG_CACHES, cache_counter
@ -149,7 +152,7 @@ class CacheDescriptor(object):
self.lru = lru self.lru = lru
self.tree = tree self.tree = tree
self.arg_names = inspect.getargspec(orig).args[1:num_args+1] self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
if len(self.arg_names) < self.num_args: if len(self.arg_names) < self.num_args:
raise Exception( raise Exception(
@ -190,7 +193,7 @@ class CacheDescriptor(object):
defer.returnValue(cached_result) defer.returnValue(cached_result)
observer.addCallback(check_result) observer.addCallback(check_result)
return observer return preserve_context_over_deferred(observer)
except KeyError: except KeyError:
# Get the sequence number of the cache before reading from the # Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated # database so that we can tell if the cache is invalidated
@ -198,6 +201,7 @@ class CacheDescriptor(object):
sequence = self.cache.sequence sequence = self.cache.sequence
ret = defer.maybeDeferred( ret = defer.maybeDeferred(
preserve_context_over_fn,
self.function_to_call, self.function_to_call,
obj, *args, **kwargs obj, *args, **kwargs
) )
@ -211,7 +215,7 @@ class CacheDescriptor(object):
ret = ObservableDeferred(ret, consumeErrors=True) ret = ObservableDeferred(ret, consumeErrors=True)
self.cache.update(sequence, cache_key, ret) self.cache.update(sequence, cache_key, ret)
return ret.observe() return preserve_context_over_deferred(ret.observe())
wrapped.invalidate = self.cache.invalidate wrapped.invalidate = self.cache.invalidate
wrapped.invalidate_all = self.cache.invalidate_all wrapped.invalidate_all = self.cache.invalidate_all
@ -250,7 +254,7 @@ class CacheListDescriptor(object):
self.num_args = num_args self.num_args = num_args
self.list_name = list_name self.list_name = list_name
self.arg_names = inspect.getargspec(orig).args[1:num_args+1] self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
self.list_pos = self.arg_names.index(self.list_name) self.list_pos = self.arg_names.index(self.list_name)
self.cache = cache self.cache = cache
@ -299,6 +303,7 @@ class CacheListDescriptor(object):
args_to_call[self.list_name] = missing args_to_call[self.list_name] = missing
ret_d = defer.maybeDeferred( ret_d = defer.maybeDeferred(
preserve_context_over_fn,
self.function_to_call, self.function_to_call,
**args_to_call **args_to_call
) )
@ -308,7 +313,8 @@ class CacheListDescriptor(object):
# We need to create deferreds for each arg in the list so that # We need to create deferreds for each arg in the list so that
# we can insert the new deferred into the cache. # we can insert the new deferred into the cache.
for arg in missing: for arg in missing:
observer = ret_d.observe() with PreserveLoggingContext():
observer = ret_d.observe()
observer.addCallback(lambda r, arg: r.get(arg, None), arg) observer.addCallback(lambda r, arg: r.get(arg, None), arg)
observer = ObservableDeferred(observer) observer = ObservableDeferred(observer)
@ -327,10 +333,10 @@ class CacheListDescriptor(object):
cached[arg] = res cached[arg] = res
return defer.gatherResults( return preserve_context_over_deferred(defer.gatherResults(
cached.values(), cached.values(),
consumeErrors=True, consumeErrors=True,
).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)) ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)))
obj.__dict__[self.orig.__name__] = wrapped obj.__dict__[self.orig.__name__] = wrapped

View File

@ -55,7 +55,7 @@ class ExpiringCache(object):
def f(): def f():
self._prune_cache() self._prune_cache()
self._clock.looping_call(f, self._expiry_ms/2) self._clock.looping_call(f, self._expiry_ms / 2)
def __setitem__(self, key, value): def __setitem__(self, key, value):
now = self._clock.time_msec() now = self._clock.time_msec()

View File

@ -87,7 +87,8 @@ class SnapshotCache(object):
# expire from the rotation of that cache. # expire from the rotation of that cache.
self.next_result_cache[key] = result self.next_result_cache[key] = result
self.pending_result_cache.pop(key, None) self.pending_result_cache.pop(key, None)
return r
result.observe().addBoth(shuffle_along) result.addBoth(shuffle_along)
return result.observe() return result.observe()

View File

@ -32,7 +32,7 @@ class StreamChangeCache(object):
entities that may have changed since that position. If position key is too entities that may have changed since that position. If position key is too
old then the cache will simply return all given entities. old then the cache will simply return all given entities.
""" """
def __init__(self, name, current_stream_pos, max_size=10000): def __init__(self, name, current_stream_pos, max_size=10000, prefilled_cache={}):
self._max_size = max_size self._max_size = max_size
self._entity_to_key = {} self._entity_to_key = {}
self._cache = sorteddict() self._cache = sorteddict()
@ -40,6 +40,9 @@ class StreamChangeCache(object):
self.name = name self.name = name
caches_by_name[self.name] = self._cache caches_by_name[self.name] = self._cache
for entity, stream_pos in prefilled_cache.items():
self.entity_has_changed(entity, stream_pos)
def has_entity_changed(self, entity, stream_pos): def has_entity_changed(self, entity, stream_pos):
"""Returns True if the entity may have been updated since stream_pos """Returns True if the entity may have been updated since stream_pos
""" """
@ -49,15 +52,10 @@ class StreamChangeCache(object):
cache_counter.inc_misses(self.name) cache_counter.inc_misses(self.name)
return True return True
if stream_pos == self._earliest_known_stream_pos:
# If the same as the earliest key, assume nothing has changed.
cache_counter.inc_hits(self.name)
return False
latest_entity_change_pos = self._entity_to_key.get(entity, None) latest_entity_change_pos = self._entity_to_key.get(entity, None)
if latest_entity_change_pos is None: if latest_entity_change_pos is None:
cache_counter.inc_misses(self.name) cache_counter.inc_hits(self.name)
return True return False
if stream_pos < latest_entity_change_pos: if stream_pos < latest_entity_change_pos:
cache_counter.inc_misses(self.name) cache_counter.inc_misses(self.name)
@ -95,7 +93,7 @@ class StreamChangeCache(object):
if stream_pos > self._earliest_known_stream_pos: if stream_pos > self._earliest_known_stream_pos:
old_pos = self._entity_to_key.get(entity, None) old_pos = self._entity_to_key.get(entity, None)
if old_pos: if old_pos is not None:
stream_pos = max(stream_pos, old_pos) stream_pos = max(stream_pos, old_pos)
self._cache.pop(old_pos, None) self._cache.pop(old_pos, None)
self._cache[stream_pos] = entity self._cache[stream_pos] = entity

View File

@ -58,7 +58,7 @@ class TreeCache(object):
if n: if n:
break break
node_and_keys[i+1][0].pop(k) node_and_keys[i + 1][0].pop(k)
popped, cnt = _strip_and_count_entires(popped) popped, cnt = _strip_and_count_entires(popped)
self.size -= cnt self.size -= cnt

View File

@ -15,9 +15,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.util.logcontext import ( from synapse.util.logcontext import PreserveLoggingContext
PreserveLoggingContext, preserve_context_over_deferred,
)
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
@ -97,6 +95,7 @@ class Signal(object):
Each observer callable may return a Deferred.""" Each observer callable may return a Deferred."""
self.observers.append(observer) self.observers.append(observer)
@defer.inlineCallbacks
def fire(self, *args, **kwargs): def fire(self, *args, **kwargs):
"""Invokes every callable in the observer list, passing in the args and """Invokes every callable in the observer list, passing in the args and
kwargs. Exceptions thrown by observers are logged but ignored. It is kwargs. Exceptions thrown by observers are logged but ignored. It is
@ -116,6 +115,7 @@ class Signal(object):
failure.getTracebackObject())) failure.getTracebackObject()))
if not self.suppress_failures: if not self.suppress_failures:
return failure return failure
return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb) return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
with PreserveLoggingContext(): with PreserveLoggingContext():
@ -124,8 +124,11 @@ class Signal(object):
for observer in self.observers for observer in self.observers
] ]
d = defer.gatherResults(deferreds, consumeErrors=True) res = yield defer.gatherResults(
deferreds, consumeErrors=True
).addErrback(unwrapFirstError)
d.addErrback(unwrapFirstError) defer.returnValue(res)
return preserve_context_over_deferred(d) def __repr__(self):
return "<Signal name=%r>" % (self.name,)

View File

@ -41,13 +41,14 @@ except:
class LoggingContext(object): class LoggingContext(object):
"""Additional context for log formatting. Contexts are scoped within a """Additional context for log formatting. Contexts are scoped within a
"with" block. Contexts inherit the state of their parent contexts. "with" block.
Args: Args:
name (str): Name for the context for debugging. name (str): Name for the context for debugging.
""" """
__slots__ = [ __slots__ = [
"parent_context", "name", "usage_start", "usage_end", "main_thread", "__dict__" "previous_context", "name", "usage_start", "usage_end", "main_thread",
"__dict__", "tag", "alive",
] ]
thread_local = threading.local() thread_local = threading.local()
@ -72,10 +73,13 @@ class LoggingContext(object):
def add_database_transaction(self, duration_ms): def add_database_transaction(self, duration_ms):
pass pass
def __nonzero__(self):
return False
sentinel = Sentinel() sentinel = Sentinel()
def __init__(self, name=None): def __init__(self, name=None):
self.parent_context = None self.previous_context = LoggingContext.current_context()
self.name = name self.name = name
self.ru_stime = 0. self.ru_stime = 0.
self.ru_utime = 0. self.ru_utime = 0.
@ -83,6 +87,8 @@ class LoggingContext(object):
self.db_txn_duration = 0. self.db_txn_duration = 0.
self.usage_start = None self.usage_start = None
self.main_thread = threading.current_thread() self.main_thread = threading.current_thread()
self.tag = ""
self.alive = True
def __str__(self): def __str__(self):
return "%s@%x" % (self.name, id(self)) return "%s@%x" % (self.name, id(self))
@ -101,6 +107,7 @@ class LoggingContext(object):
The context that was previously active The context that was previously active
""" """
current = cls.current_context() current = cls.current_context()
if current is not context: if current is not context:
current.stop() current.stop()
cls.thread_local.current_context = context cls.thread_local.current_context = context
@ -109,9 +116,13 @@ class LoggingContext(object):
def __enter__(self): def __enter__(self):
"""Enters this logging context into thread local storage""" """Enters this logging context into thread local storage"""
if self.parent_context is not None: old_context = self.set_current_context(self)
raise Exception("Attempt to enter logging context multiple times") if self.previous_context != old_context:
self.parent_context = self.set_current_context(self) logger.warn(
"Expected previous context %r, found %r",
self.previous_context, old_context
)
self.alive = True
return self return self
def __exit__(self, type, value, traceback): def __exit__(self, type, value, traceback):
@ -120,7 +131,7 @@ class LoggingContext(object):
Returns: Returns:
None to avoid suppressing any exeptions that were thrown. None to avoid suppressing any exeptions that were thrown.
""" """
current = self.set_current_context(self.parent_context) current = self.set_current_context(self.previous_context)
if current is not self: if current is not self:
if current is self.sentinel: if current is self.sentinel:
logger.debug("Expected logging context %s has been lost", self) logger.debug("Expected logging context %s has been lost", self)
@ -130,16 +141,11 @@ class LoggingContext(object):
current, current,
self self
) )
self.parent_context = None self.previous_context = None
self.alive = False
def __getattr__(self, name):
"""Delegate member lookup to parent context"""
return getattr(self.parent_context, name)
def copy_to(self, record): def copy_to(self, record):
"""Copy fields from this context and its parents to the record""" """Copy fields from this context to the record"""
if self.parent_context is not None:
self.parent_context.copy_to(record)
for key, value in self.__dict__.items(): for key, value in self.__dict__.items():
setattr(record, key, value) setattr(record, key, value)
@ -208,7 +214,7 @@ class PreserveLoggingContext(object):
exited. Used to restore the context after a function using exited. Used to restore the context after a function using
@defer.inlineCallbacks is resumed by a callback from the reactor.""" @defer.inlineCallbacks is resumed by a callback from the reactor."""
__slots__ = ["current_context", "new_context"] __slots__ = ["current_context", "new_context", "has_parent"]
def __init__(self, new_context=LoggingContext.sentinel): def __init__(self, new_context=LoggingContext.sentinel):
self.new_context = new_context self.new_context = new_context
@ -219,12 +225,27 @@ class PreserveLoggingContext(object):
self.new_context self.new_context
) )
if self.current_context:
self.has_parent = self.current_context.previous_context is not None
if not self.current_context.alive:
logger.debug(
"Entering dead context: %s",
self.current_context,
)
def __exit__(self, type, value, traceback): def __exit__(self, type, value, traceback):
"""Restores the current logging context""" """Restores the current logging context"""
LoggingContext.set_current_context(self.current_context) context = LoggingContext.set_current_context(self.current_context)
if context != self.new_context:
logger.debug(
"Unexpected logging context: %s is not %s",
context, self.new_context,
)
if self.current_context is not LoggingContext.sentinel: if self.current_context is not LoggingContext.sentinel:
if self.current_context.parent_context is None: if not self.current_context.alive:
logger.warn( logger.debug(
"Restoring dead context: %s", "Restoring dead context: %s",
self.current_context, self.current_context,
) )
@ -284,3 +305,74 @@ def preserve_context_over_deferred(deferred):
d = _PreservingContextDeferred(current_context) d = _PreservingContextDeferred(current_context)
deferred.chainDeferred(d) deferred.chainDeferred(d)
return d return d
def preserve_fn(f):
"""Ensures that function is called with correct context and that context is
restored after return. Useful for wrapping functions that return a deferred
which you don't yield on.
"""
current = LoggingContext.current_context()
def g(*args, **kwargs):
with PreserveLoggingContext(current):
return f(*args, **kwargs)
return g
# modules to ignore in `logcontext_tracer`
_to_ignore = [
"synapse.util.logcontext",
"synapse.http.server",
"synapse.storage._base",
"synapse.util.async",
]
def logcontext_tracer(frame, event, arg):
"""A tracer that logs whenever a logcontext "unexpectedly" changes within
a function. Probably inaccurate.
Use by calling `sys.settrace(logcontext_tracer)` in the main thread.
"""
if event == 'call':
name = frame.f_globals["__name__"]
if name.startswith("synapse"):
if name == "synapse.util.logcontext":
if frame.f_code.co_name in ["__enter__", "__exit__"]:
tracer = frame.f_back.f_trace
if tracer:
tracer.just_changed = True
tracer = frame.f_trace
if tracer:
return tracer
if not any(name.startswith(ig) for ig in _to_ignore):
return LineTracer()
class LineTracer(object):
__slots__ = ["context", "just_changed"]
def __init__(self):
self.context = LoggingContext.current_context()
self.just_changed = False
def __call__(self, frame, event, arg):
if event in 'line':
if self.just_changed:
self.context = LoggingContext.current_context()
self.just_changed = False
else:
c = LoggingContext.current_context()
if c != self.context:
logger.info(
"Context changed! %s -> %s, %s, %s",
self.context, c,
frame.f_code.co_filename, frame.f_lineno
)
self.context = c
return self

View File

@ -111,7 +111,7 @@ def time_function(f):
_log_debug_as_f( _log_debug_as_f(
f, f,
"[FUNC END] {%s-%d} %f", "[FUNC END] {%s-%d} %f",
(func_name, id, end-start,), (func_name, id, end - start,),
) )
return r return r
@ -168,3 +168,38 @@ def trace_function(f):
wrapped.__name__ = func_name wrapped.__name__ = func_name
return wrapped return wrapped
def get_previous_frames():
s = inspect.currentframe().f_back.f_back
to_return = []
while s:
if s.f_globals["__name__"].startswith("synapse"):
filename, lineno, function, _, _ = inspect.getframeinfo(s)
args_string = inspect.formatargvalues(*inspect.getargvalues(s))
to_return.append("{{ %s:%d %s - Args: %s }}" % (
filename, lineno, function, args_string
))
s = s.f_back
return ", ". join(to_return)
def get_previous_frame(ignore=[]):
s = inspect.currentframe().f_back.f_back
while s:
if s.f_globals["__name__"].startswith("synapse"):
if not any(s.f_globals["__name__"].startswith(ig) for ig in ignore):
filename, lineno, function, _, _ = inspect.getframeinfo(s)
args_string = inspect.formatargvalues(*inspect.getargvalues(s))
return "{{ %s:%d %s - Args: %s }}" % (
filename, lineno, function, args_string
)
s = s.f_back
return None

97
synapse/util/metrics.py Normal file
View File

@ -0,0 +1,97 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util.logcontext import LoggingContext
import synapse.metrics
import logging
logger = logging.getLogger(__name__)
metrics = synapse.metrics.get_metrics_for(__name__)
block_timer = metrics.register_distribution(
"block_timer",
labels=["block_name"]
)
block_ru_utime = metrics.register_distribution(
"block_ru_utime", labels=["block_name"]
)
block_ru_stime = metrics.register_distribution(
"block_ru_stime", labels=["block_name"]
)
block_db_txn_count = metrics.register_distribution(
"block_db_txn_count", labels=["block_name"]
)
block_db_txn_duration = metrics.register_distribution(
"block_db_txn_duration", labels=["block_name"]
)
class Measure(object):
__slots__ = [
"clock", "name", "start_context", "start", "new_context", "ru_utime",
"ru_stime", "db_txn_count", "db_txn_duration"
]
def __init__(self, clock, name):
self.clock = clock
self.name = name
self.start_context = None
self.start = None
def __enter__(self):
self.start = self.clock.time_msec()
self.start_context = LoggingContext.current_context()
if self.start_context:
self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
self.db_txn_count = self.start_context.db_txn_count
self.db_txn_duration = self.start_context.db_txn_duration
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None or not self.start_context:
return
duration = self.clock.time_msec() - self.start
block_timer.inc_by(duration, self.name)
context = LoggingContext.current_context()
if context != self.start_context:
logger.warn(
"Context have unexpectedly changed from '%s' to '%s'. (%r)",
context, self.start_context, self.name
)
return
if not context:
logger.warn("Expected context. (%r)", self.name)
return
ru_utime, ru_stime = context.get_resource_usage()
block_ru_utime.inc_by(ru_utime - self.ru_utime, self.name)
block_ru_stime.inc_by(ru_stime - self.ru_stime, self.name)
block_db_txn_count.inc_by(context.db_txn_count - self.db_txn_count, self.name)
block_db_txn_duration.inc_by(
context.db_txn_duration - self.db_txn_duration, self.name
)

View File

@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import LimitExceededError from synapse.api.errors import LimitExceededError
from synapse.util.async import sleep from synapse.util.async import sleep
from synapse.util.logcontext import preserve_fn
import collections import collections
import contextlib import contextlib
@ -163,7 +164,7 @@ class _PerHostRatelimiter(object):
"Ratelimit [%s]: sleeping req", "Ratelimit [%s]: sleeping req",
id(request_id), id(request_id),
) )
ret_defer = sleep(self.sleep_msec/1000.0) ret_defer = preserve_fn(sleep)(self.sleep_msec / 1000.0)
self.sleeping_requests.add(request_id) self.sleeping_requests.add(request_id)

14
tests/config/__init__.py Normal file
View File

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path
import shutil
import tempfile
from synapse.config.homeserver import HomeServerConfig
from tests import unittest
class ConfigGenerationTestCase(unittest.TestCase):
def setUp(self):
self.dir = tempfile.mkdtemp()
print self.dir
self.file = os.path.join(self.dir, "homeserver.yaml")
def tearDown(self):
shutil.rmtree(self.dir)
def test_generate_config_generates_files(self):
HomeServerConfig.load_config("", [
"--generate-config",
"-c", self.file,
"--report-stats=yes",
"-H", "lemurs.win"
])
self.assertSetEqual(
set([
"homeserver.yaml",
"lemurs.win.log.config",
"lemurs.win.signing.key",
"lemurs.win.tls.crt",
"lemurs.win.tls.dh",
"lemurs.win.tls.key",
]),
set(os.listdir(self.dir))
)

78
tests/config/test_load.py Normal file
View File

@ -0,0 +1,78 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path
import shutil
import tempfile
import yaml
from synapse.config.homeserver import HomeServerConfig
from tests import unittest
class ConfigLoadingTestCase(unittest.TestCase):
def setUp(self):
self.dir = tempfile.mkdtemp()
print self.dir
self.file = os.path.join(self.dir, "homeserver.yaml")
def tearDown(self):
shutil.rmtree(self.dir)
def test_load_fails_if_server_name_missing(self):
self.generate_config_and_remove_lines_containing("server_name")
with self.assertRaises(Exception):
HomeServerConfig.load_config("", ["-c", self.file])
def test_generates_and_loads_macaroon_secret_key(self):
self.generate_config()
with open(self.file,
"r") as f:
raw = yaml.load(f)
self.assertIn("macaroon_secret_key", raw)
config = HomeServerConfig.load_config("", ["-c", self.file])
self.assertTrue(
hasattr(config, "macaroon_secret_key"),
"Want config to have attr macaroon_secret_key"
)
if len(config.macaroon_secret_key) < 5:
self.fail(
"Want macaroon secret key to be string of at least length 5,"
"was: %r" % (config.macaroon_secret_key,)
)
def test_load_succeeds_if_macaroon_secret_key_missing(self):
self.generate_config_and_remove_lines_containing("macaroon")
config1 = HomeServerConfig.load_config("", ["-c", self.file])
config2 = HomeServerConfig.load_config("", ["-c", self.file])
self.assertEqual(config1.macaroon_secret_key, config2.macaroon_secret_key)
def generate_config(self):
HomeServerConfig.load_config("", [
"--generate-config",
"-c", self.file,
"--report-stats=yes",
"-H", "lemurs.win"
])
def generate_config_and_remove_lines_containing(self, needle):
self.generate_config()
with open(self.file, "r") as f:
contents = f.readlines()
contents = [l for l in contents if needle not in l]
with open(self.file, "w") as f:
f.write("".join(contents))

View File

@ -122,7 +122,7 @@ class EventStreamPermissionsTestCase(RestTestCase):
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0) self.ratelimiter.send_message.return_value = (True, 0)
hs.config.enable_registration_captcha = False hs.config.enable_registration_captcha = False
hs.config.disable_registration = False hs.config.enable_registration = True
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()

View File

@ -41,7 +41,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.hs.hostname = "superbig~testing~thing.com" self.hs.hostname = "superbig~testing~thing.com"
self.hs.get_auth = Mock(return_value=self.auth) self.hs.get_auth = Mock(return_value=self.auth)
self.hs.get_handlers = Mock(return_value=self.handlers) self.hs.get_handlers = Mock(return_value=self.handlers)
self.hs.config.disable_registration = False self.hs.config.enable_registration = True
# init the thing we're testing # init the thing we're testing
self.servlet = RegisterRestServlet(self.hs) self.servlet = RegisterRestServlet(self.hs)
@ -120,7 +120,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
})) }))
def test_POST_disabled_registration(self): def test_POST_disabled_registration(self):
self.hs.config.disable_registration = True self.hs.config.enable_registration = False
self.request_data = json.dumps({ self.request_data = json.dumps({
"username": "kermit", "username": "kermit",
"password": "monkey" "password": "monkey"

View File

@ -51,32 +51,6 @@ class RoomStoreTestCase(unittest.TestCase):
(yield self.store.get_room(self.room.to_string())) (yield self.store.get_room(self.room.to_string()))
) )
@defer.inlineCallbacks
def test_get_rooms(self):
# get_rooms does an INNER JOIN on the room_aliases table :(
rooms = yield self.store.get_rooms(is_public=True)
# Should be empty before we add the alias
self.assertEquals([], rooms)
yield self.store.create_room_alias_association(
room_alias=self.alias,
room_id=self.room.to_string(),
servers=["test"]
)
rooms = yield self.store.get_rooms(is_public=True)
self.assertEquals(1, len(rooms))
self.assertEquals({
"name": None,
"room_id": self.room.to_string(),
"topic": None,
"aliases": [self.alias.to_string()],
"world_readable": False,
"guest_can_join": False,
}, rooms[0])
class RoomEventsStoreTestCase(unittest.TestCase): class RoomEventsStoreTestCase(unittest.TestCase):

View File

@ -5,6 +5,7 @@ from .. import unittest
from synapse.util.async import sleep from synapse.util.async import sleep
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
class LoggingContextTestCase(unittest.TestCase): class LoggingContextTestCase(unittest.TestCase):
def _check_test_key(self, value): def _check_test_key(self, value):
@ -17,15 +18,6 @@ class LoggingContextTestCase(unittest.TestCase):
context_one.test_key = "test" context_one.test_key = "test"
self._check_test_key("test") self._check_test_key("test")
def test_chaining(self):
with LoggingContext() as context_one:
context_one.test_key = "one"
with LoggingContext() as context_two:
self._check_test_key("one")
context_two.test_key = "two"
self._check_test_key("two")
self._check_test_key("one")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_sleep(self): def test_sleep(self):
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -46,9 +46,10 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
config = Mock() config = Mock()
config.signing_key = [MockKey()] config.signing_key = [MockKey()]
config.event_cache_size = 1 config.event_cache_size = 1
config.disable_registration = False config.enable_registration = True
config.macaroon_secret_key = "not even a little secret" config.macaroon_secret_key = "not even a little secret"
config.server_name = "server.under.test" config.server_name = "server.under.test"
config.trusted_third_party_id_servers = []
if "clock" not in kargs: if "clock" not in kargs:
kargs["clock"] = MockClock() kargs["clock"] = MockClock()

View File

@ -11,7 +11,7 @@ deps =
setenv = setenv =
PYTHONDONTWRITEBYTECODE = no_byte_code PYTHONDONTWRITEBYTECODE = no_byte_code
commands = commands =
/bin/bash -c "coverage run {env:COVERAGE_OPTS:} --source={toxinidir}/synapse \ /bin/bash -c "find {toxinidir} -name '*.pyc' -delete ; coverage run {env:COVERAGE_OPTS:} --source={toxinidir}/synapse \
{envbindir}/trial {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:}" {envbindir}/trial {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:}"
{env:DUMP_COVERAGE_COMMAND:coverage report -m} {env:DUMP_COVERAGE_COMMAND:coverage report -m}