Merge branch 'develop' of github.com:matrix-org/synapse into release-v0.12.1

This commit is contained in:
Erik Johnston 2016-01-29 13:52:12 +00:00
commit fd142c29d9
40 changed files with 939 additions and 1034 deletions

View File

@ -26,7 +26,7 @@ TOX_BIN=$WORKSPACE/.tox/py27/bin
if [[ ! -e .sytest-base ]]; then
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
else
(cd .sytest-base; git fetch)
(cd .sytest-base; git fetch -p)
fi
rm -rf sytest
@ -52,7 +52,7 @@ RUN_POSTGRES=""
for port in $(($PORT_BASE + 1)) $(($PORT_BASE + 2)); do
if psql synapse_jenkins_$port <<< ""; then
RUN_POSTGRES=$RUN_POSTGRES:$port
RUN_POSTGRES="$RUN_POSTGRES:$port"
cat > localhost-$port/database.yaml << EOF
name: psycopg2
args:
@ -62,7 +62,7 @@ EOF
done
# Run if both postgresql databases exist
if test $RUN_POSTGRES = ":$(($PORT_BASE + 1)):$(($PORT_BASE + 2))"; then
if test "$RUN_POSTGRES" = ":$(($PORT_BASE + 1)):$(($PORT_BASE + 2))"; then
echo >&2 "Running sytest with PostgreSQL";
$TOX_BIN/pip install psycopg2
./run-tests.pl --coverage -O tap --synapse-directory $WORKSPACE \

View File

@ -15,6 +15,8 @@
from synapse.api.errors import SynapseError
from synapse.types import UserID, RoomID
import ujson as json
class Filtering(object):
@ -149,6 +151,9 @@ class FilterCollection(object):
"include_leave", False
)
def __repr__(self):
return "<FilterCollection %s>" % (json.dumps(self._filter_json),)
def get_filter_json(self):
return self._filter_json

View File

@ -50,16 +50,14 @@ from twisted.cred import checkers, portal
from twisted.internet import reactor, task, defer
from twisted.application import service
from twisted.enterprise import adbapi
from twisted.web.resource import Resource, EncodingResourceWrapper
from twisted.web.static import File
from twisted.web.server import Site, GzipEncoderFactory, Request
from synapse.http.server import JsonResource, RootRedirect
from synapse.http.server import RootRedirect
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.rest.key.v1.server_key_resource import LocalKey
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.api.urls import (
FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
SERVER_KEY_PREFIX, MEDIA_PREFIX, STATIC_PREFIX,
@ -69,6 +67,7 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
from synapse.util.logcontext import LoggingContext
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.federation.transport.server import TransportLayerServer
from synapse import events
@ -95,80 +94,37 @@ def gz_wrap(r):
return EncodingResourceWrapper(r, [GzipEncoderFactory()])
def build_resource_for_web_client(hs):
webclient_path = hs.get_config().web_client_location
if not webclient_path:
try:
import syweb
except ImportError:
quit_with_error(
"Could not find a webclient.\n\n"
"Please either install the matrix-angular-sdk or configure\n"
"the location of the source to serve via the configuration\n"
"option `web_client_location`\n\n"
"To install the `matrix-angular-sdk` via pip, run:\n\n"
" pip install '%(dep)s'\n"
"\n"
"You can also disable hosting of the webclient via the\n"
"configuration option `web_client`\n"
% {"dep": DEPENDENCY_LINKS["matrix-angular-sdk"]}
)
syweb_path = os.path.dirname(syweb.__file__)
webclient_path = os.path.join(syweb_path, "webclient")
# GZip is disabled here due to
# https://twistedmatrix.com/trac/ticket/7678
# (It can stay enabled for the API resources: they call
# write() with the whole body and then finish() straight
# after and so do not trigger the bug.
# GzipFile was removed in commit 184ba09
# return GzipFile(webclient_path) # TODO configurable?
return File(webclient_path) # TODO configurable?
class SynapseHomeServer(HomeServer):
def build_http_client(self):
return MatrixFederationHttpClient(self)
def build_client_resource(self):
return ClientRestResource(self)
def build_resource_for_federation(self):
return JsonResource(self)
def build_resource_for_web_client(self):
webclient_path = self.get_config().web_client_location
if not webclient_path:
try:
import syweb
except ImportError:
quit_with_error(
"Could not find a webclient.\n\n"
"Please either install the matrix-angular-sdk or configure\n"
"the location of the source to serve via the configuration\n"
"option `web_client_location`\n\n"
"To install the `matrix-angular-sdk` via pip, run:\n\n"
" pip install '%(dep)s'\n"
"\n"
"You can also disable hosting of the webclient via the\n"
"configuration option `web_client`\n"
% {"dep": DEPENDENCY_LINKS["matrix-angular-sdk"]}
)
syweb_path = os.path.dirname(syweb.__file__)
webclient_path = os.path.join(syweb_path, "webclient")
# GZip is disabled here due to
# https://twistedmatrix.com/trac/ticket/7678
# (It can stay enabled for the API resources: they call
# write() with the whole body and then finish() straight
# after and so do not trigger the bug.
# GzipFile was removed in commit 184ba09
# return GzipFile(webclient_path) # TODO configurable?
return File(webclient_path) # TODO configurable?
def build_resource_for_static_content(self):
# This is old and should go away: not going to bother adding gzip
return File(
os.path.join(os.path.dirname(synapse.__file__), "static")
)
def build_resource_for_content_repo(self):
return ContentRepoResource(
self, self.config.uploads_path, self.auth, self.content_addr
)
def build_resource_for_media_repository(self):
return MediaRepositoryResource(self)
def build_resource_for_server_key(self):
return LocalKey(self)
def build_resource_for_server_key_v2(self):
return KeyApiV2Resource(self)
def build_resource_for_metrics(self):
if self.get_config().enable_metrics:
return MetricsResource(self)
else:
return None
def build_db_pool(self):
name = self.db_config["name"]
return adbapi.ConnectionPool(
name,
**self.db_config.get("args", {})
)
def _listener_http(self, config, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
@ -178,13 +134,11 @@ class SynapseHomeServer(HomeServer):
if tls and config.no_tls:
return
metrics_resource = self.get_resource_for_metrics()
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "client":
client_resource = self.get_client_resource()
client_resource = ClientRestResource(self)
if res["compress"]:
client_resource = gz_wrap(client_resource)
@ -198,31 +152,35 @@ class SynapseHomeServer(HomeServer):
if name == "federation":
resources.update({
FEDERATION_PREFIX: self.get_resource_for_federation(),
FEDERATION_PREFIX: TransportLayerServer(self),
})
if name in ["static", "client"]:
resources.update({
STATIC_PREFIX: self.get_resource_for_static_content(),
STATIC_PREFIX: File(
os.path.join(os.path.dirname(synapse.__file__), "static")
),
})
if name in ["media", "federation", "client"]:
resources.update({
MEDIA_PREFIX: self.get_resource_for_media_repository(),
CONTENT_REPO_PREFIX: self.get_resource_for_content_repo(),
MEDIA_PREFIX: MediaRepositoryResource(self),
CONTENT_REPO_PREFIX: ContentRepoResource(
self, self.config.uploads_path, self.auth, self.content_addr
),
})
if name in ["keys", "federation"]:
resources.update({
SERVER_KEY_PREFIX: self.get_resource_for_server_key(),
SERVER_KEY_V2_PREFIX: self.get_resource_for_server_key_v2(),
SERVER_KEY_PREFIX: LocalKey(self),
SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self),
})
if name == "webclient":
resources[WEB_CLIENT_PREFIX] = self.get_resource_for_web_client()
resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self)
if name == "metrics" and metrics_resource:
resources[METRICS_PREFIX] = metrics_resource
if name == "metrics" and self.get_config().enable_metrics:
resources[METRICS_PREFIX] = MetricsResource(self)
root_resource = create_resource_tree(resources)
if tls:
@ -296,6 +254,18 @@ class SynapseHomeServer(HomeServer):
except IncorrectDatabaseSetup as e:
quit_with_error(e.message)
def get_db_conn(self):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
self.database_engine.on_new_connection(db_conn)
return db_conn
def quit_with_error(error_string):
message_lines = error_string.split("\n")
@ -432,13 +402,7 @@ def setup(config_options):
logger.info("Preparing database: %s...", config.database_config['name'])
try:
db_conn = database_engine.module.connect(
**{
k: v for k, v in config.database_config.get("args", {}).items()
if not k.startswith("cp_")
}
)
db_conn = hs.get_db_conn()
database_engine.prepare_database(db_conn)
hs.run_startup_checks(db_conn, database_engine)
@ -453,13 +417,17 @@ def setup(config_options):
logger.info("Database prepared in %s.", config.database_config['name'])
hs.setup()
hs.start_listening()
hs.get_pusherpool().start()
hs.get_state_handler().start_caching()
hs.get_datastore().start_profiling()
hs.get_datastore().start_doing_background_updates()
hs.get_replication_layer().start_get_pdu_cache()
def start():
hs.get_pusherpool().start()
hs.get_state_handler().start_caching()
hs.get_datastore().start_profiling()
hs.get_datastore().start_doing_background_updates()
hs.get_replication_layer().start_get_pdu_cache()
reactor.callWhenRunning(start)
return hs
@ -675,7 +643,7 @@ def _resource_id(resource, path_seg):
the mapping should looks like _resource_id(A,C) = B.
Args:
resource (Resource): The *parent* Resource
resource (Resource): The *parent* Resourceb
path_seg (str): The name of the child Resource to be attached.
Returns:
str: A unique string which can be a key to the child Resource.

View File

