Merge branch 'release-v0.22.0' of github.com:matrix-org/synapse

This commit is contained in:
Erik Johnston 2017-07-06 10:36:25 +01:00
commit 42b50483be
91 changed files with 3929 additions and 488 deletions

View File

@ -1,3 +1,59 @@
Changes in synapse v0.22.0 (2017-07-06)
=======================================
No changes since v0.22.0-rc2
Changes in synapse v0.22.0-rc2 (2017-07-04)
===========================================
Changes:
* Improve performance of storing user IPs (PR #2307, #2308)
* Slightly improve performance of verifying access tokens (PR #2320)
* Slightly improve performance of event persistence (PR #2321)
* Increase default cache factor size from 0.1 to 0.5 (PR #2330)
Bug fixes:
* Fix bug with storing registration sessions that caused frequent CPU churn
(PR #2319)
Changes in synapse v0.22.0-rc1 (2017-06-26)
===========================================
Features:
* Add a user directory API (PR #2252, and many more)
* Add shutdown room API to remove room from local server (PR #2291)
* Add API to quarantine media (PR #2292)
* Add new config option to not send event contents to push servers (PR #2301)
Thanks to @cjdelisle!
Changes:
* Various performance fixes (PR #2177, #2233, #2230, #2238, #2248, #2256,
#2274)
* Deduplicate sync filters (PR #2219) Thanks to @krombel!
* Correct a typo in UPGRADE.rst (PR #2231) Thanks to @aaronraimist!
* Add count of one time keys to sync stream (PR #2237)
* Only store event_auth for state events (PR #2247)
* Store URL cache preview downloads separately (PR #2299)
Bug fixes:
* Fix users not getting notifications when AS listened to that user_id (PR
#2216) Thanks to @slipeer!
* Fix users without push set up not getting notifications after joining rooms
(PR #2236)
* Fix preview url API to trim long descriptions (PR #2243)
* Fix bug where we used cached but unpersisted state group as prev group,
resulting in broken state of restart (PR #2263)
* Fix removing of pushers when using workers (PR #2267)
* Fix CORS headers to allow Authorization header (PR #2285) Thanks to @krombel!
Changes in synapse v0.21.1 (2017-06-15) Changes in synapse v0.21.1 (2017-06-15)
======================================= =======================================

View File

@ -528,6 +528,30 @@ fix try re-installing from PyPI or directly from
# Install from github # Install from github
pip install --user https://github.com/pyca/pynacl/tarball/master pip install --user https://github.com/pyca/pynacl/tarball/master
Running out of File Handles
~~~~~~~~~~~~~~~~~~~~~~~~~~~
If synapse runs out of filehandles, it typically fails badly - live-locking
at 100% CPU, and/or failing to accept new TCP connections (blocking the
connecting client). Matrix currently can legitimately use a lot of file handles,
thanks to busy rooms like #matrix:matrix.org containing hundreds of participating
servers. The first time a server talks in a room it will try to connect
simultaneously to all participating servers, which could exhaust the available
file descriptors between DNS queries & HTTPS sockets, especially if DNS is slow
to respond. (We need to improve the routing algorithm used to be better than
full mesh, but as of June 2017 this hasn't happened yet).
If you hit this failure mode, we recommend increasing the maximum number of
open file handles to be at least 4096 (assuming a default of 1024 or 256).
This is typically done by editing ``/etc/security/limits.conf``
Separately, Synapse may leak file handles if inbound HTTP requests get stuck
during processing - e.g. blocked behind a lock or talking to a remote server etc.
This is best diagnosed by matching up the 'Received request' and 'Processed request'
log lines and looking for any 'Processed request' lines which take more than
a few seconds to execute. Please let us know at #matrix-dev:matrix.org if
you see this failure mode so we can help debug it, however.
ArchLinux ArchLinux
~~~~~~~~~ ~~~~~~~~~
@ -875,12 +899,9 @@ cache a lot of recent room data and metadata in RAM in order to speed up
common requests. We'll improve this in future, but for now the easiest common requests. We'll improve this in future, but for now the easiest
way to either reduce the RAM usage (at the risk of slowing things down) way to either reduce the RAM usage (at the risk of slowing things down)
is to set the almost-undocumented ``SYNAPSE_CACHE_FACTOR`` environment is to set the almost-undocumented ``SYNAPSE_CACHE_FACTOR`` environment
variable. Roughly speaking, a SYNAPSE_CACHE_FACTOR of 1.0 will max out variable. The default is 0.5, which can be decreased to reduce RAM usage
at around 3-4GB of resident memory - this is what we currently run the in memory constrained enviroments, or increased if performance starts to
matrix.org on. The default setting is currently 0.1, which is probably degrade.
around a ~700MB footprint. You can dial it down further to 0.02 if
desired, which targets roughly ~512MB. Conversely you can dial it up if
you need performance for lots of users and have a box with a lot of RAM.
.. _`key_management`: https://matrix.org/docs/spec/server_server/unstable.html#retrieving-server-keys .. _`key_management`: https://matrix.org/docs/spec/server_server/unstable.html#retrieving-server-keys

View File

@ -33,7 +33,7 @@ To check whether your update was sucessfull, run:
.. code:: bash .. code:: bash
# replace your.server.domain with ther domain of your synaspe homeserver # replace your.server.domain with ther domain of your synapse homeserver
curl https://<your.server.domain>/_matrix/federation/v1/version curl https://<your.server.domain>/_matrix/federation/v1/version
So for the Matrix.org HS server the URL would be: https://matrix.org/_matrix/federation/v1/version. So for the Matrix.org HS server the URL would be: https://matrix.org/_matrix/federation/v1/version.

View File

@ -41,6 +41,7 @@ BOOLEAN_COLUMNS = {
"presence_stream": ["currently_active"], "presence_stream": ["currently_active"],
"public_room_list_stream": ["visibility"], "public_room_list_stream": ["visibility"],
"device_lists_outbound_pokes": ["sent"], "device_lists_outbound_pokes": ["sent"],
"users_who_share_rooms": ["share_private"],
} }
@ -121,7 +122,7 @@ class Store(object):
try: try:
txn = conn.cursor() txn = conn.cursor()
return func( return func(
LoggingTransaction(txn, desc, self.database_engine, []), LoggingTransaction(txn, desc, self.database_engine, [], []),
*args, **kwargs *args, **kwargs
) )
except self.database_engine.module.DatabaseError as e: except self.database_engine.module.DatabaseError as e:

View File

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.21.1" __version__ = "0.22.0"

View File

@ -23,7 +23,8 @@ from synapse import event_auth
from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes from synapse.api.errors import AuthError, Codes
from synapse.types import UserID from synapse.types import UserID
from synapse.util import logcontext from synapse.util.caches import register_cache, CACHE_SIZE_FACTOR
from synapse.util.caches.lrucache import LruCache
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -39,6 +40,10 @@ AuthEventTypes = (
GUEST_DEVICE_ID = "guest_device" GUEST_DEVICE_ID = "guest_device"
class _InvalidMacaroonException(Exception):
pass
class Auth(object): class Auth(object):
""" """
FIXME: This class contains a mix of functions for authenticating users FIXME: This class contains a mix of functions for authenticating users
@ -51,6 +56,9 @@ class Auth(object):
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000)
register_cache("token_cache", self.token_cache)
@defer.inlineCallbacks @defer.inlineCallbacks
def check_from_context(self, event, context, do_sig_check=True): def check_from_context(self, event, context, do_sig_check=True):
auth_events_ids = yield self.compute_auth_events( auth_events_ids = yield self.compute_auth_events(
@ -144,17 +152,8 @@ class Auth(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def check_host_in_room(self, room_id, host): def check_host_in_room(self, room_id, host):
with Measure(self.clock, "check_host_in_room"): with Measure(self.clock, "check_host_in_room"):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.is_host_joined(room_id, host)
defer.returnValue(latest_event_ids)
logger.debug("calling resolve_state_groups from check_host_in_room")
entry = yield self.state.resolve_state_groups(
room_id, latest_event_ids
)
ret = yield self.store.is_host_joined(
room_id, host, entry.state_group, entry.state
)
defer.returnValue(ret)
def _check_joined_room(self, member, user_id, room_id): def _check_joined_room(self, member, user_id, room_id):
if not member or member.membership != Membership.JOIN: if not member or member.membership != Membership.JOIN:
@ -209,7 +208,7 @@ class Auth(object):
default=[""] default=[""]
)[0] )[0]
if user and access_token and ip_addr: if user and access_token and ip_addr:
logcontext.preserve_fn(self.store.insert_client_ip)( self.store.insert_client_ip(
user=user, user=user,
access_token=access_token, access_token=access_token,
ip=ip_addr, ip=ip_addr,
@ -276,8 +275,8 @@ class Auth(object):
AuthError if no user by that token exists or the token is invalid. AuthError if no user by that token exists or the token is invalid.
""" """
try: try:
macaroon = pymacaroons.Macaroon.deserialize(token) user_id, guest = self._parse_and_validate_macaroon(token, rights)
except Exception: # deserialize can throw more-or-less anything except _InvalidMacaroonException:
# doesn't look like a macaroon: treat it as an opaque token which # doesn't look like a macaroon: treat it as an opaque token which
# must be in the database. # must be in the database.
# TODO: it would be nice to get rid of this, but apparently some # TODO: it would be nice to get rid of this, but apparently some
@ -286,19 +285,8 @@ class Auth(object):
defer.returnValue(r) defer.returnValue(r)
try: try:
user_id = self.get_user_id_from_macaroon(macaroon)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
self.validate_macaroon(
macaroon, rights, self.hs.config.expire_access_token,
user_id=user_id,
)
guest = False
for caveat in macaroon.caveats:
if caveat.caveat_id == "guest = true":
guest = True
if guest: if guest:
# Guest access tokens are not stored in the database (there can # Guest access tokens are not stored in the database (there can
# only be one access token per guest, anyway). # only be one access token per guest, anyway).
@ -370,6 +358,55 @@ class Auth(object):
errcode=Codes.UNKNOWN_TOKEN errcode=Codes.UNKNOWN_TOKEN
) )
def _parse_and_validate_macaroon(self, token, rights="access"):
"""Takes a macaroon and tries to parse and validate it. This is cached
if and only if rights == access and there isn't an expiry.
On invalid macaroon raises _InvalidMacaroonException
Returns:
(user_id, is_guest)
"""
if rights == "access":
cached = self.token_cache.get(token, None)
if cached:
return cached
try:
macaroon = pymacaroons.Macaroon.deserialize(token)
except Exception: # deserialize can throw more-or-less anything
# doesn't look like a macaroon: treat it as an opaque token which
# must be in the database.
# TODO: it would be nice to get rid of this, but apparently some
# people use access tokens which aren't macaroons
raise _InvalidMacaroonException()
try:
user_id = self.get_user_id_from_macaroon(macaroon)
has_expiry = False
guest = False
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith("time "):
has_expiry = True
elif caveat.caveat_id == "guest = true":
guest = True
self.validate_macaroon(
macaroon, rights, self.hs.config.expire_access_token,
user_id=user_id,
)
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.",
errcode=Codes.UNKNOWN_TOKEN
)
if not has_expiry and rights == "access":
self.token_cache[token] = (user_id, guest)
return user_id, guest
def get_user_id_from_macaroon(self, macaroon): def get_user_id_from_macaroon(self, macaroon):
"""Retrieve the user_id given by the caveats on the macaroon. """Retrieve the user_id given by the caveats on the macaroon.

View File

@ -24,6 +24,7 @@ from synapse.http.server import JsonResource
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.keys import SlavedKeyStore from synapse.replication.slave.storage.keys import SlavedKeyStore
from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.room import RoomStore
@ -33,7 +34,6 @@ from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.client.v1.room import PublicRoomListRestServlet from synapse.rest.client.v1.room import PublicRoomListRestServlet
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
@ -65,8 +65,8 @@ class ClientReaderSlavedStore(
SlavedApplicationServiceStore, SlavedApplicationServiceStore,
SlavedRegistrationStore, SlavedRegistrationStore,
TransactionStore, TransactionStore,
SlavedClientIpStore,
BaseSlavedStore, BaseSlavedStore,
ClientIpStore, # After BaseSlavedStore because the constructor is different
): ):
pass pass

View File

@ -51,7 +51,7 @@ import sys
import logging import logging
import gc import gc
logger = logging.getLogger("synapse.app.appservice") logger = logging.getLogger("synapse.app.federation_sender")
class FederationSenderSlaveStore( class FederationSenderSlaveStore(

View File

@ -23,13 +23,13 @@ from synapse.http.site import SynapseSite
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.transactions import TransactionStore from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.storage.media_repository import MediaRepositoryStore from synapse.storage.media_repository import MediaRepositoryStore
from synapse.util.httpresourcetree import create_resource_tree from synapse.util.httpresourcetree import create_resource_tree
@ -60,10 +60,10 @@ logger = logging.getLogger("synapse.app.media_repository")
class MediaRepositorySlavedStore( class MediaRepositorySlavedStore(
SlavedApplicationServiceStore, SlavedApplicationServiceStore,
SlavedRegistrationStore, SlavedRegistrationStore,
SlavedClientIpStore,
TransactionStore, TransactionStore,
BaseSlavedStore, BaseSlavedStore,
MediaRepositoryStore, MediaRepositoryStore,
ClientIpStore,
): ):
pass pass

View File

@ -29,6 +29,7 @@ from synapse.rest.client.v1 import events
from synapse.rest.client.v1.room import RoomInitialSyncRestServlet from synapse.rest.client.v1.room import RoomInitialSyncRestServlet
from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
@ -42,7 +43,6 @@ from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.tcp.client import ReplicationClientHandler from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.storage.presence import UserPresenceState from synapse.storage.presence import UserPresenceState
from synapse.storage.roommember import RoomMemberStore from synapse.storage.roommember import RoomMemberStore
@ -77,9 +77,9 @@ class SynchrotronSlavedStore(
SlavedPresenceStore, SlavedPresenceStore,
SlavedDeviceInboxStore, SlavedDeviceInboxStore,
SlavedDeviceStore, SlavedDeviceStore,
SlavedClientIpStore,
RoomStore, RoomStore,
BaseSlavedStore, BaseSlavedStore,
ClientIpStore, # After BaseSlavedStore because the constructor is different
): ):
who_forgot_in_room = ( who_forgot_in_room = (
RoomMemberStore.__dict__["who_forgot_in_room"] RoomMemberStore.__dict__["who_forgot_in_room"]

270
synapse/app/user_dir.py Normal file
View File

@ -0,0 +1,270 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.server import HomeServer
from synapse.config._base import ConfigError
from synapse.config.logger import setup_logging
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
from synapse.http.site import SynapseSite
from synapse.http.server import JsonResource
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.client.v2_alpha import user_directory
from synapse.storage.engines import create_engine
from synapse.storage.user_directory import UserDirectoryStore
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext, preserve_fn
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse import events
from twisted.internet import reactor
from twisted.web.resource import Resource
from daemonize import Daemonize
import sys
import logging
import gc
logger = logging.getLogger("synapse.app.user_dir")
class UserDirectorySlaveStore(
SlavedEventStore,
SlavedApplicationServiceStore,
SlavedRegistrationStore,
SlavedClientIpStore,
UserDirectoryStore,
BaseSlavedStore,
):
def __init__(self, db_conn, hs):
super(UserDirectorySlaveStore, self).__init__(db_conn, hs)
events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
db_conn, "current_state_delta_stream",
entity_column="room_id",
stream_column="stream_id",
max_value=events_max, # As we share the stream id with events token
limit=1000,
)
self._curr_state_delta_stream_cache = StreamChangeCache(
"_curr_state_delta_stream_cache", min_curr_state_delta_id,
prefilled_cache=curr_state_delta_prefill,
)
self._current_state_delta_pos = events_max
def stream_positions(self):
result = super(UserDirectorySlaveStore, self).stream_positions()
result["current_state_deltas"] = self._current_state_delta_pos
return result
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "current_state_deltas":
self._current_state_delta_pos = token
for row in rows:
self._curr_state_delta_stream_cache.entity_has_changed(
row.room_id, token
)
return super(UserDirectorySlaveStore, self).process_replication_rows(
stream_name, token, rows
)
class UserDirectoryServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = UserDirectorySlaveStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_addresses = listener_config["bind_addresses"]
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
elif name == "client":
resource = JsonResource(self, canonical_json=False)
user_directory.register_servlets(self, resource)
resources.update({
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource,
"/_matrix/client/api/v1": resource,
})
root_resource = create_resource_tree(resources, Resource())
for address in bind_addresses:
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=address
)
logger.info("Synapse user_dir now listening on port %d", port)
def start_listening(self, listeners):
for listener in listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=address
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self)
def build_tcp_replication(self):
return UserDirectoryReplicationHandler(self)
class UserDirectoryReplicationHandler(ReplicationClientHandler):
def __init__(self, hs):
super(UserDirectoryReplicationHandler, self).__init__(hs.get_datastore())
self.user_directory = hs.get_user_directory_handler()
def on_rdata(self, stream_name, token, rows):
super(UserDirectoryReplicationHandler, self).on_rdata(
stream_name, token, rows
)
if stream_name == "current_state_deltas":
preserve_fn(self.user_directory.notify_new_event)()
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse user directory", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.user_dir"
setup_logging(config, use_worker_options=True)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config)
if config.update_user_directory:
sys.stderr.write(
"\nThe update_user_directory must be disabled in the main synapse process"
"\nbefore they can be run in a separate worker."
"\nPlease add ``update_user_directory: false`` to the main config"
"\n"
)
sys.exit(1)
# Force the pushers to start since they will be disabled in the main config
config.update_user_directory = True
tls_server_context_factory = context_factory.ServerContextFactory(config)
ps = UserDirectoryServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ps.setup()
ps.start_listening(config.worker_listeners)
def run():
# make sure that we run the reactor with the sentinel log context,
# otherwise other PreserveLoggingContext instances will get confused
# and complain when they see the logcontext arbitrarily swapping
# between the sentinel and `run` logcontexts.
with PreserveLoggingContext():
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
gc.set_threshold(*config.gc_thresholds)
reactor.run()
def start():
ps.get_datastore().start_profiling()
ps.get_state_handler().start_caching()
reactor.callWhenRunning(start)
if config.worker_daemonize:
daemon = Daemonize(
app="synapse-user-dir",
pid=config.worker_pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
run()
if __name__ == '__main__':
with LoggingContext("main"):
start(sys.argv[1:])

View File

@ -241,6 +241,16 @@ class ApplicationService(object):
def is_exclusive_room(self, room_id): def is_exclusive_room(self, room_id):
return self._is_exclusive(ApplicationService.NS_ROOMS, room_id) return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
def get_exlusive_user_regexes(self):
"""Get the list of regexes used to determine if a user is exclusively
registered by the AS
"""
return [
regex_obj["regex"]
for regex_obj in self.namespaces[ApplicationService.NS_USERS]
if regex_obj["exclusive"]
]
def is_rate_limited(self): def is_rate_limited(self):
return self.rate_limited return self.rate_limited

View File

@ -33,6 +33,7 @@ from .jwt import JWTConfig
from .password_auth_providers import PasswordAuthProviderConfig from .password_auth_providers import PasswordAuthProviderConfig
from .emailconfig import EmailConfig from .emailconfig import EmailConfig
from .workers import WorkerConfig from .workers import WorkerConfig
from .push import PushConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
@ -40,7 +41,7 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig, VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
AppServiceConfig, KeyConfig, SAML2Config, CasConfig, AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
JWTConfig, PasswordConfig, EmailConfig, JWTConfig, PasswordConfig, EmailConfig,
WorkerConfig, PasswordAuthProviderConfig,): WorkerConfig, PasswordAuthProviderConfig, PushConfig,):
pass pass

45
synapse/config/push.py Normal file
View File

@ -0,0 +1,45 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
class PushConfig(Config):
def read_config(self, config):
self.push_redact_content = False
push_config = config.get("email", {})
self.push_redact_content = push_config.get("redact_content", False)
def default_config(self, config_dir_path, server_name, **kwargs):
return """
# Control how push messages are sent to google/apple to notifications.
# Normally every message said in a room with one or more people using
# mobile devices will be posted to a push server hosted by matrix.org
# which is registered with google and apple in order to allow push
# notifications to be sent to these mobile devices.
#
# Setting redact_content to true will make the push messages contain no
# message content which will provide increased privacy. This is a
# temporary solution pending improvements to Android and iPhone apps
# to get content from the app rather than the notification.
#
# For modern android devices the notification content will still appear
# because it is loaded by the app. iPhone, however will send a
# notification saying only that a message arrived and who it came from.
#
#push:
# redact_content: false
"""

View File

@ -35,6 +35,10 @@ class ServerConfig(Config):
# "disable" federation # "disable" federation
self.send_federation = config.get("send_federation", True) self.send_federation = config.get("send_federation", True)
# Whether to update the user directory or not. This should be set to
# false only if we are updating the user directory in a worker
self.update_user_directory = config.get("update_user_directory", True)
self.filter_timeline_limit = config.get("filter_timeline_limit", -1) self.filter_timeline_limit = config.get("filter_timeline_limit", -1)
if self.public_baseurl is not None: if self.public_baseurl is not None:

View File

@ -187,6 +187,7 @@ class TransactionQueue(object):
prev_id for prev_id, _ in event.prev_events prev_id for prev_id, _ in event.prev_events
], ],
) )
destinations = set(destinations)
if send_on_behalf_of is not None: if send_on_behalf_of is not None:
# If we are sending the event on behalf of another server # If we are sending the event on behalf of another server

View File

@ -21,6 +21,7 @@ from synapse.api.constants import LoginType
from synapse.types import UserID from synapse.types import UserID
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.caches.expiringcache import ExpiringCache
from twisted.web.client import PartialDownloadError from twisted.web.client import PartialDownloadError
@ -52,7 +53,15 @@ class AuthHandler(BaseHandler):
LoginType.DUMMY: self._check_dummy_auth, LoginType.DUMMY: self._check_dummy_auth,
} }
self.bcrypt_rounds = hs.config.bcrypt_rounds self.bcrypt_rounds = hs.config.bcrypt_rounds
self.sessions = {}
# This is not a cache per se, but a store of all current sessions that
# expire after N hours
self.sessions = ExpiringCache(
cache_name="register_sessions",
clock=hs.get_clock(),
expiry_ms=self.SESSION_EXPIRE_MS,
reset_expiry_on_get=True,
)
account_handler = _AccountHandler( account_handler = _AccountHandler(
hs, check_user_exists=self.check_user_exists hs, check_user_exists=self.check_user_exists
@ -617,16 +626,6 @@ class AuthHandler(BaseHandler):
logger.debug("Saving session %s", session) logger.debug("Saving session %s", session)
session["last_used"] = self.hs.get_clock().time_msec() session["last_used"] = self.hs.get_clock().time_msec()
self.sessions[session["id"]] = session self.sessions[session["id"]] = session
self._prune_sessions()
def _prune_sessions(self):
for sid, sess in self.sessions.items():
last_used = 0
if 'last_used' in sess:
last_used = sess['last_used']
now = self.hs.get_clock().time_msec()
if last_used < now - AuthHandler.SESSION_EXPIRE_MS:
del self.sessions[sid]
def hash(self, password): def hash(self, password):
"""Computes a secure hash of password. """Computes a secure hash of password.

View File

@ -106,7 +106,7 @@ class DeviceHandler(BaseHandler):
device_map = yield self.store.get_devices_by_user(user_id) device_map = yield self.store.get_devices_by_user(user_id)
ips = yield self.store.get_last_client_ip_by_device( ips = yield self.store.get_last_client_ip_by_device(
devices=((user_id, device_id) for device_id in device_map.keys()) user_id, device_id=None
) )
devices = device_map.values() devices = device_map.values()
@ -133,7 +133,7 @@ class DeviceHandler(BaseHandler):
except errors.StoreError: except errors.StoreError:
raise errors.NotFoundError raise errors.NotFoundError
ips = yield self.store.get_last_client_ip_by_device( ips = yield self.store.get_last_client_ip_by_device(
devices=((user_id, device_id),) user_id, device_id,
) )
_update_device_from_client_ips(device, ips) _update_device_from_client_ips(device, ips)
defer.returnValue(device) defer.returnValue(device)

View File

@ -43,7 +43,6 @@ from synapse.events.utils import prune_event
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
from synapse.push.action_generator import ActionGenerator
from synapse.util.distributor import user_joined_room from synapse.util.distributor import user_joined_room
from twisted.internet import defer from twisted.internet import defer
@ -75,6 +74,8 @@ class FederationHandler(BaseHandler):
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
self.server_name = hs.hostname self.server_name = hs.hostname
self.keyring = hs.get_keyring() self.keyring = hs.get_keyring()
self.action_generator = hs.get_action_generator()
self.is_mine_id = hs.is_mine_id
self.replication_layer.set_handler(self) self.replication_layer.set_handler(self)
@ -832,7 +833,11 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_event_auth(self, event_id): def on_event_auth(self, event_id):
auth = yield self.store.get_auth_chain([event_id]) event = yield self.store.get_event(event_id)
auth = yield self.store.get_auth_chain(
[auth_id for auth_id, _ in event.auth_events],
include_given=True
)
for event in auth: for event in auth:
event.signatures.update( event.signatures.update(
@ -1047,9 +1052,7 @@ class FederationHandler(BaseHandler):
yield user_joined_room(self.distributor, user, event.room_id) yield user_joined_room(self.distributor, user, event.room_id)
state_ids = context.prev_state_ids.values() state_ids = context.prev_state_ids.values()
auth_chain = yield self.store.get_auth_chain(set( auth_chain = yield self.store.get_auth_chain(state_ids)
[event.event_id] + state_ids
))
state = yield self.store.get_events(context.prev_state_ids.values()) state = yield self.store.get_events(context.prev_state_ids.values())
@ -1066,6 +1069,24 @@ class FederationHandler(BaseHandler):
""" """
event = pdu event = pdu
is_blocked = yield self.store.is_room_blocked(event.room_id)
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
membership = event.content.get("membership")
if event.type != EventTypes.Member or membership != Membership.INVITE:
raise SynapseError(400, "The event was not an m.room.member invite event")
sender_domain = get_domain_from_id(event.sender)
if sender_domain != origin:
raise SynapseError(400, "The invite event was not from the server sending it")
if event.state_key is None:
raise SynapseError(400, "The invite event did not have a state key")
if not self.is_mine_id(event.state_key):
raise SynapseError(400, "The invite event must be for this server")
event.internal_metadata.outlier = True event.internal_metadata.outlier = True
event.internal_metadata.invite_from_remote = True event.internal_metadata.invite_from_remote = True
@ -1100,6 +1121,9 @@ class FederationHandler(BaseHandler):
user_id, user_id,
"leave" "leave"
) )
# Mark as outlier as we don't have any state for this event; we're not
# even in the room.
event.internal_metadata.outlier = True
event = self._sign_event(event) event = self._sign_event(event)
# Try the host that we succesfully called /make_leave/ on first for # Try the host that we succesfully called /make_leave/ on first for
@ -1271,7 +1295,7 @@ class FederationHandler(BaseHandler):
for event in res: for event in res:
# We sign these again because there was a bug where we # We sign these again because there was a bug where we
# incorrectly signed things the first time round # incorrectly signed things the first time round
if self.hs.is_mine_id(event.event_id): if self.is_mine_id(event.event_id):
event.signatures.update( event.signatures.update(
compute_event_signature( compute_event_signature(
event, event,
@ -1344,7 +1368,7 @@ class FederationHandler(BaseHandler):
) )
if event: if event:
if self.hs.is_mine_id(event.event_id): if self.is_mine_id(event.event_id):
# FIXME: This is a temporary work around where we occasionally # FIXME: This is a temporary work around where we occasionally
# return events slightly differently than when they were # return events slightly differently than when they were
# originally signed # originally signed
@ -1389,8 +1413,7 @@ class FederationHandler(BaseHandler):
) )
if not event.internal_metadata.is_outlier(): if not event.internal_metadata.is_outlier():
action_generator = ActionGenerator(self.hs) yield self.action_generator.handle_push_actions_for_event(
yield action_generator.handle_push_actions_for_event(
event, context event, context
) )
@ -1599,7 +1622,11 @@ class FederationHandler(BaseHandler):
pass pass
# Now get the current auth_chain for the event. # Now get the current auth_chain for the event.
local_auth_chain = yield self.store.get_auth_chain([event_id]) event = yield self.store.get_event(event_id)
local_auth_chain = yield self.store.get_auth_chain(
[auth_id for auth_id, _ in event.auth_events],
include_given=True
)
# TODO: Check if we would now reject event_id. If so we need to tell # TODO: Check if we would now reject event_id. If so we need to tell
# everyone. # everyone.
@ -1792,7 +1819,9 @@ class FederationHandler(BaseHandler):
auth_ids = yield self.auth.compute_auth_events( auth_ids = yield self.auth.compute_auth_events(
event, context.prev_state_ids event, context.prev_state_ids
) )
local_auth_chain = yield self.store.get_auth_chain(auth_ids) local_auth_chain = yield self.store.get_auth_chain(
auth_ids, include_given=True
)
try: try:
# 2. Get remote difference. # 2. Get remote difference.

View File

@ -20,7 +20,6 @@ from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.push.action_generator import ActionGenerator
from synapse.types import ( from synapse.types import (
UserID, RoomAlias, RoomStreamToken, UserID, RoomAlias, RoomStreamToken,
) )
@ -35,6 +34,7 @@ from canonicaljson import encode_canonical_json
import logging import logging
import random import random
import ujson
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -54,6 +54,8 @@ class MessageHandler(BaseHandler):
# This is to stop us from diverging history *too* much. # This is to stop us from diverging history *too* much.
self.limiter = Limiter(max_count=5) self.limiter = Limiter(max_count=5)
self.action_generator = hs.get_action_generator()
@defer.inlineCallbacks @defer.inlineCallbacks
def purge_history(self, room_id, event_id): def purge_history(self, room_id, event_id):
event = yield self.store.get_event(event_id) event = yield self.store.get_event(event_id)
@ -497,6 +499,14 @@ class MessageHandler(BaseHandler):
logger.warn("Denying new event %r because %s", event, err) logger.warn("Denying new event %r because %s", event, err)
raise err raise err
# Ensure that we can round trip before trying to persist in db
try:
dump = ujson.dumps(event.content)
ujson.loads(dump)
except:
logger.exception("Failed to encode content: %r", event.content)
raise
yield self.maybe_kick_guest_users(event, context) yield self.maybe_kick_guest_users(event, context)
if event.type == EventTypes.CanonicalAlias: if event.type == EventTypes.CanonicalAlias:
@ -590,8 +600,7 @@ class MessageHandler(BaseHandler):
"Changing the room create event is forbidden", "Changing the room create event is forbidden",
) )
action_generator = ActionGenerator(self.hs) yield self.action_generator.handle_push_actions_for_event(
yield action_generator.handle_push_actions_for_event(
event, context event, context
) )

View File

@ -61,7 +61,7 @@ class RoomCreationHandler(BaseHandler):
} }
@defer.inlineCallbacks @defer.inlineCallbacks
def create_room(self, requester, config): def create_room(self, requester, config, ratelimit=True):
""" Creates a new room. """ Creates a new room.
Args: Args:
@ -75,6 +75,7 @@ class RoomCreationHandler(BaseHandler):
""" """
user_id = requester.user.to_string() user_id = requester.user.to_string()
if ratelimit:
yield self.ratelimit(requester) yield self.ratelimit(requester)
if "room_alias_name" in config: if "room_alias_name" in config:
@ -167,6 +168,7 @@ class RoomCreationHandler(BaseHandler):
initial_state=initial_state, initial_state=initial_state,
creation_content=creation_content, creation_content=creation_content,
room_alias=room_alias, room_alias=room_alias,
power_level_content_override=config.get("power_level_content_override", {})
) )
if "name" in config: if "name" in config:
@ -245,7 +247,8 @@ class RoomCreationHandler(BaseHandler):
invite_list, invite_list,
initial_state, initial_state,
creation_content, creation_content,
room_alias room_alias,
power_level_content_override,
): ):
def create(etype, content, **kwargs): def create(etype, content, **kwargs):
e = { e = {
@ -291,7 +294,15 @@ class RoomCreationHandler(BaseHandler):
ratelimit=False, ratelimit=False,
) )
if (EventTypes.PowerLevels, '') not in initial_state: # We treat the power levels override specially as this needs to be one
# of the first events that get sent into a room.
pl_content = initial_state.pop((EventTypes.PowerLevels, ''), None)
if pl_content is not None:
yield send(
etype=EventTypes.PowerLevels,
content=pl_content,
)
else:
power_level_content = { power_level_content = {
"users": { "users": {
creator_id: 100, creator_id: 100,
@ -316,6 +327,8 @@ class RoomCreationHandler(BaseHandler):
for invitee in invite_list: for invitee in invite_list:
power_level_content["users"][invitee] = 100 power_level_content["users"][invitee] = 100
power_level_content.update(power_level_content_override)
yield send( yield send(
etype=EventTypes.PowerLevels, etype=EventTypes.PowerLevels,
content=power_level_content, content=power_level_content,

View File

@ -203,6 +203,11 @@ class RoomMemberHandler(BaseHandler):
if not remote_room_hosts: if not remote_room_hosts:
remote_room_hosts = [] remote_room_hosts = []
if effective_membership_state not in ("leave", "ban",):
is_blocked = yield self.store.is_room_blocked(room_id)
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
current_state_ids = yield self.state_handler.get_current_state_ids( current_state_ids = yield self.state_handler.get_current_state_ids(
room_id, latest_event_ids=latest_event_ids, room_id, latest_event_ids=latest_event_ids,
@ -369,6 +374,11 @@ class RoomMemberHandler(BaseHandler):
# so don't really fit into the general auth process. # so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed") raise AuthError(403, "Guest access not allowed")
if event.membership not in (Membership.LEAVE, Membership.BAN):
is_blocked = yield self.store.is_room_blocked(room_id)
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
yield message_handler.handle_new_client_event( yield message_handler.handle_new_client_event(
requester, requester,
event, event,

View File

@ -117,6 +117,8 @@ class SyncResult(collections.namedtuple("SyncResult", [
"archived", # ArchivedSyncResult for each archived room. "archived", # ArchivedSyncResult for each archived room.
"to_device", # List of direct messages for the device. "to_device", # List of direct messages for the device.
"device_lists", # List of user_ids whose devices have chanegd "device_lists", # List of user_ids whose devices have chanegd
"device_one_time_keys_count", # Dict of algorithm to count for one time keys
# for this device
])): ])):
__slots__ = [] __slots__ = []
@ -550,6 +552,14 @@ class SyncHandler(object):
sync_result_builder sync_result_builder
) )
device_id = sync_config.device_id
one_time_key_counts = {}
if device_id:
user_id = sync_config.user.to_string()
one_time_key_counts = yield self.store.count_e2e_one_time_keys(
user_id, device_id
)
defer.returnValue(SyncResult( defer.returnValue(SyncResult(
presence=sync_result_builder.presence, presence=sync_result_builder.presence,
account_data=sync_result_builder.account_data, account_data=sync_result_builder.account_data,
@ -558,6 +568,7 @@ class SyncHandler(object):
archived=sync_result_builder.archived, archived=sync_result_builder.archived,
to_device=sync_result_builder.to_device, to_device=sync_result_builder.to_device,
device_lists=device_lists, device_lists=device_lists,
device_one_time_keys_count=one_time_key_counts,
next_batch=sync_result_builder.now_token, next_batch=sync_result_builder.now_token,
)) ))

View File

@ -89,7 +89,7 @@ class TypingHandler(object):
until = self._member_typing_until.get(member, None) until = self._member_typing_until.get(member, None)
if not until or until <= now: if not until or until <= now:
logger.info("Timing out typing for: %s", member.user_id) logger.info("Timing out typing for: %s", member.user_id)
preserve_fn(self._stopped_typing)(member) self._stopped_typing(member)
continue continue
# Check if we need to resend a keep alive over federation for this # Check if we need to resend a keep alive over federation for this
@ -147,7 +147,7 @@ class TypingHandler(object):
# No point sending another notification # No point sending another notification
defer.returnValue(None) defer.returnValue(None)
yield self._push_update( self._push_update(
member=member, member=member,
typing=True, typing=True,
) )
@ -171,7 +171,7 @@ class TypingHandler(object):
member = RoomMember(room_id=room_id, user_id=target_user_id) member = RoomMember(room_id=room_id, user_id=target_user_id)
yield self._stopped_typing(member) self._stopped_typing(member)
@defer.inlineCallbacks @defer.inlineCallbacks
def user_left_room(self, user, room_id): def user_left_room(self, user, room_id):
@ -180,7 +180,6 @@ class TypingHandler(object):
member = RoomMember(room_id=room_id, user_id=user_id) member = RoomMember(room_id=room_id, user_id=user_id)
yield self._stopped_typing(member) yield self._stopped_typing(member)
@defer.inlineCallbacks
def _stopped_typing(self, member): def _stopped_typing(self, member):
if member.user_id not in self._room_typing.get(member.room_id, set()): if member.user_id not in self._room_typing.get(member.room_id, set()):
# No point # No point
@ -189,16 +188,15 @@ class TypingHandler(object):
self._member_typing_until.pop(member, None) self._member_typing_until.pop(member, None)
self._member_last_federation_poke.pop(member, None) self._member_last_federation_poke.pop(member, None)
yield self._push_update( self._push_update(
member=member, member=member,
typing=False, typing=False,
) )
@defer.inlineCallbacks
def _push_update(self, member, typing): def _push_update(self, member, typing):
if self.hs.is_mine_id(member.user_id): if self.hs.is_mine_id(member.user_id):
# Only send updates for changes to our own users. # Only send updates for changes to our own users.
yield self._push_remote(member, typing) preserve_fn(self._push_remote)(member, typing)
self._push_update_local( self._push_update_local(
member=member, member=member,

View File

@ -0,0 +1,641 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.storage.roommember import ProfileInfo
from synapse.util.metrics import Measure
from synapse.util.async import sleep
logger = logging.getLogger(__name__)
class UserDirectoyHandler(object):
"""Handles querying of and keeping updated the user_directory.
N.B.: ASSUMES IT IS THE ONLY THING THAT MODIFIES THE USER DIRECTORY
The user directory is filled with users who this server can see are joined to a
world_readable or publically joinable room. We keep a database table up to date
by streaming changes of the current state and recalculating whether users should
be in the directory or not when necessary.
For each user in the directory we also store a room_id which is public and that the
user is joined to. This allows us to ignore history_visibility and join_rules changes
for that user in all other public rooms, as we know they'll still be in at least
one public room.
"""
INITIAL_SLEEP_MS = 50
INITIAL_SLEEP_COUNT = 100
INITIAL_BATCH_SIZE = 100
def __init__(self, hs):
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
self.server_name = hs.hostname
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
self.update_user_directory = hs.config.update_user_directory
# When start up for the first time we need to populate the user_directory.
# This is a set of user_id's we've inserted already
self.initially_handled_users = set()
self.initially_handled_users_in_public = set()
self.initially_handled_users_share = set()
self.initially_handled_users_share_private_room = set()
# The current position in the current_state_delta stream
self.pos = None
# Guard to ensure we only process deltas one at a time
self._is_processing = False
if self.update_user_directory:
self.notifier.add_replication_callback(self.notify_new_event)
# We kick this off so that we don't have to wait for a change before
# we start populating the user directory
self.clock.call_later(0, self.notify_new_event)
def search_users(self, user_id, search_term, limit):
"""Searches for users in directory
Returns:
dict of the form::
{
"limited": <bool>, # whether there were more results or not
"results": [ # Ordered by best match first
{
"user_id": <user_id>,
"display_name": <display_name>,
"avatar_url": <avatar_url>
}
]
}
"""
return self.store.search_user_dir(user_id, search_term, limit)
@defer.inlineCallbacks
def notify_new_event(self):
"""Called when there may be more deltas to process
"""
if not self.update_user_directory:
return
if self._is_processing:
return
self._is_processing = True
try:
yield self._unsafe_process()
finally:
self._is_processing = False
@defer.inlineCallbacks
def _unsafe_process(self):
# If self.pos is None then means we haven't fetched it from DB
if self.pos is None:
self.pos = yield self.store.get_user_directory_stream_pos()
# If still None then we need to do the initial fill of directory
if self.pos is None:
yield self._do_initial_spam()
self.pos = yield self.store.get_user_directory_stream_pos()
# Loop round handling deltas until we're up to date
while True:
with Measure(self.clock, "user_dir_delta"):
deltas = yield self.store.get_current_state_deltas(self.pos)
if not deltas:
return
logger.info("Handling %d state deltas", len(deltas))
yield self._handle_deltas(deltas)
self.pos = deltas[-1]["stream_id"]
yield self.store.update_user_directory_stream_pos(self.pos)
@defer.inlineCallbacks
def _do_initial_spam(self):
"""Populates the user_directory from the current state of the DB, used
when synapse first starts with user_directory support
"""
new_pos = yield self.store.get_max_stream_id_in_current_state_deltas()
# Delete any existing entries just in case there are any
yield self.store.delete_all_from_user_dir()
# We process by going through each existing room at a time.
room_ids = yield self.store.get_all_rooms()
logger.info("Doing initial update of user directory. %d rooms", len(room_ids))
num_processed_rooms = 1
for room_id in room_ids:
logger.info("Handling room %d/%d", num_processed_rooms, len(room_ids))
yield self._handle_intial_room(room_id)
num_processed_rooms += 1
yield sleep(self.INITIAL_SLEEP_MS / 1000.)
logger.info("Processed all rooms.")
self.initially_handled_users = None
self.initially_handled_users_in_public = None
self.initially_handled_users_share = None
self.initially_handled_users_share_private_room = None
yield self.store.update_user_directory_stream_pos(new_pos)
@defer.inlineCallbacks
def _handle_intial_room(self, room_id):
"""Called when we initially fill out user_directory one room at a time
"""
is_in_room = yield self.store.is_host_joined(room_id, self.server_name)
if not is_in_room:
return
is_public = yield self.store.is_room_world_readable_or_publicly_joinable(room_id)
users_with_profile = yield self.state.get_current_user_in_room(room_id)
user_ids = set(users_with_profile)
unhandled_users = user_ids - self.initially_handled_users
yield self.store.add_profiles_to_user_dir(
room_id, {
user_id: users_with_profile[user_id] for user_id in unhandled_users
}
)
self.initially_handled_users |= unhandled_users
if is_public:
yield self.store.add_users_to_public_room(
room_id,
user_ids=user_ids - self.initially_handled_users_in_public
)
self.initially_handled_users_in_public |= user_ids
# We now go and figure out the new users who share rooms with user entries
# We sleep aggressively here as otherwise it can starve resources.
# We also batch up inserts/updates, but try to avoid too many at once.
to_insert = set()
to_update = set()
count = 0
for user_id in user_ids:
if count % self.INITIAL_SLEEP_COUNT == 0:
yield sleep(self.INITIAL_SLEEP_MS / 1000.)
if not self.is_mine_id(user_id):
count += 1
continue
if self.store.get_if_app_services_interested_in_user(user_id):
count += 1
continue
for other_user_id in user_ids:
if user_id == other_user_id:
continue
if count % self.INITIAL_SLEEP_COUNT == 0:
yield sleep(self.INITIAL_SLEEP_MS / 1000.)
count += 1
user_set = (user_id, other_user_id)
if user_set in self.initially_handled_users_share_private_room:
continue
if user_set in self.initially_handled_users_share:
if is_public:
continue
to_update.add(user_set)
else:
to_insert.add(user_set)
if is_public:
self.initially_handled_users_share.add(user_set)
else:
self.initially_handled_users_share_private_room.add(user_set)
if len(to_insert) > self.INITIAL_BATCH_SIZE:
yield self.store.add_users_who_share_room(
room_id, not is_public, to_insert,
)
to_insert.clear()
if len(to_update) > self.INITIAL_BATCH_SIZE:
yield self.store.update_users_who_share_room(
room_id, not is_public, to_update,
)
to_update.clear()
if to_insert:
yield self.store.add_users_who_share_room(
room_id, not is_public, to_insert,
)
to_insert.clear()
if to_update:
yield self.store.update_users_who_share_room(
room_id, not is_public, to_update,
)
to_update.clear()
@defer.inlineCallbacks
def _handle_deltas(self, deltas):
"""Called with the state deltas to process
"""
for delta in deltas:
typ = delta["type"]
state_key = delta["state_key"]
room_id = delta["room_id"]
event_id = delta["event_id"]
prev_event_id = delta["prev_event_id"]
logger.debug("Handling: %r %r, %s", typ, state_key, event_id)
# For join rule and visibility changes we need to check if the room
# may have become public or not and add/remove the users in said room
if typ in (EventTypes.RoomHistoryVisibility, EventTypes.JoinRules):
yield self._handle_room_publicity_change(
room_id, prev_event_id, event_id, typ,
)
elif typ == EventTypes.Member:
change = yield self._get_key_change(
prev_event_id, event_id,
key_name="membership",
public_value=Membership.JOIN,
)
if change is None:
# Handle any profile changes
yield self._handle_profile_change(
state_key, room_id, prev_event_id, event_id,
)
continue
if not change:
# Need to check if the server left the room entirely, if so
# we might need to remove all the users in that room
is_in_room = yield self.store.is_host_joined(
room_id, self.server_name,
)
if not is_in_room:
logger.info("Server left room: %r", room_id)
# Fetch all the users that we marked as being in user
# directory due to being in the room and then check if
# need to remove those users or not
user_ids = yield self.store.get_users_in_dir_due_to_room(room_id)
for user_id in user_ids:
yield self._handle_remove_user(room_id, user_id)
return
else:
logger.debug("Server is still in room: %r", room_id)
if change: # The user joined
event = yield self.store.get_event(event_id, allow_none=True)
profile = ProfileInfo(
avatar_url=event.content.get("avatar_url"),
display_name=event.content.get("displayname"),
)
yield self._handle_new_user(room_id, state_key, profile)
else: # The user left
yield self._handle_remove_user(room_id, state_key)
else:
logger.debug("Ignoring irrelevant type: %r", typ)
@defer.inlineCallbacks
def _handle_room_publicity_change(self, room_id, prev_event_id, event_id, typ):
"""Handle a room having potentially changed from/to world_readable/publically
joinable.
Args:
room_id (str)
prev_event_id (str|None): The previous event before the state change
event_id (str|None): The new event after the state change
typ (str): Type of the event
"""
logger.debug("Handling change for %s: %s", typ, room_id)
if typ == EventTypes.RoomHistoryVisibility:
change = yield self._get_key_change(
prev_event_id, event_id,
key_name="history_visibility",
public_value="world_readable",
)
elif typ == EventTypes.JoinRules:
change = yield self._get_key_change(
prev_event_id, event_id,
key_name="join_rule",
public_value=JoinRules.PUBLIC,
)
else:
raise Exception("Invalid event type")
# If change is None, no change. True => become world_readable/public,
# False => was world_readable/public
if change is None:
logger.debug("No change")
return
# There's been a change to or from being world readable.
is_public = yield self.store.is_room_world_readable_or_publicly_joinable(
room_id
)
logger.debug("Change: %r, is_public: %r", change, is_public)
if change and not is_public:
# If we became world readable but room isn't currently public then
# we ignore the change
return
elif not change and is_public:
# If we stopped being world readable but are still public,
# ignore the change
return
if change:
users_with_profile = yield self.state.get_current_user_in_room(room_id)
for user_id, profile in users_with_profile.iteritems():
yield self._handle_new_user(room_id, user_id, profile)
else:
users = yield self.store.get_users_in_public_due_to_room(room_id)
for user_id in users:
yield self._handle_remove_user(room_id, user_id)
@defer.inlineCallbacks
def _handle_new_user(self, room_id, user_id, profile):
"""Called when we might need to add user to directory
Args:
room_id (str): room_id that user joined or started being public that
user_id (str)
"""
logger.debug("Adding user to dir, %r", user_id)
row = yield self.store.get_user_in_directory(user_id)
if not row:
yield self.store.add_profiles_to_user_dir(room_id, {user_id: profile})
is_public = yield self.store.is_room_world_readable_or_publicly_joinable(
room_id
)
if is_public:
row = yield self.store.get_user_in_public_room(user_id)
if not row:
yield self.store.add_users_to_public_room(room_id, [user_id])
else:
logger.debug("Not adding user to public dir, %r", user_id)
# Now we update users who share rooms with users. We do this by getting
# all the current users in the room and seeing which aren't already
# marked in the database as sharing with `user_id`
users_with_profile = yield self.state.get_current_user_in_room(room_id)
to_insert = set()
to_update = set()
is_appservice = self.store.get_if_app_services_interested_in_user(user_id)
# First, if they're our user then we need to update for every user
if self.is_mine_id(user_id) and not is_appservice:
# Returns a map of other_user_id -> shared_private. We only need
# to update mappings if for users that either don't share a room
# already (aren't in the map) or, if the room is private, those that
# only share a public room.
user_ids_shared = yield self.store.get_users_who_share_room_from_dir(
user_id
)
for other_user_id in users_with_profile:
if user_id == other_user_id:
continue
shared_is_private = user_ids_shared.get(other_user_id)
if shared_is_private is True:
# We've already marked in the database they share a private room
continue
elif shared_is_private is False:
# They already share a public room, so only update if this is
# a private room
if not is_public:
to_update.add((user_id, other_user_id))
elif shared_is_private is None:
# This is the first time they both share a room
to_insert.add((user_id, other_user_id))
# Next we need to update for every local user in the room
for other_user_id in users_with_profile:
if user_id == other_user_id:
continue
is_appservice = self.store.get_if_app_services_interested_in_user(
other_user_id
)
if self.is_mine_id(other_user_id) and not is_appservice:
shared_is_private = yield self.store.get_if_users_share_a_room(
other_user_id, user_id,
)
if shared_is_private is True:
# We've already marked in the database they share a private room
continue
elif shared_is_private is False:
# They already share a public room, so only update if this is
# a private room
if not is_public:
to_update.add((other_user_id, user_id))
elif shared_is_private is None:
# This is the first time they both share a room
to_insert.add((other_user_id, user_id))
if to_insert:
yield self.store.add_users_who_share_room(
room_id, not is_public, to_insert,
)
if to_update:
yield self.store.update_users_who_share_room(
room_id, not is_public, to_update,
)
@defer.inlineCallbacks
def _handle_remove_user(self, room_id, user_id):
"""Called when we might need to remove user to directory
Args:
room_id (str): room_id that user left or stopped being public that
user_id (str)
"""
logger.debug("Maybe removing user %r", user_id)
row = yield self.store.get_user_in_directory(user_id)
update_user_dir = row and row["room_id"] == room_id
row = yield self.store.get_user_in_public_room(user_id)
update_user_in_public = row and row["room_id"] == room_id
if (update_user_in_public or update_user_dir):
# XXX: Make this faster?
rooms = yield self.store.get_rooms_for_user(user_id)
for j_room_id in rooms:
if (not update_user_in_public and not update_user_dir):
break
is_in_room = yield self.store.is_host_joined(
j_room_id, self.server_name,
)
if not is_in_room:
continue
if update_user_dir:
update_user_dir = False
yield self.store.update_user_in_user_dir(user_id, j_room_id)
is_public = yield self.store.is_room_world_readable_or_publicly_joinable(
j_room_id
)
if update_user_in_public and is_public:
yield self.store.update_user_in_public_user_list(user_id, j_room_id)
update_user_in_public = False
if update_user_dir:
yield self.store.remove_from_user_dir(user_id)
elif update_user_in_public:
yield self.store.remove_from_user_in_public_room(user_id)
# Now handle users_who_share_rooms.
# Get a list of user tuples that were in the DB due to this room and
# users (this includes tuples where the other user matches `user_id`)
user_tuples = yield self.store.get_users_in_share_dir_with_room_id(
user_id, room_id,
)
for user_id, other_user_id in user_tuples:
# For each user tuple get a list of rooms that they still share,
# trying to find a private room, and update the entry in the DB
rooms = yield self.store.get_rooms_in_common_for_users(user_id, other_user_id)
# If they dont share a room anymore, remove the mapping
if not rooms:
yield self.store.remove_user_who_share_room(
user_id, other_user_id,
)
continue
found_public_share = None
for j_room_id in rooms:
is_public = yield self.store.is_room_world_readable_or_publicly_joinable(
j_room_id
)
if is_public:
found_public_share = j_room_id
else:
found_public_share = None
yield self.store.update_users_who_share_room(
room_id, not is_public, [(user_id, other_user_id)],
)
break
if found_public_share:
yield self.store.update_users_who_share_room(
room_id, not is_public, [(user_id, other_user_id)],
)
@defer.inlineCallbacks
def _handle_profile_change(self, user_id, room_id, prev_event_id, event_id):
"""Check member event changes for any profile changes and update the
database if there are.
"""
if not prev_event_id or not event_id:
return
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
event = yield self.store.get_event(event_id, allow_none=True)
if not prev_event or not event:
return
if event.membership != Membership.JOIN:
return
prev_name = prev_event.content.get("displayname")
new_name = event.content.get("displayname")
prev_avatar = prev_event.content.get("avatar_url")
new_avatar = event.content.get("avatar_url")
if prev_name != new_name or prev_avatar != new_avatar:
yield self.store.update_profile_in_user_dir(
user_id, new_name, new_avatar, room_id,
)
@defer.inlineCallbacks
def _get_key_change(self, prev_event_id, event_id, key_name, public_value):
"""Given two events check if the `key_name` field in content changed
from not matching `public_value` to doing so.
For example, check if `history_visibility` (`key_name`) changed from
`shared` to `world_readable` (`public_value`).
Returns:
None if the field in the events either both match `public_value`
or if neither do, i.e. there has been no change.
True if it didnt match `public_value` but now does
False if it did match `public_value` but now doesn't
"""
prev_event = None
event = None
if prev_event_id:
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
if event_id:
event = yield self.store.get_event(event_id, allow_none=True)
if not event and not prev_event:
logger.debug("Neither event exists: %r %r", prev_event_id, event_id)
defer.returnValue(None)
prev_value = None
value = None
if prev_event:
prev_value = prev_event.content.get(key_name)
if event:
value = event.content.get(key_name)
logger.debug("prev_value: %r -> value: %r", prev_value, value)
if value == public_value and prev_value != public_value:
defer.returnValue(True)
elif value != public_value and prev_value == public_value:
defer.returnValue(False)
else:
defer.returnValue(None)

View File

@ -412,7 +412,7 @@ def set_cors_headers(request):
) )
request.setHeader( request.setHeader(
"Access-Control-Allow-Headers", "Access-Control-Allow-Headers",
"Origin, X-Requested-With, Content-Type, Accept" "Origin, X-Requested-With, Content-Type, Accept, Authorization"
) )

View File

@ -251,7 +251,8 @@ class Notifier(object):
"""Notify any user streams that are interested in this room event""" """Notify any user streams that are interested in this room event"""
# poke any interested application service. # poke any interested application service.
preserve_fn(self.appservice_handler.notify_interested_services)( preserve_fn(self.appservice_handler.notify_interested_services)(
room_stream_id) room_stream_id
)
if self.federation_sender: if self.federation_sender:
preserve_fn(self.federation_sender.notify_new_events)( preserve_fn(self.federation_sender.notify_new_events)(

View File

@ -15,7 +15,7 @@
from twisted.internet import defer from twisted.internet import defer
from .bulk_push_rule_evaluator import evaluator_for_event from .bulk_push_rule_evaluator import BulkPushRuleEvaluator
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -24,11 +24,12 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ActionGenerator: class ActionGenerator(object):
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.bulk_evaluator = BulkPushRuleEvaluator(hs)
# really we want to get all user ids and all profile tags too, # really we want to get all user ids and all profile tags too,
# since we want the actions for each profile tag for every user and # since we want the actions for each profile tag for every user and
# also actions for a client with no profile tag for each user. # also actions for a client with no profile tag for each user.
@ -38,16 +39,11 @@ class ActionGenerator:
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_push_actions_for_event(self, event, context): def handle_push_actions_for_event(self, event, context):
with Measure(self.clock, "evaluator_for_event"):
bulk_evaluator = yield evaluator_for_event(
event, self.hs, self.store, context
)
with Measure(self.clock, "action_for_event_by_user"): with Measure(self.clock, "action_for_event_by_user"):
actions_by_user = yield bulk_evaluator.action_for_event_by_user( actions_by_user = yield self.bulk_evaluator.action_for_event_by_user(
event, context event, context
) )
context.push_actions = [ context.push_actions = [
(uid, actions) for uid, actions in actions_by_user.items() (uid, actions) for uid, actions in actions_by_user.iteritems()
] ]

View File

@ -19,60 +19,83 @@ from twisted.internet import defer
from .push_rule_evaluator import PushRuleEvaluatorForEvent from .push_rule_evaluator import PushRuleEvaluatorForEvent
from synapse.api.constants import EventTypes
from synapse.visibility import filter_events_for_clients_context from synapse.visibility import filter_events_for_clients_context
from synapse.api.constants import EventTypes, Membership
from synapse.util.caches.descriptors import cached
from synapse.util.async import Linearizer
from collections import namedtuple
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
rules_by_room = {}
class BulkPushRuleEvaluator(object):
"""Calculates the outcome of push rules for an event for all users in the
room at once.
"""
def __init__(self, hs):
self.hs = hs
self.store = hs.get_datastore()
@defer.inlineCallbacks @defer.inlineCallbacks
def evaluator_for_event(event, hs, store, context): def _get_rules_for_event(self, event, context):
rules_by_user = yield store.bulk_get_push_rules_for_room( """This gets the rules for all users in the room at the time of the event,
event, context as well as the push rules for the invitee if the event is an invite.
)
Returns:
dict of user_id -> push_rules
"""
room_id = event.room_id
rules_for_room = self._get_rules_for_room(room_id)
rules_by_user = yield rules_for_room.get_rules(event, context)
# if this event is an invite event, we may need to run rules for the user # if this event is an invite event, we may need to run rules for the user
# who's been invited, otherwise they won't get told they've been invited # who's been invited, otherwise they won't get told they've been invited
if event.type == 'm.room.member' and event.content['membership'] == 'invite': if event.type == 'm.room.member' and event.content['membership'] == 'invite':
invited_user = event.state_key invited = event.state_key
if invited_user and hs.is_mine_id(invited_user): if invited and self.hs.is_mine_id(invited):
has_pusher = yield store.user_has_pusher(invited_user) has_pusher = yield self.store.user_has_pusher(invited)
if has_pusher: if has_pusher:
rules_by_user = dict(rules_by_user) rules_by_user = dict(rules_by_user)
rules_by_user[invited_user] = yield store.get_push_rules_for_user( rules_by_user[invited] = yield self.store.get_push_rules_for_user(
invited_user invited
) )
defer.returnValue(BulkPushRuleEvaluator( defer.returnValue(rules_by_user)
event.room_id, rules_by_user, store
))
@cached()
def _get_rules_for_room(self, room_id):
"""Get the current RulesForRoom object for the given room id
class BulkPushRuleEvaluator: Returns:
RulesForRoom
""" """
Runs push rules for all users in a room. # It's important that RulesForRoom gets added to self._get_rules_for_room.cache
This is faster than running PushRuleEvaluator for each user because it # before any lookup methods get called on it as otherwise there may be
fetches all the rules for all the users in one (batched) db query # a race if invalidate_all gets called (which assumes its in the cache)
rather than doing multiple queries per-user. It currently uses return RulesForRoom(self.hs, room_id, self._get_rules_for_room.cache)
the same logic to run the actual rules, but could be optimised further
(see https://matrix.org/jira/browse/SYN-562)
"""
def __init__(self, room_id, rules_by_user, store):
self.room_id = room_id
self.rules_by_user = rules_by_user
self.store = store
@defer.inlineCallbacks @defer.inlineCallbacks
def action_for_event_by_user(self, event, context): def action_for_event_by_user(self, event, context):
"""Given an event and context, evaluate the push rules and return
the results
Returns:
dict of user_id -> action
"""
rules_by_user = yield self._get_rules_for_event(event, context)
actions_by_user = {} actions_by_user = {}
# None of these users can be peeking since this list of users comes # None of these users can be peeking since this list of users comes
# from the set of users in the room, so we know for sure they're all # from the set of users in the room, so we know for sure they're all
# actually in the room. # actually in the room.
user_tuples = [ user_tuples = [(u, False) for u in rules_by_user]
(u, False) for u in self.rules_by_user.keys()
]
filtered_by_user = yield filter_events_for_clients_context( filtered_by_user = yield filter_events_for_clients_context(
self.store, user_tuples, [event], {event.event_id: context} self.store, user_tuples, [event], {event.event_id: context}
@ -86,7 +109,7 @@ class BulkPushRuleEvaluator:
condition_cache = {} condition_cache = {}
for uid, rules in self.rules_by_user.items(): for uid, rules in rules_by_user.iteritems():
display_name = None display_name = None
profile_info = room_members.get(uid) profile_info = room_members.get(uid)
if profile_info: if profile_info:
@ -138,3 +161,240 @@ def _condition_checker(evaluator, conditions, uid, display_name, cache):
return False return False
return True return True
class RulesForRoom(object):
"""Caches push rules for users in a room.
This efficiently handles users joining/leaving the room by not invalidating
the entire cache for the room.
"""
def __init__(self, hs, room_id, rules_for_room_cache):
"""
Args:
hs (HomeServer)
room_id (str)
rules_for_room_cache(Cache): The cache object that caches these
RoomsForUser objects.
"""
self.room_id = room_id
self.is_mine_id = hs.is_mine_id
self.store = hs.get_datastore()
self.linearizer = Linearizer(name="rules_for_room")
self.member_map = {} # event_id -> (user_id, state)
self.rules_by_user = {} # user_id -> rules
# The last state group we updated the caches for. If the state_group of
# a new event comes along, we know that we can just return the cached
# result.
# On invalidation of the rules themselves (if the user changes them),
# we invalidate everything and set state_group to `object()`
self.state_group = object()
# A sequence number to keep track of when we're allowed to update the
# cache. We bump the sequence number when we invalidate the cache. If
# the sequence number changes while we're calculating stuff we should
# not update the cache with it.
self.sequence = 0
# A cache of user_ids that we *know* aren't interesting, e.g. user_ids
# owned by AS's, or remote users, etc. (I.e. users we will never need to
# calculate push for)
# These never need to be invalidated as we will never set up push for
# them.
self.uninteresting_user_set = set()
# We need to be clever on the invalidating caches callbacks, as
# otherwise the invalidation callback holds a reference to the object,
# potentially causing it to leak.
# To get around this we pass a function that on invalidations looks ups
# the RoomsForUser entry in the cache, rather than keeping a reference
# to self around in the callback.
self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)
@defer.inlineCallbacks
def get_rules(self, event, context):
"""Given an event context return the rules for all users who are
currently in the room.
"""
state_group = context.state_group
with (yield self.linearizer.queue(())):
if state_group and self.state_group == state_group:
logger.debug("Using cached rules for %r", self.room_id)
defer.returnValue(self.rules_by_user)
ret_rules_by_user = {}
missing_member_event_ids = {}
if state_group and self.state_group == context.prev_group:
# If we have a simple delta then we can reuse most of the previous
# results.
ret_rules_by_user = self.rules_by_user
current_state_ids = context.delta_ids
else:
current_state_ids = context.current_state_ids
logger.debug(
"Looking for member changes in %r %r", state_group, current_state_ids
)
# Loop through to see which member events we've seen and have rules
# for and which we need to fetch
for key in current_state_ids:
typ, user_id = key
if typ != EventTypes.Member:
continue
if user_id in self.uninteresting_user_set:
continue
if not self.is_mine_id(user_id):
self.uninteresting_user_set.add(user_id)
continue
if self.store.get_if_app_services_interested_in_user(user_id):
self.uninteresting_user_set.add(user_id)
continue
event_id = current_state_ids[key]
res = self.member_map.get(event_id, None)
if res:
user_id, state = res
if state == Membership.JOIN:
rules = self.rules_by_user.get(user_id, None)
if rules:
ret_rules_by_user[user_id] = rules
continue
# If a user has left a room we remove their push rule. If they
# joined then we readd it later in _update_rules_with_member_event_ids
ret_rules_by_user.pop(user_id, None)
missing_member_event_ids[user_id] = event_id
if missing_member_event_ids:
# If we have some memebr events we haven't seen, look them up
# and fetch push rules for them if appropriate.
logger.debug("Found new member events %r", missing_member_event_ids)
yield self._update_rules_with_member_event_ids(
ret_rules_by_user, missing_member_event_ids, state_group, event
)
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"Returning push rules for %r %r",
self.room_id, ret_rules_by_user.keys(),
)
defer.returnValue(ret_rules_by_user)
@defer.inlineCallbacks
def _update_rules_with_member_event_ids(self, ret_rules_by_user, member_event_ids,
state_group, event):
"""Update the partially filled rules_by_user dict by fetching rules for
any newly joined users in the `member_event_ids` list.
Args:
ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets
updated with any new rules.
member_event_ids (list): List of event ids for membership events that
have happened since the last time we filled rules_by_user
state_group: The state group we are currently computing push rules
for. Used when updating the cache.
"""
sequence = self.sequence
rows = yield self.store._simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=member_event_ids.values(),
retcols=('user_id', 'membership', 'event_id'),
keyvalues={},
batch_size=500,
desc="_get_rules_for_member_event_ids",
)
members = {
row["event_id"]: (row["user_id"], row["membership"])
for row in rows
}
# If the event is a join event then it will be in current state evnts
# map but not in the DB, so we have to explicitly insert it.
if event.type == EventTypes.Member:
for event_id in member_event_ids.itervalues():
if event_id == event.event_id:
members[event_id] = (event.state_key, event.membership)
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Found members %r: %r", self.room_id, members.values())
interested_in_user_ids = set(
user_id for user_id, membership in members.itervalues()
if membership == Membership.JOIN
)
logger.debug("Joined: %r", interested_in_user_ids)
if_users_with_pushers = yield self.store.get_if_users_have_pushers(
interested_in_user_ids,
on_invalidate=self.invalidate_all_cb,
)
user_ids = set(
uid for uid, have_pusher in if_users_with_pushers.iteritems() if have_pusher
)
logger.debug("With pushers: %r", user_ids)
users_with_receipts = yield self.store.get_users_with_read_receipts_in_room(
self.room_id, on_invalidate=self.invalidate_all_cb,
)
logger.debug("With receipts: %r", users_with_receipts)
# any users with pushers must be ours: they have pushers
for uid in users_with_receipts:
if uid in interested_in_user_ids:
user_ids.add(uid)
rules_by_user = yield self.store.bulk_get_push_rules(
user_ids, on_invalidate=self.invalidate_all_cb,
)
ret_rules_by_user.update(
item for item in rules_by_user.iteritems() if item[0] is not None
)
self.update_cache(sequence, members, ret_rules_by_user, state_group)
def invalidate_all(self):
# Note: Don't hand this function directly to an invalidation callback
# as it keeps a reference to self and will stop this instance from being
# GC'd if it gets dropped from the rules_to_user cache. Instead use
# `self.invalidate_all_cb`
logger.debug("Invalidating RulesForRoom for %r", self.room_id)
self.sequence += 1
self.state_group = object()
self.member_map = {}
self.rules_by_user = {}
def update_cache(self, sequence, members, rules_by_user, state_group):
if sequence == self.sequence:
self.member_map.update(members)
self.rules_by_user = rules_by_user
self.state_group = state_group
class _Invalidation(namedtuple("_Invalidation", ("cache", "room_id"))):
# We rely on _CacheContext implementing __eq__ and __hash__ sensibly,
# which namedtuple does for us (i.e. two _CacheContext are the same if
# their caches and keys match). This is important in particular to
# dedupe when we add callbacks to lru cache nodes, otherwise the number
# of callbacks would grow.
def __call__(self):
rules = self.cache.get(self.room_id, None, update_metrics=False)
if rules:
rules.invalidate_all()

View File

@ -21,7 +21,6 @@ import logging
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
from mailer import Mailer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -56,8 +55,10 @@ class EmailPusher(object):
This shares quite a bit of code with httpusher: it would be good to This shares quite a bit of code with httpusher: it would be good to
factor out the common parts factor out the common parts
""" """
def __init__(self, hs, pusherdict): def __init__(self, hs, pusherdict, mailer):
self.hs = hs self.hs = hs
self.mailer = mailer
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock() self.clock = self.hs.get_clock()
self.pusher_id = pusherdict['id'] self.pusher_id = pusherdict['id']
@ -73,16 +74,6 @@ class EmailPusher(object):
self.processing = False self.processing = False
if self.hs.config.email_enable_notifs:
if 'data' in pusherdict and 'brand' in pusherdict['data']:
app_name = pusherdict['data']['brand']
else:
app_name = self.hs.config.email_app_name
self.mailer = Mailer(self.hs, app_name)
else:
self.mailer = None
@defer.inlineCallbacks @defer.inlineCallbacks
def on_started(self): def on_started(self):
if self.mailer is not None: if self.mailer is not None:

View File

@ -275,7 +275,7 @@ class HttpPusher(object):
if event.type == 'm.room.member': if event.type == 'm.room.member':
d['notification']['membership'] = event.content['membership'] d['notification']['membership'] = event.content['membership']
d['notification']['user_is_target'] = event.state_key == self.user_id d['notification']['user_is_target'] = event.state_key == self.user_id
if 'content' in event: if not self.hs.config.push_redact_content and 'content' in event:
d['notification']['content'] = event.content d['notification']['content'] = event.content
# We no longer send aliases separately, instead, we send the human # We no longer send aliases separately, instead, we send the human

View File

@ -78,23 +78,17 @@ ALLOWED_ATTRS = {
class Mailer(object): class Mailer(object):
def __init__(self, hs, app_name): def __init__(self, hs, app_name, notif_template_html, notif_template_text):
self.hs = hs self.hs = hs
self.notif_template_html = notif_template_html
self.notif_template_text = notif_template_text
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.macaroon_gen = self.hs.get_macaroon_generator() self.macaroon_gen = self.hs.get_macaroon_generator()
self.state_handler = self.hs.get_state_handler() self.state_handler = self.hs.get_state_handler()
loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir)
self.app_name = app_name self.app_name = app_name
logger.info("Created Mailer for app_name %s" % app_name) logger.info("Created Mailer for app_name %s" % app_name)
env = jinja2.Environment(loader=loader)
env.filters["format_ts"] = format_ts_filter
env.filters["mxc_to_http"] = self.mxc_to_http_filter
self.notif_template_html = env.get_template(
self.hs.config.email_notif_template_html
)
self.notif_template_text = env.get_template(
self.hs.config.email_notif_template_text
)
@defer.inlineCallbacks @defer.inlineCallbacks
def send_notification_mail(self, app_id, user_id, email_address, def send_notification_mail(self, app_id, user_id, email_address,
@ -481,28 +475,6 @@ class Mailer(object):
urllib.urlencode(params), urllib.urlencode(params),
) )
def mxc_to_http_filter(self, value, width, height, resize_method="crop"):
if value[0:6] != "mxc://":
return ""
serverAndMediaId = value[6:]
fragment = None
if '#' in serverAndMediaId:
(serverAndMediaId, fragment) = serverAndMediaId.split('#', 1)
fragment = "#" + fragment
params = {
"width": width,
"height": height,
"method": resize_method,
}
return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
self.hs.config.public_baseurl,
serverAndMediaId,
urllib.urlencode(params),
fragment or "",
)
def safe_markup(raw_html): def safe_markup(raw_html):
return jinja2.Markup(bleach.linkify(bleach.clean( return jinja2.Markup(bleach.linkify(bleach.clean(
@ -543,3 +515,52 @@ def string_ordinal_total(s):
def format_ts_filter(value, format): def format_ts_filter(value, format):
return time.strftime(format, time.localtime(value / 1000)) return time.strftime(format, time.localtime(value / 1000))
def load_jinja2_templates(config):
"""Load the jinja2 email templates from disk
Returns:
(notif_template_html, notif_template_text)
"""
logger.info("loading jinja2")
loader = jinja2.FileSystemLoader(config.email_template_dir)
env = jinja2.Environment(loader=loader)
env.filters["format_ts"] = format_ts_filter
env.filters["mxc_to_http"] = _create_mxc_to_http_filter(config)
notif_template_html = env.get_template(
config.email_notif_template_html
)
notif_template_text = env.get_template(
config.email_notif_template_text
)
return notif_template_html, notif_template_text
def _create_mxc_to_http_filter(config):
def mxc_to_http_filter(value, width, height, resize_method="crop"):
if value[0:6] != "mxc://":
return ""
serverAndMediaId = value[6:]
fragment = None
if '#' in serverAndMediaId:
(serverAndMediaId, fragment) = serverAndMediaId.split('#', 1)
fragment = "#" + fragment
params = {
"width": width,
"height": height,
"method": resize_method,
}
return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
config.public_baseurl,
serverAndMediaId,
urllib.urlencode(params),
fragment or "",
)
return mxc_to_http_filter

View File

@ -26,22 +26,54 @@ logger = logging.getLogger(__name__)
# process works fine) # process works fine)
try: try:
from synapse.push.emailpusher import EmailPusher from synapse.push.emailpusher import EmailPusher
from synapse.push.mailer import Mailer, load_jinja2_templates
except: except:
pass pass
def create_pusher(hs, pusherdict): class PusherFactory(object):
logger.info("trying to create_pusher for %r", pusherdict) def __init__(self, hs):
self.hs = hs
PUSHER_TYPES = { self.pusher_types = {
"http": HttpPusher, "http": HttpPusher,
} }
logger.info("email enable notifs: %r", hs.config.email_enable_notifs) logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
if hs.config.email_enable_notifs: if hs.config.email_enable_notifs:
PUSHER_TYPES["email"] = EmailPusher self.mailers = {} # app_name -> Mailer
templates = load_jinja2_templates(hs.config)
self.notif_template_html, self.notif_template_text = templates
self.pusher_types["email"] = self._create_email_pusher
logger.info("defined email pusher type") logger.info("defined email pusher type")
if pusherdict['kind'] in PUSHER_TYPES: def create_pusher(self, pusherdict):
logger.info("trying to create_pusher for %r", pusherdict)
if pusherdict['kind'] in self.pusher_types:
logger.info("found pusher") logger.info("found pusher")
return PUSHER_TYPES[pusherdict['kind']](hs, pusherdict) return self.pusher_types[pusherdict['kind']](self.hs, pusherdict)
def _create_email_pusher(self, _hs, pusherdict):
app_name = self._app_name_from_pusherdict(pusherdict)
mailer = self.mailers.get(app_name)
if not mailer:
mailer = Mailer(
hs=self.hs,
app_name=app_name,
notif_template_html=self.notif_template_html,
notif_template_text=self.notif_template_text,
)
self.mailers[app_name] = mailer
return EmailPusher(self.hs, pusherdict, mailer)
def _app_name_from_pusherdict(self, pusherdict):
if 'data' in pusherdict and 'brand' in pusherdict['data']:
app_name = pusherdict['data']['brand']
else:
app_name = self.hs.config.email_app_name
return app_name

View File

@ -16,7 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
import pusher from .pusher import PusherFactory
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
@ -28,6 +28,7 @@ logger = logging.getLogger(__name__)
class PusherPool: class PusherPool:
def __init__(self, _hs): def __init__(self, _hs):
self.hs = _hs self.hs = _hs
self.pusher_factory = PusherFactory(_hs)
self.start_pushers = _hs.config.start_pushers self.start_pushers = _hs.config.start_pushers
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock() self.clock = self.hs.get_clock()
@ -48,7 +49,7 @@ class PusherPool:
# will then get pulled out of the database, # will then get pulled out of the database,
# recreated, added and started: this means we have only one # recreated, added and started: this means we have only one
# code path adding pushers. # code path adding pushers.
pusher.create_pusher(self.hs, { self.pusher_factory.create_pusher({
"id": None, "id": None,
"user_name": user_id, "user_name": user_id,
"kind": kind, "kind": kind,
@ -186,7 +187,7 @@ class PusherPool:
logger.info("Starting %d pushers", len(pushers)) logger.info("Starting %d pushers", len(pushers))
for pusherdict in pushers: for pusherdict in pushers:
try: try:
p = pusher.create_pusher(self.hs, pusherdict) p = self.pusher_factory.create_pusher(pusherdict)
except: except:
logger.exception("Couldn't start a pusher: caught Exception") logger.exception("Couldn't start a pusher: caught Exception")
continue continue

View File

@ -16,6 +16,7 @@
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.config.appservice import load_appservices from synapse.config.appservice import load_appservices
from synapse.storage.appservice import _make_exclusive_regex
class SlavedApplicationServiceStore(BaseSlavedStore): class SlavedApplicationServiceStore(BaseSlavedStore):
@ -25,6 +26,7 @@ class SlavedApplicationServiceStore(BaseSlavedStore):
hs.config.server_name, hs.config.server_name,
hs.config.app_service_config_files hs.config.app_service_config_files
) )
self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
get_app_service_by_token = DataStore.get_app_service_by_token.__func__ get_app_service_by_token = DataStore.get_app_service_by_token.__func__
get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__ get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__
@ -38,3 +40,6 @@ class SlavedApplicationServiceStore(BaseSlavedStore):
get_appservice_state = DataStore.get_appservice_state.__func__ get_appservice_state = DataStore.get_appservice_state.__func__
set_appservice_last_pos = DataStore.set_appservice_last_pos.__func__ set_appservice_last_pos = DataStore.set_appservice_last_pos.__func__
set_appservice_state = DataStore.set_appservice_state.__func__ set_appservice_state = DataStore.set_appservice_state.__func__
get_if_app_services_interested_in_user = (
DataStore.get_if_app_services_interested_in_user.__func__
)

View File

@ -0,0 +1,48 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from synapse.storage.client_ips import LAST_SEEN_GRANULARITY
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.caches.descriptors import Cache
class SlavedClientIpStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedClientIpStore, self).__init__(db_conn, hs)
self.client_ip_last_seen = Cache(
name="client_ip_last_seen",
keylen=4,
max_entries=50000 * CACHE_SIZE_FACTOR,
)
def insert_client_ip(self, user, access_token, ip, user_agent, device_id):
now = int(self._clock.time_msec())
user_id = user.to_string()
key = (user_id, access_token, ip)
try:
last_seen = self.client_ip_last_seen.get(key)
except KeyError:
last_seen = None
# Rate-limited inserts
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
return
self.hs.get_tcp_replication().send_user_ip(
user_id, access_token, ip, user_agent, device_id, now
)

View File

@ -16,6 +16,7 @@
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.storage.end_to_end_keys import EndToEndKeyStore
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -45,6 +46,7 @@ class SlavedDeviceStore(BaseSlavedStore):
_mark_as_sent_devices_by_remote_txn = ( _mark_as_sent_devices_by_remote_txn = (
DataStore._mark_as_sent_devices_by_remote_txn.__func__ DataStore._mark_as_sent_devices_by_remote_txn.__func__
) )
count_e2e_one_time_keys = EndToEndKeyStore.__dict__["count_e2e_one_time_keys"]
def stream_positions(self): def stream_positions(self):
result = super(SlavedDeviceStore, self).stream_positions() result = super(SlavedDeviceStore, self).stream_positions()

View File

@ -108,6 +108,8 @@ class SlavedEventStore(BaseSlavedStore):
get_current_state_ids = ( get_current_state_ids = (
StateStore.__dict__["get_current_state_ids"] StateStore.__dict__["get_current_state_ids"]
) )
get_state_group_delta = StateStore.__dict__["get_state_group_delta"]
_get_joined_hosts_cache = RoomMemberStore.__dict__["_get_joined_hosts_cache"]
has_room_changed_since = DataStore.has_room_changed_since.__func__ has_room_changed_since = DataStore.has_room_changed_since.__func__
get_unread_push_actions_for_user_in_range_for_http = ( get_unread_push_actions_for_user_in_range_for_http = (
@ -151,8 +153,7 @@ class SlavedEventStore(BaseSlavedStore):
get_room_events_stream_for_rooms = ( get_room_events_stream_for_rooms = (
DataStore.get_room_events_stream_for_rooms.__func__ DataStore.get_room_events_stream_for_rooms.__func__
) )
is_host_joined = DataStore.is_host_joined.__func__ is_host_joined = RoomMemberStore.__dict__["is_host_joined"]
_is_host_joined = RoomMemberStore.__dict__["_is_host_joined"]
get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__ get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__
_set_before_and_after = staticmethod(DataStore._set_before_and_after) _set_before_and_after = staticmethod(DataStore._set_before_and_after)

View File

@ -20,6 +20,7 @@ from twisted.internet.protocol import ReconnectingClientFactory
from .commands import ( from .commands import (
FederationAckCommand, UserSyncCommand, RemovePusherCommand, InvalidateCacheCommand, FederationAckCommand, UserSyncCommand, RemovePusherCommand, InvalidateCacheCommand,
UserIpCommand,
) )
from .protocol import ClientReplicationStreamProtocol from .protocol import ClientReplicationStreamProtocol
@ -178,6 +179,12 @@ class ReplicationClientHandler(object):
cmd = InvalidateCacheCommand(cache_func.__name__, keys) cmd = InvalidateCacheCommand(cache_func.__name__, keys)
self.send_command(cmd) self.send_command(cmd)
def send_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
"""Tell the master that the user made a request.
"""
cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
self.send_command(cmd)
def await_sync(self, data): def await_sync(self, data):
"""Returns a deferred that is resolved when we receive a SYNC command """Returns a deferred that is resolved when we receive a SYNC command
with given data. with given data.

View File

@ -304,6 +304,36 @@ class InvalidateCacheCommand(Command):
return " ".join((self.cache_func, json.dumps(self.keys))) return " ".join((self.cache_func, json.dumps(self.keys)))
class UserIpCommand(Command):
"""Sent periodically when a worker sees activity from a client.
Format::
USER_IP <user_id>, <access_token>, <ip>, <device_id>, <last_seen>, <user_agent>
"""
NAME = "USER_IP"
def __init__(self, user_id, access_token, ip, user_agent, device_id, last_seen):
self.user_id = user_id
self.access_token = access_token
self.ip = ip
self.user_agent = user_agent
self.device_id = device_id
self.last_seen = last_seen
@classmethod
def from_line(cls, line):
user_id, access_token, ip, device_id, last_seen, user_agent = line.split(" ", 5)
return cls(user_id, access_token, ip, user_agent, device_id, int(last_seen))
def to_line(self):
return " ".join((
self.user_id, self.access_token, self.ip, self.device_id,
str(self.last_seen), self.user_agent,
))
# Map of command name to command type. # Map of command name to command type.
COMMAND_MAP = { COMMAND_MAP = {
cmd.NAME: cmd cmd.NAME: cmd
@ -320,6 +350,7 @@ COMMAND_MAP = {
SyncCommand, SyncCommand,
RemovePusherCommand, RemovePusherCommand,
InvalidateCacheCommand, InvalidateCacheCommand,
UserIpCommand,
) )
} }
@ -342,5 +373,6 @@ VALID_CLIENT_COMMANDS = (
FederationAckCommand.NAME, FederationAckCommand.NAME,
RemovePusherCommand.NAME, RemovePusherCommand.NAME,
InvalidateCacheCommand.NAME, InvalidateCacheCommand.NAME,
UserIpCommand.NAME,
ErrorCommand.NAME, ErrorCommand.NAME,
) )

View File

@ -406,6 +406,12 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
def on_INVALIDATE_CACHE(self, cmd): def on_INVALIDATE_CACHE(self, cmd):
self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
def on_USER_IP(self, cmd):
self.streamer.on_user_ip(
cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id,
cmd.last_seen,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def subscribe_to_stream(self, stream_name, token): def subscribe_to_stream(self, stream_name, token):
"""Subscribe the remote to a streams. """Subscribe the remote to a streams.

View File

@ -35,6 +35,7 @@ user_sync_counter = metrics.register_counter("user_sync")
federation_ack_counter = metrics.register_counter("federation_ack") federation_ack_counter = metrics.register_counter("federation_ack")
remove_pusher_counter = metrics.register_counter("remove_pusher") remove_pusher_counter = metrics.register_counter("remove_pusher")
invalidate_cache_counter = metrics.register_counter("invalidate_cache") invalidate_cache_counter = metrics.register_counter("invalidate_cache")
user_ip_cache_counter = metrics.register_counter("user_ip_cache")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -67,6 +68,7 @@ class ReplicationStreamer(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
# Current connections. # Current connections.
self.connections = [] self.connections = []
@ -99,7 +101,7 @@ class ReplicationStreamer(object):
if not hs.config.send_federation: if not hs.config.send_federation:
self.federation_sender = hs.get_federation_sender() self.federation_sender = hs.get_federation_sender()
hs.get_notifier().add_replication_callback(self.on_notifier_poke) self.notifier.add_replication_callback(self.on_notifier_poke)
# Keeps track of whether we are currently checking for updates # Keeps track of whether we are currently checking for updates
self.is_looping = False self.is_looping = False
@ -237,6 +239,15 @@ class ReplicationStreamer(object):
invalidate_cache_counter.inc() invalidate_cache_counter.inc()
getattr(self.store, cache_func).invalidate(tuple(keys)) getattr(self.store, cache_func).invalidate(tuple(keys))
@measure_func("repl.on_user_ip")
def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
"""The client saw a user request
"""
user_ip_cache_counter.inc()
self.store.insert_client_ip(
user_id, access_token, ip, user_agent, device_id, last_seen,
)
def send_sync_to_all_connections(self, data): def send_sync_to_all_connections(self, data):
"""Sends a SYNC command to all clients. """Sends a SYNC command to all clients.

View File

@ -112,6 +112,12 @@ AccountDataStreamRow = namedtuple("AccountDataStream", (
"data_type", # str "data_type", # str
"data", # dict "data", # dict
)) ))
CurrentStateDeltaStreamRow = namedtuple("CurrentStateDeltaStream", (
"room_id", # str
"type", # str
"state_key", # str
"event_id", # str, optional
))
class Stream(object): class Stream(object):
@ -443,6 +449,21 @@ class AccountDataStream(Stream):
defer.returnValue(results) defer.returnValue(results)
class CurrentStateDeltaStream(Stream):
"""Current state for a room was changed
"""
NAME = "current_state_deltas"
ROW_TYPE = CurrentStateDeltaStreamRow
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_max_current_state_delta_stream_id
self.update_function = store.get_all_updated_current_state_deltas
super(CurrentStateDeltaStream, self).__init__(hs)
STREAMS_MAP = { STREAMS_MAP = {
stream.NAME: stream stream.NAME: stream
for stream in ( for stream in (
@ -460,5 +481,6 @@ STREAMS_MAP = {
FederationStream, FederationStream,
TagAccountDataStream, TagAccountDataStream,
AccountDataStream, AccountDataStream,
CurrentStateDeltaStream,
) )
} }

View File

@ -51,6 +51,7 @@ from synapse.rest.client.v2_alpha import (
devices, devices,
thirdparty, thirdparty,
sendtodevice, sendtodevice,
user_directory,
) )
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
@ -100,3 +101,4 @@ class ClientRestResource(JsonResource):
devices.register_servlets(hs, client_resource) devices.register_servlets(hs, client_resource)
thirdparty.register_servlets(hs, client_resource) thirdparty.register_servlets(hs, client_resource)
sendtodevice.register_servlets(hs, client_resource) sendtodevice.register_servlets(hs, client_resource)
user_directory.register_servlets(hs, client_resource)

View File

@ -15,8 +15,9 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import Membership
from synapse.api.errors import AuthError, SynapseError from synapse.api.errors import AuthError, SynapseError
from synapse.types import UserID from synapse.types import UserID, create_requester
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
@ -157,6 +158,142 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
defer.returnValue((200, {})) defer.returnValue((200, {}))
class ShutdownRoomRestServlet(ClientV1RestServlet):
"""Shuts down a room by removing all local users from the room and blocking
all future invites and joins to the room. Any local aliases will be repointed
to a new room created by `new_room_user_id` and kicked users will be auto
joined to the new room.
"""
PATTERNS = client_path_patterns("/admin/shutdown_room/(?P<room_id>[^/]+)")
DEFAULT_MESSAGE = (
"Sharing illegal content on this server is not permitted and rooms in"
" violatation will be blocked."
)
def __init__(self, hs):
super(ShutdownRoomRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
self.handlers = hs.get_handlers()
self.state = hs.get_state_handler()
@defer.inlineCallbacks
def on_POST(self, request, room_id):
requester = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(requester.user)
if not is_admin:
raise AuthError(403, "You are not a server admin")
content = parse_json_object_from_request(request)
new_room_user_id = content.get("new_room_user_id")
if not new_room_user_id:
raise SynapseError(400, "Please provide field `new_room_user_id`")
room_creator_requester = create_requester(new_room_user_id)
message = content.get("message", self.DEFAULT_MESSAGE)
room_name = content.get("room_name", "Content Violation Notification")
info = yield self.handlers.room_creation_handler.create_room(
room_creator_requester,
config={
"preset": "public_chat",
"name": room_name,
"power_level_content_override": {
"users_default": -10,
},
},
ratelimit=False,
)
new_room_id = info["room_id"]
msg_handler = self.handlers.message_handler
yield msg_handler.create_and_send_nonmember_event(
room_creator_requester,
{
"type": "m.room.message",
"content": {"body": message, "msgtype": "m.text"},
"room_id": new_room_id,
"sender": new_room_user_id,
},
ratelimit=False,
)
requester_user_id = requester.user.to_string()
logger.info("Shutting down room %r", room_id)
yield self.store.block_room(room_id, requester_user_id)
users = yield self.state.get_current_user_in_room(room_id)
kicked_users = []
for user_id in users:
if not self.hs.is_mine_id(user_id):
continue
logger.info("Kicking %r from %r...", user_id, room_id)
target_requester = create_requester(user_id)
yield self.handlers.room_member_handler.update_membership(
requester=target_requester,
target=target_requester.user,
room_id=room_id,
action=Membership.LEAVE,
content={},
ratelimit=False
)
yield self.handlers.room_member_handler.forget(target_requester.user, room_id)
yield self.handlers.room_member_handler.update_membership(
requester=target_requester,
target=target_requester.user,
room_id=new_room_id,
action=Membership.JOIN,
content={},
ratelimit=False
)
kicked_users.append(user_id)
aliases_for_room = yield self.store.get_aliases_for_room(room_id)
yield self.store.update_aliases_for_room(
room_id, new_room_id, requester_user_id
)
defer.returnValue((200, {
"kicked_users": kicked_users,
"local_aliases": aliases_for_room,
"new_room_id": new_room_id,
}))
class QuarantineMediaInRoom(ClientV1RestServlet):
"""Quarantines all media in a room so that no one can download it via
this server.
"""
PATTERNS = client_path_patterns("/admin/quarantine_media/(?P<room_id>[^/]+)")
def __init__(self, hs):
super(QuarantineMediaInRoom, self).__init__(hs)
self.store = hs.get_datastore()
@defer.inlineCallbacks
def on_POST(self, request, room_id):
requester = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(requester.user)
if not is_admin:
raise AuthError(403, "You are not a server admin")
num_quarantined = yield self.store.quarantine_media_ids_in_room(
room_id, requester.user.to_string(),
)
defer.returnValue((200, {"num_quarantined": num_quarantined}))
class ResetPasswordRestServlet(ClientV1RestServlet): class ResetPasswordRestServlet(ClientV1RestServlet):
"""Post request to allow an administrator reset password for a user. """Post request to allow an administrator reset password for a user.
This need a user have a administrator access in Synapse. This need a user have a administrator access in Synapse.
@ -353,3 +490,5 @@ def register_servlets(hs, http_server):
ResetPasswordRestServlet(hs).register(http_server) ResetPasswordRestServlet(hs).register(http_server)
GetUsersPaginatedRestServlet(hs).register(http_server) GetUsersPaginatedRestServlet(hs).register(http_server)
SearchUsersRestServlet(hs).register(http_server) SearchUsersRestServlet(hs).register(http_server)
ShutdownRoomRestServlet(hs).register(http_server)
QuarantineMediaInRoom(hs).register(http_server)

View File

@ -192,6 +192,7 @@ class SyncRestServlet(RestServlet):
"invite": invited, "invite": invited,
"leave": archived, "leave": archived,
}, },
"device_one_time_keys_count": sync_result.device_one_time_keys_count,
"next_batch": sync_result.next_batch.to_string(), "next_batch": sync_result.next_batch.to_string(),
} }

View File

@ -0,0 +1,79 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
class UserDirectorySearchRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/user_directory/search$")
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super(UserDirectorySearchRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.user_directory_handler = hs.get_user_directory_handler()
@defer.inlineCallbacks
def on_POST(self, request):
"""Searches for users in directory
Returns:
dict of the form::
{
"limited": <bool>, # whether there were more results or not
"results": [ # Ordered by best match first
{
"user_id": <user_id>,
"display_name": <display_name>,
"avatar_url": <avatar_url>
}
]
}
"""
requester = yield self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
limit = body.get("limit", 10)
limit = min(limit, 50)
try:
search_term = body["search_term"]
except:
raise SynapseError(400, "`search_term` is required field")
results = yield self.user_directory_handler.search_users(
user_id, search_term, limit,
)
defer.returnValue((200, results))
def register_servlets(hs, http_server):
UserDirectorySearchRestServlet(hs).register(http_server)

View File

@ -66,13 +66,18 @@ class DownloadResource(Resource):
@defer.inlineCallbacks @defer.inlineCallbacks
def _respond_local_file(self, request, media_id, name): def _respond_local_file(self, request, media_id, name):
media_info = yield self.store.get_local_media(media_id) media_info = yield self.store.get_local_media(media_id)
if not media_info: if not media_info or media_info["quarantined_by"]:
respond_404(request) respond_404(request)
return return
media_type = media_info["media_type"] media_type = media_info["media_type"]
media_length = media_info["media_length"] media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"] upload_name = name if name else media_info["upload_name"]
if media_info["url_cache"]:
# TODO: Check the file still exists, if it doesn't we can redownload
# it from the url `media_info["url_cache"]`
file_path = self.filepaths.url_cache_filepath(media_id)
else:
file_path = self.filepaths.local_media_filepath(media_id) file_path = self.filepaths.local_media_filepath(media_id)
yield respond_with_file( yield respond_with_file(

View File

@ -71,3 +71,21 @@ class MediaFilePaths(object):
self.base_path, "remote_thumbnail", server_name, self.base_path, "remote_thumbnail", server_name,
file_id[0:2], file_id[2:4], file_id[4:], file_id[0:2], file_id[2:4], file_id[4:],
) )
def url_cache_filepath(self, media_id):
return os.path.join(
self.base_path, "url_cache",
media_id[0:2], media_id[2:4], media_id[4:]
)
def url_cache_thumbnail(self, media_id, width, height, content_type,
method):
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (
width, height, top_level_type, sub_type, method
)
return os.path.join(
self.base_path, "url_cache_thumbnails",
media_id[0:2], media_id[2:4], media_id[4:],
file_name
)

View File

@ -135,6 +135,8 @@ class MediaRepository(object):
media_info = yield self._download_remote_file( media_info = yield self._download_remote_file(
server_name, media_id server_name, media_id
) )
elif media_info["quarantined_by"]:
raise NotFoundError()
else: else:
self.recently_accessed_remotes.add((server_name, media_id)) self.recently_accessed_remotes.add((server_name, media_id))
yield self.store.update_cached_last_access_time( yield self.store.update_cached_last_access_time(
@ -184,6 +186,7 @@ class MediaRepository(object):
raise raise
except NotRetryingDestination: except NotRetryingDestination:
logger.warn("Not retrying destination %r", server_name) logger.warn("Not retrying destination %r", server_name)
raise SynapseError(502, "Failed to fetch remote media")
except Exception: except Exception:
logger.exception("Failed to fetch remote media %s/%s", logger.exception("Failed to fetch remote media %s/%s",
server_name, media_id) server_name, media_id)
@ -323,13 +326,17 @@ class MediaRepository(object):
defer.returnValue(t_path) defer.returnValue(t_path)
@defer.inlineCallbacks @defer.inlineCallbacks
def _generate_local_thumbnails(self, media_id, media_info): def _generate_local_thumbnails(self, media_id, media_info, url_cache=False):
media_type = media_info["media_type"] media_type = media_info["media_type"]
requirements = self._get_thumbnail_requirements(media_type) requirements = self._get_thumbnail_requirements(media_type)
if not requirements: if not requirements:
return return
if url_cache:
input_path = self.filepaths.url_cache_filepath(media_id)
else:
input_path = self.filepaths.local_media_filepath(media_id) input_path = self.filepaths.local_media_filepath(media_id)
thumbnailer = Thumbnailer(input_path) thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width m_width = thumbnailer.width
m_height = thumbnailer.height m_height = thumbnailer.height
@ -357,6 +364,11 @@ class MediaRepository(object):
for t_width, t_height, t_type in scales: for t_width, t_height, t_type in scales:
t_method = "scale" t_method = "scale"
if url_cache:
t_path = self.filepaths.url_cache_thumbnail(
media_id, t_width, t_height, t_type, t_method
)
else:
t_path = self.filepaths.local_media_thumbnail( t_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method media_id, t_width, t_height, t_type, t_method
) )
@ -374,6 +386,11 @@ class MediaRepository(object):
# thumbnail. # thumbnail.
continue continue
t_method = "crop" t_method = "crop"
if url_cache:
t_path = self.filepaths.url_cache_thumbnail(
media_id, t_width, t_height, t_type, t_method
)
else:
t_path = self.filepaths.local_media_thumbnail( t_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method media_id, t_width, t_height, t_type, t_method
) )

View File

@ -164,7 +164,7 @@ class PreviewUrlResource(Resource):
if _is_media(media_info['media_type']): if _is_media(media_info['media_type']):
dims = yield self.media_repo._generate_local_thumbnails( dims = yield self.media_repo._generate_local_thumbnails(
media_info['filesystem_id'], media_info media_info['filesystem_id'], media_info, url_cache=True,
) )
og = { og = {
@ -210,7 +210,7 @@ class PreviewUrlResource(Resource):
if _is_media(image_info['media_type']): if _is_media(image_info['media_type']):
# TODO: make sure we don't choke on white-on-transparent images # TODO: make sure we don't choke on white-on-transparent images
dims = yield self.media_repo._generate_local_thumbnails( dims = yield self.media_repo._generate_local_thumbnails(
image_info['filesystem_id'], image_info image_info['filesystem_id'], image_info, url_cache=True,
) )
if dims: if dims:
og["og:image:width"] = dims['width'] og["og:image:width"] = dims['width']
@ -256,7 +256,7 @@ class PreviewUrlResource(Resource):
# XXX: horrible duplication with base_resource's _download_remote_file() # XXX: horrible duplication with base_resource's _download_remote_file()
file_id = random_string(24) file_id = random_string(24)
fname = self.filepaths.local_media_filepath(file_id) fname = self.filepaths.url_cache_filepath(file_id)
self.media_repo._makedirs(fname) self.media_repo._makedirs(fname)
try: try:
@ -303,6 +303,7 @@ class PreviewUrlResource(Resource):
upload_name=download_name, upload_name=download_name,
media_length=length, media_length=length,
user_id=user, user_id=user,
url_cache=url,
) )
except Exception as e: except Exception as e:
@ -434,6 +435,8 @@ def _calc_og(tree, media_uri):
for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE) for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
) )
og['og:description'] = summarize_paragraphs(text_nodes) og['og:description'] = summarize_paragraphs(text_nodes)
else:
og['og:description'] = summarize_paragraphs([og['og:description']])
# TODO: delete the url downloads to stop diskfilling, # TODO: delete the url downloads to stop diskfilling,
# as we only ever cared about its OG # as we only ever cared about its OG

View File

@ -81,7 +81,7 @@ class ThumbnailResource(Resource):
method, m_type): method, m_type):
media_info = yield self.store.get_local_media(media_id) media_info = yield self.store.get_local_media(media_id)
if not media_info: if not media_info or media_info["quarantined_by"]:
respond_404(request) respond_404(request)
return return
@ -101,6 +101,13 @@ class ThumbnailResource(Resource):
t_type = thumbnail_info["thumbnail_type"] t_type = thumbnail_info["thumbnail_type"]
t_method = thumbnail_info["thumbnail_method"] t_method = thumbnail_info["thumbnail_method"]
if media_info["url_cache"]:
# TODO: Check the file still exists, if it doesn't we can redownload
# it from the url `media_info["url_cache"]`
file_path = self.filepaths.url_cache_thumbnail(
media_id, t_width, t_height, t_type, t_method,
)
else:
file_path = self.filepaths.local_media_thumbnail( file_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method, media_id, t_width, t_height, t_type, t_method,
) )
@ -117,7 +124,7 @@ class ThumbnailResource(Resource):
desired_type): desired_type):
media_info = yield self.store.get_local_media(media_id) media_info = yield self.store.get_local_media(media_id)
if not media_info: if not media_info or media_info["quarantined_by"]:
respond_404(request) respond_404(request)
return return
@ -134,8 +141,17 @@ class ThumbnailResource(Resource):
t_type = info["thumbnail_type"] == desired_type t_type = info["thumbnail_type"] == desired_type
if t_w and t_h and t_method and t_type: if t_w and t_h and t_method and t_type:
if media_info["url_cache"]:
# TODO: Check the file still exists, if it doesn't we can redownload
# it from the url `media_info["url_cache"]`
file_path = self.filepaths.url_cache_thumbnail(
media_id, desired_width, desired_height, desired_type,
desired_method,
)
else:
file_path = self.filepaths.local_media_thumbnail( file_path = self.filepaths.local_media_thumbnail(
media_id, desired_width, desired_height, desired_type, desired_method, media_id, desired_width, desired_height, desired_type,
desired_method,
) )
yield respond_with_file(request, desired_type, file_path) yield respond_with_file(request, desired_type, file_path)
return return

View File

@ -49,9 +49,11 @@ from synapse.handlers.events import EventHandler, EventStreamHandler
from synapse.handlers.initial_sync import InitialSyncHandler from synapse.handlers.initial_sync import InitialSyncHandler
from synapse.handlers.receipts import ReceiptsHandler from synapse.handlers.receipts import ReceiptsHandler
from synapse.handlers.read_marker import ReadMarkerHandler from synapse.handlers.read_marker import ReadMarkerHandler
from synapse.handlers.user_directory import UserDirectoyHandler
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.notifier import Notifier from synapse.notifier import Notifier
from synapse.push.action_generator import ActionGenerator
from synapse.push.pusherpool import PusherPool from synapse.push.pusherpool import PusherPool
from synapse.rest.media.v1.media_repository import MediaRepository from synapse.rest.media.v1.media_repository import MediaRepository
from synapse.state import StateHandler from synapse.state import StateHandler
@ -135,6 +137,8 @@ class HomeServer(object):
'macaroon_generator', 'macaroon_generator',
'tcp_replication', 'tcp_replication',
'read_marker_handler', 'read_marker_handler',
'action_generator',
'user_directory_handler',
] ]
def __init__(self, hostname, **kwargs): def __init__(self, hostname, **kwargs):
@ -299,6 +303,12 @@ class HomeServer(object):
def build_tcp_replication(self): def build_tcp_replication(self):
raise NotImplementedError() raise NotImplementedError()
def build_action_generator(self):
return ActionGenerator(self)
def build_user_directory_handler(self):
return UserDirectoyHandler(self)
def remove_pusher(self, app_id, push_key, user_id): def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)

View File

@ -24,13 +24,13 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.util.async import Linearizer from synapse.util.async import Linearizer
from synapse.util.caches import CACHE_SIZE_FACTOR
from collections import namedtuple from collections import namedtuple
from frozendict import frozendict from frozendict import frozendict
import logging import logging
import hashlib import hashlib
import os
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -38,9 +38,6 @@ logger = logging.getLogger(__name__)
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
SIZE_OF_CACHE = int(100000 * CACHE_SIZE_FACTOR) SIZE_OF_CACHE = int(100000 * CACHE_SIZE_FACTOR)
EVICTION_TIMEOUT_SECONDS = 60 * 60 EVICTION_TIMEOUT_SECONDS = 60 * 60
@ -170,9 +167,7 @@ class StateHandler(object):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_user_in_room") logger.debug("calling resolve_state_groups from get_current_user_in_room")
entry = yield self.resolve_state_groups(room_id, latest_event_ids) entry = yield self.resolve_state_groups(room_id, latest_event_ids)
joined_users = yield self.store.get_joined_users_from_state( joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
room_id, entry.state_id, entry.state
)
defer.returnValue(joined_users) defer.returnValue(joined_users)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -181,9 +176,7 @@ class StateHandler(object):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_hosts_in_room") logger.debug("calling resolve_state_groups from get_current_hosts_in_room")
entry = yield self.resolve_state_groups(room_id, latest_event_ids) entry = yield self.resolve_state_groups(room_id, latest_event_ids)
joined_hosts = yield self.store.get_joined_hosts( joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
room_id, entry.state_id, entry.state
)
defer.returnValue(joined_hosts) defer.returnValue(joined_hosts)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -195,12 +188,12 @@ class StateHandler(object):
Returns: Returns:
synapse.events.snapshot.EventContext: synapse.events.snapshot.EventContext:
""" """
context = EventContext()
if event.internal_metadata.is_outlier(): if event.internal_metadata.is_outlier():
# If this is an outlier, then we know it shouldn't have any current # If this is an outlier, then we know it shouldn't have any current
# state. Certainly store.get_current_state won't return any, and # state. Certainly store.get_current_state won't return any, and
# persisting the event won't store the state group. # persisting the event won't store the state group.
context = EventContext()
if old_state: if old_state:
context.prev_state_ids = { context.prev_state_ids = {
(s.type, s.state_key): s.event_id for s in old_state (s.type, s.state_key): s.event_id for s in old_state
@ -219,6 +212,7 @@ class StateHandler(object):
defer.returnValue(context) defer.returnValue(context)
if old_state: if old_state:
context = EventContext()
context.prev_state_ids = { context.prev_state_ids = {
(s.type, s.state_key): s.event_id for s in old_state (s.type, s.state_key): s.event_id for s in old_state
} }
@ -239,19 +233,13 @@ class StateHandler(object):
defer.returnValue(context) defer.returnValue(context)
logger.debug("calling resolve_state_groups from compute_event_context") logger.debug("calling resolve_state_groups from compute_event_context")
if event.is_state():
entry = yield self.resolve_state_groups(
event.room_id, [e for e, _ in event.prev_events],
event_type=event.type,
state_key=event.state_key,
)
else:
entry = yield self.resolve_state_groups( entry = yield self.resolve_state_groups(
event.room_id, [e for e, _ in event.prev_events], event.room_id, [e for e, _ in event.prev_events],
) )
curr_state = entry.state curr_state = entry.state
context = EventContext()
context.prev_state_ids = curr_state context.prev_state_ids = curr_state
if event.is_state(): if event.is_state():
context.state_group = self.store.get_next_state_group() context.state_group = self.store.get_next_state_group()
@ -264,10 +252,14 @@ class StateHandler(object):
context.current_state_ids = dict(context.prev_state_ids) context.current_state_ids = dict(context.prev_state_ids)
context.current_state_ids[key] = event.event_id context.current_state_ids[key] = event.event_id
if entry.state_group:
context.prev_group = entry.state_group
context.delta_ids = {
key: event.event_id
}
elif entry.prev_group:
context.prev_group = entry.prev_group context.prev_group = entry.prev_group
context.delta_ids = entry.delta_ids context.delta_ids = dict(entry.delta_ids)
if context.delta_ids is not None:
context.delta_ids = dict(context.delta_ids)
context.delta_ids[key] = event.event_id context.delta_ids[key] = event.event_id
else: else:
if entry.state_group is None: if entry.state_group is None:
@ -284,7 +276,7 @@ class StateHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def resolve_state_groups(self, room_id, event_ids, event_type=None, state_key=""): def resolve_state_groups(self, room_id, event_ids):
""" Given a list of event_ids this method fetches the state at each """ Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them. event, resolves conflicts between them and returns them.
@ -309,11 +301,13 @@ class StateHandler(object):
if len(group_names) == 1: if len(group_names) == 1:
name, state_list = state_groups_ids.items().pop() name, state_list = state_groups_ids.items().pop()
prev_group, delta_ids = yield self.store.get_state_group_delta(name)
defer.returnValue(_StateCacheEntry( defer.returnValue(_StateCacheEntry(
state=state_list, state=state_list,
state_group=name, state_group=name,
prev_group=name, prev_group=prev_group,
delta_ids={}, delta_ids=delta_ids,
)) ))
with (yield self.resolve_linearizer.queue(group_names)): with (yield self.resolve_linearizer.queue(group_names)):
@ -357,20 +351,18 @@ class StateHandler(object):
if new_state_event_ids == frozenset(e_id for e_id in events): if new_state_event_ids == frozenset(e_id for e_id in events):
state_group = sg state_group = sg
break break
if state_group is None:
# Worker instances don't have access to this method, but we want # TODO: We want to create a state group for this set of events, to
# to set the state_group on the main instance to increase cache # increase cache hits, but we need to make sure that it doesn't
# hits. # end up as a prev_group without being added to the database
if hasattr(self.store, "get_next_state_group"):
state_group = self.store.get_next_state_group()
prev_group = None prev_group = None
delta_ids = None delta_ids = None
for old_group, old_ids in state_groups_ids.items(): for old_group, old_ids in state_groups_ids.iteritems():
if not set(new_state.iterkeys()) - set(old_ids.iterkeys()): if not set(new_state) - set(old_ids):
n_delta_ids = { n_delta_ids = {
k: v k: v
for k, v in new_state.items() for k, v in new_state.iteritems()
if old_ids.get(k) != v if old_ids.get(k) != v
} }
if not delta_ids or len(n_delta_ids) < len(delta_ids): if not delta_ids or len(n_delta_ids) < len(delta_ids):

View File

@ -49,6 +49,7 @@ from .tags import TagsStore
from .account_data import AccountDataStore from .account_data import AccountDataStore
from .openid import OpenIdStore from .openid import OpenIdStore
from .client_ips import ClientIpStore from .client_ips import ClientIpStore
from .user_directory import UserDirectoryStore
from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
from .engines import PostgresEngine from .engines import PostgresEngine
@ -86,6 +87,7 @@ class DataStore(RoomMemberStore, RoomStore,
ClientIpStore, ClientIpStore,
DeviceStore, DeviceStore,
DeviceInboxStore, DeviceInboxStore,
UserDirectoryStore,
): ):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
@ -221,11 +223,24 @@ class DataStore(RoomMemberStore, RoomStore,
"DeviceListFederationStreamChangeCache", device_list_max, "DeviceListFederationStreamChangeCache", device_list_max,
) )
curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
db_conn, "current_state_delta_stream",
entity_column="room_id",
stream_column="stream_id",
max_value=events_max, # As we share the stream id with events token
limit=1000,
)
self._curr_state_delta_stream_cache = StreamChangeCache(
"_curr_state_delta_stream_cache", min_curr_state_delta_id,
prefilled_cache=curr_state_delta_prefill,
)
cur = LoggingTransaction( cur = LoggingTransaction(
db_conn.cursor(), db_conn.cursor(),
name="_find_stream_orderings_for_times_txn", name="_find_stream_orderings_for_times_txn",
database_engine=self.database_engine, database_engine=self.database_engine,
after_callbacks=[] after_callbacks=[],
final_callbacks=[],
) )
self._find_stream_orderings_for_times_txn(cur) self._find_stream_orderings_for_times_txn(cur)
cur.close() cur.close()
@ -289,16 +304,6 @@ class DataStore(RoomMemberStore, RoomStore,
ret = yield self.runInteraction("count_users", _count_users) ret = yield self.runInteraction("count_users", _count_users)
defer.returnValue(ret) defer.returnValue(ret)
def get_user_ip_and_agents(self, user):
return self._simple_select_list(
table="user_ips",
keyvalues={"user_id": user.to_string()},
retcols=[
"access_token", "ip", "user_agent", "last_seen"
],
desc="get_user_ip_and_agents",
)
def get_users(self): def get_users(self):
"""Function to reterive a list of users in users table. """Function to reterive a list of users in users table.

View File

@ -16,6 +16,7 @@ import logging
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache from synapse.util.caches.descriptors import Cache
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
@ -27,10 +28,6 @@ from twisted.internet import defer
import sys import sys
import time import time
import threading import threading
import os
CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -52,13 +49,17 @@ class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object """An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute() passed to the constructor. Adds logging and metrics to the .execute()
method.""" method."""
__slots__ = ["txn", "name", "database_engine", "after_callbacks"] __slots__ = [
"txn", "name", "database_engine", "after_callbacks", "final_callbacks",
]
def __init__(self, txn, name, database_engine, after_callbacks): def __init__(self, txn, name, database_engine, after_callbacks,
final_callbacks):
object.__setattr__(self, "txn", txn) object.__setattr__(self, "txn", txn)
object.__setattr__(self, "name", name) object.__setattr__(self, "name", name)
object.__setattr__(self, "database_engine", database_engine) object.__setattr__(self, "database_engine", database_engine)
object.__setattr__(self, "after_callbacks", after_callbacks) object.__setattr__(self, "after_callbacks", after_callbacks)
object.__setattr__(self, "final_callbacks", final_callbacks)
def call_after(self, callback, *args, **kwargs): def call_after(self, callback, *args, **kwargs):
"""Call the given callback on the main twisted thread after the """Call the given callback on the main twisted thread after the
@ -67,6 +68,9 @@ class LoggingTransaction(object):
""" """
self.after_callbacks.append((callback, args, kwargs)) self.after_callbacks.append((callback, args, kwargs))
def call_finally(self, callback, *args, **kwargs):
self.final_callbacks.append((callback, args, kwargs))
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self.txn, name) return getattr(self.txn, name)
@ -217,8 +221,8 @@ class SQLBaseStore(object):
self._clock.looping_call(loop, 10000) self._clock.looping_call(loop, 10000)
def _new_transaction(self, conn, desc, after_callbacks, logging_context, def _new_transaction(self, conn, desc, after_callbacks, final_callbacks,
func, *args, **kwargs): logging_context, func, *args, **kwargs):
start = time.time() * 1000 start = time.time() * 1000
txn_id = self._TXN_ID txn_id = self._TXN_ID
@ -237,7 +241,8 @@ class SQLBaseStore(object):
try: try:
txn = conn.cursor() txn = conn.cursor()
txn = LoggingTransaction( txn = LoggingTransaction(
txn, name, self.database_engine, after_callbacks txn, name, self.database_engine, after_callbacks,
final_callbacks,
) )
r = func(txn, *args, **kwargs) r = func(txn, *args, **kwargs)
conn.commit() conn.commit()
@ -298,6 +303,7 @@ class SQLBaseStore(object):
start_time = time.time() * 1000 start_time = time.time() * 1000
after_callbacks = [] after_callbacks = []
final_callbacks = []
def inner_func(conn, *args, **kwargs): def inner_func(conn, *args, **kwargs):
with LoggingContext("runInteraction") as context: with LoggingContext("runInteraction") as context:
@ -309,7 +315,7 @@ class SQLBaseStore(object):
current_context.copy_to(context) current_context.copy_to(context)
return self._new_transaction( return self._new_transaction(
conn, desc, after_callbacks, current_context, conn, desc, after_callbacks, final_callbacks, current_context,
func, *args, **kwargs func, *args, **kwargs
) )
@ -318,9 +324,13 @@ class SQLBaseStore(object):
result = yield self._db_pool.runWithConnection( result = yield self._db_pool.runWithConnection(
inner_func, *args, **kwargs inner_func, *args, **kwargs
) )
finally:
for after_callback, after_args, after_kwargs in after_callbacks: for after_callback, after_args, after_kwargs in after_callbacks:
after_callback(*after_args, **after_kwargs) after_callback(*after_args, **after_kwargs)
finally:
for after_callback, after_args, after_kwargs in final_callbacks:
after_callback(*after_args, **after_kwargs)
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -425,6 +435,11 @@ class SQLBaseStore(object):
txn.execute(sql, vals) txn.execute(sql, vals)
def _simple_insert_many(self, table, values, desc):
return self.runInteraction(
desc, self._simple_insert_many_txn, table, values
)
@staticmethod @staticmethod
def _simple_insert_many_txn(txn, table, values): def _simple_insert_many_txn(txn, table, values):
if not values: if not values:
@ -936,7 +951,7 @@ class SQLBaseStore(object):
# __exit__ called after the transaction finishes. # __exit__ called after the transaction finishes.
ctx = self._cache_id_gen.get_next() ctx = self._cache_id_gen.get_next()
stream_id = ctx.__enter__() stream_id = ctx.__enter__()
txn.call_after(ctx.__exit__, None, None, None) txn.call_finally(ctx.__exit__, None, None, None)
txn.call_after(self.hs.get_notifier().on_new_replication_data) txn.call_after(self.hs.get_notifier().on_new_replication_data)
self._simple_insert_txn( self._simple_insert_txn(

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
import re
import simplejson as json import simplejson as json
from twisted.internet import defer from twisted.internet import defer
@ -26,6 +27,25 @@ from ._base import SQLBaseStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _make_exclusive_regex(services_cache):
# We precompie a regex constructed from all the regexes that the AS's
# have registered for exclusive users.
exclusive_user_regexes = [
regex.pattern
for service in services_cache
for regex in service.get_exlusive_user_regexes()
]
if exclusive_user_regexes:
exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
exclusive_user_regex = re.compile(exclusive_user_regex)
else:
# We handle this case specially otherwise the constructed regex
# will always match
exclusive_user_regex = None
return exclusive_user_regex
class ApplicationServiceStore(SQLBaseStore): class ApplicationServiceStore(SQLBaseStore):
def __init__(self, hs): def __init__(self, hs):
@ -35,16 +55,17 @@ class ApplicationServiceStore(SQLBaseStore):
hs.hostname, hs.hostname,
hs.config.app_service_config_files hs.config.app_service_config_files
) )
self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
def get_app_services(self): def get_app_services(self):
return self.services_cache return self.services_cache
def get_if_app_services_interested_in_user(self, user_id): def get_if_app_services_interested_in_user(self, user_id):
"""Check if the user is one associated with an app service """Check if the user is one associated with an app service (exclusively)
""" """
for service in self.services_cache: if self.exclusive_user_regex:
if service.is_interested_in_user(user_id): return bool(self.exclusive_user_regex.match(user_id))
return True else:
return False return False
def get_app_service_by_user_id(self, user_id): def get_app_service_by_user_id(self, user_id):

View File

@ -15,11 +15,14 @@
import logging import logging
from twisted.internet import defer from twisted.internet import defer, reactor
from ._base import Cache from ._base import Cache
from . import background_updates from . import background_updates
from synapse.util.caches import CACHE_SIZE_FACTOR
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Number of msec of granularity to store the user IP 'last seen' time. Smaller # Number of msec of granularity to store the user IP 'last seen' time. Smaller
@ -33,7 +36,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
self.client_ip_last_seen = Cache( self.client_ip_last_seen = Cache(
name="client_ip_last_seen", name="client_ip_last_seen",
keylen=4, keylen=4,
max_entries=5000, max_entries=50000 * CACHE_SIZE_FACTOR,
) )
super(ClientIpStore, self).__init__(hs) super(ClientIpStore, self).__init__(hs)
@ -45,7 +48,14 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
columns=["user_id", "device_id", "last_seen"], columns=["user_id", "device_id", "last_seen"],
) )
@defer.inlineCallbacks # (user_id, access_token, ip) -> (user_agent, device_id, last_seen)
self._batch_row_update = {}
self._client_ip_looper = self._clock.looping_call(
self._update_client_ips_batch, 5 * 1000
)
reactor.addSystemEventTrigger("before", "shutdown", self._update_client_ips_batch)
def insert_client_ip(self, user, access_token, ip, user_agent, device_id): def insert_client_ip(self, user, access_token, ip, user_agent, device_id):
now = int(self._clock.time_msec()) now = int(self._clock.time_msec())
key = (user.to_string(), access_token, ip) key = (user.to_string(), access_token, ip)
@ -57,34 +67,48 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
# Rate-limited inserts # Rate-limited inserts
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
defer.returnValue(None) return
self.client_ip_last_seen.prefill(key, now) self.client_ip_last_seen.prefill(key, now)
# It's safe not to lock here: a) no unique constraint, self._batch_row_update[key] = (user_agent, device_id, now)
# b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely
yield self._simple_upsert( def _update_client_ips_batch(self):
"user_ips", to_update = self._batch_row_update
self._batch_row_update = {}
return self.runInteraction(
"_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
)
def _update_client_ips_batch_txn(self, txn, to_update):
self.database_engine.lock_table(txn, "user_ips")
for entry in to_update.iteritems():
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
self._simple_upsert_txn(
txn,
table="user_ips",
keyvalues={ keyvalues={
"user_id": user.to_string(), "user_id": user_id,
"access_token": access_token, "access_token": access_token,
"ip": ip, "ip": ip,
"user_agent": user_agent, "user_agent": user_agent,
"device_id": device_id, "device_id": device_id,
}, },
values={ values={
"last_seen": now, "last_seen": last_seen,
}, },
desc="insert_client_ip",
lock=False, lock=False,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def get_last_client_ip_by_device(self, devices): def get_last_client_ip_by_device(self, user_id, device_id):
"""For each device_id listed, give the user_ip it was last seen on """For each device_id listed, give the user_ip it was last seen on
Args: Args:
devices (iterable[(str, str)]): list of (user_id, device_id) pairs user_id (str)
device_id (str): If None fetches all devices for the user
Returns: Returns:
defer.Deferred: resolves to a dict, where the keys defer.Deferred: resolves to a dict, where the keys
@ -95,6 +119,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
res = yield self.runInteraction( res = yield self.runInteraction(
"get_last_client_ip_by_device", "get_last_client_ip_by_device",
self._get_last_client_ip_by_device_txn, self._get_last_client_ip_by_device_txn,
user_id, device_id,
retcols=( retcols=(
"user_id", "user_id",
"access_token", "access_token",
@ -103,19 +128,30 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"device_id", "device_id",
"last_seen", "last_seen",
), ),
devices=devices
) )
ret = {(d["user_id"], d["device_id"]): d for d in res} ret = {(d["user_id"], d["device_id"]): d for d in res}
for key in self._batch_row_update:
uid, access_token, ip = key
if uid == user_id:
user_agent, did, last_seen = self._batch_row_update[key]
if not device_id or did == device_id:
ret[(user_id, device_id)] = {
"user_id": user_id,
"access_token": access_token,
"ip": ip,
"user_agent": user_agent,
"device_id": did,
"last_seen": last_seen,
}
defer.returnValue(ret) defer.returnValue(ret)
@classmethod @classmethod
def _get_last_client_ip_by_device_txn(cls, txn, devices, retcols): def _get_last_client_ip_by_device_txn(cls, txn, user_id, device_id, retcols):
where_clauses = [] where_clauses = []
bindings = [] bindings = []
for (user_id, device_id) in devices:
if device_id is None: if device_id is None:
where_clauses.append("(user_id = ? AND device_id IS NULL)") where_clauses.append("user_id = ?")
bindings.extend((user_id, )) bindings.extend((user_id, ))
else: else:
where_clauses.append("(user_id = ? AND device_id = ?)") where_clauses.append("(user_id = ? AND device_id = ?)")
@ -147,3 +183,37 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
txn.execute(sql, bindings) txn.execute(sql, bindings)
return cls.cursor_to_dict(txn) return cls.cursor_to_dict(txn)
@defer.inlineCallbacks
def get_user_ip_and_agents(self, user):
user_id = user.to_string()
results = {}
for key in self._batch_row_update:
uid, access_token, ip = key
if uid == user_id:
user_agent, _, last_seen = self._batch_row_update[key]
results[(access_token, ip)] = (user_agent, last_seen)
rows = yield self._simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=[
"access_token", "ip", "user_agent", "last_seen"
],
desc="get_user_ip_and_agents",
)
results.update(
((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"]))
for row in rows
)
defer.returnValue(list(
{
"access_token": access_token,
"ip": ip,
"user_agent": user_agent,
"last_seen": last_seen,
}
for (access_token, ip), (user_agent, last_seen) in results.iteritems()
))

View File

@ -368,7 +368,7 @@ class DeviceStore(SQLBaseStore):
prev_sent_id_sql = """ prev_sent_id_sql = """
SELECT coalesce(max(stream_id), 0) as stream_id SELECT coalesce(max(stream_id), 0) as stream_id
FROM device_lists_outbound_pokes FROM device_lists_outbound_last_success
WHERE destination = ? AND user_id = ? AND stream_id <= ? WHERE destination = ? AND user_id = ? AND stream_id <= ?
""" """
@ -510,32 +510,43 @@ class DeviceStore(SQLBaseStore):
) )
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
# First we DELETE all rows such that only the latest row for each # We update the device_lists_outbound_last_success with the successfully
# (destination, user_id is left. We do this by selecting first and # poked users. We do the join to see which users need to be inserted and
# deleting. # which updated.
sql = """ sql = """
SELECT user_id, coalesce(max(stream_id), 0) FROM device_lists_outbound_pokes SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
WHERE destination = ? AND stream_id <= ? FROM device_lists_outbound_pokes as o
LEFT JOIN device_lists_outbound_last_success as s
USING (destination, user_id)
WHERE destination = ? AND o.stream_id <= ?
GROUP BY user_id GROUP BY user_id
HAVING count(*) > 1
""" """
txn.execute(sql, (destination, stream_id,)) txn.execute(sql, (destination, stream_id,))
rows = txn.fetchall() rows = txn.fetchall()
sql = """ sql = """
DELETE FROM device_lists_outbound_pokes UPDATE device_lists_outbound_last_success
WHERE destination = ? AND user_id = ? AND stream_id < ? SET stream_id = ?
WHERE destination = ? AND user_id = ?
""" """
txn.executemany( txn.executemany(
sql, ((destination, row[0], row[1],) for row in rows) sql, ((row[1], destination, row[0],) for row in rows if row[2])
) )
# Mark everything that is left as sent
sql = """ sql = """
UPDATE device_lists_outbound_pokes SET sent = ? INSERT INTO device_lists_outbound_last_success
(destination, user_id, stream_id) VALUES (?, ?, ?)
"""
txn.executemany(
sql, ((destination, row[0], row[1],) for row in rows if not row[2])
)
# Delete all sent outbound pokes
sql = """
DELETE FROM device_lists_outbound_pokes
WHERE destination = ? AND stream_id <= ? WHERE destination = ? AND stream_id <= ?
""" """
txn.execute(sql, (True, destination, stream_id,)) txn.execute(sql, (destination, stream_id,))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_user_whose_devices_changed(self, from_key): def get_user_whose_devices_changed(self, from_key):
@ -670,6 +681,14 @@ class DeviceStore(SQLBaseStore):
) )
) )
# Since we've deleted unsent deltas, we need to remove the entry
# of last successful sent so that the prev_ids are correctly set.
sql = """
DELETE FROM device_lists_outbound_last_success
WHERE destination = ? AND user_id = ?
"""
txn.executemany(sql, ((row[0], row[1]) for row in rows))
logger.info("Pruned %d device list outbound pokes", txn.rowcount) logger.info("Pruned %d device list outbound pokes", txn.rowcount)
return self.runInteraction( return self.runInteraction(

View File

@ -170,3 +170,17 @@ class DirectoryStore(SQLBaseStore):
"room_alias", "room_alias",
desc="get_aliases_for_room", desc="get_aliases_for_room",
) )
def update_aliases_for_room(self, old_room_id, new_room_id, creator):
def _update_aliases_for_room_txn(txn):
sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?"
txn.execute(sql, (new_room_id, creator, old_room_id,))
self._invalidate_cache_and_stream(
txn, self.get_aliases_for_room, (old_room_id,)
)
self._invalidate_cache_and_stream(
txn, self.get_aliases_for_room, (new_room_id,)
)
return self.runInteraction(
"_update_aliases_for_room_txn", _update_aliases_for_room_txn
)

View File

@ -185,8 +185,8 @@ class EndToEndKeyStore(SQLBaseStore):
for algorithm, key_id, json_bytes in new_keys for algorithm, key_id, json_bytes in new_keys
], ],
) )
txn.call_after( self._invalidate_cache_and_stream(
self.count_e2e_one_time_keys.invalidate, (user_id, device_id,) txn, self.count_e2e_one_time_keys, (user_id, device_id,)
) )
yield self.runInteraction( yield self.runInteraction(
"add_e2e_one_time_keys_insert", _add_e2e_one_time_keys "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
@ -237,24 +237,29 @@ class EndToEndKeyStore(SQLBaseStore):
) )
for user_id, device_id, algorithm, key_id in delete: for user_id, device_id, algorithm, key_id in delete:
txn.execute(sql, (user_id, device_id, algorithm, key_id)) txn.execute(sql, (user_id, device_id, algorithm, key_id))
txn.call_after( self._invalidate_cache_and_stream(
self.count_e2e_one_time_keys.invalidate, (user_id, device_id,) txn, self.count_e2e_one_time_keys, (user_id, device_id,)
) )
return result return result
return self.runInteraction( return self.runInteraction(
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
) )
@defer.inlineCallbacks
def delete_e2e_keys_by_device(self, user_id, device_id): def delete_e2e_keys_by_device(self, user_id, device_id):
yield self._simple_delete( def delete_e2e_keys_by_device_txn(txn):
self._simple_delete_txn(
txn,
table="e2e_device_keys_json", table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id}, keyvalues={"user_id": user_id, "device_id": device_id},
desc="delete_e2e_device_keys_by_device"
) )
yield self._simple_delete( self._simple_delete_txn(
txn,
table="e2e_one_time_keys_json", table="e2e_one_time_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id}, keyvalues={"user_id": user_id, "device_id": device_id},
desc="delete_e2e_one_time_keys_by_device"
) )
self.count_e2e_one_time_keys.invalidate((user_id, device_id,)) self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id,)
)
return self.runInteraction(
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
)

View File

@ -37,24 +37,54 @@ class EventFederationStore(SQLBaseStore):
and backfilling from another server respectively. and backfilling from another server respectively.
""" """
EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
def __init__(self, hs): def __init__(self, hs):
super(EventFederationStore, self).__init__(hs) super(EventFederationStore, self).__init__(hs)
self.register_background_update_handler(
self.EVENT_AUTH_STATE_ONLY,
self._background_delete_non_state_event_auth,
)
hs.get_clock().looping_call( hs.get_clock().looping_call(
self._delete_old_forward_extrem_cache, 60 * 60 * 1000 self._delete_old_forward_extrem_cache, 60 * 60 * 1000
) )
def get_auth_chain(self, event_ids): def get_auth_chain(self, event_ids, include_given=False):
return self.get_auth_chain_ids(event_ids).addCallback(self._get_events) """Get auth events for given event_ids. The events *must* be state events.
def get_auth_chain_ids(self, event_ids): Args:
event_ids (list): state events
include_given (bool): include the given events in result
Returns:
list of events
"""
return self.get_auth_chain_ids(
event_ids, include_given=include_given,
).addCallback(self._get_events)
def get_auth_chain_ids(self, event_ids, include_given=False):
"""Get auth events for given event_ids. The events *must* be state events.
Args:
event_ids (list): state events
include_given (bool): include the given events in result
Returns:
list of event_ids
"""
return self.runInteraction( return self.runInteraction(
"get_auth_chain_ids", "get_auth_chain_ids",
self._get_auth_chain_ids_txn, self._get_auth_chain_ids_txn,
event_ids event_ids, include_given
) )
def _get_auth_chain_ids_txn(self, txn, event_ids): def _get_auth_chain_ids_txn(self, txn, event_ids, include_given):
if include_given:
results = set(event_ids)
else:
results = set() results = set()
base_sql = ( base_sql = (
@ -504,3 +534,52 @@ class EventFederationStore(SQLBaseStore):
txn.execute(query, (room_id,)) txn.execute(query, (room_id,))
txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
@defer.inlineCallbacks
def _background_delete_non_state_event_auth(self, progress, batch_size):
def delete_event_auth(txn):
target_min_stream_id = progress.get("target_min_stream_id_inclusive")
max_stream_id = progress.get("max_stream_id_exclusive")
if not target_min_stream_id or not max_stream_id:
txn.execute("SELECT COALESCE(MIN(stream_ordering), 0) FROM events")
rows = txn.fetchall()
target_min_stream_id = rows[0][0]
txn.execute("SELECT COALESCE(MAX(stream_ordering), 0) FROM events")
rows = txn.fetchall()
max_stream_id = rows[0][0]
min_stream_id = max_stream_id - batch_size
sql = """
DELETE FROM event_auth
WHERE event_id IN (
SELECT event_id FROM events
LEFT JOIN state_events USING (room_id, event_id)
WHERE ? <= stream_ordering AND stream_ordering < ?
AND state_key IS null
)
"""
txn.execute(sql, (min_stream_id, max_stream_id,))
new_progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
"max_stream_id_exclusive": min_stream_id,
}
self._background_update_progress_txn(
txn, self.EVENT_AUTH_STATE_ONLY, new_progress
)
return min_stream_id >= target_min_stream_id
result = yield self.runInteraction(
self.EVENT_AUTH_STATE_ONLY, delete_event_auth
)
if not result:
yield self._end_background_update(self.EVENT_AUTH_STATE_ONLY)
defer.returnValue(batch_size)

View File

@ -403,6 +403,11 @@ class EventsStore(SQLBaseStore):
(room_id, ), new_state (room_id, ), new_state
) )
for room_id, latest_event_ids in new_forward_extremeties.iteritems():
self.get_latest_event_ids_in_room.prefill(
(room_id,), list(latest_event_ids)
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _calculate_new_extremeties(self, room_id, event_contexts, latest_event_ids): def _calculate_new_extremeties(self, room_id, event_contexts, latest_event_ids):
"""Calculates the new forward extremeties for a room given events to """Calculates the new forward extremeties for a room given events to
@ -647,9 +652,10 @@ class EventsStore(SQLBaseStore):
list of the event ids which are the forward extremities. list of the event ids which are the forward extremities.
""" """
self._update_current_state_txn(txn, current_state_for_room)
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
self._update_current_state_txn(txn, current_state_for_room, max_stream_order)
self._update_forward_extremities_txn( self._update_forward_extremities_txn(
txn, txn,
new_forward_extremities=new_forward_extremeties, new_forward_extremities=new_forward_extremeties,
@ -712,7 +718,7 @@ class EventsStore(SQLBaseStore):
backfilled=backfilled, backfilled=backfilled,
) )
def _update_current_state_txn(self, txn, state_delta_by_room): def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
for room_id, current_state_tuple in state_delta_by_room.iteritems(): for room_id, current_state_tuple in state_delta_by_room.iteritems():
to_delete, to_insert, _ = current_state_tuple to_delete, to_insert, _ = current_state_tuple
txn.executemany( txn.executemany(
@ -734,6 +740,29 @@ class EventsStore(SQLBaseStore):
], ],
) )
state_deltas = {key: None for key in to_delete}
state_deltas.update(to_insert)
self._simple_insert_many_txn(
txn,
table="current_state_delta_stream",
values=[
{
"stream_id": max_stream_order,
"room_id": room_id,
"type": key[0],
"state_key": key[1],
"event_id": ev_id,
"prev_event_id": to_delete.get(key, None),
}
for key, ev_id in state_deltas.iteritems()
]
)
self._curr_state_delta_stream_cache.entity_has_changed(
room_id, max_stream_order,
)
# Invalidate the various caches # Invalidate the various caches
# Figure out the changes of membership to invalidate the # Figure out the changes of membership to invalidate the
@ -742,11 +771,7 @@ class EventsStore(SQLBaseStore):
# and which we have added, then we invlidate the caches for all # and which we have added, then we invlidate the caches for all
# those users. # those users.
members_changed = set( members_changed = set(
state_key for ev_type, state_key in to_delete.iterkeys() state_key for ev_type, state_key in state_deltas
if ev_type == EventTypes.Member
)
members_changed.update(
state_key for ev_type, state_key in to_insert.iterkeys()
if ev_type == EventTypes.Member if ev_type == EventTypes.Member
) )
@ -755,6 +780,11 @@ class EventsStore(SQLBaseStore):
txn, self.get_rooms_for_user, (member,) txn, self.get_rooms_for_user, (member,)
) )
for host in set(get_domain_from_id(u) for u in members_changed):
self._invalidate_cache_and_stream(
txn, self.is_host_joined, (room_id, host)
)
self._invalidate_cache_and_stream( self._invalidate_cache_and_stream(
txn, self.get_users_in_room, (room_id,) txn, self.get_users_in_room, (room_id,)
) )
@ -1119,6 +1149,7 @@ class EventsStore(SQLBaseStore):
} }
for event, _ in events_and_contexts for event, _ in events_and_contexts
for auth_id, _ in event.auth_events for auth_id, _ in event.auth_events
if event.is_state()
], ],
) )
@ -1418,7 +1449,7 @@ class EventsStore(SQLBaseStore):
] ]
rows = self._new_transaction( rows = self._new_transaction(
conn, "do_fetch", [], None, self._fetch_event_rows, event_ids conn, "do_fetch", [], [], None, self._fetch_event_rows, event_ids
) )
row_dict = { row_dict = {
@ -2243,6 +2274,24 @@ class EventsStore(SQLBaseStore):
defer.returnValue((int(res["topological_ordering"]), int(res["stream_ordering"]))) defer.returnValue((int(res["topological_ordering"]), int(res["stream_ordering"])))
def get_max_current_state_delta_stream_id(self):
return self._stream_id_gen.get_current_token()
def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
def get_all_updated_current_state_deltas_txn(txn):
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC LIMIT ?
"""
txn.execute(sql, (from_token, to_token, limit))
return txn.fetchall()
return self.runInteraction(
"get_all_updated_current_state_deltas",
get_all_updated_current_state_deltas_txn,
)
AllNewEventsResult = namedtuple("AllNewEventsResult", [ AllNewEventsResult = namedtuple("AllNewEventsResult", [
"new_forward_events", "new_backfill_events", "new_forward_events", "new_backfill_events",

View File

@ -19,6 +19,7 @@ from ._base import SQLBaseStore
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cachedInlineCallbacks
from canonicaljson import encode_canonical_json
import simplejson as json import simplejson as json
@ -46,11 +47,20 @@ class FilteringStore(SQLBaseStore):
defer.returnValue(json.loads(str(def_json).decode("utf-8"))) defer.returnValue(json.loads(str(def_json).decode("utf-8")))
def add_user_filter(self, user_localpart, user_filter): def add_user_filter(self, user_localpart, user_filter):
def_json = json.dumps(user_filter).encode("utf-8") def_json = encode_canonical_json(user_filter)
# Need an atomic transaction to SELECT the maximal ID so far then # Need an atomic transaction to SELECT the maximal ID so far then
# INSERT a new one # INSERT a new one
def _do_txn(txn): def _do_txn(txn):
sql = (
"SELECT filter_id FROM user_filters "
"WHERE user_id = ? AND filter_json = ?"
)
txn.execute(sql, (user_localpart, def_json))
filter_id_response = txn.fetchone()
if filter_id_response is not None:
return filter_id_response[0]
sql = ( sql = (
"SELECT MAX(filter_id) FROM user_filters " "SELECT MAX(filter_id) FROM user_filters "
"WHERE user_id = ?" "WHERE user_id = ?"

View File

@ -30,13 +30,16 @@ class MediaRepositoryStore(SQLBaseStore):
return self._simple_select_one( return self._simple_select_one(
"local_media_repository", "local_media_repository",
{"media_id": media_id}, {"media_id": media_id},
("media_type", "media_length", "upload_name", "created_ts"), (
"media_type", "media_length", "upload_name", "created_ts",
"quarantined_by", "url_cache",
),
allow_none=True, allow_none=True,
desc="get_local_media", desc="get_local_media",
) )
def store_local_media(self, media_id, media_type, time_now_ms, upload_name, def store_local_media(self, media_id, media_type, time_now_ms, upload_name,
media_length, user_id): media_length, user_id, url_cache=None):
return self._simple_insert( return self._simple_insert(
"local_media_repository", "local_media_repository",
{ {
@ -46,6 +49,7 @@ class MediaRepositoryStore(SQLBaseStore):
"upload_name": upload_name, "upload_name": upload_name,
"media_length": media_length, "media_length": media_length,
"user_id": user_id.to_string(), "user_id": user_id.to_string(),
"url_cache": url_cache,
}, },
desc="store_local_media", desc="store_local_media",
) )
@ -138,7 +142,7 @@ class MediaRepositoryStore(SQLBaseStore):
{"media_origin": origin, "media_id": media_id}, {"media_origin": origin, "media_id": media_id},
( (
"media_type", "media_length", "upload_name", "created_ts", "media_type", "media_length", "upload_name", "created_ts",
"filesystem_id", "filesystem_id", "quarantined_by",
), ),
allow_none=True, allow_none=True,
desc="get_cached_remote_media", desc="get_cached_remote_media",

View File

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 41 SCHEMA_VERSION = 43
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))

View File

@ -49,7 +49,7 @@ def _load_rules(rawrules, enabled_map):
class PushRuleStore(SQLBaseStore): class PushRuleStore(SQLBaseStore):
@cachedInlineCallbacks() @cachedInlineCallbacks(max_entries=5000)
def get_push_rules_for_user(self, user_id): def get_push_rules_for_user(self, user_id):
rows = yield self._simple_select_list( rows = yield self._simple_select_list(
table="push_rules", table="push_rules",
@ -73,7 +73,7 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(rules) defer.returnValue(rules)
@cachedInlineCallbacks() @cachedInlineCallbacks(max_entries=5000)
def get_push_rules_enabled_for_user(self, user_id): def get_push_rules_enabled_for_user(self, user_id):
results = yield self._simple_select_list( results = yield self._simple_select_list(
table="push_rules_enable", table="push_rules_enable",

View File

@ -45,7 +45,9 @@ class ReceiptsStore(SQLBaseStore):
return return
# Returns an ObservableDeferred # Returns an ObservableDeferred
res = self.get_users_with_read_receipts_in_room.cache.get((room_id,), None) res = self.get_users_with_read_receipts_in_room.cache.get(
room_id, None, update_metrics=False,
)
if res: if res:
if isinstance(res, defer.Deferred) and res.called: if isinstance(res, defer.Deferred) and res.called:

View File

@ -24,6 +24,7 @@ from .engines import PostgresEngine, Sqlite3Engine
import collections import collections
import logging import logging
import ujson as json import ujson as json
import re
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -507,3 +508,98 @@ class RoomStore(SQLBaseStore):
)) ))
else: else:
defer.returnValue(None) defer.returnValue(None)
@cached(max_entries=10000)
def is_room_blocked(self, room_id):
return self._simple_select_one_onecol(
table="blocked_rooms",
keyvalues={
"room_id": room_id,
},
retcol="1",
allow_none=True,
desc="is_room_blocked",
)
@defer.inlineCallbacks
def block_room(self, room_id, user_id):
yield self._simple_insert(
table="blocked_rooms",
values={
"room_id": room_id,
"user_id": user_id,
},
desc="block_room",
)
self.is_room_blocked.invalidate((room_id,))
def quarantine_media_ids_in_room(self, room_id, quarantined_by):
"""For a room loops through all events with media and quarantines
the associated media
"""
def _get_media_ids_in_room(txn):
mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
next_token = self.get_current_events_token() + 1
total_media_quarantined = 0
while next_token:
sql = """
SELECT stream_ordering, content FROM events
WHERE room_id = ?
AND stream_ordering < ?
AND contains_url = ? AND outlier = ?
ORDER BY stream_ordering DESC
LIMIT ?
"""
txn.execute(sql, (room_id, next_token, True, False, 100))
next_token = None
local_media_mxcs = []
remote_media_mxcs = []
for stream_ordering, content_json in txn:
next_token = stream_ordering
content = json.loads(content_json)
content_url = content.get("url")
thumbnail_url = content.get("info", {}).get("thumbnail_url")
for url in (content_url, thumbnail_url):
if not url:
continue
matches = mxc_re.match(url)
if matches:
hostname = matches.group(1)
media_id = matches.group(2)
if hostname == self.hostname:
local_media_mxcs.append(media_id)
else:
remote_media_mxcs.append((hostname, media_id))
# Now update all the tables to set the quarantined_by flag
txn.executemany("""
UPDATE local_media_repository
SET quarantined_by = ?
WHERE media_id = ?
""", ((quarantined_by, media_id) for media_id in local_media_mxcs))
txn.executemany(
"""
UPDATE remote_media_cache
SET quarantined_by = ?
WHERE media_origin AND media_id = ?
""",
(
(quarantined_by, origin, media_id)
for origin, media_id in remote_media_mxcs
)
)
total_media_quarantined += len(local_media_mxcs)
total_media_quarantined += len(remote_media_mxcs)
return total_media_quarantined
return self.runInteraction("get_media_ids_in_room", _get_media_ids_in_room)

View File

@ -18,6 +18,7 @@ from twisted.internet import defer
from collections import namedtuple from collections import namedtuple
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.async import Linearizer
from synapse.util.caches import intern_string from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.stringutils import to_ascii from synapse.util.stringutils import to_ascii
@ -392,7 +393,8 @@ class RoomMemberStore(SQLBaseStore):
context=context, context=context,
) )
def get_joined_users_from_state(self, room_id, state_group, state_ids): def get_joined_users_from_state(self, room_id, state_entry):
state_group = state_entry.state_group
if not state_group: if not state_group:
# If state_group is None it means it has yet to be assigned a # If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group # state group, i.e. we need to make sure that calls with a state_group
@ -401,7 +403,7 @@ class RoomMemberStore(SQLBaseStore):
state_group = object() state_group = object()
return self._get_joined_users_from_context( return self._get_joined_users_from_context(
room_id, state_group, state_ids, room_id, state_group, state_entry.state, context=state_entry,
) )
@cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True, @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True,
@ -499,42 +501,40 @@ class RoomMemberStore(SQLBaseStore):
defer.returnValue(users_in_room) defer.returnValue(users_in_room)
def is_host_joined(self, room_id, host, state_group, state_ids): @cachedInlineCallbacks(max_entries=10000)
if not state_group: def is_host_joined(self, room_id, host):
# If state_group is None it means it has yet to be assigned a if '%' in host or '_' in host:
# state group, i.e. we need to make sure that calls with a state_group raise Exception("Invalid host name")
# of None don't hit previous cached calls with a None state_group.
# To do this we set the state_group to a new object as object() != object()
state_group = object()
return self._is_host_joined( sql = """
room_id, host, state_group, state_ids SELECT state_key FROM current_state_events AS c
) INNER JOIN room_memberships USING (event_id)
WHERE membership = 'join'
AND type = 'm.room.member'
AND c.room_id = ?
AND state_key LIKE ?
LIMIT 1
"""
@cachedInlineCallbacks(num_args=3) # We do need to be careful to ensure that host doesn't have any wild cards
def _is_host_joined(self, room_id, host, state_group, current_state_ids): # in it, but we checked above for known ones and we'll check below that
# We don't use `state_group`, its there so that we can cache based # the returned user actually has the correct domain.
# on it. However, its important that its never None, since two current_state's like_clause = "%:" + host
# with a state_group of None are likely to be different.
# See bulk_get_push_rules_for_room for how we work around this.
assert state_group is not None
for (etype, state_key), event_id in current_state_ids.items(): rows = yield self._execute("is_host_joined", None, sql, room_id, like_clause)
if etype == EventTypes.Member:
try:
if get_domain_from_id(state_key) != host:
continue
except:
logger.warn("state_key not user_id: %s", state_key)
continue
event = yield self.get_event(event_id, allow_none=True)
if event and event.content["membership"] == Membership.JOIN:
defer.returnValue(True)
if not rows:
defer.returnValue(False) defer.returnValue(False)
def get_joined_hosts(self, room_id, state_group, state_ids): user_id = rows[0][0]
if get_domain_from_id(user_id) != host:
# This can only happen if the host name has something funky in it
raise Exception("Invalid host name")
defer.returnValue(True)
def get_joined_hosts(self, room_id, state_entry):
state_group = state_entry.state_group
if not state_group: if not state_group:
# If state_group is None it means it has yet to be assigned a # If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group # state group, i.e. we need to make sure that calls with a state_group
@ -543,33 +543,20 @@ class RoomMemberStore(SQLBaseStore):
state_group = object() state_group = object()
return self._get_joined_hosts( return self._get_joined_hosts(
room_id, state_group, state_ids room_id, state_group, state_entry.state, state_entry=state_entry,
) )
@cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True) @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
def _get_joined_hosts(self, room_id, state_group, current_state_ids): # @defer.inlineCallbacks
def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry):
# We don't use `state_group`, its there so that we can cache based # We don't use `state_group`, its there so that we can cache based
# on it. However, its important that its never None, since two current_state's # on it. However, its important that its never None, since two current_state's
# with a state_group of None are likely to be different. # with a state_group of None are likely to be different.
# See bulk_get_push_rules_for_room for how we work around this. # See bulk_get_push_rules_for_room for how we work around this.
assert state_group is not None assert state_group is not None
joined_hosts = set() cache = self._get_joined_hosts_cache(room_id)
for etype, state_key in current_state_ids: joined_hosts = yield cache.get_destinations(state_entry)
if etype == EventTypes.Member:
try:
host = get_domain_from_id(state_key)
except:
logger.warn("state_key not user_id: %s", state_key)
continue
if host in joined_hosts:
continue
event_id = current_state_ids[(etype, state_key)]
event = yield self.get_event(event_id, allow_none=True)
if event and event.content["membership"] == Membership.JOIN:
joined_hosts.add(intern_string(host))
defer.returnValue(joined_hosts) defer.returnValue(joined_hosts)
@ -647,3 +634,75 @@ class RoomMemberStore(SQLBaseStore):
yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME) yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME)
defer.returnValue(result) defer.returnValue(result)
@cached(max_entries=10000, iterable=True)
def _get_joined_hosts_cache(self, room_id):
return _JoinedHostsCache(self, room_id)
class _JoinedHostsCache(object):
"""Cache for joined hosts in a room that is optimised to handle updates
via state deltas.
"""
def __init__(self, store, room_id):
self.store = store
self.room_id = room_id
self.hosts_to_joined_users = {}
self.state_group = object()
self.linearizer = Linearizer("_JoinedHostsCache")
self._len = 0
@defer.inlineCallbacks
def get_destinations(self, state_entry):
"""Get set of destinations for a state entry
Args:
state_entry(synapse.state._StateCacheEntry)
"""
if state_entry.state_group == self.state_group:
defer.returnValue(frozenset(self.hosts_to_joined_users))
with (yield self.linearizer.queue(())):
if state_entry.state_group == self.state_group:
pass
elif state_entry.prev_group == self.state_group:
for (typ, state_key), event_id in state_entry.delta_ids.iteritems():
if typ != EventTypes.Member:
continue
host = intern_string(get_domain_from_id(state_key))
user_id = state_key
known_joins = self.hosts_to_joined_users.setdefault(host, set())
event = yield self.store.get_event(event_id)
if event.membership == Membership.JOIN:
known_joins.add(user_id)
else:
known_joins.discard(user_id)
if not known_joins:
self.hosts_to_joined_users.pop(host, None)
else:
joined_users = yield self.store.get_joined_users_from_state(
self.room_id, state_entry,
)
self.hosts_to_joined_users = {}
for user_id in joined_users:
host = intern_string(get_domain_from_id(user_id))
self.hosts_to_joined_users.setdefault(host, set()).add(user_id)
if state_entry.state_group:
self.state_group = state_entry.state_group
else:
self.state_group = object()
self._len = sum(len(v) for v in self.hosts_to_joined_users.itervalues())
defer.returnValue(frozenset(self.hosts_to_joined_users))
def __len__(self):
return self._len

View File

@ -0,0 +1,26 @@
/* Copyright 2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE current_state_delta_stream (
stream_id BIGINT NOT NULL,
room_id TEXT NOT NULL,
type TEXT NOT NULL,
state_key TEXT NOT NULL,
event_id TEXT, -- Is null if the key was removed
prev_event_id TEXT -- Is null if the key was added
);
CREATE INDEX current_state_delta_stream_idx ON current_state_delta_stream(stream_id);

View File

@ -0,0 +1,33 @@
/* Copyright 2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Table of last stream_id that we sent to destination for user_id. This is
-- used to fill out the `prev_id` fields of outbound device list updates.
CREATE TABLE device_lists_outbound_last_success (
destination TEXT NOT NULL,
user_id TEXT NOT NULL,
stream_id BIGINT NOT NULL
);
INSERT INTO device_lists_outbound_last_success
SELECT destination, user_id, coalesce(max(stream_id), 0) as stream_id
FROM device_lists_outbound_pokes
WHERE sent = (1 = 1) -- sqlite doesn't have inbuilt boolean values
GROUP BY destination, user_id;
CREATE INDEX device_lists_outbound_last_success_idx ON device_lists_outbound_last_success(
destination, user_id, stream_id
);

View File

@ -0,0 +1,17 @@
/* Copyright 2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
INSERT INTO background_updates (update_name, progress_json) VALUES
('event_auth_state_only', '{}');

View File

@ -0,0 +1,84 @@
# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from synapse.storage.prepare_database import get_statements
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
logger = logging.getLogger(__name__)
BOTH_TABLES = """
CREATE TABLE user_directory_stream_pos (
Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
stream_id BIGINT,
CHECK (Lock='X')
);
INSERT INTO user_directory_stream_pos (stream_id) VALUES (null);
CREATE TABLE user_directory (
user_id TEXT NOT NULL,
room_id TEXT NOT NULL, -- A room_id that we know the user is joined to
display_name TEXT,
avatar_url TEXT
);
CREATE INDEX user_directory_room_idx ON user_directory(room_id);
CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id);
CREATE TABLE users_in_pubic_room (
user_id TEXT NOT NULL,
room_id TEXT NOT NULL -- A room_id that we know is public
);
CREATE INDEX users_in_pubic_room_room_idx ON users_in_pubic_room(room_id);
CREATE UNIQUE INDEX users_in_pubic_room_user_idx ON users_in_pubic_room(user_id);
"""
POSTGRES_TABLE = """
CREATE TABLE user_directory_search (
user_id TEXT NOT NULL,
vector tsvector
);
CREATE INDEX user_directory_search_fts_idx ON user_directory_search USING gin(vector);
CREATE UNIQUE INDEX user_directory_search_user_idx ON user_directory_search(user_id);
"""
SQLITE_TABLE = """
CREATE VIRTUAL TABLE user_directory_search
USING fts4 ( user_id, value );
"""
def run_create(cur, database_engine, *args, **kwargs):
for statement in get_statements(BOTH_TABLES.splitlines()):
cur.execute(statement)
if isinstance(database_engine, PostgresEngine):
for statement in get_statements(POSTGRES_TABLE.splitlines()):
cur.execute(statement)
elif isinstance(database_engine, Sqlite3Engine):
for statement in get_statements(SQLITE_TABLE.splitlines()):
cur.execute(statement)
else:
raise Exception("Unrecognized database engine")
def run_upgrade(*args, **kwargs):
pass

View File

@ -0,0 +1,21 @@
/* Copyright 2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE blocked_rooms (
room_id TEXT NOT NULL,
user_id TEXT NOT NULL -- Admin who blocked the room
);
CREATE UNIQUE INDEX blocked_rooms_idx ON blocked_rooms(room_id);

View File

@ -0,0 +1,17 @@
/* Copyright 2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE local_media_repository ADD COLUMN quarantined_by TEXT;
ALTER TABLE remote_media_cache ADD COLUMN quarantined_by TEXT;

View File

@ -0,0 +1,16 @@
/* Copyright 2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE local_media_repository ADD COLUMN url_cache TEXT;

View File

@ -0,0 +1,33 @@
/* Copyright 2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Table keeping track of who shares a room with who. We only keep track
-- of this for local users, so `user_id` is local users only (but we do keep track
-- of which remote users share a room)
CREATE TABLE users_who_share_rooms (
user_id TEXT NOT NULL,
other_user_id TEXT NOT NULL,
room_id TEXT NOT NULL,
share_private BOOLEAN NOT NULL -- is the shared room private? i.e. they share a private room
);
CREATE UNIQUE INDEX users_who_share_rooms_u_idx ON users_who_share_rooms(user_id, other_user_id);
CREATE INDEX users_who_share_rooms_r_idx ON users_who_share_rooms(room_id);
CREATE INDEX users_who_share_rooms_o_idx ON users_who_share_rooms(other_user_id);
-- Make sure that we popualte the table initially
UPDATE user_directory_stream_pos SET stream_id = NULL;

View File

@ -20,6 +20,7 @@ from synapse.util.stringutils import to_ascii
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from twisted.internet import defer from twisted.internet import defer
from collections import namedtuple
import logging import logging
@ -29,6 +30,16 @@ logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100 MAX_STATE_DELTA_HOPS = 100
class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))):
"""Return type of get_state_group_delta that implements __len__, which lets
us use the itrable flag when caching
"""
__slots__ = []
def __len__(self):
return len(self.delta_ids) if self.delta_ids else 0
class StateStore(SQLBaseStore): class StateStore(SQLBaseStore):
""" Keeps track of the state at a given event. """ Keeps track of the state at a given event.
@ -98,6 +109,46 @@ class StateStore(SQLBaseStore):
_get_current_state_ids_txn, _get_current_state_ids_txn,
) )
@cached(max_entries=10000, iterable=True)
def get_state_group_delta(self, state_group):
"""Given a state group try to return a previous group and a delta between
the old and the new.
Returns:
(prev_group, delta_ids), where both may be None.
"""
def _get_state_group_delta_txn(txn):
prev_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={
"state_group": state_group,
},
retcol="prev_state_group",
allow_none=True,
)
if not prev_group:
return _GetStateGroupDelta(None, None)
delta_ids = self._simple_select_list_txn(
txn,
table="state_groups_state",
keyvalues={
"state_group": state_group,
},
retcols=("type", "state_key", "event_id",)
)
return _GetStateGroupDelta(prev_group, {
(row["type"], row["state_key"]): row["event_id"]
for row in delta_ids
})
return self.runInteraction(
"get_state_group_delta",
_get_state_group_delta_txn,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_groups_ids(self, room_id, event_ids): def get_state_groups_ids(self, room_id, event_ids):
if not event_ids: if not event_ids:
@ -184,6 +235,19 @@ class StateStore(SQLBaseStore):
# We persist as a delta if we can, while also ensuring the chain # We persist as a delta if we can, while also ensuring the chain
# of deltas isn't tooo long, as otherwise read performance degrades. # of deltas isn't tooo long, as otherwise read performance degrades.
if context.prev_group: if context.prev_group:
is_in_db = self._simple_select_one_onecol_txn(
txn,
table="state_groups",
keyvalues={"id": context.prev_group},
retcol="id",
allow_none=True,
)
if not is_in_db:
raise Exception(
"Trying to persist state with unpersisted prev_group: %r"
% (context.prev_group,)
)
potential_hops = self._count_state_group_hops_txn( potential_hops = self._count_state_group_hops_txn(
txn, context.prev_group txn, context.prev_group
) )
@ -251,6 +315,12 @@ class StateStore(SQLBaseStore):
], ],
) )
for event_id, state_group_id in state_groups.iteritems():
txn.call_after(
self._get_state_group_for_event.prefill,
(event_id,), state_group_id
)
def _count_state_group_hops_txn(self, txn, state_group): def _count_state_group_hops_txn(self, txn, state_group):
"""Given a state group, count how many hops there are in the tree. """Given a state group, count how many hops there are in the tree.
@ -520,8 +590,8 @@ class StateStore(SQLBaseStore):
state_map = yield self.get_state_ids_for_events([event_id], types) state_map = yield self.get_state_ids_for_events([event_id], types)
defer.returnValue(state_map[event_id]) defer.returnValue(state_map[event_id])
@cached(num_args=2, max_entries=50000) @cached(max_entries=50000)
def _get_state_group_for_event(self, room_id, event_id): def _get_state_group_for_event(self, event_id):
return self._simple_select_one_onecol( return self._simple_select_one_onecol(
table="event_to_state_groups", table="event_to_state_groups",
keyvalues={ keyvalues={
@ -563,20 +633,22 @@ class StateStore(SQLBaseStore):
where a `state_key` of `None` matches all state_keys for the where a `state_key` of `None` matches all state_keys for the
`type`. `type`.
""" """
is_all, state_dict_ids = self._state_group_cache.get(group) is_all, known_absent, state_dict_ids = self._state_group_cache.get(group)
type_to_key = {} type_to_key = {}
missing_types = set() missing_types = set()
for typ, state_key in types: for typ, state_key in types:
key = (typ, state_key)
if state_key is None: if state_key is None:
type_to_key[typ] = None type_to_key[typ] = None
missing_types.add((typ, state_key)) missing_types.add(key)
else: else:
if type_to_key.get(typ, object()) is not None: if type_to_key.get(typ, object()) is not None:
type_to_key.setdefault(typ, set()).add(state_key) type_to_key.setdefault(typ, set()).add(state_key)
if (typ, state_key) not in state_dict_ids: if key not in state_dict_ids and key not in known_absent:
missing_types.add((typ, state_key)) missing_types.add(key)
sentinel = object() sentinel = object()
@ -590,7 +662,7 @@ class StateStore(SQLBaseStore):
return True return True
return False return False
got_all = not (missing_types or types is None) got_all = is_all or not missing_types
return { return {
k: v for k, v in state_dict_ids.iteritems() k: v for k, v in state_dict_ids.iteritems()
@ -607,7 +679,7 @@ class StateStore(SQLBaseStore):
Args: Args:
group: The state group to lookup group: The state group to lookup
""" """
is_all, state_dict_ids = self._state_group_cache.get(group) is_all, _, state_dict_ids = self._state_group_cache.get(group)
return state_dict_ids, is_all return state_dict_ids, is_all
@ -624,7 +696,7 @@ class StateStore(SQLBaseStore):
missing_groups = [] missing_groups = []
if types is not None: if types is not None:
for group in set(groups): for group in set(groups):
state_dict_ids, missing_types, got_all = self._get_some_state_from_cache( state_dict_ids, _, got_all = self._get_some_state_from_cache(
group, types group, types
) )
results[group] = state_dict_ids results[group] = state_dict_ids
@ -653,18 +725,6 @@ class StateStore(SQLBaseStore):
# Now we want to update the cache with all the things we fetched # Now we want to update the cache with all the things we fetched
# from the database. # from the database.
for group, group_state_dict in group_to_state_dict.iteritems(): for group, group_state_dict in group_to_state_dict.iteritems():
if types:
# We delibrately put key -> None mappings into the cache to
# cache absence of the key, on the assumption that if we've
# explicitly asked for some types then we will probably ask
# for them again.
state_dict = {
(intern_string(etype), intern_string(state_key)): None
for (etype, state_key) in types
}
state_dict.update(results[group])
results[group] = state_dict
else:
state_dict = results[group] state_dict = results[group]
state_dict.update( state_dict.update(
@ -677,17 +737,9 @@ class StateStore(SQLBaseStore):
key=group, key=group,
value=state_dict, value=state_dict,
full=(types is None), full=(types is None),
known_absent=types,
) )
# Remove all the entries with None values. The None values were just
# used for bookkeeping in the cache.
for group, state_dict in results.iteritems():
results[group] = {
key: event_id
for key, event_id in state_dict.iteritems()
if event_id
}
defer.returnValue(results) defer.returnValue(results)
def get_next_state_group(self): def get_next_state_group(self):

View File

@ -0,0 +1,743 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.api.constants import EventTypes, JoinRules
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import get_domain_from_id, get_localpart_from_id
import re
import logging
logger = logging.getLogger(__name__)
class UserDirectoryStore(SQLBaseStore):
@cachedInlineCallbacks(cache_context=True)
def is_room_world_readable_or_publicly_joinable(self, room_id, cache_context):
"""Check if the room is either world_readable or publically joinable
"""
current_state_ids = yield self.get_current_state_ids(
room_id, on_invalidate=cache_context.invalidate
)
join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
if join_rules_id:
join_rule_ev = yield self.get_event(join_rules_id, allow_none=True)
if join_rule_ev:
if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC:
defer.returnValue(True)
hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, ""))
if hist_vis_id:
hist_vis_ev = yield self.get_event(hist_vis_id, allow_none=True)
if hist_vis_ev:
if hist_vis_ev.content.get("history_visibility") == "world_readable":
defer.returnValue(True)
defer.returnValue(False)
@defer.inlineCallbacks
def add_users_to_public_room(self, room_id, user_ids):
"""Add user to the list of users in public rooms
Args:
room_id (str): A room_id that all users are in that is world_readable
or publically joinable
user_ids (list(str)): Users to add
"""
yield self._simple_insert_many(
table="users_in_pubic_room",
values=[
{
"user_id": user_id,
"room_id": room_id,
}
for user_id in user_ids
],
desc="add_users_to_public_room"
)
for user_id in user_ids:
self.get_user_in_public_room.invalidate((user_id,))
def add_profiles_to_user_dir(self, room_id, users_with_profile):
"""Add profiles to the user directory
Args:
room_id (str): A room_id that all users are joined to
users_with_profile (dict): Users to add to directory in the form of
mapping of user_id -> ProfileInfo
"""
if isinstance(self.database_engine, PostgresEngine):
# We weight the loclpart most highly, then display name and finally
# server name
sql = """
INSERT INTO user_directory_search(user_id, vector)
VALUES (?,
setweight(to_tsvector('english', ?), 'A')
|| setweight(to_tsvector('english', ?), 'D')
|| setweight(to_tsvector('english', COALESCE(?, '')), 'B')
)
"""
args = (
(
user_id, get_localpart_from_id(user_id), get_domain_from_id(user_id),
profile.display_name,
)
for user_id, profile in users_with_profile.iteritems()
)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = """
INSERT INTO user_directory_search(user_id, value)
VALUES (?,?)
"""
args = (
(
user_id,
"%s %s" % (user_id, p.display_name,) if p.display_name else user_id
)
for user_id, p in users_with_profile.iteritems()
)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
def _add_profiles_to_user_dir_txn(txn):
txn.executemany(sql, args)
self._simple_insert_many_txn(
txn,
table="user_directory",
values=[
{
"user_id": user_id,
"room_id": room_id,
"display_name": profile.display_name,
"avatar_url": profile.avatar_url,
}
for user_id, profile in users_with_profile.iteritems()
]
)
for user_id in users_with_profile:
txn.call_after(
self.get_user_in_directory.invalidate, (user_id,)
)
return self.runInteraction(
"add_profiles_to_user_dir", _add_profiles_to_user_dir_txn
)
@defer.inlineCallbacks
def update_user_in_user_dir(self, user_id, room_id):
yield self._simple_update_one(
table="user_directory",
keyvalues={"user_id": user_id},
updatevalues={"room_id": room_id},
desc="update_user_in_user_dir",
)
self.get_user_in_directory.invalidate((user_id,))
def update_profile_in_user_dir(self, user_id, display_name, avatar_url, room_id):
def _update_profile_in_user_dir_txn(txn):
new_entry = self._simple_upsert_txn(
txn,
table="user_directory",
keyvalues={"user_id": user_id},
insertion_values={"room_id": room_id},
values={"display_name": display_name, "avatar_url": avatar_url},
lock=False, # We're only inserter
)
if isinstance(self.database_engine, PostgresEngine):
# We weight the loclpart most highly, then display name and finally
# server name
if new_entry:
sql = """
INSERT INTO user_directory_search(user_id, vector)
VALUES (?,
setweight(to_tsvector('english', ?), 'A')
|| setweight(to_tsvector('english', ?), 'D')
|| setweight(to_tsvector('english', COALESCE(?, '')), 'B')
)
"""
txn.execute(
sql,
(
user_id, get_localpart_from_id(user_id),
get_domain_from_id(user_id), display_name,
)
)
else:
sql = """
UPDATE user_directory_search
SET vector = setweight(to_tsvector('english', ?), 'A')
|| setweight(to_tsvector('english', ?), 'D')
|| setweight(to_tsvector('english', COALESCE(?, '')), 'B')
WHERE user_id = ?
"""
txn.execute(
sql,
(
get_localpart_from_id(user_id), get_domain_from_id(user_id),
display_name, user_id,
)
)
elif isinstance(self.database_engine, Sqlite3Engine):
value = "%s %s" % (user_id, display_name,) if display_name else user_id
self._simple_upsert_txn(
txn,
table="user_directory_search",
keyvalues={"user_id": user_id},
values={"value": value},
lock=False, # We're only inserter
)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
return self.runInteraction(
"update_profile_in_user_dir", _update_profile_in_user_dir_txn
)
@defer.inlineCallbacks
def update_user_in_public_user_list(self, user_id, room_id):
yield self._simple_update_one(
table="users_in_pubic_room",
keyvalues={"user_id": user_id},
updatevalues={"room_id": room_id},
desc="update_user_in_public_user_list",
)
self.get_user_in_public_room.invalidate((user_id,))
def remove_from_user_dir(self, user_id):
def _remove_from_user_dir_txn(txn):
self._simple_delete_txn(
txn,
table="user_directory",
keyvalues={"user_id": user_id},
)
self._simple_delete_txn(
txn,
table="user_directory_search",
keyvalues={"user_id": user_id},
)
self._simple_delete_txn(
txn,
table="users_in_pubic_room",
keyvalues={"user_id": user_id},
)
txn.call_after(
self.get_user_in_directory.invalidate, (user_id,)
)
txn.call_after(
self.get_user_in_public_room.invalidate, (user_id,)
)
return self.runInteraction(
"remove_from_user_dir", _remove_from_user_dir_txn,
)
@defer.inlineCallbacks
def remove_from_user_in_public_room(self, user_id):
yield self._simple_delete(
table="users_in_pubic_room",
keyvalues={"user_id": user_id},
desc="remove_from_user_in_public_room",
)
self.get_user_in_public_room.invalidate((user_id,))
def get_users_in_public_due_to_room(self, room_id):
"""Get all user_ids that are in the room directory becuase they're
in the given room_id
"""
return self._simple_select_onecol(
table="users_in_pubic_room",
keyvalues={"room_id": room_id},
retcol="user_id",
desc="get_users_in_public_due_to_room",
)
@defer.inlineCallbacks
def get_users_in_dir_due_to_room(self, room_id):
"""Get all user_ids that are in the room directory becuase they're
in the given room_id
"""
user_ids_dir = yield self._simple_select_onecol(
table="user_directory",
keyvalues={"room_id": room_id},
retcol="user_id",
desc="get_users_in_dir_due_to_room",
)
user_ids_pub = yield self._simple_select_onecol(
table="users_in_pubic_room",
keyvalues={"room_id": room_id},
retcol="user_id",
desc="get_users_in_dir_due_to_room",
)
user_ids_share = yield self._simple_select_onecol(
table="users_who_share_rooms",
keyvalues={"room_id": room_id},
retcol="user_id",
desc="get_users_in_dir_due_to_room",
)
user_ids = set(user_ids_dir)
user_ids.update(user_ids_pub)
user_ids.update(user_ids_share)
defer.returnValue(user_ids)
@defer.inlineCallbacks
def get_all_rooms(self):
"""Get all room_ids we've ever known about, in ascending order of "size"
"""
sql = """
SELECT room_id FROM current_state_events
GROUP BY room_id
ORDER BY count(*) ASC
"""
rows = yield self._execute("get_all_rooms", None, sql)
defer.returnValue([room_id for room_id, in rows])
def add_users_who_share_room(self, room_id, share_private, user_id_tuples):
"""Insert entries into the users_who_share_rooms table. The first
user should be a local user.
Args:
room_id (str)
share_private (bool): Is the room private
user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
"""
def _add_users_who_share_room_txn(txn):
self._simple_insert_many_txn(
txn,
table="users_who_share_rooms",
values=[
{
"user_id": user_id,
"other_user_id": other_user_id,
"room_id": room_id,
"share_private": share_private,
}
for user_id, other_user_id in user_id_tuples
],
)
for user_id, other_user_id in user_id_tuples:
txn.call_after(
self.get_users_who_share_room_from_dir.invalidate,
(user_id,),
)
txn.call_after(
self.get_if_users_share_a_room.invalidate,
(user_id, other_user_id),
)
return self.runInteraction(
"add_users_who_share_room", _add_users_who_share_room_txn
)
def update_users_who_share_room(self, room_id, share_private, user_id_sets):
"""Updates entries in the users_who_share_rooms table. The first
user should be a local user.
Args:
room_id (str)
share_private (bool): Is the room private
user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
"""
def _update_users_who_share_room_txn(txn):
sql = """
UPDATE users_who_share_rooms
SET room_id = ?, share_private = ?
WHERE user_id = ? AND other_user_id = ?
"""
txn.executemany(
sql,
(
(room_id, share_private, uid, oid)
for uid, oid in user_id_sets
)
)
for user_id, other_user_id in user_id_sets:
txn.call_after(
self.get_users_who_share_room_from_dir.invalidate,
(user_id,),
)
txn.call_after(
self.get_if_users_share_a_room.invalidate,
(user_id, other_user_id),
)
return self.runInteraction(
"update_users_who_share_room", _update_users_who_share_room_txn
)
def remove_user_who_share_room(self, user_id, other_user_id):
"""Deletes entries in the users_who_share_rooms table. The first
user should be a local user.
Args:
room_id (str)
share_private (bool): Is the room private
user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
"""
def _remove_user_who_share_room_txn(txn):
self._simple_delete_txn(
txn,
table="users_who_share_rooms",
keyvalues={
"user_id": user_id,
"other_user_id": other_user_id,
},
)
txn.call_after(
self.get_users_who_share_room_from_dir.invalidate,
(user_id,),
)
txn.call_after(
self.get_if_users_share_a_room.invalidate,
(user_id, other_user_id),
)
return self.runInteraction(
"remove_user_who_share_room", _remove_user_who_share_room_txn
)
@cached(max_entries=500000)
def get_if_users_share_a_room(self, user_id, other_user_id):
"""Gets if users share a room.
Args:
user_id (str): Must be a local user_id
other_user_id (str)
Returns:
bool|None: None if they don't share a room, otherwise whether they
share a private room or not.
"""
return self._simple_select_one_onecol(
table="users_who_share_rooms",
keyvalues={
"user_id": user_id,
"other_user_id": other_user_id,
},
retcol="share_private",
allow_none=True,
desc="get_if_users_share_a_room",
)
@cachedInlineCallbacks(max_entries=500000, iterable=True)
def get_users_who_share_room_from_dir(self, user_id):
"""Returns the set of users who share a room with `user_id`
Args:
user_id(str): Must be a local user
Returns:
dict: user_id -> share_private mapping
"""
rows = yield self._simple_select_list(
table="users_who_share_rooms",
keyvalues={
"user_id": user_id,
},
retcols=("other_user_id", "share_private",),
desc="get_users_who_share_room_with_user",
)
defer.returnValue({
row["other_user_id"]: row["share_private"]
for row in rows
})
def get_users_in_share_dir_with_room_id(self, user_id, room_id):
"""Get all user tuples that are in the users_who_share_rooms due to the
given room_id.
Returns:
[(user_id, other_user_id)]: where one of the two will match the given
user_id.
"""
sql = """
SELECT user_id, other_user_id FROM users_who_share_rooms
WHERE room_id = ? AND (user_id = ? OR other_user_id = ?)
"""
return self._execute(
"get_users_in_share_dir_with_room_id", None, sql, room_id, user_id, user_id
)
@defer.inlineCallbacks
def get_rooms_in_common_for_users(self, user_id, other_user_id):
"""Given two user_ids find out the list of rooms they share.
"""
sql = """
SELECT room_id FROM (
SELECT c.room_id FROM current_state_events AS c
INNER JOIN room_memberships USING (event_id)
WHERE type = 'm.room.member'
AND membership = 'join'
AND state_key = ?
) AS f1 INNER JOIN (
SELECT c.room_id FROM current_state_events AS c
INNER JOIN room_memberships USING (event_id)
WHERE type = 'm.room.member'
AND membership = 'join'
AND state_key = ?
) f2 USING (room_id)
"""
rows = yield self._execute(
"get_rooms_in_common_for_users", None, sql, user_id, other_user_id
)
defer.returnValue([room_id for room_id, in rows])
def delete_all_from_user_dir(self):
"""Delete the entire user directory
"""
def _delete_all_from_user_dir_txn(txn):
txn.execute("DELETE FROM user_directory")
txn.execute("DELETE FROM user_directory_search")
txn.execute("DELETE FROM users_in_pubic_room")
txn.execute("DELETE FROM users_who_share_rooms")
txn.call_after(self.get_user_in_directory.invalidate_all)
txn.call_after(self.get_user_in_public_room.invalidate_all)
txn.call_after(self.get_users_who_share_room_from_dir.invalidate_all)
txn.call_after(self.get_if_users_share_a_room.invalidate_all)
return self.runInteraction(
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
)
@cached()
def get_user_in_directory(self, user_id):
return self._simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
retcols=("room_id", "display_name", "avatar_url",),
allow_none=True,
desc="get_user_in_directory",
)
@cached()
def get_user_in_public_room(self, user_id):
return self._simple_select_one(
table="users_in_pubic_room",
keyvalues={"user_id": user_id},
retcols=("room_id",),
allow_none=True,
desc="get_user_in_public_room",
)
def get_user_directory_stream_pos(self):
return self._simple_select_one_onecol(
table="user_directory_stream_pos",
keyvalues={},
retcol="stream_id",
desc="get_user_directory_stream_pos",
)
def update_user_directory_stream_pos(self, stream_id):
return self._simple_update_one(
table="user_directory_stream_pos",
keyvalues={},
updatevalues={"stream_id": stream_id},
desc="update_user_directory_stream_pos",
)
def get_current_state_deltas(self, prev_stream_id):
prev_stream_id = int(prev_stream_id)
if not self._curr_state_delta_stream_cache.has_any_entity_changed(prev_stream_id):
return []
def get_current_state_deltas_txn(txn):
# First we calculate the max stream id that will give us less than
# N results.
# We arbitarily limit to 100 stream_id entries to ensure we don't
# select toooo many.
sql = """
SELECT stream_id, count(*)
FROM current_state_delta_stream
WHERE stream_id > ?
GROUP BY stream_id
ORDER BY stream_id ASC
LIMIT 100
"""
txn.execute(sql, (prev_stream_id,))
total = 0
max_stream_id = prev_stream_id
for max_stream_id, count in txn:
total += count
if total > 100:
# We arbitarily limit to 100 entries to ensure we don't
# select toooo many.
break
# Now actually get the deltas
sql = """
SELECT stream_id, room_id, type, state_key, event_id, prev_event_id
FROM current_state_delta_stream
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
"""
txn.execute(sql, (prev_stream_id, max_stream_id,))
return self.cursor_to_dict(txn)
return self.runInteraction(
"get_current_state_deltas", get_current_state_deltas_txn
)
def get_max_stream_id_in_current_state_deltas(self):
return self._simple_select_one_onecol(
table="current_state_delta_stream",
keyvalues={},
retcol="COALESCE(MAX(stream_id), -1)",
desc="get_max_stream_id_in_current_state_deltas",
)
@defer.inlineCallbacks
def search_user_dir(self, user_id, search_term, limit):
"""Searches for users in directory
Returns:
dict of the form::
{
"limited": <bool>, # whether there were more results or not
"results": [ # Ordered by best match first
{
"user_id": <user_id>,
"display_name": <display_name>,
"avatar_url": <avatar_url>
}
]
}
"""
if isinstance(self.database_engine, PostgresEngine):
full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
# We order by rank and then if they have profile info
# The ranking algorithm is hand tweaked for "best" results. Broadly
# the idea is we give a higher weight to exact matches.
# The array of numbers are the weights for the various part of the
# search: (domain, _, display name, localpart)
sql = """
SELECT d.user_id, display_name, avatar_url
FROM user_directory_search
INNER JOIN user_directory AS d USING (user_id)
LEFT JOIN users_in_pubic_room AS p USING (user_id)
LEFT JOIN (
SELECT other_user_id AS user_id FROM users_who_share_rooms
WHERE user_id = ? AND share_private
) AS s USING (user_id)
WHERE
(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)
AND vector @@ to_tsquery('english', ?)
ORDER BY
(CASE WHEN s.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
* (CASE WHEN display_name IS NOT NULL THEN 1.2 ELSE 1.0 END)
* (CASE WHEN avatar_url IS NOT NULL THEN 1.2 ELSE 1.0 END)
* (
3 * ts_rank_cd(
'{0.1, 0.1, 0.9, 1.0}',
vector,
to_tsquery('english', ?),
8
)
+ ts_rank_cd(
'{0.1, 0.1, 0.9, 1.0}',
vector,
to_tsquery('english', ?),
8
)
)
DESC,
display_name IS NULL,
avatar_url IS NULL
LIMIT ?
"""
args = (user_id, full_query, exact_query, prefix_query, limit + 1,)
elif isinstance(self.database_engine, Sqlite3Engine):
search_query = _parse_query_sqlite(search_term)
sql = """
SELECT d.user_id, display_name, avatar_url
FROM user_directory_search
INNER JOIN user_directory AS d USING (user_id)
LEFT JOIN users_in_pubic_room AS p USING (user_id)
LEFT JOIN (
SELECT other_user_id AS user_id FROM users_who_share_rooms
WHERE user_id = ? AND share_private
) AS s USING (user_id)
WHERE
(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)
AND value MATCH ?
ORDER BY
rank(matchinfo(user_directory_search)) DESC,
display_name IS NULL,
avatar_url IS NULL
LIMIT ?
"""
args = (user_id, search_query, limit + 1)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
results = yield self._execute(
"search_user_dir", self.cursor_to_dict, sql, *args
)
limited = len(results) > limit
defer.returnValue({
"limited": limited,
"results": results,
})
def _parse_query_sqlite(search_term):
"""Takes a plain unicode string from the user and converts it into a form
that can be passed to database.
We use this so that we can add prefix matching, which isn't something
that is supported by default.
We specifically add both a prefix and non prefix matching term so that
exact matches get ranked higher.
"""
# Pull out the individual words, discarding any non-word characters.
results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
return " & ".join("(%s* | %s)" % (result, result,) for result in results)
def _parse_query_postgres(search_term):
"""Takes a plain unicode string from the user and converts it into a form
that can be passed to database.
We use this so that we can add prefix matching, which isn't something
that is supported by default.
"""
# Pull out the individual words, discarding any non-word characters.
results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
both = " & ".join("(%s:* | %s)" % (result, result,) for result in results)
exact = " & ".join("%s" % (result,) for result in results)
prefix = " & ".join("%s:*" % (result,) for result in results)
return both, exact, prefix

View File

@ -62,6 +62,13 @@ def get_domain_from_id(string):
return string[idx + 1:] return string[idx + 1:]
def get_localpart_from_id(string):
idx = string.find(":")
if idx == -1:
raise SynapseError(400, "Invalid ID: %r" % (string,))
return string[1:idx]
class DomainSpecificString( class DomainSpecificString(
namedtuple("DomainSpecificString", ("localpart", "domain")) namedtuple("DomainSpecificString", ("localpart", "domain"))
): ):

View File

@ -16,7 +16,7 @@
import synapse.metrics import synapse.metrics
import os import os
CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5))
metrics = synapse.metrics.get_metrics_for("synapse.util.caches") metrics = synapse.metrics.get_metrics_for("synapse.util.caches")

View File

@ -16,6 +16,7 @@ import logging
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util import unwrapFirstError, logcontext from synapse.util import unwrapFirstError, logcontext
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
from synapse.util.stringutils import to_ascii from synapse.util.stringutils import to_ascii
@ -25,7 +26,6 @@ from . import register_cache
from twisted.internet import defer from twisted.internet import defer
from collections import namedtuple from collections import namedtuple
import os
import functools import functools
import inspect import inspect
import threading import threading
@ -37,9 +37,6 @@ logger = logging.getLogger(__name__)
_CacheSentinel = object() _CacheSentinel = object()
CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
class CacheEntry(object): class CacheEntry(object):
__slots__ = [ __slots__ = [
"deferred", "sequence", "callbacks", "invalidated" "deferred", "sequence", "callbacks", "invalidated"
@ -404,6 +401,7 @@ class CacheDescriptor(_CacheDescriptorBase):
wrapped.invalidate_all = cache.invalidate_all wrapped.invalidate_all = cache.invalidate_all
wrapped.cache = cache wrapped.cache = cache
wrapped.num_args = self.num_args
obj.__dict__[self.orig.__name__] = wrapped obj.__dict__[self.orig.__name__] = wrapped
@ -451,8 +449,9 @@ class CacheListDescriptor(_CacheDescriptorBase):
) )
def __get__(self, obj, objtype=None): def __get__(self, obj, objtype=None):
cached_method = getattr(obj, self.cached_method_name)
cache = getattr(obj, self.cached_method_name).cache cache = cached_method.cache
num_args = cached_method.num_args
@functools.wraps(self.orig) @functools.wraps(self.orig)
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
@ -469,12 +468,23 @@ class CacheListDescriptor(_CacheDescriptorBase):
results = {} results = {}
cached_defers = {} cached_defers = {}
missing = [] missing = []
for arg in list_args:
key = list(keyargs)
key[self.list_pos] = arg
# If the cache takes a single arg then that is used as the key,
# otherwise a tuple is used.
if num_args == 1:
def cache_get(arg):
return cache.get(arg, callback=invalidate_callback)
else:
key = list(keyargs)
def cache_get(arg):
key[self.list_pos] = arg
return cache.get(tuple(key), callback=invalidate_callback)
for arg in list_args:
try: try:
res = cache.get(tuple(key), callback=invalidate_callback) res = cache_get(arg)
if not isinstance(res, ObservableDeferred): if not isinstance(res, ObservableDeferred):
results[arg] = res results[arg] = res
elif not res.has_succeeded(): elif not res.has_succeeded():
@ -505,6 +515,17 @@ class CacheListDescriptor(_CacheDescriptorBase):
observer = ObservableDeferred(observer) observer = ObservableDeferred(observer)
if num_args == 1:
cache.set(
arg, observer,
callback=invalidate_callback
)
def invalidate(f, key):
cache.invalidate(key)
return f
observer.addErrback(invalidate, arg)
else:
key = list(keyargs) key = list(keyargs)
key[self.list_pos] = arg key[self.list_pos] = arg
cache.set( cache.set(

View File

@ -23,7 +23,17 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "value"))): class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "value"))):
"""Returned when getting an entry from the cache
Attributes:
full (bool): Whether the cache has the full or dict or just some keys.
If not full then not all requested keys will necessarily be present
in `value`
known_absent (set): Keys that were looked up in the dict and were not
there.
value (dict): The full or partial dict value
"""
def __len__(self): def __len__(self):
return len(self.value) return len(self.value)
@ -58,21 +68,31 @@ class DictionaryCache(object):
) )
def get(self, key, dict_keys=None): def get(self, key, dict_keys=None):
"""Fetch an entry out of the cache
Args:
key
dict_key(list): If given a set of keys then return only those keys
that exist in the cache.
Returns:
DictionaryEntry
"""
entry = self.cache.get(key, self.sentinel) entry = self.cache.get(key, self.sentinel)
if entry is not self.sentinel: if entry is not self.sentinel:
self.metrics.inc_hits() self.metrics.inc_hits()
if dict_keys is None: if dict_keys is None:
return DictionaryEntry(entry.full, dict(entry.value)) return DictionaryEntry(entry.full, entry.known_absent, dict(entry.value))
else: else:
return DictionaryEntry(entry.full, { return DictionaryEntry(entry.full, entry.known_absent, {
k: entry.value[k] k: entry.value[k]
for k in dict_keys for k in dict_keys
if k in entry.value if k in entry.value
}) })
self.metrics.inc_misses() self.metrics.inc_misses()
return DictionaryEntry(False, {}) return DictionaryEntry(False, set(), {})
def invalidate(self, key): def invalidate(self, key):
self.check_thread() self.check_thread()
@ -87,19 +107,34 @@ class DictionaryCache(object):
self.sequence += 1 self.sequence += 1
self.cache.clear() self.cache.clear()
def update(self, sequence, key, value, full=False): def update(self, sequence, key, value, full=False, known_absent=None):
"""Updates the entry in the cache
Args:
sequence
key
value (dict): The value to update the cache with.
full (bool): Whether the given value is the full dict, or just a
partial subset there of. If not full then any existing entries
for the key will be updated.
known_absent (set): Set of keys that we know don't exist in the full
dict.
"""
self.check_thread() self.check_thread()
if self.sequence == sequence: if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the # Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369) # number that the cache had before the SELECT was started (SYN-369)
if known_absent is None:
known_absent = set()
if full: if full:
self._insert(key, value) self._insert(key, value, known_absent)
else: else:
self._update_or_insert(key, value) self._update_or_insert(key, value, known_absent)
def _update_or_insert(self, key, value): def _update_or_insert(self, key, value, known_absent):
entry = self.cache.setdefault(key, DictionaryEntry(False, {})) entry = self.cache.setdefault(key, DictionaryEntry(False, set(), {}))
entry.value.update(value) entry.value.update(value)
entry.known_absent.update(known_absent)
def _insert(self, key, value): def _insert(self, key, value, known_absent):
self.cache[key] = DictionaryEntry(True, value) self.cache[key] = DictionaryEntry(True, known_absent, value)

View File

@ -94,6 +94,9 @@ class ExpiringCache(object):
return entry.value return entry.value
def __contains__(self, key):
return key in self._cache
def get(self, key, default=None): def get(self, key, default=None):
try: try:
return self[key] return self[key]

View File

@ -13,20 +13,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.util.caches import register_cache from synapse.util.caches import register_cache, CACHE_SIZE_FACTOR
from blist import sorteddict from blist import sorteddict
import logging import logging
import os
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
class StreamChangeCache(object): class StreamChangeCache(object):
"""Keeps track of the stream positions of the latest change in a set of entities. """Keeps track of the stream positions of the latest change in a set of entities.
@ -89,6 +85,21 @@ class StreamChangeCache(object):
return result return result
def has_any_entity_changed(self, stream_pos):
"""Returns if any entity has changed
"""
assert type(stream_pos) is int
if stream_pos >= self._earliest_known_stream_pos:
self.metrics.inc_hits()
keys = self._cache.keys()
i = keys.bisect_right(stream_pos)
return i < len(keys)
else:
self.metrics.inc_misses()
return True
def get_all_entities_changed(self, stream_pos): def get_all_entities_changed(self, stream_pos):
"""Returns all entites that have had new things since the given """Returns all entites that have had new things since the given
position. If the position is too old it will return None. position. If the position is too old it will return None.

View File

@ -241,7 +241,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
callcount2 = [0] callcount2 = [0]
class A(object): class A(object):
@cached(max_entries=20) # HACK: This makes it 2 due to cache factor @cached(max_entries=4) # HACK: This makes it 2 due to cache factor
def func(self, key): def func(self, key):
callcount[0] += 1 callcount[0] += 1
return key return key

View File

@ -43,10 +43,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
"access_token", "ip", "user_agent", "device_id", "access_token", "ip", "user_agent", "device_id",
) )
# deliberately use an iterable here to make sure that the lookup result = yield self.store.get_last_client_ip_by_device(user_id, "device_id")
# method doesn't iterate it twice
device_list = iter(((user_id, "device_id"),))
result = yield self.store.get_last_client_ip_by_device(device_list)
r = result[(user_id, "device_id")] r = result[(user_id, "device_id")]
self.assertDictContainsSubset( self.assertDictContainsSubset(

View File

@ -143,6 +143,7 @@ class StateTestCase(unittest.TestCase):
"add_event_hashes", "add_event_hashes",
"get_events", "get_events",
"get_next_state_group", "get_next_state_group",
"get_state_group_delta",
] ]
) )
hs = Mock(spec_set=[ hs = Mock(spec_set=[
@ -154,6 +155,7 @@ class StateTestCase(unittest.TestCase):
hs.get_auth.return_value = Auth(hs) hs.get_auth.return_value = Auth(hs)
self.store.get_next_state_group.side_effect = Mock self.store.get_next_state_group.side_effect = Mock
self.store.get_state_group_delta.return_value = (None, None)
self.state = StateHandler(hs) self.state = StateHandler(hs)
self.event_id = 0 self.event_id = 0

View File

@ -28,7 +28,7 @@ class DictCacheTestCase(unittest.TestCase):
key = "test_simple_cache_hit_full" key = "test_simple_cache_hit_full"
v = self.cache.get(key) v = self.cache.get(key)
self.assertEqual((False, {}), v) self.assertEqual((False, set(), {}), v)
seq = self.cache.sequence seq = self.cache.sequence
test_value = {"test": "test_simple_cache_hit_full"} test_value = {"test": "test_simple_cache_hit_full"}

View File

@ -55,6 +55,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
config.password_providers = [] config.password_providers = []
config.worker_replication_url = "" config.worker_replication_url = ""
config.worker_app = None config.worker_app = None
config.email_enable_notifs = False
config.use_frozen_dicts = True config.use_frozen_dicts = True
config.database_config = {"name": "sqlite3"} config.database_config = {"name": "sqlite3"}