@ -17,15 +17,10 @@
"""
from .replication import ReplicationLayer
from .transport import TransportLayer
from .transport.client import TransportLayerClient
def initialize_http_replication(homeserver):
transport = TransportLayer(
homeserver,
homeserver.hostname,
server=homeserver.get_resource_for_federation(),
client=homeserver.get_http_client()
)
transport = TransportLayerClient(homeserver)
return ReplicationLayer(homeserver, transport)

View File

@ -54,8 +54,6 @@ class ReplicationLayer(FederationClient, FederationServer):
self.keyring = hs.get_keyring()
self.transport_layer = transport_layer
self.transport_layer.register_received_handler(self)
self.transport_layer.register_request_handler(self)
self.federation_client = self

View File

@ -20,55 +20,3 @@ By default this is done over HTTPS (and all home servers are required to
support HTTPS), however individual pairings of servers may decide to
communicate over a different (albeit still reliable) protocol.
"""
from .server import TransportLayerServer
from .client import TransportLayerClient
from synapse.util.ratelimitutils import FederationRateLimiter
class TransportLayer(TransportLayerServer, TransportLayerClient):
"""This is a basic implementation of the transport layer that translates
transactions and other requests to/from HTTP.
Attributes:
server_name (str): Local home server host
server (synapse.http.server.HttpServer): the http server to
register listeners on
client (synapse.http.client.HttpClient): the http client used to
send requests
request_handler (TransportRequestHandler): The handler to fire when we
receive requests for data.
received_handler (TransportReceivedHandler): The handler to fire when
we receive data.
"""
def __init__(self, homeserver, server_name, server, client):
"""
Args:
server_name (str): Local home server host
server (synapse.protocol.http.HttpServer): the http server to
register listeners on
client (synapse.protocol.http.HttpClient): the http client used to
send requests
"""
self.keyring = homeserver.get_keyring()
self.clock = homeserver.get_clock()
self.server_name = server_name
self.server = server
self.client = client
self.request_handler = None
self.received_handler = None
self.ratelimiter = FederationRateLimiter(
self.clock,
window_size=homeserver.config.federation_rc_window_size,
sleep_limit=homeserver.config.federation_rc_sleep_limit,
sleep_msec=homeserver.config.federation_rc_sleep_delay,
reject_limit=homeserver.config.federation_rc_reject_limit,
concurrent_requests=homeserver.config.federation_rc_concurrent,
)

View File

@ -28,6 +28,10 @@ logger = logging.getLogger(__name__)
class TransportLayerClient(object):
"""Sends federation HTTP requests to other servers"""
def __init__(self, hs):
self.server_name = hs.hostname
self.client = hs.get_http_client()
@log_function
def get_room_state(self, destination, room_id, event_id):
""" Requests all state for a given room from the given server at the

View File

@ -17,7 +17,8 @@ from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError
from synapse.util.logutils import log_function
from synapse.http.server import JsonResource
from synapse.util.ratelimitutils import FederationRateLimiter
import functools
import logging
@ -28,9 +29,41 @@ import re
logger = logging.getLogger(__name__)
class TransportLayerServer(object):
class TransportLayerServer(JsonResource):
"""Handles incoming federation HTTP requests"""
def __init__(self, hs):
self.hs = hs
self.clock = hs.get_clock()
super(TransportLayerServer, self).__init__(hs)
self.authenticator = Authenticator(hs)
self.ratelimiter = FederationRateLimiter(
self.clock,
window_size=hs.config.federation_rc_window_size,
sleep_limit=hs.config.federation_rc_sleep_limit,
sleep_msec=hs.config.federation_rc_sleep_delay,
reject_limit=hs.config.federation_rc_reject_limit,
concurrent_requests=hs.config.federation_rc_concurrent,
)
self.register_servlets()
def register_servlets(self):
register_servlets(
self.hs,
resource=self,
ratelimiter=self.ratelimiter,
authenticator=self.authenticator,
)
class Authenticator(object):
def __init__(self, hs):
self.keyring = hs.get_keyring()
self.server_name = hs.hostname
# A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks
def authenticate_request(self, request):
@ -98,37 +131,9 @@ class TransportLayerServer(object):
defer.returnValue((origin, content))
@log_function
def register_received_handler(self, handler):
""" Register a handler that will be fired when we receive data.
Args:
handler (TransportReceivedHandler)
"""
FederationSendServlet(
handler,
authenticator=self,
ratelimiter=self.ratelimiter,
server_name=self.server_name,
).register(self.server)
@log_function
def register_request_handler(self, handler):
""" Register a handler that will be fired when we get asked for data.
Args:
handler (TransportRequestHandler)
"""
for servletclass in SERVLET_CLASSES:
servletclass(
handler,
authenticator=self,
ratelimiter=self.ratelimiter,
).register(self.server)
class BaseFederationServlet(object):
def __init__(self, handler, authenticator, ratelimiter):
def __init__(self, handler, authenticator, ratelimiter, server_name):
self.handler = handler
self.authenticator = authenticator
self.ratelimiter = ratelimiter
@ -172,7 +177,9 @@ class FederationSendServlet(BaseFederationServlet):
PATH = "/send/([^/]*)/"
def __init__(self, handler, server_name, **kwargs):
super(FederationSendServlet, self).__init__(handler, **kwargs)
super(FederationSendServlet, self).__init__(
handler, server_name=server_name, **kwargs
)
self.server_name = server_name
# This is when someone is trying to send us a bunch of data.
@ -432,6 +439,7 @@ class On3pidBindServlet(BaseFederationServlet):
SERVLET_CLASSES = (
FederationSendServlet,
FederationPullServlet,
FederationEventServlet,
FederationStateServlet,
@ -451,3 +459,13 @@ SERVLET_CLASSES = (
FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet,
)
def register_servlets(hs, resource, authenticator, ratelimiter):
for servletclass in SERVLET_CLASSES:
servletclass(
handler=hs.get_replication_layer(),
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
).register(resource)

View File

@ -1186,7 +1186,13 @@ class FederationHandler(BaseHandler):
try:
self.auth.check(e, auth_events=auth_for_e)
except AuthError as err:
except SynapseError as err:
# we may get SynapseErrors here as well as AuthErrors. For
# instance, there are a couple of (ancient) events in some
# rooms whose senders do not have the correct sigil; these
# cause SynapseErrors in auth.check. We don't want to give up
# the attempt to federate altogether in such cases.
logger.warn(
"Rejecting %s because %s",
e.event_id, err.msg

View File

@ -16,7 +16,7 @@
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError, AuthError, Codes
from synapse.api.errors import AuthError, Codes
from synapse.streams.config import PaginationConfig
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
@ -105,8 +105,6 @@ class MessageHandler(BaseHandler):
room_token = pagin_config.from_token.room_key
room_token = RoomStreamToken.parse(room_token)
if room_token.topological is None:
raise SynapseError(400, "Invalid token")
pagin_config.from_token = pagin_config.from_token.copy_and_replace(
"room_key", str(room_token)
@ -117,27 +115,31 @@ class MessageHandler(BaseHandler):
membership, member_event_id = yield self._check_in_room_or_world_readable(
room_id, user_id
)
if membership == Membership.LEAVE:
# If they have left the room then clamp the token to be before
# they left the room.
leave_token = yield self.store.get_topological_token_for_event(
member_event_id
if source_config.direction == 'b':
# if we're going backwards, we might need to backfill. This
# requires that we have a topo token.
if room_token.topological:
max_topo = room_token.topological
else:
max_topo = yield self.store.get_max_topological_token_for_stream_and_room(
room_id, room_token.stream
)
if membership == Membership.LEAVE:
# If they have left the room then clamp the token to be before
# they left the room, to save the effort of loading from the
# database.
leave_token = yield self.store.get_topological_token_for_event(
member_event_id
)
leave_token = RoomStreamToken.parse(leave_token)
if leave_token.topological < max_topo:
source_config.from_key = str(leave_token)
yield self.hs.get_handlers().federation_handler.maybe_backfill(
room_id, max_topo
)
leave_token = RoomStreamToken.parse(leave_token)
if leave_token.topological < room_token.topological:
source_config.from_key = str(leave_token)
if source_config.direction == "f":
if source_config.to_key is None:
source_config.to_key = str(leave_token)
else:
to_token = RoomStreamToken.parse(source_config.to_key)
if leave_token.topological < to_token.topological:
source_config.to_key = str(leave_token)
yield self.hs.get_handlers().federation_handler.maybe_backfill(
room_id, room_token.topological
)
events, next_key = yield data_source.get_pagination_rows(
requester.user, source_config, room_id

View File

@ -72,7 +72,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
)
class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [
class ArchivedSyncResult(collections.namedtuple("ArchivedSyncResult", [
"room_id", # str
"timeline", # TimelineBatch
"state", # dict[(str, str), FrozenEvent]
@ -298,46 +298,19 @@ class SyncHandler(BaseHandler):
room_id, sync_config, now_token, since_token=timeline_since_token
)
notifs = yield self.unread_notifs_for_room_id(
room_id, sync_config, ephemeral_by_room
room_sync = yield self.incremental_sync_with_gap_for_room(
room_id, sync_config,
now_token=now_token,
since_token=timeline_since_token,
ephemeral_by_room=ephemeral_by_room,
tags_by_room=tags_by_room,
account_data_by_room=account_data_by_room,
all_ephemeral_by_room=ephemeral_by_room,
batch=batch,
full_state=True,
)
unread_notifications = {}
if notifs is not None:
unread_notifications["notification_count"] = len(notifs)
unread_notifications["highlight_count"] = len([
1 for notif in notifs if _action_has_highlight(notif["actions"])
])
current_state = yield self.get_state_at(room_id, now_token)
current_state = {
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(
current_state.values()
)
}
account_data = self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
)
account_data = sync_config.filter_collection.filter_room_account_data(
account_data
)
ephemeral = sync_config.filter_collection.filter_room_ephemeral(
ephemeral_by_room.get(room_id, [])
)
defer.returnValue(JoinedSyncResult(
room_id=room_id,
timeline=batch,
state=current_state,
ephemeral=ephemeral,
account_data=account_data,
unread_notifications=unread_notifications,
))
defer.returnValue(room_sync)
def account_data_for_user(self, account_data):
account_data_events = []
@ -429,44 +402,20 @@ class SyncHandler(BaseHandler):
defer.returnValue((now_token, ephemeral_by_room))
@defer.inlineCallbacks
def full_state_sync_for_archived_room(self, room_id, sync_config,
leave_event_id, leave_token,
timeline_since_token, tags_by_room,
account_data_by_room):
"""Sync a room for a client which is starting without any state
Returns:
A Deferred JoinedSyncResult.
A Deferred ArchivedSyncResult.
"""
batch = yield self.load_filtered_recents(
room_id, sync_config, leave_token, since_token=timeline_since_token
return self.incremental_sync_for_archived_room(
sync_config, room_id, leave_event_id, timeline_since_token, tags_by_room,
account_data_by_room, full_state=True, leave_token=leave_token,
)
leave_state = yield self.store.get_state_for_event(leave_event_id)
leave_state = {
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(
leave_state.values()
)
}
account_data = self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
)
account_data = sync_config.filter_collection.filter_room_account_data(
account_data
)
defer.returnValue(ArchivedSyncResult(
room_id=room_id,
timeline=batch,
state=leave_state,
account_data=account_data,
))
@defer.inlineCallbacks
def incremental_sync_with_gap(self, sync_config, since_token):
""" Get the incremental delta needed to bring the client up to
@ -512,154 +461,127 @@ class SyncHandler(BaseHandler):
sync_config.user
)
user_id = sync_config.user.to_string()
timeline_limit = sync_config.filter_collection.timeline_limit()
room_events, _ = yield self.store.get_room_events_stream(
sync_config.user.to_string(),
from_key=since_token.room_key,
to_key=now_token.room_key,
limit=timeline_limit + 1,
)
tags_by_room = yield self.store.get_updated_tags(
sync_config.user.to_string(),
user_id,
since_token.account_data_key,
)
account_data, account_data_by_room = (
yield self.store.get_updated_account_data_for_user(
sync_config.user.to_string(),
user_id,
since_token.account_data_key,
)
)
joined = []
# Get a list of membership change events that have happened.
rooms_changed = yield self.store.get_room_changes_for_user(
user_id, since_token.room_key, now_token.room_key
)
mem_change_events_by_room_id = {}
for event in rooms_changed:
mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
newly_joined_rooms = []
archived = []
if len(room_events) <= timeline_limit:
# There is no gap in any of the rooms. Therefore we can just
# partition the new events by room and return them.
logger.debug("Got %i events for incremental sync - not limited",
len(room_events))
invited = []
for room_id, events in mem_change_events_by_room_id.items():
non_joins = [e for e in events if e.membership != Membership.JOIN]
has_join = len(non_joins) != len(events)
invite_events = []
leave_events = []
events_by_room_id = {}
for event in room_events:
events_by_room_id.setdefault(event.room_id, []).append(event)
if event.room_id not in joined_room_ids:
if (event.type == EventTypes.Member
and event.state_key == sync_config.user.to_string()):
if event.membership == Membership.INVITE:
invite_events.append(event)
elif event.membership in (Membership.LEAVE, Membership.BAN):
leave_events.append(event)
# We want to figure out if we joined the room at some point since
# the last sync (even if we have since left). This is to make sure
# we do send down the room, and with full state, where necessary
if room_id in joined_room_ids or has_join:
old_state = yield self.get_state_at(room_id, since_token)
old_mem_ev = old_state.get((EventTypes.Member, user_id), None)
if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
newly_joined_rooms.append(room_id)
for room_id in joined_room_ids:
recents = events_by_room_id.get(room_id, [])
logger.debug("Events for room %s: %r", room_id, recents)
state = {
(event.type, event.state_key): event
for event in recents if event.is_state()}
limited = False
if room_id in joined_room_ids:
continue
if recents:
prev_batch = now_token.copy_and_replace(
"room_key", recents[0].internal_metadata.before
)
else:
prev_batch = now_token
just_joined = yield self.check_joined_room(sync_config, state)
if just_joined:
logger.debug("User has just joined %s: needs full state",
room_id)
state = yield self.get_state_at(room_id, now_token)
# the timeline is inherently limited if we've just joined
limited = True
recents = sync_config.filter_collection.filter_room_timeline(recents)
state = {
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(
state.values()
)
}
acc_data = self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
)
acc_data = sync_config.filter_collection.filter_room_account_data(
acc_data
)
ephemeral = sync_config.filter_collection.filter_room_ephemeral(
ephemeral_by_room.get(room_id, [])
)
room_sync = JoinedSyncResult(
room_id=room_id,
timeline=TimelineBatch(
events=recents,
prev_batch=prev_batch,
limited=limited,
),
state=state,
ephemeral=ephemeral,
account_data=acc_data,
unread_notifications={},
)
logger.debug("Result for room %s: %r", room_id, room_sync)
if not non_joins:
continue
# Only bother if we're still currently invited
should_invite = non_joins[-1].membership == Membership.INVITE
if should_invite:
room_sync = InvitedSyncResult(room_id, invite=non_joins[-1])
if room_sync:
notifs = yield self.unread_notifs_for_room_id(
room_id, sync_config, all_ephemeral_by_room
)
invited.append(room_sync)
if notifs is not None:
notif_dict = room_sync.unread_notifications
notif_dict["notification_count"] = len(notifs)
notif_dict["highlight_count"] = len([
1 for notif in notifs
if _action_has_highlight(notif["actions"])
])
# Always include leave/ban events. Just take the last one.
# TODO: How do we handle ban -> leave in same batch?
leave_events = [
e for e in non_joins
if e.membership in (Membership.LEAVE, Membership.BAN)
]
joined.append(room_sync)
else:
logger.debug("Got %i events for incremental sync - hit limit",
len(room_events))
invite_events = yield self.store.get_invites_for_user(
sync_config.user.to_string()
)
leave_events = yield self.store.get_leave_and_ban_events_for_user(
sync_config.user.to_string()
)
for room_id in joined_room_ids:
room_sync = yield self.incremental_sync_with_gap_for_room(
room_id, sync_config, since_token, now_token,
ephemeral_by_room, tags_by_room, account_data_by_room,
all_ephemeral_by_room=all_ephemeral_by_room,
if leave_events:
leave_event = leave_events[-1]
room_sync = yield self.incremental_sync_for_archived_room(
sync_config, room_id, leave_event.event_id, since_token,
tags_by_room, account_data_by_room,
full_state=room_id in newly_joined_rooms
)
if room_sync:
joined.append(room_sync)
archived.append(room_sync)
for leave_event in leave_events:
room_sync = yield self.incremental_sync_for_archived_room(
sync_config, leave_event, since_token, tags_by_room,
account_data_by_room
# Get all events for rooms we're currently joined to.
room_to_events = yield self.store.get_room_events_stream_for_rooms(
room_ids=joined_room_ids,
from_key=since_token.room_key,
to_key=now_token.room_key,
limit=timeline_limit + 1,
)
joined = []
# We loop through all room ids, even if there are no new events, in case
# there are non room events taht we need to notify about.
for room_id in joined_room_ids:
room_entry = room_to_events.get(room_id, None)
if room_entry:
events, start_key = room_entry
prev_batch_token = now_token.copy_and_replace("room_key", start_key)
newly_joined_room = room_id in newly_joined_rooms
full_state = newly_joined_room
batch = yield self.load_filtered_recents(
room_id, sync_config, prev_batch_token,
since_token=since_token,
recents=events,
newly_joined_room=newly_joined_room,
)
else:
batch = TimelineBatch(
events=[],
prev_batch=since_token,
limited=False,
)
full_state = False
room_sync = yield self.incremental_sync_with_gap_for_room(
room_id=room_id,
sync_config=sync_config,
since_token=since_token,
now_token=now_token,
ephemeral_by_room=ephemeral_by_room,
tags_by_room=tags_by_room,
account_data_by_room=account_data_by_room,
all_ephemeral_by_room=all_ephemeral_by_room,
batch=batch,
full_state=full_state,
)
if room_sync:
archived.append(room_sync)
invited = [
InvitedSyncResult(room_id=event.room_id, invite=event)
for event in invite_events
]
joined.append(room_sync)
account_data_for_user = sync_config.filter_collection.filter_account_data(
self.account_data_for_user(account_data)
@ -680,28 +602,40 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks
def load_filtered_recents(self, room_id, sync_config, now_token,
since_token=None):
since_token=None, recents=None, newly_joined_room=False):
"""
:returns a Deferred TimelineBatch
"""
limited = True
recents = []
filtering_factor = 2
timeline_limit = sync_config.filter_collection.timeline_limit()
load_limit = max(timeline_limit * filtering_factor, 100)
max_repeat = 3 # Only try a few times per room, otherwise
load_limit = max(timeline_limit * filtering_factor, 10)
max_repeat = 5 # Only try a few times per room, otherwise
room_key = now_token.room_key
end_key = room_key
limited = recents is None or newly_joined_room or timeline_limit < len(recents)
if recents is not None:
recents = sync_config.filter_collection.filter_room_timeline(recents)
recents = yield self._filter_events_for_client(
sync_config.user.to_string(),
recents,
is_peeking=sync_config.is_guest,
)
else:
recents = []
since_key = None
if since_token and not newly_joined_room:
since_key = since_token.room_key
while limited and len(recents) < timeline_limit and max_repeat:
events, keys = yield self.store.get_recent_events_for_room(
events, end_key = yield self.store.get_room_events_stream_for_room(
room_id,
limit=load_limit + 1,
from_token=since_token.room_key if since_token else None,
end_token=end_key,
from_key=since_key,
to_key=end_key,
)
room_key, _ = keys
end_key = "s" + room_key.split('-')[-1]
loaded_recents = sync_config.filter_collection.filter_room_timeline(events)
loaded_recents = yield self._filter_events_for_client(
sync_config.user.to_string(),
@ -710,8 +644,10 @@ class SyncHandler(BaseHandler):
)
loaded_recents.extend(recents)
recents = loaded_recents
if len(events) <= load_limit:
limited = False
break
max_repeat -= 1
if len(recents) > timeline_limit:
@ -724,7 +660,9 @@ class SyncHandler(BaseHandler):
)
defer.returnValue(TimelineBatch(
events=recents, prev_batch=prev_batch_token, limited=limited
events=recents,
prev_batch=prev_batch_token,
limited=limited or newly_joined_room
))
@defer.inlineCallbacks
@ -732,25 +670,12 @@ class SyncHandler(BaseHandler):
since_token, now_token,
ephemeral_by_room, tags_by_room,
account_data_by_room,
all_ephemeral_by_room):
""" Get the incremental delta needed to bring the client up to date for
the room. Gives the client the most recent events and the changes to
state.
Returns:
A Deferred JoinedSyncResult
"""
logger.debug("Doing incremental sync for room %s between %s and %s",
room_id, since_token, now_token)
all_ephemeral_by_room,
batch, full_state=False):
if full_state:
state = yield self.get_state_at(room_id, now_token)
# TODO(mjark): Check for redactions we might have missed.
batch = yield self.load_filtered_recents(
room_id, sync_config, now_token, since_token,
)
logger.debug("Recents %r", batch)
if batch.limited:
elif batch.limited:
current_state = yield self.get_state_at(room_id, now_token)
state_at_previous_sync = yield self.get_state_at(
@ -772,17 +697,6 @@ class SyncHandler(BaseHandler):
if just_joined:
state = yield self.get_state_at(room_id, now_token)
notifs = yield self.unread_notifs_for_room_id(
room_id, sync_config, all_ephemeral_by_room
)
unread_notifications = {}
if notifs is not None:
unread_notifications["notification_count"] = len(notifs)
unread_notifications["highlight_count"] = len([
1 for notif in notifs if _action_has_highlight(notif["actions"])
])
state = {
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(state.values())
@ -800,6 +714,7 @@ class SyncHandler(BaseHandler):
ephemeral_by_room.get(room_id, [])
)
unread_notifications = {}
room_sync = JoinedSyncResult(
room_id=room_id,
timeline=batch,
@ -809,48 +724,64 @@ class SyncHandler(BaseHandler):
unread_notifications=unread_notifications,
)
if room_sync:
notifs = yield self.unread_notifs_for_room_id(
room_id, sync_config, all_ephemeral_by_room
)
if notifs is not None:
unread_notifications["notification_count"] = len(notifs)
unread_notifications["highlight_count"] = len([
1 for notif in notifs if _action_has_highlight(notif["actions"])
])
logger.debug("Room sync: %r", room_sync)
defer.returnValue(room_sync)
@defer.inlineCallbacks
def incremental_sync_for_archived_room(self, sync_config, leave_event,
def incremental_sync_for_archived_room(self, sync_config, room_id, leave_event_id,
since_token, tags_by_room,
account_data_by_room):
account_data_by_room, full_state,
leave_token=None):
""" Get the incremental delta needed to bring the client up to date for
the archived room.
Returns:
A Deferred ArchivedSyncResult
"""
stream_token = yield self.store.get_stream_token_for_event(
leave_event.event_id
)
if not leave_token:
stream_token = yield self.store.get_stream_token_for_event(
leave_event_id
)
leave_token = since_token.copy_and_replace("room_key", stream_token)
leave_token = since_token.copy_and_replace("room_key", stream_token)
if since_token.is_after(leave_token):
if since_token and since_token.is_after(leave_token):
defer.returnValue(None)
batch = yield self.load_filtered_recents(
leave_event.room_id, sync_config, leave_token, since_token,
room_id, sync_config, leave_token, since_token,
)
logger.debug("Recents %r", batch)
state_events_at_leave = yield self.store.get_state_for_event(
leave_event.event_id
leave_event_id
)
state_at_previous_sync = yield self.get_state_at(
leave_event.room_id, stream_position=since_token
)
if not full_state:
state_at_previous_sync = yield self.get_state_at(
room_id, stream_position=since_token
)
state_events_delta = yield self.compute_state_delta(
since_token=since_token,
previous_state=state_at_previous_sync,
current_state=state_events_at_leave,
)
state_events_delta = yield self.compute_state_delta(
since_token=since_token,
previous_state=state_at_previous_sync,
current_state=state_events_at_leave,
)
else:
state_events_delta = state_events_at_leave
state_events_delta = {
(e.type, e.state_key): e
@ -860,7 +791,7 @@ class SyncHandler(BaseHandler):
}
account_data = self.account_data_for_room(
leave_event.room_id, tags_by_room, account_data_by_room
room_id, tags_by_room, account_data_by_room
)
account_data = sync_config.filter_collection.filter_room_account_data(
@ -868,7 +799,7 @@ class SyncHandler(BaseHandler):
)
room_sync = ArchivedSyncResult(
room_id=leave_event.room_id,
room_id=room_id,
timeline=batch,
state=state_events_delta,
account_data=account_data,

View File

@ -22,7 +22,7 @@ REQUIREMENTS = {
"unpaddedbase64>=1.0.1": ["unpaddedbase64>=1.0.1"],
"canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],
"signedjson>=1.0.0": ["signedjson>=1.0.0"],
"pynacl>=0.3.0": ["nacl>=0.3.0", "nacl.bindings"],
"pynacl==0.3.0": ["nacl==0.3.0", "nacl.bindings"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
"Twisted>=15.1.0": ["twisted>=15.1.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"],

View File

@ -66,11 +66,12 @@ class PushRuleRestServlet(ClientV1RestServlet):
raise SynapseError(400, e.message)
before = request.args.get("before", None)
if before and len(before):
before = before[0]
if before:
before = _namespaced_rule_id(spec, before[0])
after = request.args.get("after", None)
if after and len(after):
after = after[0]
if after:
after = _namespaced_rule_id(spec, after[0])
try:
yield self.hs.get_datastore().add_push_rule(
@ -452,11 +453,15 @@ def _strip_device_condition(rule):
def _namespaced_rule_id_from_spec(spec):
return _namespaced_rule_id(spec, spec['rule_id'])
def _namespaced_rule_id(spec, rule_id):
if spec['scope'] == 'global':
scope = 'global'
else:
scope = 'device/%s' % (spec['profile_tag'])
return "%s/%s/%s" % (scope, spec['template'], spec['rule_id'])
return "%s/%s/%s" % (scope, spec['template'], rule_id)
def _rule_id_from_namespaced(in_rule_id):

View File

@ -20,6 +20,8 @@
# Imports required for the default HomeServer() implementation
from twisted.web.client import BrowserLikePolicyForHTTPS
from twisted.enterprise import adbapi
from synapse.federation import initialize_http_replication
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.notifier import Notifier
@ -36,8 +38,15 @@ from synapse.push.pusherpool import PusherPool
from synapse.events.builder import EventBuilderFactory
from synapse.api.filtering import Filtering
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
class BaseHomeServer(object):
import logging
logger = logging.getLogger(__name__)
class HomeServer(object):
"""A basic homeserver object without lazy component builders.
This will need all of the components it requires to either be passed as
@ -98,39 +107,18 @@ class BaseHomeServer(object):
self.hostname = hostname
self._building = {}
self.clock = Clock()
self.distributor = Distributor()
self.ratelimiter = Ratelimiter()
# Other kwargs are explicit dependencies
for depname in kwargs:
setattr(self, depname, kwargs[depname])
@classmethod
def _make_dependency_method(cls, depname):
def _get(self):
if hasattr(self, depname):
return getattr(self, depname)
if hasattr(self, "build_%s" % (depname)):
# Prevent cyclic dependencies from deadlocking
if depname in self._building:
raise ValueError("Cyclic dependency while building %s" % (
depname,
))
self._building[depname] = 1
builder = getattr(self, "build_%s" % (depname))
dep = builder()
setattr(self, depname, dep)
del self._building[depname]
return dep
raise NotImplementedError(
"%s has no %s nor a builder for it" % (
type(self).__name__, depname,
)
)
setattr(BaseHomeServer, "get_%s" % (depname), _get)
def setup(self):
logger.info("Setting up.")
self.datastore = DataStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def get_ip_from_request(self, request):
# X-Forwarded-For is handled by our custom request type.
@ -142,33 +130,9 @@ class BaseHomeServer(object):
def is_mine_id(self, string):
return string.split(":", 1)[1] == self.hostname
# Build magic accessors for every dependency
for depname in BaseHomeServer.DEPENDENCIES:
BaseHomeServer._make_dependency_method(depname)
class HomeServer(BaseHomeServer):
"""A homeserver object that will construct most of its dependencies as
required.
It still requires the following to be specified by the caller:
resource_for_client
resource_for_web_client
resource_for_federation
resource_for_content_repo
http_client
db_pool
"""
def build_clock(self):
return Clock()
def build_replication_layer(self):
return initialize_http_replication(self)
def build_datastore(self):
return DataStore(self)
def build_handlers(self):
return Handlers(self)
@ -179,10 +143,9 @@ class HomeServer(BaseHomeServer):
return Auth(self)
def build_http_client_context_factory(self):
config = self.get_config()
return (
InsecureInterceptableContextFactory()
if config.use_insecure_ssl_client_just_for_testing_do_not_use
if self.config.use_insecure_ssl_client_just_for_testing_do_not_use
else BrowserLikePolicyForHTTPS()
)
@ -201,15 +164,9 @@ class HomeServer(BaseHomeServer):
def build_state_handler(self):
return StateHandler(self)
def build_distributor(self):
return Distributor()
def build_event_sources(self):
return EventSources(self)
def build_ratelimiter(self):
return Ratelimiter()
def build_keyring(self):
return Keyring(self)
@ -224,3 +181,55 @@ class HomeServer(BaseHomeServer):
def build_pusherpool(self):
return PusherPool(self)
def build_http_client(self):
return MatrixFederationHttpClient(self)
def build_db_pool(self):
name = self.db_config["name"]
return adbapi.ConnectionPool(
name,
**self.db_config.get("args", {})
)
def _make_dependency_method(depname):
def _get(hs):
try:
return getattr(hs, depname)
except AttributeError:
pass
try:
builder = getattr(hs, "build_%s" % (depname))
except AttributeError:
builder = None
if builder:
# Prevent cyclic dependencies from deadlocking
if depname in hs._building:
raise ValueError("Cyclic dependency while building %s" % (
depname,
))
hs._building[depname] = 1
dep = builder()
setattr(hs, depname, dep)
del hs._building[depname]
return dep
raise NotImplementedError(
"%s has no %s nor a builder for it" % (
type(hs).__name__, depname,
)
)
setattr(HomeServer, "get_%s" % (depname), _get)
# Build magic accessors for every dependency
for depname in HomeServer.DEPENDENCIES:
_make_dependency_method(depname)

View File

@ -46,6 +46,9 @@ from .tags import TagsStore
from .account_data import AccountDataStore
from util.id_generators import IdGenerator, StreamIdGenerator
import logging
@ -79,18 +82,43 @@ class DataStore(RoomMemberStore, RoomStore,
EventPushActionsStore
):
def __init__(self, hs):
super(DataStore, self).__init__(hs)
def __init__(self, db_conn, hs):
self.hs = hs
self.min_token_deferred = self._get_min_token()
self.min_token = None
cur = db_conn.cursor()
try:
cur.execute("SELECT MIN(stream_ordering) FROM events",)
rows = cur.fetchall()
self.min_stream_token = rows[0][0] if rows and rows[0] and rows[0][0] else -1
self.min_stream_token = min(self.min_stream_token, -1)
finally:
cur.close()
self.client_ip_last_seen = Cache(
name="client_ip_last_seen",
keylen=4,
)
self._stream_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering"
)
self._receipts_id_gen = StreamIdGenerator(
db_conn, "receipts_linearized", "stream_id"
)
self._account_data_id_gen = StreamIdGenerator(
db_conn, "account_data_max_stream_id", "stream_id"
)
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
self._pushers_id_gen = IdGenerator("pushers", "id", self)
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
super(DataStore, self).__init__(hs)
@defer.inlineCallbacks
def insert_client_ip(self, user, access_token, ip, user_agent):
now = int(self._clock.time_msec())

View File

@ -15,13 +15,11 @@
import logging
from synapse.api.errors import StoreError
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache
import synapse.metrics
from util.id_generators import IdGenerator, StreamIdGenerator
from twisted.internet import defer
@ -175,16 +173,6 @@ class SQLBaseStore(object):
self.database_engine = hs.database_engine
self._stream_id_gen = StreamIdGenerator("events", "stream_ordering")
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
self._pushers_id_gen = IdGenerator("pushers", "id", self)
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
self._receipts_id_gen = StreamIdGenerator("receipts_linearized", "stream_id")
def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec()
@ -345,7 +333,8 @@ class SQLBaseStore(object):
defer.returnValue(result)
def cursor_to_dict(self, cursor):
@staticmethod
def cursor_to_dict(cursor):
"""Converts a SQL cursor into an list of dicts.
Args:
@ -402,8 +391,8 @@ class SQLBaseStore(object):
if not or_ignore:
raise
@log_function
def _simple_insert_txn(self, txn, table, values):
@staticmethod
def _simple_insert_txn(txn, table, values):
keys, vals = zip(*values.items())
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
@ -414,7 +403,8 @@ class SQLBaseStore(object):
txn.execute(sql, vals)
def _simple_insert_many_txn(self, txn, table, values):
@staticmethod
def _simple_insert_many_txn(txn, table, values):
if not values:
return
@ -537,9 +527,10 @@ class SQLBaseStore(object):
table, keyvalues, retcol, allow_none=allow_none,
)
def _simple_select_one_onecol_txn(self, txn, table, keyvalues, retcol,
@classmethod
def _simple_select_one_onecol_txn(cls, txn, table, keyvalues, retcol,
allow_none=False):
ret = self._simple_select_onecol_txn(
ret = cls._simple_select_onecol_txn(
txn,
table=table,
keyvalues=keyvalues,
@ -554,7 +545,8 @@ class SQLBaseStore(object):
else:
raise StoreError(404, "No row found")
def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
@staticmethod
def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
sql = (
"SELECT %(retcol)s FROM %(table)s WHERE %(where)s"
) % {
@ -603,7 +595,8 @@ class SQLBaseStore(object):
table, keyvalues, retcols
)
def _simple_select_list_txn(self, txn, table, keyvalues, retcols):
@classmethod
def _simple_select_list_txn(cls, txn, table, keyvalues, retcols):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@ -627,7 +620,7 @@ class SQLBaseStore(object):
)
txn.execute(sql)
return self.cursor_to_dict(txn)
return cls.cursor_to_dict(txn)
@defer.inlineCallbacks
def _simple_select_many_batch(self, table, column, iterable, retcols,
@ -662,7 +655,8 @@ class SQLBaseStore(object):
defer.returnValue(results)
def _simple_select_many_txn(self, txn, table, column, iterable, keyvalues, retcols):
@classmethod
def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@ -699,7 +693,7 @@ class SQLBaseStore(object):
)
txn.execute(sql, values)
return self.cursor_to_dict(txn)
return cls.cursor_to_dict(txn)
def _simple_update_one(self, table, keyvalues, updatevalues,
desc="_simple_update_one"):
@ -726,7 +720,8 @@ class SQLBaseStore(object):
table, keyvalues, updatevalues,
)
def _simple_update_one_txn(self, txn, table, keyvalues, updatevalues):
@staticmethod
def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
update_sql = "UPDATE %s SET %s WHERE %s" % (
table,
", ".join("%s = ?" % (k,) for k in updatevalues),
@ -743,7 +738,8 @@ class SQLBaseStore(object):
if txn.rowcount > 1:
raise StoreError(500, "More than one row matched")
def _simple_select_one_txn(self, txn, table, keyvalues, retcols,
@staticmethod
def _simple_select_one_txn(txn, table, keyvalues, retcols,
allow_none=False):
select_sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
@ -784,7 +780,8 @@ class SQLBaseStore(object):
raise StoreError(500, "more than one row matched")
return self.runInteraction(desc, func)
def _simple_delete_txn(self, txn, table, keyvalues):
@staticmethod
def _simple_delete_txn(txn, table, keyvalues):
sql = "DELETE FROM %s WHERE %s" % (
table,
" AND ".join("%s = ?" % (k, ) for k in keyvalues)

View File

@ -14,6 +14,7 @@
# limitations under the License.
from ._base import SQLBaseStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
from twisted.internet import defer
import ujson as json
@ -23,6 +24,14 @@ logger = logging.getLogger(__name__)
class AccountDataStore(SQLBaseStore):
def __init__(self, hs):
super(AccountDataStore, self).__init__(hs)
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache",
self._account_data_id_gen.get_max_token(None),
max_size=10000,
)
def get_account_data_for_user(self, user_id):
"""Get all the client account_data for a user.
@ -83,7 +92,7 @@ class AccountDataStore(SQLBaseStore):
"get_account_data_for_room", get_account_data_for_room_txn
)
def get_updated_account_data_for_user(self, user_id, stream_id):
def get_updated_account_data_for_user(self, user_id, stream_id, room_ids=None):
"""Get all the client account_data for a that's changed.
Args:
@ -120,6 +129,12 @@ class AccountDataStore(SQLBaseStore):
return (global_account_data, account_data_by_room)
changed = self._account_data_stream_cache.has_entity_changed(
user_id, int(stream_id)
)
if not changed:
return ({}, {})
return self.runInteraction(
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
)
@ -186,6 +201,10 @@ class AccountDataStore(SQLBaseStore):
"content": content_json,
}
)
txn.call_after(
self._account_data_stream_cache.entity_has_changed,
user_id, next_id,
)
self._update_max_stream_id(txn, next_id)
with (yield self._account_data_id_gen.get_next(self)) as next_id:

View File

@ -66,11 +66,9 @@ class EventsStore(SQLBaseStore):
return
if backfilled:
if not self.min_token_deferred.called:
yield self.min_token_deferred
start = self.min_token - 1
self.min_token -= len(events_and_contexts) + 1
stream_orderings = range(start, self.min_token, -1)
start = self.min_stream_token - 1
self.min_stream_token -= len(events_and_contexts) + 1
stream_orderings = range(start, self.min_stream_token, -1)
@contextmanager
def stream_ordering_manager():
@ -107,10 +105,8 @@ class EventsStore(SQLBaseStore):
is_new_state=True, current_state=None):
stream_ordering = None
if backfilled:
if not self.min_token_deferred.called:
yield self.min_token_deferred
self.min_token -= 1
stream_ordering = self.min_token
self.min_stream_token -= 1
stream_ordering = self.min_stream_token
if stream_ordering is None:
stream_ordering_manager = yield self._stream_id_gen.get_next(self)
@ -214,6 +210,12 @@ class EventsStore(SQLBaseStore):
for event, _ in events_and_contexts:
txn.call_after(self._invalidate_get_event_cache, event.event_id)
if not backfilled:
txn.call_after(
self._events_stream_cache.entity_has_changed,
event.room_id, event.internal_metadata.stream_ordering,
)
depth_updates = {}
for event, _ in events_and_contexts:
if event.internal_metadata.is_outlier():

View File

@ -16,12 +16,13 @@
from twisted.internet import defer
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks
import simplejson as json
class FilteringStore(SQLBaseStore):
@defer.inlineCallbacks
@cachedInlineCallbacks(num_args=2)
def get_user_filter(self, user_localpart, filter_id):
def_json = yield self._simple_select_one_onecol(
table="user_filters",

View File

@ -130,7 +130,8 @@ class PushRuleStore(SQLBaseStore):
def _add_push_rule_relative_txn(self, txn, user_id, **kwargs):
after = kwargs.pop("after", None)
relative_to_rule = kwargs.pop("before", after)
before = kwargs.pop("before", None)
relative_to_rule = before or after
res = self._simple_select_one_txn(
txn,

View File

@ -15,11 +15,10 @@
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
from synapse.util.caches import cache_counter, caches_by_name
from synapse.util.caches.stream_change_cache import StreamChangeCache
from twisted.internet import defer
from blist import sorteddict
import logging
import ujson as json
@ -31,7 +30,9 @@ class ReceiptsStore(SQLBaseStore):
def __init__(self, hs):
super(ReceiptsStore, self).__init__(hs)
self._receipts_stream_cache = _RoomStreamChangeCache()
self._receipts_stream_cache = StreamChangeCache(
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token(None)
)
@cached(num_args=2)
def get_receipts_for_room(self, room_id, receipt_type):
@ -76,8 +77,8 @@ class ReceiptsStore(SQLBaseStore):
room_ids = set(room_ids)
if from_key:
room_ids = yield self._receipts_stream_cache.get_rooms_changed(
self, room_ids, from_key
room_ids = yield self._receipts_stream_cache.get_entities_changed(
room_ids, from_key
)
results = yield self._get_linearized_receipts_for_rooms(
@ -220,6 +221,11 @@ class ReceiptsStore(SQLBaseStore):
# FIXME: This shouldn't invalidate the whole cache
txn.call_after(self.get_linearized_receipts_for_room.invalidate_all)
txn.call_after(
self._receipts_stream_cache.entity_has_changed,
room_id, stream_id
)
# We don't want to clobber receipts for more recent events, so we
# have to compare orderings of existing receipts
sql = (
@ -307,9 +313,6 @@ class ReceiptsStore(SQLBaseStore):
stream_id_manager = yield self._receipts_id_gen.get_next(self)
with stream_id_manager as stream_id:
yield self._receipts_stream_cache.room_has_changed(
self, room_id, stream_id
)
have_persisted = yield self.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
@ -368,63 +371,3 @@ class ReceiptsStore(SQLBaseStore):
"data": json.dumps(data),
}
)
class _RoomStreamChangeCache(object):
"""Keeps track of the stream_id of the latest change in rooms.
Given a list of rooms and stream key, it will give a subset of rooms that
may have changed since that key. If the key is too old then the cache
will simply return all rooms.
"""
def __init__(self, size_of_cache=10000):
self._size_of_cache = size_of_cache
self._room_to_key = {}
self._cache = sorteddict()
self._earliest_key = None
self.name = "ReceiptsRoomChangeCache"
caches_by_name[self.name] = self._cache
@defer.inlineCallbacks
def get_rooms_changed(self, store, room_ids, key):
"""Returns subset of room ids that have had new receipts since the
given key. If the key is too old it will just return the given list.
"""
if key > (yield self._get_earliest_key(store)):
keys = self._cache.keys()
i = keys.bisect_right(key)
result = set(
self._cache[k] for k in keys[i:]
).intersection(room_ids)
cache_counter.inc_hits(self.name)
else:
result = room_ids
cache_counter.inc_misses(self.name)
defer.returnValue(result)
@defer.inlineCallbacks
def room_has_changed(self, store, room_id, key):
"""Informs the cache that the room has been changed at the given key.
"""
if key > (yield self._get_earliest_key(store)):
old_key = self._room_to_key.get(room_id, None)
if old_key:
key = max(key, old_key)
self._cache.pop(old_key, None)
self._cache[key] = room_id
while len(self._cache) > self._size_of_cache:
k, r = self._cache.popitem()
self._earliest_key = max(k, self._earliest_key)
self._room_to_key.pop(r, None)
@defer.inlineCallbacks
def _get_earliest_key(self, store):
if self._earliest_key is None:
self._earliest_key = yield store.get_max_receipt_stream_id()
self._earliest_key = int(self._earliest_key)
defer.returnValue(self._earliest_key)

View File

@ -241,7 +241,7 @@ class RoomMemberStore(SQLBaseStore):
return rows
@cached()
@cached(max_entries=5000)
def get_rooms_for_user(self, user_id):
return self.get_rooms_for_user_where_membership_is(
user_id, membership_list=[Membership.JOIN],

View File

@ -0,0 +1,16 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE INDEX events_room_stream on events(room_id, stream_ordering);

View File

@ -37,6 +37,7 @@ from twisted.internet import defer
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.api.constants import EventTypes
from synapse.types import RoomStreamToken
from synapse.util.logutils import log_function
@ -77,6 +78,12 @@ def upper_bound(token):
class StreamStore(SQLBaseStore):
def __init__(self, hs):
super(StreamStore, self).__init__(hs)
self._events_stream_cache = StreamChangeCache(
"EventsRoomStreamChangeCache", self._stream_id_gen.get_max_token(None)
)
@defer.inlineCallbacks
def get_appservice_room_stream(self, service, from_key, to_key, limit=0):
@ -157,6 +164,135 @@ class StreamStore(SQLBaseStore):
results = yield self.runInteraction("get_appservice_room_stream", f)
defer.returnValue(results)
@defer.inlineCallbacks
def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0):
from_id = RoomStreamToken.parse_stream_token(from_key).stream
room_ids = yield self._events_stream_cache.get_entities_changed(
room_ids, from_id
)
if not room_ids:
defer.returnValue({})
results = {}
room_ids = list(room_ids)
for rm_ids in (room_ids[i:i+20] for i in xrange(0, len(room_ids), 20)):
res = yield defer.gatherResults([
self.get_room_events_stream_for_room(
room_id, from_key, to_key, limit
).addCallback(lambda r, rm: (rm, r), room_id)
for room_id in room_ids
])
results.update(dict(res))
defer.returnValue(results)
@defer.inlineCallbacks
def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0):
if from_key is not None:
from_id = RoomStreamToken.parse_stream_token(from_key).stream
else:
from_id = None
to_id = RoomStreamToken.parse_stream_token(to_key).stream
if from_key == to_key:
defer.returnValue(([], from_key))
if from_id:
has_changed = yield self._events_stream_cache.has_entity_changed(
room_id, from_id
)
if not has_changed:
defer.returnValue(([], from_key))
def f(txn):
if from_id is not None:
sql = (
"SELECT event_id, stream_ordering FROM events WHERE"
" room_id = ?"
" AND not outlier"
" AND stream_ordering > ? AND stream_ordering <= ?"
" ORDER BY stream_ordering DESC LIMIT ?"
)
txn.execute(sql, (room_id, from_id, to_id, limit))
else:
sql = (
"SELECT event_id, stream_ordering FROM events WHERE"
" room_id = ?"
" AND not outlier"
" AND stream_ordering <= ?"
" ORDER BY stream_ordering DESC LIMIT ?"
)
txn.execute(sql, (room_id, to_id, limit))
rows = self.cursor_to_dict(txn)
ret = self._get_events_txn(
txn,
[r["event_id"] for r in rows],
get_prev_content=True
)
self._set_before_and_after(ret, rows, topo_order=False)
ret.reverse()
if rows:
key = "s%d" % min(r["stream_ordering"] for r in rows)
else:
# Assume we didn't get anything because there was nothing to
# get.
key = from_key
return ret, key
res = yield self.runInteraction("get_room_events_stream_for_room", f)
defer.returnValue(res)
def get_room_changes_for_user(self, user_id, from_key, to_key):
if from_key is not None:
from_id = RoomStreamToken.parse_stream_token(from_key).stream
else:
from_id = None
to_id = RoomStreamToken.parse_stream_token(to_key).stream
if from_key == to_key:
return defer.succeed([])
def f(txn):
if from_id is not None:
sql = (
"SELECT m.event_id, stream_ordering FROM events AS e,"
" room_memberships AS m"
" WHERE e.event_id = m.event_id"
" AND m.user_id = ?"
" AND e.stream_ordering > ? AND e.stream_ordering <= ?"
" ORDER BY e.stream_ordering ASC"
)
txn.execute(sql, (user_id, from_id, to_id,))
else:
sql = (
"SELECT m.event_id, stream_ordering FROM events AS e,"
" room_memberships AS m"
" WHERE e.event_id = m.event_id"
" AND m.user_id = ?"
" AND stream_ordering <= ?"
" ORDER BY stream_ordering ASC"
)
txn.execute(sql, (user_id, to_id,))
rows = self.cursor_to_dict(txn)
ret = self._get_events_txn(
txn,
[r["event_id"] for r in rows],
get_prev_content=True
)
return ret
return self.runInteraction("get_room_changes_for_user", f)
@log_function
def get_room_events_stream(
self,
@ -174,7 +310,8 @@ class StreamStore(SQLBaseStore):
"SELECT c.room_id FROM history_visibility AS h"
" INNER JOIN current_state_events AS c"
" ON h.event_id = c.event_id"
" WHERE c.room_id IN (%s) AND h.history_visibility = 'world_readable'" % (
" WHERE c.room_id IN (%s)"
" AND h.history_visibility = 'world_readable'" % (
",".join(map(lambda _: "?", room_ids))
)
)
@ -434,6 +571,18 @@ class StreamStore(SQLBaseStore):
row["topological_ordering"], row["stream_ordering"],)
)
def get_max_topological_token_for_stream_and_room(self, room_id, stream_key):
sql = (
"SELECT max(topological_ordering) FROM events"
" WHERE room_id = ? AND stream_ordering < ?"
)
return self._execute(
"get_max_topological_token_for_stream_and_room", None,
sql, room_id, stream_key,
).addCallback(
lambda r: r[0][0] if r else 0
)
def _get_max_topological_txn(self, txn):
txn.execute(
"SELECT MAX(topological_ordering) FROM events"
@ -444,24 +593,14 @@ class StreamStore(SQLBaseStore):
rows = txn.fetchall()
return rows[0][0] if rows else 0
@defer.inlineCallbacks
def _get_min_token(self):
row = yield self._execute(
"_get_min_token", None, "SELECT MIN(stream_ordering) FROM events"
)
self.min_token = row[0][0] if row and row[0] and row[0][0] else -1
self.min_token = min(self.min_token, -1)
logger.debug("min_token is: %s", self.min_token)
defer.returnValue(self.min_token)
@staticmethod
def _set_before_and_after(events, rows):
def _set_before_and_after(events, rows, topo_order=True):
for event, row in zip(events, rows):
stream = row["stream_ordering"]
topo = event.depth
if topo_order:
topo = event.depth
else:
topo = None
internal = event.internal_metadata
internal.before = str(RoomStreamToken(topo, stream - 1))
internal.after = str(RoomStreamToken(topo, stream))

View File

@ -16,7 +16,6 @@
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
from twisted.internet import defer
from .util.id_generators import StreamIdGenerator
import ujson as json
import logging
@ -25,13 +24,6 @@ logger = logging.getLogger(__name__)
class TagsStore(SQLBaseStore):
def __init__(self, hs):
super(TagsStore, self).__init__(hs)
self._account_data_id_gen = StreamIdGenerator(
"account_data_max_stream_id", "stream_id"
)
def get_max_account_data_stream_id(self):
"""Get the current max stream id for the private user data stream
@ -87,6 +79,12 @@ class TagsStore(SQLBaseStore):
room_ids = [row[0] for row in txn.fetchall()]
return room_ids
changed = self._account_data_stream_cache.has_entity_changed(
user_id, int(stream_id)
)
if not changed:
defer.returnValue({})
room_ids = yield self.runInteraction(
"get_updated_tags", get_updated_tags_txn
)
@ -184,6 +182,11 @@ class TagsStore(SQLBaseStore):
next_id(int): The the revision to advance to.
"""
txn.call_after(
self._account_data_stream_cache.entity_has_changed,
user_id, next_id
)
update_max_id_sql = (
"UPDATE account_data_max_stream_id"
" SET stream_id = ?"

View File

@ -72,28 +72,24 @@ class StreamIdGenerator(object):
with stream_id_gen.get_next_txn(txn) as stream_id:
# ... persist event ...
"""
def __init__(self, table, column):
def __init__(self, db_conn, table, column):
self.table = table
self.column = column
self._lock = threading.Lock()
self._current_max = None
cur = db_conn.cursor()
self._current_max = self._get_or_compute_current_max(cur)
cur.close()
self._unfinished_ids = deque()
@defer.inlineCallbacks
def get_next(self, store):
"""
Usage:
with yield stream_id_gen.get_next as stream_id:
# ... persist event ...
"""
if not self._current_max:
yield store.runInteraction(
"_compute_current_max",
self._get_or_compute_current_max,
)
with self._lock:
self._current_max += 1
next_id = self._current_max
@ -108,21 +104,14 @@ class StreamIdGenerator(object):
with self._lock:
self._unfinished_ids.remove(next_id)
defer.returnValue(manager())
return manager()
@defer.inlineCallbacks
def get_next_mult(self, store, n):
"""
Usage:
with yield stream_id_gen.get_next(store, n) as stream_ids:
# ... persist events ...
"""
if not self._current_max:
yield store.runInteraction(
"_compute_current_max",
self._get_or_compute_current_max,
)
with self._lock:
next_ids = range(self._current_max + 1, self._current_max + n + 1)
self._current_max += n
@ -139,24 +128,17 @@ class StreamIdGenerator(object):
for next_id in next_ids:
self._unfinished_ids.remove(next_id)
defer.returnValue(manager())
return manager()
@defer.inlineCallbacks
def get_max_token(self, store):
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
"""
if not self._current_max:
yield store.runInteraction(
"_compute_current_max",
self._get_or_compute_current_max,
)
with self._lock:
if self._unfinished_ids:
defer.returnValue(self._unfinished_ids[0] - 1)
return self._unfinished_ids[0] - 1
defer.returnValue(self._current_max)
return self._current_max
def _get_or_compute_current_max(self, txn):
with self._lock:

View File

@ -37,7 +37,7 @@ class LruCache(object):
"""
def __init__(self, max_size, keylen=1, cache_type=dict):
cache = cache_type()
self.size = 0
self.cache = cache # Used for introspection.
list_root = []
list_root[:] = [list_root, list_root, None, None]
@ -60,7 +60,6 @@ class LruCache(object):
prev_node[NEXT] = node
next_node[PREV] = node
cache[key] = node
self.size += 1
def move_node_to_front(node):
prev_node = node[PREV]
@ -79,7 +78,6 @@ class LruCache(object):
next_node = node[NEXT]
prev_node[NEXT] = next_node
next_node[PREV] = prev_node
self.size -= 1
@synchronized
def cache_get(key, default=None):
@ -98,7 +96,7 @@ class LruCache(object):
node[VALUE] = value
else:
add_node(key, value)
if self.size > max_size:
if len(cache) > max_size:
todelete = list_root[PREV]
delete_node(todelete)
cache.pop(todelete[KEY], None)
@ -110,7 +108,7 @@ class LruCache(object):
return node[VALUE]
else:
add_node(key, value)
if self.size > max_size:
if len(cache) > max_size:
todelete = list_root[PREV]
delete_node(todelete)
cache.pop(todelete[KEY], None)
@ -145,7 +143,7 @@ class LruCache(object):
@synchronized
def cache_len():
return self.size
return len(cache)
@synchronized
def cache_contains(key):

View File

@ -0,0 +1,107 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util.caches import cache_counter, caches_by_name
from blist import sorteddict
import logging
logger = logging.getLogger(__name__)
class StreamChangeCache(object):
"""Keeps track of the stream positions of the latest change in a set of entities.
Typically the entity will be a room or user id.
Given a list of entities and a stream position, it will give a subset of
entities that may have changed since that position. If position key is too
old then the cache will simply return all given entities.
"""
def __init__(self, name, current_stream_pos, max_size=10000):
self._max_size = max_size
self._entity_to_key = {}
self._cache = sorteddict()
self._earliest_known_stream_pos = current_stream_pos
self.name = name
caches_by_name[self.name] = self._cache
def has_entity_changed(self, entity, stream_pos):
"""Returns True if the entity may have been updated since stream_pos
"""
assert type(stream_pos) is int
if stream_pos < self._earliest_known_stream_pos:
cache_counter.inc_misses(self.name)
return True
if stream_pos == self._earliest_known_stream_pos:
# If the same as the earliest key, assume nothing has changed.
cache_counter.inc_hits(self.name)
return False
latest_entity_change_pos = self._entity_to_key.get(entity, None)
if latest_entity_change_pos is None:
cache_counter.inc_misses(self.name)
return True
if stream_pos < latest_entity_change_pos:
cache_counter.inc_misses(self.name)
return True
cache_counter.inc_hits(self.name)
return False
def get_entities_changed(self, entities, stream_pos):
"""Returns subset of entities that have had new things since the
given position. If the position is too old it will just return the given list.
"""
assert type(stream_pos) is int
if stream_pos >= self._earliest_known_stream_pos:
keys = self._cache.keys()
i = keys.bisect_right(stream_pos)
result = set(
self._cache[k] for k in keys[i:]
).intersection(entities)
cache_counter.inc_hits(self.name)
else:
result = entities
cache_counter.inc_misses(self.name)
return result
def entity_has_changed(self, entity, stream_pos):
"""Informs the cache that the entity has been changed at the given
position.
"""
assert type(stream_pos) is int
if stream_pos > self._earliest_known_stream_pos:
old_pos = self._entity_to_key.get(entity, None)
if old_pos:
stream_pos = max(stream_pos, old_pos)
self._cache.pop(old_pos, None)
self._cache[stream_pos] = entity
self._entity_to_key[entity] = stream_pos
while len(self._cache) > self._max_size:
k, r = self._cache.popitem()
self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos)
self._entity_to_key.pop(r, None)

View File

@ -8,6 +8,7 @@ class TreeCache(object):
Keys must be tuples.
"""
def __init__(self):
self.size = 0
self.root = {}
def __setitem__(self, key, value):
@ -20,7 +21,8 @@ class TreeCache(object):
node = self.root
for k in key[:-1]:
node = node.setdefault(k, {})
node[key[-1]] = value
node[key[-1]] = _Entry(value)
self.size += 1
def get(self, key, default=None):
node = self.root
@ -28,9 +30,10 @@ class TreeCache(object):
node = node.get(k, None)
if node is None:
return default
return node.get(key[-1], default)
return node.get(key[-1], _Entry(default)).value
def clear(self):
self.size = 0
self.root = {}
def pop(self, key, default=None):
@ -57,4 +60,33 @@ class TreeCache(object):
break
node_and_keys[i+1][0].pop(k)
popped, cnt = _strip_and_count_entires(popped)
self.size -= cnt
return popped
def __len__(self):
return self.size
class _Entry(object):
__slots__ = ["value"]
def __init__(self, value):
self.value = value
def _strip_and_count_entires(d):
"""Takes an _Entry or dict with leaves of _Entry's, and either returns the
value or a dictionary with _Entry's replaced by their values.
Also returns the count of _Entry's
"""
if isinstance(d, dict):
cnt = 0
for key, value in d.items():
v, n = _strip_and_count_entires(value)
d[key] = v
cnt += n
return d, cnt
else:
return d.value, 1

View File

@ -382,19 +382,20 @@ class FilteringTestCase(unittest.TestCase):
"types": ["m.*"]
}
}
user = UserID.from_string("@" + user_localpart + ":test")
filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart,
user_localpart=user_localpart + "2",
user_filter=user_filter_json,
)
event = MockEvent(
event_id="$asdasd:localhost",
sender="@foo:bar",
type="custom.avatar.3d.crazy",
)
events = [event]
user_filter = yield self.filtering.get_user_filter(
user_localpart=user_localpart,
user_localpart=user_localpart + "2",
filter_id=filter_id,
)

View File

@ -1,303 +0,0 @@
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# trial imports
from twisted.internet import defer
from tests import unittest
# python imports
from mock import Mock, ANY
from ..utils import MockHttpResource, MockClock, setup_test_homeserver
from synapse.federation import initialize_http_replication
from synapse.events import FrozenEvent
def make_pdu(prev_pdus=[], **kwargs):
"""Provide some default fields for making a PduTuple."""
pdu_fields = {
"state_key": None,
"prev_events": prev_pdus,
}
pdu_fields.update(kwargs)
return FrozenEvent(pdu_fields)
class FederationTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.mock_resource = MockHttpResource()
self.mock_http_client = Mock(spec=[
"get_json",
"put_json",
])
self.mock_persistence = Mock(spec=[
"prep_send_transaction",
"delivered_txn",
"get_received_txn_response",
"set_received_txn_response",
"get_destination_retry_timings",
"get_auth_chain",
])
self.mock_persistence.get_received_txn_response.return_value = (
defer.succeed(None)
)
retry_timings_res = {
"destination": "",
"retry_last_ts": 0,
"retry_interval": 0,
}
self.mock_persistence.get_destination_retry_timings.return_value = (
defer.succeed(retry_timings_res)
)
self.mock_persistence.get_auth_chain.return_value = []
self.clock = MockClock()
hs = yield setup_test_homeserver(
resource_for_federation=self.mock_resource,
http_client=self.mock_http_client,
datastore=self.mock_persistence,
clock=self.clock,
keyring=Mock(),
)
self.federation = initialize_http_replication(hs)
self.distributor = hs.get_distributor()
@defer.inlineCallbacks
def test_get_state(self):
mock_handler = Mock(spec=[
"get_state_for_pdu",
])
self.federation.set_handler(mock_handler)
mock_handler.get_state_for_pdu.return_value = defer.succeed([])
# Empty context initially
(code, response) = yield self.mock_resource.trigger(
"GET",
"/_matrix/federation/v1/state/my-context/",
None
)
self.assertEquals(200, code)
self.assertFalse(response["pdus"])
# Now lets give the context some state
mock_handler.get_state_for_pdu.return_value = (
defer.succeed([
make_pdu(
event_id="the-pdu-id",
origin="red",
user_id="@a:red",
room_id="my-context",
type="m.topic",
origin_server_ts=123456789000,
depth=1,
content={"topic": "The topic"},
state_key="",
power_level=1000,
prev_state="last-pdu-id",
),
])
)
(code, response) = yield self.mock_resource.trigger(
"GET",
"/_matrix/federation/v1/state/my-context/",
None
)
self.assertEquals(200, code)
self.assertEquals(1, len(response["pdus"]))
@defer.inlineCallbacks
def test_get_pdu(self):
mock_handler = Mock(spec=[
"get_persisted_pdu",
])
self.federation.set_handler(mock_handler)
mock_handler.get_persisted_pdu.return_value = (
defer.succeed(None)
)
(code, response) = yield self.mock_resource.trigger(
"GET",
"/_matrix/federation/v1/event/abc123def456/",
None
)
self.assertEquals(404, code)
# Now insert such a PDU
mock_handler.get_persisted_pdu.return_value = (
defer.succeed(
make_pdu(
event_id="abc123def456",
origin="red",
user_id="@a:red",
room_id="my-context",
type="m.text",
origin_server_ts=123456789001,
depth=1,
content={"text": "Here is the message"},
)
)
)
(code, response) = yield self.mock_resource.trigger(
"GET",
"/_matrix/federation/v1/event/abc123def456/",
None
)
self.assertEquals(200, code)
self.assertEquals(1, len(response["pdus"]))
self.assertEquals("m.text", response["pdus"][0]["type"])
@defer.inlineCallbacks
def test_send_pdu(self):
self.mock_http_client.put_json.return_value = defer.succeed(
(200, "OK")
)
pdu = make_pdu(
event_id="abc123def456",
origin="red",
user_id="@a:red",
room_id="my-context",
type="m.text",
origin_server_ts=123456789001,
depth=1,
content={"text": "Here is the message"},
)
yield self.federation.send_pdu(pdu, ["remote"])
self.mock_http_client.put_json.assert_called_with(
"remote",
path="/_matrix/federation/v1/send/1000000/",
data={
"origin_server_ts": 1000000,
"origin": "test",
"pdus": [
pdu.get_pdu_json(),
],
'pdu_failures': [],
},
json_data_callback=ANY,
long_retries=True,
)
@defer.inlineCallbacks
def test_send_edu(self):
self.mock_http_client.put_json.return_value = defer.succeed(
(200, "OK")
)
yield self.federation.send_edu(
destination="remote",
edu_type="m.test",
content={"testing": "content here"},
)
# MockClock ensures we can guess these timestamps
self.mock_http_client.put_json.assert_called_with(
"remote",
path="/_matrix/federation/v1/send/1000000/",
data={
"origin": "test",
"origin_server_ts": 1000000,
"pdus": [],
"edus": [
{
"edu_type": "m.test",
"content": {"testing": "content here"},
}
],
'pdu_failures': [],
},
json_data_callback=ANY,
long_retries=True,
)
@defer.inlineCallbacks
def test_recv_edu(self):
recv_observer = Mock()
recv_observer.return_value = defer.succeed(())
self.federation.register_edu_handler("m.test", recv_observer)
yield self.mock_resource.trigger(
"PUT",
"/_matrix/federation/v1/send/1001000/",
"""{
"origin": "remote",
"origin_server_ts": 1001000,
"pdus": [],
"edus": [
{
"origin": "remote",
"destination": "test",
"edu_type": "m.test",
"content": {"testing": "reply here"}
}
]
}"""
)
recv_observer.assert_called_with(
"remote", {"testing": "reply here"}
)
@defer.inlineCallbacks
def test_send_query(self):
self.mock_http_client.get_json.return_value = defer.succeed(
{"your": "response"}
)
response = yield self.federation.make_query(
destination="remote",
query_type="a-question",
args={"one": "1", "two": "2"},
)
self.assertEquals({"your": "response"}, response)
self.mock_http_client.get_json.assert_called_with(
destination="remote",
path="/_matrix/federation/v1/query/a-question",
args={"one": "1", "two": "2"},
retry_on_dns_fail=True,
)
@defer.inlineCallbacks
def test_recv_query(self):
recv_handler = Mock()
recv_handler.return_value = defer.succeed({"another": "response"})
self.federation.register_query_handler("a-question", recv_handler)
code, response = yield self.mock_resource.trigger(
"GET",
"/_matrix/federation/v1/query/a-question?three=3&four=4",
None
)
self.assertEquals(200, code)
self.assertEquals({"another": "response"}, response)
recv_handler.assert_called_with(
{"three": "3", "four": "4"}
)

View File

@ -280,6 +280,15 @@ class PresenceEventStreamTestCase(unittest.TestCase):
}
EventSources.SOURCE_TYPES["presence"] = PresenceEventSource
clock = Mock(spec=[
"call_later",
"cancel_call_later",
"time_msec",
"looping_call",
])
clock.time_msec.return_value = 1000000
hs = yield setup_test_homeserver(
http_client=None,
resource_for_client=self.mock_resource,
@ -289,16 +298,9 @@ class PresenceEventStreamTestCase(unittest.TestCase):
"get_presence_list",
"get_rooms_for_user",
]),
clock=Mock(spec=[
"call_later",
"cancel_call_later",
"time_msec",
"looping_call",
]),
clock=clock,
)
hs.get_clock().time_msec.return_value = 1000000
def _get_user_by_req(req=None, allow_guest=False):
return Requester(UserID.from_string(myid), "", False)

View File

@ -1045,8 +1045,13 @@ class RoomMessageListTestCase(RestTestCase):
self.assertTrue("end" in response)
@defer.inlineCallbacks
def test_stream_token_is_rejected(self):
def test_stream_token_is_accepted_for_fwd_pagianation(self):
token = "s0_0_0_0_0"
(code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=s0_0_0_0" %
self.room_id)
self.assertEquals(400, code)
"/rooms/%s/messages?access_token=x&from=%s" %
(self.room_id, token))
self.assertEquals(200, code)
self.assertTrue("start" in response)
self.assertEquals(token, response['start'])
self.assertTrue("chunk" in response)
self.assertTrue("end" in response)

View File

@ -439,7 +439,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
f2 = self._write_config(suffix="2")
config = Mock(app_service_config_files=[f1, f2])
hs = yield setup_test_homeserver(config=config)
hs = yield setup_test_homeserver(config=config, datastore=Mock())
ApplicationServiceStore(hs)
@ -449,7 +449,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
f2 = self._write_config(id="id", suffix="2")
config = Mock(app_service_config_files=[f1, f2])
hs = yield setup_test_homeserver(config=config)
hs = yield setup_test_homeserver(config=config, datastore=Mock())
with self.assertRaises(ConfigError) as cm:
ApplicationServiceStore(hs)
@ -465,7 +465,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
f2 = self._write_config(as_token="as_token", suffix="2")
config = Mock(app_service_config_files=[f1, f2])
hs = yield setup_test_homeserver(config=config)
hs = yield setup_test_homeserver(config=config, datastore=Mock())
with self.assertRaises(ConfigError) as cm:
ApplicationServiceStore(hs)

View File

@ -18,7 +18,6 @@ from tests import unittest
from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.storage.registration import RegistrationStore
from synapse.util import stringutils
from tests.utils import setup_test_homeserver
@ -31,7 +30,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
hs = yield setup_test_homeserver()
self.db_pool = hs.get_db_pool()
self.store = RegistrationStore(hs)
self.store = hs.get_datastore()
self.user_id = "@my-user:test"
self.tokens = ["AbCdEfGhIjKlMnOpQrStUvWxYz",

View File

@ -16,10 +16,10 @@
from tests import unittest
from synapse.api.errors import SynapseError
from synapse.server import BaseHomeServer
from synapse.server import HomeServer
from synapse.types import UserID, RoomAlias
mock_homeserver = BaseHomeServer(hostname="my.domain")
mock_homeserver = HomeServer(hostname="my.domain")
class UserIDTestCase(unittest.TestCase):
@ -34,7 +34,6 @@ class UserIDTestCase(unittest.TestCase):
with self.assertRaises(SynapseError):
UserID.from_string("")
def test_build(self):
user = UserID("5678efgh", "my.domain")

View File

@ -19,6 +19,7 @@ from .. import unittest
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache
class LruCacheTestCase(unittest.TestCase):
def test_get_set(self):
@ -72,3 +73,9 @@ class LruCacheTestCase(unittest.TestCase):
self.assertEquals(cache.get(("vehicles", "car")), "vroom")
self.assertEquals(cache.get(("vehicles", "train")), "chuff")
# Man from del_multi say "Yes".
def test_clear(self):
cache = LruCache(1)
cache["key"] = 1
cache.clear()
self.assertEquals(len(cache), 0)

View File

@ -25,6 +25,7 @@ class TreeCacheTestCase(unittest.TestCase):
cache[("b",)] = "B"
self.assertEquals(cache.get(("a",)), "A")
self.assertEquals(cache.get(("b",)), "B")
self.assertEquals(len(cache), 2)
def test_pop_onelevel(self):
cache = TreeCache()
@ -33,6 +34,7 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEquals(cache.pop(("a",)), "A")
self.assertEquals(cache.pop(("a",)), None)
self.assertEquals(cache.get(("b",)), "B")
self.assertEquals(len(cache), 1)
def test_get_set_twolevel(self):
cache = TreeCache()
@ -42,6 +44,7 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEquals(cache.get(("a", "a")), "AA")
self.assertEquals(cache.get(("a", "b")), "AB")
self.assertEquals(cache.get(("b", "a")), "BA")
self.assertEquals(len(cache), 3)
def test_pop_twolevel(self):
cache = TreeCache()
@ -53,6 +56,7 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEquals(cache.get(("a", "b")), "AB")
self.assertEquals(cache.pop(("b", "a")), "BA")
self.assertEquals(cache.pop(("b", "a")), None)
self.assertEquals(len(cache), 1)
def test_pop_mixedlevel(self):
cache = TreeCache()
@ -64,3 +68,11 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEquals(cache.get(("a", "a")), None)
self.assertEquals(cache.get(("a", "b")), None)
self.assertEquals(cache.get(("b", "a")), "BA")
self.assertEquals(len(cache), 1)
def test_clear(self):
cache = TreeCache()
cache[("a",)] = "A"
cache[("b",)] = "B"
cache.clear()
self.assertEquals(len(cache), 0)

View File

@ -19,6 +19,8 @@ from synapse.api.constants import EventTypes
from synapse.storage.prepare_database import prepare_database
from synapse.storage.engines import create_engine
from synapse.server import HomeServer
from synapse.federation.transport import server
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.logcontext import LoggingContext
@ -58,8 +60,10 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
name, db_pool=db_pool, config=config,
version_string="Synapse/tests",
database_engine=create_engine("sqlite3"),
get_db_conn=db_pool.get_db_conn,
**kargs
)
hs.setup()
else:
hs = HomeServer(
name, db_pool=None, datastore=datastore, config=config,
@ -80,6 +84,22 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
hs.build_handlers = swap_out_hash_for_testing(hs.build_handlers)
fed = kargs.get("resource_for_federation", None)
if fed:
server.register_servlets(
hs,
resource=fed,
authenticator=server.Authenticator(hs),
ratelimiter=FederationRateLimiter(
hs.get_clock(),
window_size=hs.config.federation_rc_window_size,
sleep_limit=hs.config.federation_rc_sleep_limit,
sleep_msec=hs.config.federation_rc_sleep_delay,
reject_limit=hs.config.federation_rc_reject_limit,
concurrent_requests=hs.config.federation_rc_concurrent
),
)
defer.returnValue(hs)
@ -262,6 +282,12 @@ class SQLiteMemoryDbPool(ConnectionPool, object):
lambda conn: prepare_database(conn, engine)
)
def get_db_conn(self):
conn = self.connect()
engine = create_engine("sqlite3")
prepare_database(conn, engine)
return conn
class MemoryDataStore(object):