mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Merge branch 'release-v0.22.0' of github.com:matrix-org/synapse
This commit is contained in:
commit
42b50483be
56
CHANGES.rst
56
CHANGES.rst
@ -1,3 +1,59 @@
|
|||||||
|
Changes in synapse v0.22.0 (2017-07-06)
|
||||||
|
=======================================
|
||||||
|
|
||||||
|
No changes since v0.22.0-rc2
|
||||||
|
|
||||||
|
|
||||||
|
Changes in synapse v0.22.0-rc2 (2017-07-04)
|
||||||
|
===========================================
|
||||||
|
|
||||||
|
Changes:
|
||||||
|
|
||||||
|
* Improve performance of storing user IPs (PR #2307, #2308)
|
||||||
|
* Slightly improve performance of verifying access tokens (PR #2320)
|
||||||
|
* Slightly improve performance of event persistence (PR #2321)
|
||||||
|
* Increase default cache factor size from 0.1 to 0.5 (PR #2330)
|
||||||
|
|
||||||
|
Bug fixes:
|
||||||
|
|
||||||
|
* Fix bug with storing registration sessions that caused frequent CPU churn
|
||||||
|
(PR #2319)
|
||||||
|
|
||||||
|
|
||||||
|
Changes in synapse v0.22.0-rc1 (2017-06-26)
|
||||||
|
===========================================
|
||||||
|
|
||||||
|
Features:
|
||||||
|
|
||||||
|
* Add a user directory API (PR #2252, and many more)
|
||||||
|
* Add shutdown room API to remove room from local server (PR #2291)
|
||||||
|
* Add API to quarantine media (PR #2292)
|
||||||
|
* Add new config option to not send event contents to push servers (PR #2301)
|
||||||
|
Thanks to @cjdelisle!
|
||||||
|
|
||||||
|
Changes:
|
||||||
|
|
||||||
|
* Various performance fixes (PR #2177, #2233, #2230, #2238, #2248, #2256,
|
||||||
|
#2274)
|
||||||
|
* Deduplicate sync filters (PR #2219) Thanks to @krombel!
|
||||||
|
* Correct a typo in UPGRADE.rst (PR #2231) Thanks to @aaronraimist!
|
||||||
|
* Add count of one time keys to sync stream (PR #2237)
|
||||||
|
* Only store event_auth for state events (PR #2247)
|
||||||
|
* Store URL cache preview downloads separately (PR #2299)
|
||||||
|
|
||||||
|
Bug fixes:
|
||||||
|
|
||||||
|
* Fix users not getting notifications when AS listened to that user_id (PR
|
||||||
|
#2216) Thanks to @slipeer!
|
||||||
|
* Fix users without push set up not getting notifications after joining rooms
|
||||||
|
(PR #2236)
|
||||||
|
* Fix preview url API to trim long descriptions (PR #2243)
|
||||||
|
* Fix bug where we used cached but unpersisted state group as prev group,
|
||||||
|
resulting in broken state of restart (PR #2263)
|
||||||
|
* Fix removing of pushers when using workers (PR #2267)
|
||||||
|
* Fix CORS headers to allow Authorization header (PR #2285) Thanks to @krombel!
|
||||||
|
|
||||||
|
|
||||||
Changes in synapse v0.21.1 (2017-06-15)
|
Changes in synapse v0.21.1 (2017-06-15)
|
||||||
=======================================
|
=======================================
|
||||||
|
|
||||||
|
33
README.rst
33
README.rst
@ -528,6 +528,30 @@ fix try re-installing from PyPI or directly from
|
|||||||
# Install from github
|
# Install from github
|
||||||
pip install --user https://github.com/pyca/pynacl/tarball/master
|
pip install --user https://github.com/pyca/pynacl/tarball/master
|
||||||
|
|
||||||
|
Running out of File Handles
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
If synapse runs out of filehandles, it typically fails badly - live-locking
|
||||||
|
at 100% CPU, and/or failing to accept new TCP connections (blocking the
|
||||||
|
connecting client). Matrix currently can legitimately use a lot of file handles,
|
||||||
|
thanks to busy rooms like #matrix:matrix.org containing hundreds of participating
|
||||||
|
servers. The first time a server talks in a room it will try to connect
|
||||||
|
simultaneously to all participating servers, which could exhaust the available
|
||||||
|
file descriptors between DNS queries & HTTPS sockets, especially if DNS is slow
|
||||||
|
to respond. (We need to improve the routing algorithm used to be better than
|
||||||
|
full mesh, but as of June 2017 this hasn't happened yet).
|
||||||
|
|
||||||
|
If you hit this failure mode, we recommend increasing the maximum number of
|
||||||
|
open file handles to be at least 4096 (assuming a default of 1024 or 256).
|
||||||
|
This is typically done by editing ``/etc/security/limits.conf``
|
||||||
|
|
||||||
|
Separately, Synapse may leak file handles if inbound HTTP requests get stuck
|
||||||
|
during processing - e.g. blocked behind a lock or talking to a remote server etc.
|
||||||
|
This is best diagnosed by matching up the 'Received request' and 'Processed request'
|
||||||
|
log lines and looking for any 'Processed request' lines which take more than
|
||||||
|
a few seconds to execute. Please let us know at #matrix-dev:matrix.org if
|
||||||
|
you see this failure mode so we can help debug it, however.
|
||||||
|
|
||||||
ArchLinux
|
ArchLinux
|
||||||
~~~~~~~~~
|
~~~~~~~~~
|
||||||
|
|
||||||
@ -875,12 +899,9 @@ cache a lot of recent room data and metadata in RAM in order to speed up
|
|||||||
common requests. We'll improve this in future, but for now the easiest
|
common requests. We'll improve this in future, but for now the easiest
|
||||||
way to either reduce the RAM usage (at the risk of slowing things down)
|
way to either reduce the RAM usage (at the risk of slowing things down)
|
||||||
is to set the almost-undocumented ``SYNAPSE_CACHE_FACTOR`` environment
|
is to set the almost-undocumented ``SYNAPSE_CACHE_FACTOR`` environment
|
||||||
variable. Roughly speaking, a SYNAPSE_CACHE_FACTOR of 1.0 will max out
|
variable. The default is 0.5, which can be decreased to reduce RAM usage
|
||||||
at around 3-4GB of resident memory - this is what we currently run the
|
in memory constrained enviroments, or increased if performance starts to
|
||||||
matrix.org on. The default setting is currently 0.1, which is probably
|
degrade.
|
||||||
around a ~700MB footprint. You can dial it down further to 0.02 if
|
|
||||||
desired, which targets roughly ~512MB. Conversely you can dial it up if
|
|
||||||
you need performance for lots of users and have a box with a lot of RAM.
|
|
||||||
|
|
||||||
|
|
||||||
.. _`key_management`: https://matrix.org/docs/spec/server_server/unstable.html#retrieving-server-keys
|
.. _`key_management`: https://matrix.org/docs/spec/server_server/unstable.html#retrieving-server-keys
|
||||||
|
@ -33,7 +33,7 @@ To check whether your update was sucessfull, run:
|
|||||||
|
|
||||||
.. code:: bash
|
.. code:: bash
|
||||||
|
|
||||||
# replace your.server.domain with ther domain of your synaspe homeserver
|
# replace your.server.domain with ther domain of your synapse homeserver
|
||||||
curl https://<your.server.domain>/_matrix/federation/v1/version
|
curl https://<your.server.domain>/_matrix/federation/v1/version
|
||||||
|
|
||||||
So for the Matrix.org HS server the URL would be: https://matrix.org/_matrix/federation/v1/version.
|
So for the Matrix.org HS server the URL would be: https://matrix.org/_matrix/federation/v1/version.
|
||||||
|
@ -41,6 +41,7 @@ BOOLEAN_COLUMNS = {
|
|||||||
"presence_stream": ["currently_active"],
|
"presence_stream": ["currently_active"],
|
||||||
"public_room_list_stream": ["visibility"],
|
"public_room_list_stream": ["visibility"],
|
||||||
"device_lists_outbound_pokes": ["sent"],
|
"device_lists_outbound_pokes": ["sent"],
|
||||||
|
"users_who_share_rooms": ["share_private"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -121,7 +122,7 @@ class Store(object):
|
|||||||
try:
|
try:
|
||||||
txn = conn.cursor()
|
txn = conn.cursor()
|
||||||
return func(
|
return func(
|
||||||
LoggingTransaction(txn, desc, self.database_engine, []),
|
LoggingTransaction(txn, desc, self.database_engine, [], []),
|
||||||
*args, **kwargs
|
*args, **kwargs
|
||||||
)
|
)
|
||||||
except self.database_engine.module.DatabaseError as e:
|
except self.database_engine.module.DatabaseError as e:
|
||||||
|
@ -16,4 +16,4 @@
|
|||||||
""" This is a reference implementation of a Matrix home server.
|
""" This is a reference implementation of a Matrix home server.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__version__ = "0.21.1"
|
__version__ = "0.22.0"
|
||||||
|
@ -23,7 +23,8 @@ from synapse import event_auth
|
|||||||
from synapse.api.constants import EventTypes, Membership, JoinRules
|
from synapse.api.constants import EventTypes, Membership, JoinRules
|
||||||
from synapse.api.errors import AuthError, Codes
|
from synapse.api.errors import AuthError, Codes
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
from synapse.util import logcontext
|
from synapse.util.caches import register_cache, CACHE_SIZE_FACTOR
|
||||||
|
from synapse.util.caches.lrucache import LruCache
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -39,6 +40,10 @@ AuthEventTypes = (
|
|||||||
GUEST_DEVICE_ID = "guest_device"
|
GUEST_DEVICE_ID = "guest_device"
|
||||||
|
|
||||||
|
|
||||||
|
class _InvalidMacaroonException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Auth(object):
|
class Auth(object):
|
||||||
"""
|
"""
|
||||||
FIXME: This class contains a mix of functions for authenticating users
|
FIXME: This class contains a mix of functions for authenticating users
|
||||||
@ -51,6 +56,9 @@ class Auth(object):
|
|||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
|
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
|
||||||
|
|
||||||
|
self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000)
|
||||||
|
register_cache("token_cache", self.token_cache)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_from_context(self, event, context, do_sig_check=True):
|
def check_from_context(self, event, context, do_sig_check=True):
|
||||||
auth_events_ids = yield self.compute_auth_events(
|
auth_events_ids = yield self.compute_auth_events(
|
||||||
@ -144,17 +152,8 @@ class Auth(object):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_host_in_room(self, room_id, host):
|
def check_host_in_room(self, room_id, host):
|
||||||
with Measure(self.clock, "check_host_in_room"):
|
with Measure(self.clock, "check_host_in_room"):
|
||||||
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
latest_event_ids = yield self.store.is_host_joined(room_id, host)
|
||||||
|
defer.returnValue(latest_event_ids)
|
||||||
logger.debug("calling resolve_state_groups from check_host_in_room")
|
|
||||||
entry = yield self.state.resolve_state_groups(
|
|
||||||
room_id, latest_event_ids
|
|
||||||
)
|
|
||||||
|
|
||||||
ret = yield self.store.is_host_joined(
|
|
||||||
room_id, host, entry.state_group, entry.state
|
|
||||||
)
|
|
||||||
defer.returnValue(ret)
|
|
||||||
|
|
||||||
def _check_joined_room(self, member, user_id, room_id):
|
def _check_joined_room(self, member, user_id, room_id):
|
||||||
if not member or member.membership != Membership.JOIN:
|
if not member or member.membership != Membership.JOIN:
|
||||||
@ -209,7 +208,7 @@ class Auth(object):
|
|||||||
default=[""]
|
default=[""]
|
||||||
)[0]
|
)[0]
|
||||||
if user and access_token and ip_addr:
|
if user and access_token and ip_addr:
|
||||||
logcontext.preserve_fn(self.store.insert_client_ip)(
|
self.store.insert_client_ip(
|
||||||
user=user,
|
user=user,
|
||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
ip=ip_addr,
|
ip=ip_addr,
|
||||||
@ -276,8 +275,8 @@ class Auth(object):
|
|||||||
AuthError if no user by that token exists or the token is invalid.
|
AuthError if no user by that token exists or the token is invalid.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
macaroon = pymacaroons.Macaroon.deserialize(token)
|
user_id, guest = self._parse_and_validate_macaroon(token, rights)
|
||||||
except Exception: # deserialize can throw more-or-less anything
|
except _InvalidMacaroonException:
|
||||||
# doesn't look like a macaroon: treat it as an opaque token which
|
# doesn't look like a macaroon: treat it as an opaque token which
|
||||||
# must be in the database.
|
# must be in the database.
|
||||||
# TODO: it would be nice to get rid of this, but apparently some
|
# TODO: it would be nice to get rid of this, but apparently some
|
||||||
@ -286,19 +285,8 @@ class Auth(object):
|
|||||||
defer.returnValue(r)
|
defer.returnValue(r)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
user_id = self.get_user_id_from_macaroon(macaroon)
|
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
|
|
||||||
self.validate_macaroon(
|
|
||||||
macaroon, rights, self.hs.config.expire_access_token,
|
|
||||||
user_id=user_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
guest = False
|
|
||||||
for caveat in macaroon.caveats:
|
|
||||||
if caveat.caveat_id == "guest = true":
|
|
||||||
guest = True
|
|
||||||
|
|
||||||
if guest:
|
if guest:
|
||||||
# Guest access tokens are not stored in the database (there can
|
# Guest access tokens are not stored in the database (there can
|
||||||
# only be one access token per guest, anyway).
|
# only be one access token per guest, anyway).
|
||||||
@ -370,6 +358,55 @@ class Auth(object):
|
|||||||
errcode=Codes.UNKNOWN_TOKEN
|
errcode=Codes.UNKNOWN_TOKEN
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _parse_and_validate_macaroon(self, token, rights="access"):
|
||||||
|
"""Takes a macaroon and tries to parse and validate it. This is cached
|
||||||
|
if and only if rights == access and there isn't an expiry.
|
||||||
|
|
||||||
|
On invalid macaroon raises _InvalidMacaroonException
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(user_id, is_guest)
|
||||||
|
"""
|
||||||
|
if rights == "access":
|
||||||
|
cached = self.token_cache.get(token, None)
|
||||||
|
if cached:
|
||||||
|
return cached
|
||||||
|
|
||||||
|
try:
|
||||||
|
macaroon = pymacaroons.Macaroon.deserialize(token)
|
||||||
|
except Exception: # deserialize can throw more-or-less anything
|
||||||
|
# doesn't look like a macaroon: treat it as an opaque token which
|
||||||
|
# must be in the database.
|
||||||
|
# TODO: it would be nice to get rid of this, but apparently some
|
||||||
|
# people use access tokens which aren't macaroons
|
||||||
|
raise _InvalidMacaroonException()
|
||||||
|
|
||||||
|
try:
|
||||||
|
user_id = self.get_user_id_from_macaroon(macaroon)
|
||||||
|
|
||||||
|
has_expiry = False
|
||||||
|
guest = False
|
||||||
|
for caveat in macaroon.caveats:
|
||||||
|
if caveat.caveat_id.startswith("time "):
|
||||||
|
has_expiry = True
|
||||||
|
elif caveat.caveat_id == "guest = true":
|
||||||
|
guest = True
|
||||||
|
|
||||||
|
self.validate_macaroon(
|
||||||
|
macaroon, rights, self.hs.config.expire_access_token,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
|
||||||
|
raise AuthError(
|
||||||
|
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.",
|
||||||
|
errcode=Codes.UNKNOWN_TOKEN
|
||||||
|
)
|
||||||
|
|
||||||
|
if not has_expiry and rights == "access":
|
||||||
|
self.token_cache[token] = (user_id, guest)
|
||||||
|
|
||||||
|
return user_id, guest
|
||||||
|
|
||||||
def get_user_id_from_macaroon(self, macaroon):
|
def get_user_id_from_macaroon(self, macaroon):
|
||||||
"""Retrieve the user_id given by the caveats on the macaroon.
|
"""Retrieve the user_id given by the caveats on the macaroon.
|
||||||
|
|
||||||
|
@ -24,6 +24,7 @@ from synapse.http.server import JsonResource
|
|||||||
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||||
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
|
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
|
||||||
|
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
|
||||||
from synapse.replication.slave.storage.events import SlavedEventStore
|
from synapse.replication.slave.storage.events import SlavedEventStore
|
||||||
from synapse.replication.slave.storage.keys import SlavedKeyStore
|
from synapse.replication.slave.storage.keys import SlavedKeyStore
|
||||||
from synapse.replication.slave.storage.room import RoomStore
|
from synapse.replication.slave.storage.room import RoomStore
|
||||||
@ -33,7 +34,6 @@ from synapse.replication.slave.storage.transactions import TransactionStore
|
|||||||
from synapse.replication.tcp.client import ReplicationClientHandler
|
from synapse.replication.tcp.client import ReplicationClientHandler
|
||||||
from synapse.rest.client.v1.room import PublicRoomListRestServlet
|
from synapse.rest.client.v1.room import PublicRoomListRestServlet
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.client_ips import ClientIpStore
|
|
||||||
from synapse.storage.engines import create_engine
|
from synapse.storage.engines import create_engine
|
||||||
from synapse.util.httpresourcetree import create_resource_tree
|
from synapse.util.httpresourcetree import create_resource_tree
|
||||||
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
||||||
@ -65,8 +65,8 @@ class ClientReaderSlavedStore(
|
|||||||
SlavedApplicationServiceStore,
|
SlavedApplicationServiceStore,
|
||||||
SlavedRegistrationStore,
|
SlavedRegistrationStore,
|
||||||
TransactionStore,
|
TransactionStore,
|
||||||
|
SlavedClientIpStore,
|
||||||
BaseSlavedStore,
|
BaseSlavedStore,
|
||||||
ClientIpStore, # After BaseSlavedStore because the constructor is different
|
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -51,7 +51,7 @@ import sys
|
|||||||
import logging
|
import logging
|
||||||
import gc
|
import gc
|
||||||
|
|
||||||
logger = logging.getLogger("synapse.app.appservice")
|
logger = logging.getLogger("synapse.app.federation_sender")
|
||||||
|
|
||||||
|
|
||||||
class FederationSenderSlaveStore(
|
class FederationSenderSlaveStore(
|
||||||
|
@ -23,13 +23,13 @@ from synapse.http.site import SynapseSite
|
|||||||
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||||
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
|
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
|
||||||
|
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
|
||||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||||
from synapse.replication.slave.storage.transactions import TransactionStore
|
from synapse.replication.slave.storage.transactions import TransactionStore
|
||||||
from synapse.replication.tcp.client import ReplicationClientHandler
|
from synapse.replication.tcp.client import ReplicationClientHandler
|
||||||
from synapse.rest.media.v0.content_repository import ContentRepoResource
|
from synapse.rest.media.v0.content_repository import ContentRepoResource
|
||||||
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
|
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.client_ips import ClientIpStore
|
|
||||||
from synapse.storage.engines import create_engine
|
from synapse.storage.engines import create_engine
|
||||||
from synapse.storage.media_repository import MediaRepositoryStore
|
from synapse.storage.media_repository import MediaRepositoryStore
|
||||||
from synapse.util.httpresourcetree import create_resource_tree
|
from synapse.util.httpresourcetree import create_resource_tree
|
||||||
@ -60,10 +60,10 @@ logger = logging.getLogger("synapse.app.media_repository")
|
|||||||
class MediaRepositorySlavedStore(
|
class MediaRepositorySlavedStore(
|
||||||
SlavedApplicationServiceStore,
|
SlavedApplicationServiceStore,
|
||||||
SlavedRegistrationStore,
|
SlavedRegistrationStore,
|
||||||
|
SlavedClientIpStore,
|
||||||
TransactionStore,
|
TransactionStore,
|
||||||
BaseSlavedStore,
|
BaseSlavedStore,
|
||||||
MediaRepositoryStore,
|
MediaRepositoryStore,
|
||||||
ClientIpStore,
|
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -29,6 +29,7 @@ from synapse.rest.client.v1 import events
|
|||||||
from synapse.rest.client.v1.room import RoomInitialSyncRestServlet
|
from synapse.rest.client.v1.room import RoomInitialSyncRestServlet
|
||||||
from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
|
from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
|
||||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||||
|
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
|
||||||
from synapse.replication.slave.storage.events import SlavedEventStore
|
from synapse.replication.slave.storage.events import SlavedEventStore
|
||||||
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
||||||
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
|
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
|
||||||
@ -42,7 +43,6 @@ from synapse.replication.slave.storage.devices import SlavedDeviceStore
|
|||||||
from synapse.replication.slave.storage.room import RoomStore
|
from synapse.replication.slave.storage.room import RoomStore
|
||||||
from synapse.replication.tcp.client import ReplicationClientHandler
|
from synapse.replication.tcp.client import ReplicationClientHandler
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.client_ips import ClientIpStore
|
|
||||||
from synapse.storage.engines import create_engine
|
from synapse.storage.engines import create_engine
|
||||||
from synapse.storage.presence import UserPresenceState
|
from synapse.storage.presence import UserPresenceState
|
||||||
from synapse.storage.roommember import RoomMemberStore
|
from synapse.storage.roommember import RoomMemberStore
|
||||||
@ -77,9 +77,9 @@ class SynchrotronSlavedStore(
|
|||||||
SlavedPresenceStore,
|
SlavedPresenceStore,
|
||||||
SlavedDeviceInboxStore,
|
SlavedDeviceInboxStore,
|
||||||
SlavedDeviceStore,
|
SlavedDeviceStore,
|
||||||
|
SlavedClientIpStore,
|
||||||
RoomStore,
|
RoomStore,
|
||||||
BaseSlavedStore,
|
BaseSlavedStore,
|
||||||
ClientIpStore, # After BaseSlavedStore because the constructor is different
|
|
||||||
):
|
):
|
||||||
who_forgot_in_room = (
|
who_forgot_in_room = (
|
||||||
RoomMemberStore.__dict__["who_forgot_in_room"]
|
RoomMemberStore.__dict__["who_forgot_in_room"]
|
||||||
|
270
synapse/app/user_dir.py
Normal file
270
synapse/app/user_dir.py
Normal file
@ -0,0 +1,270 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2017 Vector Creations Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import synapse
|
||||||
|
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.config._base import ConfigError
|
||||||
|
from synapse.config.logger import setup_logging
|
||||||
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
|
from synapse.crypto import context_factory
|
||||||
|
from synapse.http.site import SynapseSite
|
||||||
|
from synapse.http.server import JsonResource
|
||||||
|
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||||
|
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||||
|
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
|
||||||
|
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
|
||||||
|
from synapse.replication.slave.storage.events import SlavedEventStore
|
||||||
|
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||||
|
from synapse.replication.tcp.client import ReplicationClientHandler
|
||||||
|
from synapse.rest.client.v2_alpha import user_directory
|
||||||
|
from synapse.storage.engines import create_engine
|
||||||
|
from synapse.storage.user_directory import UserDirectoryStore
|
||||||
|
from synapse.util.httpresourcetree import create_resource_tree
|
||||||
|
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext, preserve_fn
|
||||||
|
from synapse.util.manhole import manhole
|
||||||
|
from synapse.util.rlimit import change_resource_limit
|
||||||
|
from synapse.util.versionstring import get_version_string
|
||||||
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
|
|
||||||
|
from synapse import events
|
||||||
|
|
||||||
|
from twisted.internet import reactor
|
||||||
|
from twisted.web.resource import Resource
|
||||||
|
|
||||||
|
from daemonize import Daemonize
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
import gc
|
||||||
|
|
||||||
|
logger = logging.getLogger("synapse.app.user_dir")
|
||||||
|
|
||||||
|
|
||||||
|
class UserDirectorySlaveStore(
|
||||||
|
SlavedEventStore,
|
||||||
|
SlavedApplicationServiceStore,
|
||||||
|
SlavedRegistrationStore,
|
||||||
|
SlavedClientIpStore,
|
||||||
|
UserDirectoryStore,
|
||||||
|
BaseSlavedStore,
|
||||||
|
):
|
||||||
|
def __init__(self, db_conn, hs):
|
||||||
|
super(UserDirectorySlaveStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
|
events_max = self._stream_id_gen.get_current_token()
|
||||||
|
curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
|
||||||
|
db_conn, "current_state_delta_stream",
|
||||||
|
entity_column="room_id",
|
||||||
|
stream_column="stream_id",
|
||||||
|
max_value=events_max, # As we share the stream id with events token
|
||||||
|
limit=1000,
|
||||||
|
)
|
||||||
|
self._curr_state_delta_stream_cache = StreamChangeCache(
|
||||||
|
"_curr_state_delta_stream_cache", min_curr_state_delta_id,
|
||||||
|
prefilled_cache=curr_state_delta_prefill,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._current_state_delta_pos = events_max
|
||||||
|
|
||||||
|
def stream_positions(self):
|
||||||
|
result = super(UserDirectorySlaveStore, self).stream_positions()
|
||||||
|
result["current_state_deltas"] = self._current_state_delta_pos
|
||||||
|
return result
|
||||||
|
|
||||||
|
def process_replication_rows(self, stream_name, token, rows):
|
||||||
|
if stream_name == "current_state_deltas":
|
||||||
|
self._current_state_delta_pos = token
|
||||||
|
for row in rows:
|
||||||
|
self._curr_state_delta_stream_cache.entity_has_changed(
|
||||||
|
row.room_id, token
|
||||||
|
)
|
||||||
|
return super(UserDirectorySlaveStore, self).process_replication_rows(
|
||||||
|
stream_name, token, rows
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UserDirectoryServer(HomeServer):
|
||||||
|
def get_db_conn(self, run_new_connection=True):
|
||||||
|
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||||
|
# not be passed to the database engine.
|
||||||
|
db_params = {
|
||||||
|
k: v for k, v in self.db_config.get("args", {}).items()
|
||||||
|
if not k.startswith("cp_")
|
||||||
|
}
|
||||||
|
db_conn = self.database_engine.module.connect(**db_params)
|
||||||
|
|
||||||
|
if run_new_connection:
|
||||||
|
self.database_engine.on_new_connection(db_conn)
|
||||||
|
return db_conn
|
||||||
|
|
||||||
|
def setup(self):
|
||||||
|
logger.info("Setting up.")
|
||||||
|
self.datastore = UserDirectorySlaveStore(self.get_db_conn(), self)
|
||||||
|
logger.info("Finished setting up.")
|
||||||
|
|
||||||
|
def _listen_http(self, listener_config):
|
||||||
|
port = listener_config["port"]
|
||||||
|
bind_addresses = listener_config["bind_addresses"]
|
||||||
|
site_tag = listener_config.get("tag", port)
|
||||||
|
resources = {}
|
||||||
|
for res in listener_config["resources"]:
|
||||||
|
for name in res["names"]:
|
||||||
|
if name == "metrics":
|
||||||
|
resources[METRICS_PREFIX] = MetricsResource(self)
|
||||||
|
elif name == "client":
|
||||||
|
resource = JsonResource(self, canonical_json=False)
|
||||||
|
user_directory.register_servlets(self, resource)
|
||||||
|
resources.update({
|
||||||
|
"/_matrix/client/r0": resource,
|
||||||
|
"/_matrix/client/unstable": resource,
|
||||||
|
"/_matrix/client/v2_alpha": resource,
|
||||||
|
"/_matrix/client/api/v1": resource,
|
||||||
|
})
|
||||||
|
|
||||||
|
root_resource = create_resource_tree(resources, Resource())
|
||||||
|
|
||||||
|
for address in bind_addresses:
|
||||||
|
reactor.listenTCP(
|
||||||
|
port,
|
||||||
|
SynapseSite(
|
||||||
|
"synapse.access.http.%s" % (site_tag,),
|
||||||
|
site_tag,
|
||||||
|
listener_config,
|
||||||
|
root_resource,
|
||||||
|
),
|
||||||
|
interface=address
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Synapse user_dir now listening on port %d", port)
|
||||||
|
|
||||||
|
def start_listening(self, listeners):
|
||||||
|
for listener in listeners:
|
||||||
|
if listener["type"] == "http":
|
||||||
|
self._listen_http(listener)
|
||||||
|
elif listener["type"] == "manhole":
|
||||||
|
bind_addresses = listener["bind_addresses"]
|
||||||
|
|
||||||
|
for address in bind_addresses:
|
||||||
|
reactor.listenTCP(
|
||||||
|
listener["port"],
|
||||||
|
manhole(
|
||||||
|
username="matrix",
|
||||||
|
password="rabbithole",
|
||||||
|
globals={"hs": self},
|
||||||
|
),
|
||||||
|
interface=address
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warn("Unrecognized listener type: %s", listener["type"])
|
||||||
|
|
||||||
|
self.get_tcp_replication().start_replication(self)
|
||||||
|
|
||||||
|
def build_tcp_replication(self):
|
||||||
|
return UserDirectoryReplicationHandler(self)
|
||||||
|
|
||||||
|
|
||||||
|
class UserDirectoryReplicationHandler(ReplicationClientHandler):
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(UserDirectoryReplicationHandler, self).__init__(hs.get_datastore())
|
||||||
|
self.user_directory = hs.get_user_directory_handler()
|
||||||
|
|
||||||
|
def on_rdata(self, stream_name, token, rows):
|
||||||
|
super(UserDirectoryReplicationHandler, self).on_rdata(
|
||||||
|
stream_name, token, rows
|
||||||
|
)
|
||||||
|
if stream_name == "current_state_deltas":
|
||||||
|
preserve_fn(self.user_directory.notify_new_event)()
|
||||||
|
|
||||||
|
|
||||||
|
def start(config_options):
|
||||||
|
try:
|
||||||
|
config = HomeServerConfig.load_config(
|
||||||
|
"Synapse user directory", config_options
|
||||||
|
)
|
||||||
|
except ConfigError as e:
|
||||||
|
sys.stderr.write("\n" + e.message + "\n")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
assert config.worker_app == "synapse.app.user_dir"
|
||||||
|
|
||||||
|
setup_logging(config, use_worker_options=True)
|
||||||
|
|
||||||
|
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||||
|
|
||||||
|
database_engine = create_engine(config.database_config)
|
||||||
|
|
||||||
|
if config.update_user_directory:
|
||||||
|
sys.stderr.write(
|
||||||
|
"\nThe update_user_directory must be disabled in the main synapse process"
|
||||||
|
"\nbefore they can be run in a separate worker."
|
||||||
|
"\nPlease add ``update_user_directory: false`` to the main config"
|
||||||
|
"\n"
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Force the pushers to start since they will be disabled in the main config
|
||||||
|
config.update_user_directory = True
|
||||||
|
|
||||||
|
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
||||||
|
|
||||||
|
ps = UserDirectoryServer(
|
||||||
|
config.server_name,
|
||||||
|
db_config=config.database_config,
|
||||||
|
tls_server_context_factory=tls_server_context_factory,
|
||||||
|
config=config,
|
||||||
|
version_string="Synapse/" + get_version_string(synapse),
|
||||||
|
database_engine=database_engine,
|
||||||
|
)
|
||||||
|
|
||||||
|
ps.setup()
|
||||||
|
ps.start_listening(config.worker_listeners)
|
||||||
|
|
||||||
|
def run():
|
||||||
|
# make sure that we run the reactor with the sentinel log context,
|
||||||
|
# otherwise other PreserveLoggingContext instances will get confused
|
||||||
|
# and complain when they see the logcontext arbitrarily swapping
|
||||||
|
# between the sentinel and `run` logcontexts.
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
logger.info("Running")
|
||||||
|
change_resource_limit(config.soft_file_limit)
|
||||||
|
if config.gc_thresholds:
|
||||||
|
gc.set_threshold(*config.gc_thresholds)
|
||||||
|
reactor.run()
|
||||||
|
|
||||||
|
def start():
|
||||||
|
ps.get_datastore().start_profiling()
|
||||||
|
ps.get_state_handler().start_caching()
|
||||||
|
|
||||||
|
reactor.callWhenRunning(start)
|
||||||
|
|
||||||
|
if config.worker_daemonize:
|
||||||
|
daemon = Daemonize(
|
||||||
|
app="synapse-user-dir",
|
||||||
|
pid=config.worker_pid_file,
|
||||||
|
action=run,
|
||||||
|
auto_close_fds=False,
|
||||||
|
verbose=True,
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
daemon.start()
|
||||||
|
else:
|
||||||
|
run()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
with LoggingContext("main"):
|
||||||
|
start(sys.argv[1:])
|
@ -241,6 +241,16 @@ class ApplicationService(object):
|
|||||||
def is_exclusive_room(self, room_id):
|
def is_exclusive_room(self, room_id):
|
||||||
return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
|
return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
|
||||||
|
|
||||||
|
def get_exlusive_user_regexes(self):
|
||||||
|
"""Get the list of regexes used to determine if a user is exclusively
|
||||||
|
registered by the AS
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
regex_obj["regex"]
|
||||||
|
for regex_obj in self.namespaces[ApplicationService.NS_USERS]
|
||||||
|
if regex_obj["exclusive"]
|
||||||
|
]
|
||||||
|
|
||||||
def is_rate_limited(self):
|
def is_rate_limited(self):
|
||||||
return self.rate_limited
|
return self.rate_limited
|
||||||
|
|
||||||
|
@ -33,6 +33,7 @@ from .jwt import JWTConfig
|
|||||||
from .password_auth_providers import PasswordAuthProviderConfig
|
from .password_auth_providers import PasswordAuthProviderConfig
|
||||||
from .emailconfig import EmailConfig
|
from .emailconfig import EmailConfig
|
||||||
from .workers import WorkerConfig
|
from .workers import WorkerConfig
|
||||||
|
from .push import PushConfig
|
||||||
|
|
||||||
|
|
||||||
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
|
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
|
||||||
@ -40,7 +41,7 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
|
|||||||
VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
|
VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
|
||||||
AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
|
AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
|
||||||
JWTConfig, PasswordConfig, EmailConfig,
|
JWTConfig, PasswordConfig, EmailConfig,
|
||||||
WorkerConfig, PasswordAuthProviderConfig,):
|
WorkerConfig, PasswordAuthProviderConfig, PushConfig,):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
45
synapse/config/push.py
Normal file
45
synapse/config/push.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2015, 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from ._base import Config
|
||||||
|
|
||||||
|
|
||||||
|
class PushConfig(Config):
|
||||||
|
def read_config(self, config):
|
||||||
|
self.push_redact_content = False
|
||||||
|
|
||||||
|
push_config = config.get("email", {})
|
||||||
|
self.push_redact_content = push_config.get("redact_content", False)
|
||||||
|
|
||||||
|
def default_config(self, config_dir_path, server_name, **kwargs):
|
||||||
|
return """
|
||||||
|
# Control how push messages are sent to google/apple to notifications.
|
||||||
|
# Normally every message said in a room with one or more people using
|
||||||
|
# mobile devices will be posted to a push server hosted by matrix.org
|
||||||
|
# which is registered with google and apple in order to allow push
|
||||||
|
# notifications to be sent to these mobile devices.
|
||||||
|
#
|
||||||
|
# Setting redact_content to true will make the push messages contain no
|
||||||
|
# message content which will provide increased privacy. This is a
|
||||||
|
# temporary solution pending improvements to Android and iPhone apps
|
||||||
|
# to get content from the app rather than the notification.
|
||||||
|
#
|
||||||
|
# For modern android devices the notification content will still appear
|
||||||
|
# because it is loaded by the app. iPhone, however will send a
|
||||||
|
# notification saying only that a message arrived and who it came from.
|
||||||
|
#
|
||||||
|
#push:
|
||||||
|
# redact_content: false
|
||||||
|
"""
|
@ -35,6 +35,10 @@ class ServerConfig(Config):
|
|||||||
# "disable" federation
|
# "disable" federation
|
||||||
self.send_federation = config.get("send_federation", True)
|
self.send_federation = config.get("send_federation", True)
|
||||||
|
|
||||||
|
# Whether to update the user directory or not. This should be set to
|
||||||
|
# false only if we are updating the user directory in a worker
|
||||||
|
self.update_user_directory = config.get("update_user_directory", True)
|
||||||
|
|
||||||
self.filter_timeline_limit = config.get("filter_timeline_limit", -1)
|
self.filter_timeline_limit = config.get("filter_timeline_limit", -1)
|
||||||
|
|
||||||
if self.public_baseurl is not None:
|
if self.public_baseurl is not None:
|
||||||
|
@ -187,6 +187,7 @@ class TransactionQueue(object):
|
|||||||
prev_id for prev_id, _ in event.prev_events
|
prev_id for prev_id, _ in event.prev_events
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
destinations = set(destinations)
|
||||||
|
|
||||||
if send_on_behalf_of is not None:
|
if send_on_behalf_of is not None:
|
||||||
# If we are sending the event on behalf of another server
|
# If we are sending the event on behalf of another server
|
||||||
|
@ -21,6 +21,7 @@ from synapse.api.constants import LoginType
|
|||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
|
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
|
from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
|
|
||||||
from twisted.web.client import PartialDownloadError
|
from twisted.web.client import PartialDownloadError
|
||||||
|
|
||||||
@ -52,7 +53,15 @@ class AuthHandler(BaseHandler):
|
|||||||
LoginType.DUMMY: self._check_dummy_auth,
|
LoginType.DUMMY: self._check_dummy_auth,
|
||||||
}
|
}
|
||||||
self.bcrypt_rounds = hs.config.bcrypt_rounds
|
self.bcrypt_rounds = hs.config.bcrypt_rounds
|
||||||
self.sessions = {}
|
|
||||||
|
# This is not a cache per se, but a store of all current sessions that
|
||||||
|
# expire after N hours
|
||||||
|
self.sessions = ExpiringCache(
|
||||||
|
cache_name="register_sessions",
|
||||||
|
clock=hs.get_clock(),
|
||||||
|
expiry_ms=self.SESSION_EXPIRE_MS,
|
||||||
|
reset_expiry_on_get=True,
|
||||||
|
)
|
||||||
|
|
||||||
account_handler = _AccountHandler(
|
account_handler = _AccountHandler(
|
||||||
hs, check_user_exists=self.check_user_exists
|
hs, check_user_exists=self.check_user_exists
|
||||||
@ -617,16 +626,6 @@ class AuthHandler(BaseHandler):
|
|||||||
logger.debug("Saving session %s", session)
|
logger.debug("Saving session %s", session)
|
||||||
session["last_used"] = self.hs.get_clock().time_msec()
|
session["last_used"] = self.hs.get_clock().time_msec()
|
||||||
self.sessions[session["id"]] = session
|
self.sessions[session["id"]] = session
|
||||||
self._prune_sessions()
|
|
||||||
|
|
||||||
def _prune_sessions(self):
|
|
||||||
for sid, sess in self.sessions.items():
|
|
||||||
last_used = 0
|
|
||||||
if 'last_used' in sess:
|
|
||||||
last_used = sess['last_used']
|
|
||||||
now = self.hs.get_clock().time_msec()
|
|
||||||
if last_used < now - AuthHandler.SESSION_EXPIRE_MS:
|
|
||||||
del self.sessions[sid]
|
|
||||||
|
|
||||||
def hash(self, password):
|
def hash(self, password):
|
||||||
"""Computes a secure hash of password.
|
"""Computes a secure hash of password.
|
||||||
|
@ -106,7 +106,7 @@ class DeviceHandler(BaseHandler):
|
|||||||
device_map = yield self.store.get_devices_by_user(user_id)
|
device_map = yield self.store.get_devices_by_user(user_id)
|
||||||
|
|
||||||
ips = yield self.store.get_last_client_ip_by_device(
|
ips = yield self.store.get_last_client_ip_by_device(
|
||||||
devices=((user_id, device_id) for device_id in device_map.keys())
|
user_id, device_id=None
|
||||||
)
|
)
|
||||||
|
|
||||||
devices = device_map.values()
|
devices = device_map.values()
|
||||||
@ -133,7 +133,7 @@ class DeviceHandler(BaseHandler):
|
|||||||
except errors.StoreError:
|
except errors.StoreError:
|
||||||
raise errors.NotFoundError
|
raise errors.NotFoundError
|
||||||
ips = yield self.store.get_last_client_ip_by_device(
|
ips = yield self.store.get_last_client_ip_by_device(
|
||||||
devices=((user_id, device_id),)
|
user_id, device_id,
|
||||||
)
|
)
|
||||||
_update_device_from_client_ips(device, ips)
|
_update_device_from_client_ips(device, ips)
|
||||||
defer.returnValue(device)
|
defer.returnValue(device)
|
||||||
|
@ -43,7 +43,6 @@ from synapse.events.utils import prune_event
|
|||||||
|
|
||||||
from synapse.util.retryutils import NotRetryingDestination
|
from synapse.util.retryutils import NotRetryingDestination
|
||||||
|
|
||||||
from synapse.push.action_generator import ActionGenerator
|
|
||||||
from synapse.util.distributor import user_joined_room
|
from synapse.util.distributor import user_joined_room
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
@ -75,6 +74,8 @@ class FederationHandler(BaseHandler):
|
|||||||
self.state_handler = hs.get_state_handler()
|
self.state_handler = hs.get_state_handler()
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
self.keyring = hs.get_keyring()
|
self.keyring = hs.get_keyring()
|
||||||
|
self.action_generator = hs.get_action_generator()
|
||||||
|
self.is_mine_id = hs.is_mine_id
|
||||||
|
|
||||||
self.replication_layer.set_handler(self)
|
self.replication_layer.set_handler(self)
|
||||||
|
|
||||||
@ -832,7 +833,11 @@ class FederationHandler(BaseHandler):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_event_auth(self, event_id):
|
def on_event_auth(self, event_id):
|
||||||
auth = yield self.store.get_auth_chain([event_id])
|
event = yield self.store.get_event(event_id)
|
||||||
|
auth = yield self.store.get_auth_chain(
|
||||||
|
[auth_id for auth_id, _ in event.auth_events],
|
||||||
|
include_given=True
|
||||||
|
)
|
||||||
|
|
||||||
for event in auth:
|
for event in auth:
|
||||||
event.signatures.update(
|
event.signatures.update(
|
||||||
@ -1047,9 +1052,7 @@ class FederationHandler(BaseHandler):
|
|||||||
yield user_joined_room(self.distributor, user, event.room_id)
|
yield user_joined_room(self.distributor, user, event.room_id)
|
||||||
|
|
||||||
state_ids = context.prev_state_ids.values()
|
state_ids = context.prev_state_ids.values()
|
||||||
auth_chain = yield self.store.get_auth_chain(set(
|
auth_chain = yield self.store.get_auth_chain(state_ids)
|
||||||
[event.event_id] + state_ids
|
|
||||||
))
|
|
||||||
|
|
||||||
state = yield self.store.get_events(context.prev_state_ids.values())
|
state = yield self.store.get_events(context.prev_state_ids.values())
|
||||||
|
|
||||||
@ -1066,6 +1069,24 @@ class FederationHandler(BaseHandler):
|
|||||||
"""
|
"""
|
||||||
event = pdu
|
event = pdu
|
||||||
|
|
||||||
|
is_blocked = yield self.store.is_room_blocked(event.room_id)
|
||||||
|
if is_blocked:
|
||||||
|
raise SynapseError(403, "This room has been blocked on this server")
|
||||||
|
|
||||||
|
membership = event.content.get("membership")
|
||||||
|
if event.type != EventTypes.Member or membership != Membership.INVITE:
|
||||||
|
raise SynapseError(400, "The event was not an m.room.member invite event")
|
||||||
|
|
||||||
|
sender_domain = get_domain_from_id(event.sender)
|
||||||
|
if sender_domain != origin:
|
||||||
|
raise SynapseError(400, "The invite event was not from the server sending it")
|
||||||
|
|
||||||
|
if event.state_key is None:
|
||||||
|
raise SynapseError(400, "The invite event did not have a state key")
|
||||||
|
|
||||||
|
if not self.is_mine_id(event.state_key):
|
||||||
|
raise SynapseError(400, "The invite event must be for this server")
|
||||||
|
|
||||||
event.internal_metadata.outlier = True
|
event.internal_metadata.outlier = True
|
||||||
event.internal_metadata.invite_from_remote = True
|
event.internal_metadata.invite_from_remote = True
|
||||||
|
|
||||||
@ -1100,6 +1121,9 @@ class FederationHandler(BaseHandler):
|
|||||||
user_id,
|
user_id,
|
||||||
"leave"
|
"leave"
|
||||||
)
|
)
|
||||||
|
# Mark as outlier as we don't have any state for this event; we're not
|
||||||
|
# even in the room.
|
||||||
|
event.internal_metadata.outlier = True
|
||||||
event = self._sign_event(event)
|
event = self._sign_event(event)
|
||||||
|
|
||||||
# Try the host that we succesfully called /make_leave/ on first for
|
# Try the host that we succesfully called /make_leave/ on first for
|
||||||
@ -1271,7 +1295,7 @@ class FederationHandler(BaseHandler):
|
|||||||
for event in res:
|
for event in res:
|
||||||
# We sign these again because there was a bug where we
|
# We sign these again because there was a bug where we
|
||||||
# incorrectly signed things the first time round
|
# incorrectly signed things the first time round
|
||||||
if self.hs.is_mine_id(event.event_id):
|
if self.is_mine_id(event.event_id):
|
||||||
event.signatures.update(
|
event.signatures.update(
|
||||||
compute_event_signature(
|
compute_event_signature(
|
||||||
event,
|
event,
|
||||||
@ -1344,7 +1368,7 @@ class FederationHandler(BaseHandler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if event:
|
if event:
|
||||||
if self.hs.is_mine_id(event.event_id):
|
if self.is_mine_id(event.event_id):
|
||||||
# FIXME: This is a temporary work around where we occasionally
|
# FIXME: This is a temporary work around where we occasionally
|
||||||
# return events slightly differently than when they were
|
# return events slightly differently than when they were
|
||||||
# originally signed
|
# originally signed
|
||||||
@ -1389,8 +1413,7 @@ class FederationHandler(BaseHandler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not event.internal_metadata.is_outlier():
|
if not event.internal_metadata.is_outlier():
|
||||||
action_generator = ActionGenerator(self.hs)
|
yield self.action_generator.handle_push_actions_for_event(
|
||||||
yield action_generator.handle_push_actions_for_event(
|
|
||||||
event, context
|
event, context
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1599,7 +1622,11 @@ class FederationHandler(BaseHandler):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
# Now get the current auth_chain for the event.
|
# Now get the current auth_chain for the event.
|
||||||
local_auth_chain = yield self.store.get_auth_chain([event_id])
|
event = yield self.store.get_event(event_id)
|
||||||
|
local_auth_chain = yield self.store.get_auth_chain(
|
||||||
|
[auth_id for auth_id, _ in event.auth_events],
|
||||||
|
include_given=True
|
||||||
|
)
|
||||||
|
|
||||||
# TODO: Check if we would now reject event_id. If so we need to tell
|
# TODO: Check if we would now reject event_id. If so we need to tell
|
||||||
# everyone.
|
# everyone.
|
||||||
@ -1792,7 +1819,9 @@ class FederationHandler(BaseHandler):
|
|||||||
auth_ids = yield self.auth.compute_auth_events(
|
auth_ids = yield self.auth.compute_auth_events(
|
||||||
event, context.prev_state_ids
|
event, context.prev_state_ids
|
||||||
)
|
)
|
||||||
local_auth_chain = yield self.store.get_auth_chain(auth_ids)
|
local_auth_chain = yield self.store.get_auth_chain(
|
||||||
|
auth_ids, include_given=True
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 2. Get remote difference.
|
# 2. Get remote difference.
|
||||||
|
@ -20,7 +20,6 @@ from synapse.api.errors import AuthError, Codes, SynapseError
|
|||||||
from synapse.crypto.event_signing import add_hashes_and_signatures
|
from synapse.crypto.event_signing import add_hashes_and_signatures
|
||||||
from synapse.events.utils import serialize_event
|
from synapse.events.utils import serialize_event
|
||||||
from synapse.events.validator import EventValidator
|
from synapse.events.validator import EventValidator
|
||||||
from synapse.push.action_generator import ActionGenerator
|
|
||||||
from synapse.types import (
|
from synapse.types import (
|
||||||
UserID, RoomAlias, RoomStreamToken,
|
UserID, RoomAlias, RoomStreamToken,
|
||||||
)
|
)
|
||||||
@ -35,6 +34,7 @@ from canonicaljson import encode_canonical_json
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
|
import ujson
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -54,6 +54,8 @@ class MessageHandler(BaseHandler):
|
|||||||
# This is to stop us from diverging history *too* much.
|
# This is to stop us from diverging history *too* much.
|
||||||
self.limiter = Limiter(max_count=5)
|
self.limiter = Limiter(max_count=5)
|
||||||
|
|
||||||
|
self.action_generator = hs.get_action_generator()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def purge_history(self, room_id, event_id):
|
def purge_history(self, room_id, event_id):
|
||||||
event = yield self.store.get_event(event_id)
|
event = yield self.store.get_event(event_id)
|
||||||
@ -497,6 +499,14 @@ class MessageHandler(BaseHandler):
|
|||||||
logger.warn("Denying new event %r because %s", event, err)
|
logger.warn("Denying new event %r because %s", event, err)
|
||||||
raise err
|
raise err
|
||||||
|
|
||||||
|
# Ensure that we can round trip before trying to persist in db
|
||||||
|
try:
|
||||||
|
dump = ujson.dumps(event.content)
|
||||||
|
ujson.loads(dump)
|
||||||
|
except:
|
||||||
|
logger.exception("Failed to encode content: %r", event.content)
|
||||||
|
raise
|
||||||
|
|
||||||
yield self.maybe_kick_guest_users(event, context)
|
yield self.maybe_kick_guest_users(event, context)
|
||||||
|
|
||||||
if event.type == EventTypes.CanonicalAlias:
|
if event.type == EventTypes.CanonicalAlias:
|
||||||
@ -590,8 +600,7 @@ class MessageHandler(BaseHandler):
|
|||||||
"Changing the room create event is forbidden",
|
"Changing the room create event is forbidden",
|
||||||
)
|
)
|
||||||
|
|
||||||
action_generator = ActionGenerator(self.hs)
|
yield self.action_generator.handle_push_actions_for_event(
|
||||||
yield action_generator.handle_push_actions_for_event(
|
|
||||||
event, context
|
event, context
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -61,7 +61,7 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
}
|
}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def create_room(self, requester, config):
|
def create_room(self, requester, config, ratelimit=True):
|
||||||
""" Creates a new room.
|
""" Creates a new room.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -75,6 +75,7 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
"""
|
"""
|
||||||
user_id = requester.user.to_string()
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
if ratelimit:
|
||||||
yield self.ratelimit(requester)
|
yield self.ratelimit(requester)
|
||||||
|
|
||||||
if "room_alias_name" in config:
|
if "room_alias_name" in config:
|
||||||
@ -167,6 +168,7 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
initial_state=initial_state,
|
initial_state=initial_state,
|
||||||
creation_content=creation_content,
|
creation_content=creation_content,
|
||||||
room_alias=room_alias,
|
room_alias=room_alias,
|
||||||
|
power_level_content_override=config.get("power_level_content_override", {})
|
||||||
)
|
)
|
||||||
|
|
||||||
if "name" in config:
|
if "name" in config:
|
||||||
@ -245,7 +247,8 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
invite_list,
|
invite_list,
|
||||||
initial_state,
|
initial_state,
|
||||||
creation_content,
|
creation_content,
|
||||||
room_alias
|
room_alias,
|
||||||
|
power_level_content_override,
|
||||||
):
|
):
|
||||||
def create(etype, content, **kwargs):
|
def create(etype, content, **kwargs):
|
||||||
e = {
|
e = {
|
||||||
@ -291,7 +294,15 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
ratelimit=False,
|
ratelimit=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (EventTypes.PowerLevels, '') not in initial_state:
|
# We treat the power levels override specially as this needs to be one
|
||||||
|
# of the first events that get sent into a room.
|
||||||
|
pl_content = initial_state.pop((EventTypes.PowerLevels, ''), None)
|
||||||
|
if pl_content is not None:
|
||||||
|
yield send(
|
||||||
|
etype=EventTypes.PowerLevels,
|
||||||
|
content=pl_content,
|
||||||
|
)
|
||||||
|
else:
|
||||||
power_level_content = {
|
power_level_content = {
|
||||||
"users": {
|
"users": {
|
||||||
creator_id: 100,
|
creator_id: 100,
|
||||||
@ -316,6 +327,8 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
for invitee in invite_list:
|
for invitee in invite_list:
|
||||||
power_level_content["users"][invitee] = 100
|
power_level_content["users"][invitee] = 100
|
||||||
|
|
||||||
|
power_level_content.update(power_level_content_override)
|
||||||
|
|
||||||
yield send(
|
yield send(
|
||||||
etype=EventTypes.PowerLevels,
|
etype=EventTypes.PowerLevels,
|
||||||
content=power_level_content,
|
content=power_level_content,
|
||||||
|
@ -203,6 +203,11 @@ class RoomMemberHandler(BaseHandler):
|
|||||||
if not remote_room_hosts:
|
if not remote_room_hosts:
|
||||||
remote_room_hosts = []
|
remote_room_hosts = []
|
||||||
|
|
||||||
|
if effective_membership_state not in ("leave", "ban",):
|
||||||
|
is_blocked = yield self.store.is_room_blocked(room_id)
|
||||||
|
if is_blocked:
|
||||||
|
raise SynapseError(403, "This room has been blocked on this server")
|
||||||
|
|
||||||
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||||
current_state_ids = yield self.state_handler.get_current_state_ids(
|
current_state_ids = yield self.state_handler.get_current_state_ids(
|
||||||
room_id, latest_event_ids=latest_event_ids,
|
room_id, latest_event_ids=latest_event_ids,
|
||||||
@ -369,6 +374,11 @@ class RoomMemberHandler(BaseHandler):
|
|||||||
# so don't really fit into the general auth process.
|
# so don't really fit into the general auth process.
|
||||||
raise AuthError(403, "Guest access not allowed")
|
raise AuthError(403, "Guest access not allowed")
|
||||||
|
|
||||||
|
if event.membership not in (Membership.LEAVE, Membership.BAN):
|
||||||
|
is_blocked = yield self.store.is_room_blocked(room_id)
|
||||||
|
if is_blocked:
|
||||||
|
raise SynapseError(403, "This room has been blocked on this server")
|
||||||
|
|
||||||
yield message_handler.handle_new_client_event(
|
yield message_handler.handle_new_client_event(
|
||||||
requester,
|
requester,
|
||||||
event,
|
event,
|
||||||
|
@ -117,6 +117,8 @@ class SyncResult(collections.namedtuple("SyncResult", [
|
|||||||
"archived", # ArchivedSyncResult for each archived room.
|
"archived", # ArchivedSyncResult for each archived room.
|
||||||
"to_device", # List of direct messages for the device.
|
"to_device", # List of direct messages for the device.
|
||||||
"device_lists", # List of user_ids whose devices have chanegd
|
"device_lists", # List of user_ids whose devices have chanegd
|
||||||
|
"device_one_time_keys_count", # Dict of algorithm to count for one time keys
|
||||||
|
# for this device
|
||||||
])):
|
])):
|
||||||
__slots__ = []
|
__slots__ = []
|
||||||
|
|
||||||
@ -550,6 +552,14 @@ class SyncHandler(object):
|
|||||||
sync_result_builder
|
sync_result_builder
|
||||||
)
|
)
|
||||||
|
|
||||||
|
device_id = sync_config.device_id
|
||||||
|
one_time_key_counts = {}
|
||||||
|
if device_id:
|
||||||
|
user_id = sync_config.user.to_string()
|
||||||
|
one_time_key_counts = yield self.store.count_e2e_one_time_keys(
|
||||||
|
user_id, device_id
|
||||||
|
)
|
||||||
|
|
||||||
defer.returnValue(SyncResult(
|
defer.returnValue(SyncResult(
|
||||||
presence=sync_result_builder.presence,
|
presence=sync_result_builder.presence,
|
||||||
account_data=sync_result_builder.account_data,
|
account_data=sync_result_builder.account_data,
|
||||||
@ -558,6 +568,7 @@ class SyncHandler(object):
|
|||||||
archived=sync_result_builder.archived,
|
archived=sync_result_builder.archived,
|
||||||
to_device=sync_result_builder.to_device,
|
to_device=sync_result_builder.to_device,
|
||||||
device_lists=device_lists,
|
device_lists=device_lists,
|
||||||
|
device_one_time_keys_count=one_time_key_counts,
|
||||||
next_batch=sync_result_builder.now_token,
|
next_batch=sync_result_builder.now_token,
|
||||||
))
|
))
|
||||||
|
|
||||||
|
@ -89,7 +89,7 @@ class TypingHandler(object):
|
|||||||
until = self._member_typing_until.get(member, None)
|
until = self._member_typing_until.get(member, None)
|
||||||
if not until or until <= now:
|
if not until or until <= now:
|
||||||
logger.info("Timing out typing for: %s", member.user_id)
|
logger.info("Timing out typing for: %s", member.user_id)
|
||||||
preserve_fn(self._stopped_typing)(member)
|
self._stopped_typing(member)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Check if we need to resend a keep alive over federation for this
|
# Check if we need to resend a keep alive over federation for this
|
||||||
@ -147,7 +147,7 @@ class TypingHandler(object):
|
|||||||
# No point sending another notification
|
# No point sending another notification
|
||||||
defer.returnValue(None)
|
defer.returnValue(None)
|
||||||
|
|
||||||
yield self._push_update(
|
self._push_update(
|
||||||
member=member,
|
member=member,
|
||||||
typing=True,
|
typing=True,
|
||||||
)
|
)
|
||||||
@ -171,7 +171,7 @@ class TypingHandler(object):
|
|||||||
|
|
||||||
member = RoomMember(room_id=room_id, user_id=target_user_id)
|
member = RoomMember(room_id=room_id, user_id=target_user_id)
|
||||||
|
|
||||||
yield self._stopped_typing(member)
|
self._stopped_typing(member)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def user_left_room(self, user, room_id):
|
def user_left_room(self, user, room_id):
|
||||||
@ -180,7 +180,6 @@ class TypingHandler(object):
|
|||||||
member = RoomMember(room_id=room_id, user_id=user_id)
|
member = RoomMember(room_id=room_id, user_id=user_id)
|
||||||
yield self._stopped_typing(member)
|
yield self._stopped_typing(member)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _stopped_typing(self, member):
|
def _stopped_typing(self, member):
|
||||||
if member.user_id not in self._room_typing.get(member.room_id, set()):
|
if member.user_id not in self._room_typing.get(member.room_id, set()):
|
||||||
# No point
|
# No point
|
||||||
@ -189,16 +188,15 @@ class TypingHandler(object):
|
|||||||
self._member_typing_until.pop(member, None)
|
self._member_typing_until.pop(member, None)
|
||||||
self._member_last_federation_poke.pop(member, None)
|
self._member_last_federation_poke.pop(member, None)
|
||||||
|
|
||||||
yield self._push_update(
|
self._push_update(
|
||||||
member=member,
|
member=member,
|
||||||
typing=False,
|
typing=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _push_update(self, member, typing):
|
def _push_update(self, member, typing):
|
||||||
if self.hs.is_mine_id(member.user_id):
|
if self.hs.is_mine_id(member.user_id):
|
||||||
# Only send updates for changes to our own users.
|
# Only send updates for changes to our own users.
|
||||||
yield self._push_remote(member, typing)
|
preserve_fn(self._push_remote)(member, typing)
|
||||||
|
|
||||||
self._push_update_local(
|
self._push_update_local(
|
||||||
member=member,
|
member=member,
|
||||||
|
641
synapse/handlers/user_directory.py
Normal file
641
synapse/handlers/user_directory.py
Normal file
@ -0,0 +1,641 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2017 Vector Creations Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.api.constants import EventTypes, JoinRules, Membership
|
||||||
|
from synapse.storage.roommember import ProfileInfo
|
||||||
|
from synapse.util.metrics import Measure
|
||||||
|
from synapse.util.async import sleep
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class UserDirectoyHandler(object):
|
||||||
|
"""Handles querying of and keeping updated the user_directory.
|
||||||
|
|
||||||
|
N.B.: ASSUMES IT IS THE ONLY THING THAT MODIFIES THE USER DIRECTORY
|
||||||
|
|
||||||
|
The user directory is filled with users who this server can see are joined to a
|
||||||
|
world_readable or publically joinable room. We keep a database table up to date
|
||||||
|
by streaming changes of the current state and recalculating whether users should
|
||||||
|
be in the directory or not when necessary.
|
||||||
|
|
||||||
|
For each user in the directory we also store a room_id which is public and that the
|
||||||
|
user is joined to. This allows us to ignore history_visibility and join_rules changes
|
||||||
|
for that user in all other public rooms, as we know they'll still be in at least
|
||||||
|
one public room.
|
||||||
|
"""
|
||||||
|
|
||||||
|
INITIAL_SLEEP_MS = 50
|
||||||
|
INITIAL_SLEEP_COUNT = 100
|
||||||
|
INITIAL_BATCH_SIZE = 100
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.state = hs.get_state_handler()
|
||||||
|
self.server_name = hs.hostname
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.notifier = hs.get_notifier()
|
||||||
|
self.is_mine_id = hs.is_mine_id
|
||||||
|
self.update_user_directory = hs.config.update_user_directory
|
||||||
|
|
||||||
|
# When start up for the first time we need to populate the user_directory.
|
||||||
|
# This is a set of user_id's we've inserted already
|
||||||
|
self.initially_handled_users = set()
|
||||||
|
self.initially_handled_users_in_public = set()
|
||||||
|
|
||||||
|
self.initially_handled_users_share = set()
|
||||||
|
self.initially_handled_users_share_private_room = set()
|
||||||
|
|
||||||
|
# The current position in the current_state_delta stream
|
||||||
|
self.pos = None
|
||||||
|
|
||||||
|
# Guard to ensure we only process deltas one at a time
|
||||||
|
self._is_processing = False
|
||||||
|
|
||||||
|
if self.update_user_directory:
|
||||||
|
self.notifier.add_replication_callback(self.notify_new_event)
|
||||||
|
|
||||||
|
# We kick this off so that we don't have to wait for a change before
|
||||||
|
# we start populating the user directory
|
||||||
|
self.clock.call_later(0, self.notify_new_event)
|
||||||
|
|
||||||
|
def search_users(self, user_id, search_term, limit):
|
||||||
|
"""Searches for users in directory
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict of the form::
|
||||||
|
|
||||||
|
{
|
||||||
|
"limited": <bool>, # whether there were more results or not
|
||||||
|
"results": [ # Ordered by best match first
|
||||||
|
{
|
||||||
|
"user_id": <user_id>,
|
||||||
|
"display_name": <display_name>,
|
||||||
|
"avatar_url": <avatar_url>
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
return self.store.search_user_dir(user_id, search_term, limit)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def notify_new_event(self):
|
||||||
|
"""Called when there may be more deltas to process
|
||||||
|
"""
|
||||||
|
if not self.update_user_directory:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._is_processing:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._is_processing = True
|
||||||
|
try:
|
||||||
|
yield self._unsafe_process()
|
||||||
|
finally:
|
||||||
|
self._is_processing = False
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _unsafe_process(self):
|
||||||
|
# If self.pos is None then means we haven't fetched it from DB
|
||||||
|
if self.pos is None:
|
||||||
|
self.pos = yield self.store.get_user_directory_stream_pos()
|
||||||
|
|
||||||
|
# If still None then we need to do the initial fill of directory
|
||||||
|
if self.pos is None:
|
||||||
|
yield self._do_initial_spam()
|
||||||
|
self.pos = yield self.store.get_user_directory_stream_pos()
|
||||||
|
|
||||||
|
# Loop round handling deltas until we're up to date
|
||||||
|
while True:
|
||||||
|
with Measure(self.clock, "user_dir_delta"):
|
||||||
|
deltas = yield self.store.get_current_state_deltas(self.pos)
|
||||||
|
if not deltas:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Handling %d state deltas", len(deltas))
|
||||||
|
yield self._handle_deltas(deltas)
|
||||||
|
|
||||||
|
self.pos = deltas[-1]["stream_id"]
|
||||||
|
yield self.store.update_user_directory_stream_pos(self.pos)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _do_initial_spam(self):
|
||||||
|
"""Populates the user_directory from the current state of the DB, used
|
||||||
|
when synapse first starts with user_directory support
|
||||||
|
"""
|
||||||
|
new_pos = yield self.store.get_max_stream_id_in_current_state_deltas()
|
||||||
|
|
||||||
|
# Delete any existing entries just in case there are any
|
||||||
|
yield self.store.delete_all_from_user_dir()
|
||||||
|
|
||||||
|
# We process by going through each existing room at a time.
|
||||||
|
room_ids = yield self.store.get_all_rooms()
|
||||||
|
|
||||||
|
logger.info("Doing initial update of user directory. %d rooms", len(room_ids))
|
||||||
|
num_processed_rooms = 1
|
||||||
|
|
||||||
|
for room_id in room_ids:
|
||||||
|
logger.info("Handling room %d/%d", num_processed_rooms, len(room_ids))
|
||||||
|
yield self._handle_intial_room(room_id)
|
||||||
|
num_processed_rooms += 1
|
||||||
|
yield sleep(self.INITIAL_SLEEP_MS / 1000.)
|
||||||
|
|
||||||
|
logger.info("Processed all rooms.")
|
||||||
|
|
||||||
|
self.initially_handled_users = None
|
||||||
|
self.initially_handled_users_in_public = None
|
||||||
|
self.initially_handled_users_share = None
|
||||||
|
self.initially_handled_users_share_private_room = None
|
||||||
|
|
||||||
|
yield self.store.update_user_directory_stream_pos(new_pos)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _handle_intial_room(self, room_id):
|
||||||
|
"""Called when we initially fill out user_directory one room at a time
|
||||||
|
"""
|
||||||
|
is_in_room = yield self.store.is_host_joined(room_id, self.server_name)
|
||||||
|
if not is_in_room:
|
||||||
|
return
|
||||||
|
|
||||||
|
is_public = yield self.store.is_room_world_readable_or_publicly_joinable(room_id)
|
||||||
|
|
||||||
|
users_with_profile = yield self.state.get_current_user_in_room(room_id)
|
||||||
|
user_ids = set(users_with_profile)
|
||||||
|
unhandled_users = user_ids - self.initially_handled_users
|
||||||
|
|
||||||
|
yield self.store.add_profiles_to_user_dir(
|
||||||
|
room_id, {
|
||||||
|
user_id: users_with_profile[user_id] for user_id in unhandled_users
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
self.initially_handled_users |= unhandled_users
|
||||||
|
|
||||||
|
if is_public:
|
||||||
|
yield self.store.add_users_to_public_room(
|
||||||
|
room_id,
|
||||||
|
user_ids=user_ids - self.initially_handled_users_in_public
|
||||||
|
)
|
||||||
|
self.initially_handled_users_in_public |= user_ids
|
||||||
|
|
||||||
|
# We now go and figure out the new users who share rooms with user entries
|
||||||
|
# We sleep aggressively here as otherwise it can starve resources.
|
||||||
|
# We also batch up inserts/updates, but try to avoid too many at once.
|
||||||
|
to_insert = set()
|
||||||
|
to_update = set()
|
||||||
|
count = 0
|
||||||
|
for user_id in user_ids:
|
||||||
|
if count % self.INITIAL_SLEEP_COUNT == 0:
|
||||||
|
yield sleep(self.INITIAL_SLEEP_MS / 1000.)
|
||||||
|
|
||||||
|
if not self.is_mine_id(user_id):
|
||||||
|
count += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
if self.store.get_if_app_services_interested_in_user(user_id):
|
||||||
|
count += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
for other_user_id in user_ids:
|
||||||
|
if user_id == other_user_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if count % self.INITIAL_SLEEP_COUNT == 0:
|
||||||
|
yield sleep(self.INITIAL_SLEEP_MS / 1000.)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
user_set = (user_id, other_user_id)
|
||||||
|
|
||||||
|
if user_set in self.initially_handled_users_share_private_room:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if user_set in self.initially_handled_users_share:
|
||||||
|
if is_public:
|
||||||
|
continue
|
||||||
|
to_update.add(user_set)
|
||||||
|
else:
|
||||||
|
to_insert.add(user_set)
|
||||||
|
|
||||||
|
if is_public:
|
||||||
|
self.initially_handled_users_share.add(user_set)
|
||||||
|
else:
|
||||||
|
self.initially_handled_users_share_private_room.add(user_set)
|
||||||
|
|
||||||
|
if len(to_insert) > self.INITIAL_BATCH_SIZE:
|
||||||
|
yield self.store.add_users_who_share_room(
|
||||||
|
room_id, not is_public, to_insert,
|
||||||
|
)
|
||||||
|
to_insert.clear()
|
||||||
|
|
||||||
|
if len(to_update) > self.INITIAL_BATCH_SIZE:
|
||||||
|
yield self.store.update_users_who_share_room(
|
||||||
|
room_id, not is_public, to_update,
|
||||||
|
)
|
||||||
|
to_update.clear()
|
||||||
|
|
||||||
|
if to_insert:
|
||||||
|
yield self.store.add_users_who_share_room(
|
||||||
|
room_id, not is_public, to_insert,
|
||||||
|
)
|
||||||
|
to_insert.clear()
|
||||||
|
|
||||||
|
if to_update:
|
||||||
|
yield self.store.update_users_who_share_room(
|
||||||
|
room_id, not is_public, to_update,
|
||||||
|
)
|
||||||
|
to_update.clear()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _handle_deltas(self, deltas):
|
||||||
|
"""Called with the state deltas to process
|
||||||
|
"""
|
||||||
|
for delta in deltas:
|
||||||
|
typ = delta["type"]
|
||||||
|
state_key = delta["state_key"]
|
||||||
|
room_id = delta["room_id"]
|
||||||
|
event_id = delta["event_id"]
|
||||||
|
prev_event_id = delta["prev_event_id"]
|
||||||
|
|
||||||
|
logger.debug("Handling: %r %r, %s", typ, state_key, event_id)
|
||||||
|
|
||||||
|
# For join rule and visibility changes we need to check if the room
|
||||||
|
# may have become public or not and add/remove the users in said room
|
||||||
|
if typ in (EventTypes.RoomHistoryVisibility, EventTypes.JoinRules):
|
||||||
|
yield self._handle_room_publicity_change(
|
||||||
|
room_id, prev_event_id, event_id, typ,
|
||||||
|
)
|
||||||
|
elif typ == EventTypes.Member:
|
||||||
|
change = yield self._get_key_change(
|
||||||
|
prev_event_id, event_id,
|
||||||
|
key_name="membership",
|
||||||
|
public_value=Membership.JOIN,
|
||||||
|
)
|
||||||
|
|
||||||
|
if change is None:
|
||||||
|
# Handle any profile changes
|
||||||
|
yield self._handle_profile_change(
|
||||||
|
state_key, room_id, prev_event_id, event_id,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not change:
|
||||||
|
# Need to check if the server left the room entirely, if so
|
||||||
|
# we might need to remove all the users in that room
|
||||||
|
is_in_room = yield self.store.is_host_joined(
|
||||||
|
room_id, self.server_name,
|
||||||
|
)
|
||||||
|
if not is_in_room:
|
||||||
|
logger.info("Server left room: %r", room_id)
|
||||||
|
# Fetch all the users that we marked as being in user
|
||||||
|
# directory due to being in the room and then check if
|
||||||
|
# need to remove those users or not
|
||||||
|
user_ids = yield self.store.get_users_in_dir_due_to_room(room_id)
|
||||||
|
for user_id in user_ids:
|
||||||
|
yield self._handle_remove_user(room_id, user_id)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
logger.debug("Server is still in room: %r", room_id)
|
||||||
|
|
||||||
|
if change: # The user joined
|
||||||
|
event = yield self.store.get_event(event_id, allow_none=True)
|
||||||
|
profile = ProfileInfo(
|
||||||
|
avatar_url=event.content.get("avatar_url"),
|
||||||
|
display_name=event.content.get("displayname"),
|
||||||
|
)
|
||||||
|
|
||||||
|
yield self._handle_new_user(room_id, state_key, profile)
|
||||||
|
else: # The user left
|
||||||
|
yield self._handle_remove_user(room_id, state_key)
|
||||||
|
else:
|
||||||
|
logger.debug("Ignoring irrelevant type: %r", typ)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _handle_room_publicity_change(self, room_id, prev_event_id, event_id, typ):
|
||||||
|
"""Handle a room having potentially changed from/to world_readable/publically
|
||||||
|
joinable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id (str)
|
||||||
|
prev_event_id (str|None): The previous event before the state change
|
||||||
|
event_id (str|None): The new event after the state change
|
||||||
|
typ (str): Type of the event
|
||||||
|
"""
|
||||||
|
logger.debug("Handling change for %s: %s", typ, room_id)
|
||||||
|
|
||||||
|
if typ == EventTypes.RoomHistoryVisibility:
|
||||||
|
change = yield self._get_key_change(
|
||||||
|
prev_event_id, event_id,
|
||||||
|
key_name="history_visibility",
|
||||||
|
public_value="world_readable",
|
||||||
|
)
|
||||||
|
elif typ == EventTypes.JoinRules:
|
||||||
|
change = yield self._get_key_change(
|
||||||
|
prev_event_id, event_id,
|
||||||
|
key_name="join_rule",
|
||||||
|
public_value=JoinRules.PUBLIC,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise Exception("Invalid event type")
|
||||||
|
# If change is None, no change. True => become world_readable/public,
|
||||||
|
# False => was world_readable/public
|
||||||
|
if change is None:
|
||||||
|
logger.debug("No change")
|
||||||
|
return
|
||||||
|
|
||||||
|
# There's been a change to or from being world readable.
|
||||||
|
|
||||||
|
is_public = yield self.store.is_room_world_readable_or_publicly_joinable(
|
||||||
|
room_id
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("Change: %r, is_public: %r", change, is_public)
|
||||||
|
|
||||||
|
if change and not is_public:
|
||||||
|
# If we became world readable but room isn't currently public then
|
||||||
|
# we ignore the change
|
||||||
|
return
|
||||||
|
elif not change and is_public:
|
||||||
|
# If we stopped being world readable but are still public,
|
||||||
|
# ignore the change
|
||||||
|
return
|
||||||
|
|
||||||
|
if change:
|
||||||
|
users_with_profile = yield self.state.get_current_user_in_room(room_id)
|
||||||
|
for user_id, profile in users_with_profile.iteritems():
|
||||||
|
yield self._handle_new_user(room_id, user_id, profile)
|
||||||
|
else:
|
||||||
|
users = yield self.store.get_users_in_public_due_to_room(room_id)
|
||||||
|
for user_id in users:
|
||||||
|
yield self._handle_remove_user(room_id, user_id)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _handle_new_user(self, room_id, user_id, profile):
|
||||||
|
"""Called when we might need to add user to directory
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id (str): room_id that user joined or started being public that
|
||||||
|
user_id (str)
|
||||||
|
"""
|
||||||
|
logger.debug("Adding user to dir, %r", user_id)
|
||||||
|
|
||||||
|
row = yield self.store.get_user_in_directory(user_id)
|
||||||
|
if not row:
|
||||||
|
yield self.store.add_profiles_to_user_dir(room_id, {user_id: profile})
|
||||||
|
|
||||||
|
is_public = yield self.store.is_room_world_readable_or_publicly_joinable(
|
||||||
|
room_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_public:
|
||||||
|
row = yield self.store.get_user_in_public_room(user_id)
|
||||||
|
if not row:
|
||||||
|
yield self.store.add_users_to_public_room(room_id, [user_id])
|
||||||
|
else:
|
||||||
|
logger.debug("Not adding user to public dir, %r", user_id)
|
||||||
|
|
||||||
|
# Now we update users who share rooms with users. We do this by getting
|
||||||
|
# all the current users in the room and seeing which aren't already
|
||||||
|
# marked in the database as sharing with `user_id`
|
||||||
|
|
||||||
|
users_with_profile = yield self.state.get_current_user_in_room(room_id)
|
||||||
|
|
||||||
|
to_insert = set()
|
||||||
|
to_update = set()
|
||||||
|
|
||||||
|
is_appservice = self.store.get_if_app_services_interested_in_user(user_id)
|
||||||
|
|
||||||
|
# First, if they're our user then we need to update for every user
|
||||||
|
if self.is_mine_id(user_id) and not is_appservice:
|
||||||
|
# Returns a map of other_user_id -> shared_private. We only need
|
||||||
|
# to update mappings if for users that either don't share a room
|
||||||
|
# already (aren't in the map) or, if the room is private, those that
|
||||||
|
# only share a public room.
|
||||||
|
user_ids_shared = yield self.store.get_users_who_share_room_from_dir(
|
||||||
|
user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
for other_user_id in users_with_profile:
|
||||||
|
if user_id == other_user_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
shared_is_private = user_ids_shared.get(other_user_id)
|
||||||
|
if shared_is_private is True:
|
||||||
|
# We've already marked in the database they share a private room
|
||||||
|
continue
|
||||||
|
elif shared_is_private is False:
|
||||||
|
# They already share a public room, so only update if this is
|
||||||
|
# a private room
|
||||||
|
if not is_public:
|
||||||
|
to_update.add((user_id, other_user_id))
|
||||||
|
elif shared_is_private is None:
|
||||||
|
# This is the first time they both share a room
|
||||||
|
to_insert.add((user_id, other_user_id))
|
||||||
|
|
||||||
|
# Next we need to update for every local user in the room
|
||||||
|
for other_user_id in users_with_profile:
|
||||||
|
if user_id == other_user_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
is_appservice = self.store.get_if_app_services_interested_in_user(
|
||||||
|
other_user_id
|
||||||
|
)
|
||||||
|
if self.is_mine_id(other_user_id) and not is_appservice:
|
||||||
|
shared_is_private = yield self.store.get_if_users_share_a_room(
|
||||||
|
other_user_id, user_id,
|
||||||
|
)
|
||||||
|
if shared_is_private is True:
|
||||||
|
# We've already marked in the database they share a private room
|
||||||
|
continue
|
||||||
|
elif shared_is_private is False:
|
||||||
|
# They already share a public room, so only update if this is
|
||||||
|
# a private room
|
||||||
|
if not is_public:
|
||||||
|
to_update.add((other_user_id, user_id))
|
||||||
|
elif shared_is_private is None:
|
||||||
|
# This is the first time they both share a room
|
||||||
|
to_insert.add((other_user_id, user_id))
|
||||||
|
|
||||||
|
if to_insert:
|
||||||
|
yield self.store.add_users_who_share_room(
|
||||||
|
room_id, not is_public, to_insert,
|
||||||
|
)
|
||||||
|
|
||||||
|
if to_update:
|
||||||
|
yield self.store.update_users_who_share_room(
|
||||||
|
room_id, not is_public, to_update,
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _handle_remove_user(self, room_id, user_id):
|
||||||
|
"""Called when we might need to remove user to directory
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id (str): room_id that user left or stopped being public that
|
||||||
|
user_id (str)
|
||||||
|
"""
|
||||||
|
logger.debug("Maybe removing user %r", user_id)
|
||||||
|
|
||||||
|
row = yield self.store.get_user_in_directory(user_id)
|
||||||
|
update_user_dir = row and row["room_id"] == room_id
|
||||||
|
|
||||||
|
row = yield self.store.get_user_in_public_room(user_id)
|
||||||
|
update_user_in_public = row and row["room_id"] == room_id
|
||||||
|
|
||||||
|
if (update_user_in_public or update_user_dir):
|
||||||
|
# XXX: Make this faster?
|
||||||
|
rooms = yield self.store.get_rooms_for_user(user_id)
|
||||||
|
for j_room_id in rooms:
|
||||||
|
if (not update_user_in_public and not update_user_dir):
|
||||||
|
break
|
||||||
|
|
||||||
|
is_in_room = yield self.store.is_host_joined(
|
||||||
|
j_room_id, self.server_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not is_in_room:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if update_user_dir:
|
||||||
|
update_user_dir = False
|
||||||
|
yield self.store.update_user_in_user_dir(user_id, j_room_id)
|
||||||
|
|
||||||
|
is_public = yield self.store.is_room_world_readable_or_publicly_joinable(
|
||||||
|
j_room_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if update_user_in_public and is_public:
|
||||||
|
yield self.store.update_user_in_public_user_list(user_id, j_room_id)
|
||||||
|
update_user_in_public = False
|
||||||
|
|
||||||
|
if update_user_dir:
|
||||||
|
yield self.store.remove_from_user_dir(user_id)
|
||||||
|
elif update_user_in_public:
|
||||||
|
yield self.store.remove_from_user_in_public_room(user_id)
|
||||||
|
|
||||||
|
# Now handle users_who_share_rooms.
|
||||||
|
|
||||||
|
# Get a list of user tuples that were in the DB due to this room and
|
||||||
|
# users (this includes tuples where the other user matches `user_id`)
|
||||||
|
user_tuples = yield self.store.get_users_in_share_dir_with_room_id(
|
||||||
|
user_id, room_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
for user_id, other_user_id in user_tuples:
|
||||||
|
# For each user tuple get a list of rooms that they still share,
|
||||||
|
# trying to find a private room, and update the entry in the DB
|
||||||
|
rooms = yield self.store.get_rooms_in_common_for_users(user_id, other_user_id)
|
||||||
|
|
||||||
|
# If they dont share a room anymore, remove the mapping
|
||||||
|
if not rooms:
|
||||||
|
yield self.store.remove_user_who_share_room(
|
||||||
|
user_id, other_user_id,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
found_public_share = None
|
||||||
|
for j_room_id in rooms:
|
||||||
|
is_public = yield self.store.is_room_world_readable_or_publicly_joinable(
|
||||||
|
j_room_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_public:
|
||||||
|
found_public_share = j_room_id
|
||||||
|
else:
|
||||||
|
found_public_share = None
|
||||||
|
yield self.store.update_users_who_share_room(
|
||||||
|
room_id, not is_public, [(user_id, other_user_id)],
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
if found_public_share:
|
||||||
|
yield self.store.update_users_who_share_room(
|
||||||
|
room_id, not is_public, [(user_id, other_user_id)],
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _handle_profile_change(self, user_id, room_id, prev_event_id, event_id):
|
||||||
|
"""Check member event changes for any profile changes and update the
|
||||||
|
database if there are.
|
||||||
|
"""
|
||||||
|
if not prev_event_id or not event_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
|
||||||
|
event = yield self.store.get_event(event_id, allow_none=True)
|
||||||
|
|
||||||
|
if not prev_event or not event:
|
||||||
|
return
|
||||||
|
|
||||||
|
if event.membership != Membership.JOIN:
|
||||||
|
return
|
||||||
|
|
||||||
|
prev_name = prev_event.content.get("displayname")
|
||||||
|
new_name = event.content.get("displayname")
|
||||||
|
|
||||||
|
prev_avatar = prev_event.content.get("avatar_url")
|
||||||
|
new_avatar = event.content.get("avatar_url")
|
||||||
|
|
||||||
|
if prev_name != new_name or prev_avatar != new_avatar:
|
||||||
|
yield self.store.update_profile_in_user_dir(
|
||||||
|
user_id, new_name, new_avatar, room_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _get_key_change(self, prev_event_id, event_id, key_name, public_value):
|
||||||
|
"""Given two events check if the `key_name` field in content changed
|
||||||
|
from not matching `public_value` to doing so.
|
||||||
|
|
||||||
|
For example, check if `history_visibility` (`key_name`) changed from
|
||||||
|
`shared` to `world_readable` (`public_value`).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None if the field in the events either both match `public_value`
|
||||||
|
or if neither do, i.e. there has been no change.
|
||||||
|
True if it didnt match `public_value` but now does
|
||||||
|
False if it did match `public_value` but now doesn't
|
||||||
|
"""
|
||||||
|
prev_event = None
|
||||||
|
event = None
|
||||||
|
if prev_event_id:
|
||||||
|
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
|
||||||
|
|
||||||
|
if event_id:
|
||||||
|
event = yield self.store.get_event(event_id, allow_none=True)
|
||||||
|
|
||||||
|
if not event and not prev_event:
|
||||||
|
logger.debug("Neither event exists: %r %r", prev_event_id, event_id)
|
||||||
|
defer.returnValue(None)
|
||||||
|
|
||||||
|
prev_value = None
|
||||||
|
value = None
|
||||||
|
|
||||||
|
if prev_event:
|
||||||
|
prev_value = prev_event.content.get(key_name)
|
||||||
|
|
||||||
|
if event:
|
||||||
|
value = event.content.get(key_name)
|
||||||
|
|
||||||
|
logger.debug("prev_value: %r -> value: %r", prev_value, value)
|
||||||
|
|
||||||
|
if value == public_value and prev_value != public_value:
|
||||||
|
defer.returnValue(True)
|
||||||
|
elif value != public_value and prev_value == public_value:
|
||||||
|
defer.returnValue(False)
|
||||||
|
else:
|
||||||
|
defer.returnValue(None)
|
@ -412,7 +412,7 @@ def set_cors_headers(request):
|
|||||||
)
|
)
|
||||||
request.setHeader(
|
request.setHeader(
|
||||||
"Access-Control-Allow-Headers",
|
"Access-Control-Allow-Headers",
|
||||||
"Origin, X-Requested-With, Content-Type, Accept"
|
"Origin, X-Requested-With, Content-Type, Accept, Authorization"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -251,7 +251,8 @@ class Notifier(object):
|
|||||||
"""Notify any user streams that are interested in this room event"""
|
"""Notify any user streams that are interested in this room event"""
|
||||||
# poke any interested application service.
|
# poke any interested application service.
|
||||||
preserve_fn(self.appservice_handler.notify_interested_services)(
|
preserve_fn(self.appservice_handler.notify_interested_services)(
|
||||||
room_stream_id)
|
room_stream_id
|
||||||
|
)
|
||||||
|
|
||||||
if self.federation_sender:
|
if self.federation_sender:
|
||||||
preserve_fn(self.federation_sender.notify_new_events)(
|
preserve_fn(self.federation_sender.notify_new_events)(
|
||||||
|
@ -15,7 +15,7 @@
|
|||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from .bulk_push_rule_evaluator import evaluator_for_event
|
from .bulk_push_rule_evaluator import BulkPushRuleEvaluator
|
||||||
|
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
|
|
||||||
@ -24,11 +24,12 @@ import logging
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ActionGenerator:
|
class ActionGenerator(object):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
self.bulk_evaluator = BulkPushRuleEvaluator(hs)
|
||||||
# really we want to get all user ids and all profile tags too,
|
# really we want to get all user ids and all profile tags too,
|
||||||
# since we want the actions for each profile tag for every user and
|
# since we want the actions for each profile tag for every user and
|
||||||
# also actions for a client with no profile tag for each user.
|
# also actions for a client with no profile tag for each user.
|
||||||
@ -38,16 +39,11 @@ class ActionGenerator:
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def handle_push_actions_for_event(self, event, context):
|
def handle_push_actions_for_event(self, event, context):
|
||||||
with Measure(self.clock, "evaluator_for_event"):
|
|
||||||
bulk_evaluator = yield evaluator_for_event(
|
|
||||||
event, self.hs, self.store, context
|
|
||||||
)
|
|
||||||
|
|
||||||
with Measure(self.clock, "action_for_event_by_user"):
|
with Measure(self.clock, "action_for_event_by_user"):
|
||||||
actions_by_user = yield bulk_evaluator.action_for_event_by_user(
|
actions_by_user = yield self.bulk_evaluator.action_for_event_by_user(
|
||||||
event, context
|
event, context
|
||||||
)
|
)
|
||||||
|
|
||||||
context.push_actions = [
|
context.push_actions = [
|
||||||
(uid, actions) for uid, actions in actions_by_user.items()
|
(uid, actions) for uid, actions in actions_by_user.iteritems()
|
||||||
]
|
]
|
||||||
|
@ -19,60 +19,83 @@ from twisted.internet import defer
|
|||||||
|
|
||||||
from .push_rule_evaluator import PushRuleEvaluatorForEvent
|
from .push_rule_evaluator import PushRuleEvaluatorForEvent
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
|
||||||
from synapse.visibility import filter_events_for_clients_context
|
from synapse.visibility import filter_events_for_clients_context
|
||||||
|
from synapse.api.constants import EventTypes, Membership
|
||||||
|
from synapse.util.caches.descriptors import cached
|
||||||
|
from synapse.util.async import Linearizer
|
||||||
|
|
||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
rules_by_room = {}
|
||||||
def evaluator_for_event(event, hs, store, context):
|
|
||||||
rules_by_user = yield store.bulk_get_push_rules_for_room(
|
|
||||||
event, context
|
class BulkPushRuleEvaluator(object):
|
||||||
)
|
"""Calculates the outcome of push rules for an event for all users in the
|
||||||
|
room at once.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
self.hs = hs
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _get_rules_for_event(self, event, context):
|
||||||
|
"""This gets the rules for all users in the room at the time of the event,
|
||||||
|
as well as the push rules for the invitee if the event is an invite.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict of user_id -> push_rules
|
||||||
|
"""
|
||||||
|
room_id = event.room_id
|
||||||
|
rules_for_room = self._get_rules_for_room(room_id)
|
||||||
|
|
||||||
|
rules_by_user = yield rules_for_room.get_rules(event, context)
|
||||||
|
|
||||||
# if this event is an invite event, we may need to run rules for the user
|
# if this event is an invite event, we may need to run rules for the user
|
||||||
# who's been invited, otherwise they won't get told they've been invited
|
# who's been invited, otherwise they won't get told they've been invited
|
||||||
if event.type == 'm.room.member' and event.content['membership'] == 'invite':
|
if event.type == 'm.room.member' and event.content['membership'] == 'invite':
|
||||||
invited_user = event.state_key
|
invited = event.state_key
|
||||||
if invited_user and hs.is_mine_id(invited_user):
|
if invited and self.hs.is_mine_id(invited):
|
||||||
has_pusher = yield store.user_has_pusher(invited_user)
|
has_pusher = yield self.store.user_has_pusher(invited)
|
||||||
if has_pusher:
|
if has_pusher:
|
||||||
rules_by_user = dict(rules_by_user)
|
rules_by_user = dict(rules_by_user)
|
||||||
rules_by_user[invited_user] = yield store.get_push_rules_for_user(
|
rules_by_user[invited] = yield self.store.get_push_rules_for_user(
|
||||||
invited_user
|
invited
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue(BulkPushRuleEvaluator(
|
defer.returnValue(rules_by_user)
|
||||||
event.room_id, rules_by_user, store
|
|
||||||
))
|
|
||||||
|
|
||||||
|
@cached()
|
||||||
|
def _get_rules_for_room(self, room_id):
|
||||||
|
"""Get the current RulesForRoom object for the given room id
|
||||||
|
|
||||||
class BulkPushRuleEvaluator:
|
Returns:
|
||||||
|
RulesForRoom
|
||||||
"""
|
"""
|
||||||
Runs push rules for all users in a room.
|
# It's important that RulesForRoom gets added to self._get_rules_for_room.cache
|
||||||
This is faster than running PushRuleEvaluator for each user because it
|
# before any lookup methods get called on it as otherwise there may be
|
||||||
fetches all the rules for all the users in one (batched) db query
|
# a race if invalidate_all gets called (which assumes its in the cache)
|
||||||
rather than doing multiple queries per-user. It currently uses
|
return RulesForRoom(self.hs, room_id, self._get_rules_for_room.cache)
|
||||||
the same logic to run the actual rules, but could be optimised further
|
|
||||||
(see https://matrix.org/jira/browse/SYN-562)
|
|
||||||
"""
|
|
||||||
def __init__(self, room_id, rules_by_user, store):
|
|
||||||
self.room_id = room_id
|
|
||||||
self.rules_by_user = rules_by_user
|
|
||||||
self.store = store
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def action_for_event_by_user(self, event, context):
|
def action_for_event_by_user(self, event, context):
|
||||||
|
"""Given an event and context, evaluate the push rules and return
|
||||||
|
the results
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict of user_id -> action
|
||||||
|
"""
|
||||||
|
rules_by_user = yield self._get_rules_for_event(event, context)
|
||||||
actions_by_user = {}
|
actions_by_user = {}
|
||||||
|
|
||||||
# None of these users can be peeking since this list of users comes
|
# None of these users can be peeking since this list of users comes
|
||||||
# from the set of users in the room, so we know for sure they're all
|
# from the set of users in the room, so we know for sure they're all
|
||||||
# actually in the room.
|
# actually in the room.
|
||||||
user_tuples = [
|
user_tuples = [(u, False) for u in rules_by_user]
|
||||||
(u, False) for u in self.rules_by_user.keys()
|
|
||||||
]
|
|
||||||
|
|
||||||
filtered_by_user = yield filter_events_for_clients_context(
|
filtered_by_user = yield filter_events_for_clients_context(
|
||||||
self.store, user_tuples, [event], {event.event_id: context}
|
self.store, user_tuples, [event], {event.event_id: context}
|
||||||
@ -86,7 +109,7 @@ class BulkPushRuleEvaluator:
|
|||||||
|
|
||||||
condition_cache = {}
|
condition_cache = {}
|
||||||
|
|
||||||
for uid, rules in self.rules_by_user.items():
|
for uid, rules in rules_by_user.iteritems():
|
||||||
display_name = None
|
display_name = None
|
||||||
profile_info = room_members.get(uid)
|
profile_info = room_members.get(uid)
|
||||||
if profile_info:
|
if profile_info:
|
||||||
@ -138,3 +161,240 @@ def _condition_checker(evaluator, conditions, uid, display_name, cache):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class RulesForRoom(object):
|
||||||
|
"""Caches push rules for users in a room.
|
||||||
|
|
||||||
|
This efficiently handles users joining/leaving the room by not invalidating
|
||||||
|
the entire cache for the room.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, hs, room_id, rules_for_room_cache):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hs (HomeServer)
|
||||||
|
room_id (str)
|
||||||
|
rules_for_room_cache(Cache): The cache object that caches these
|
||||||
|
RoomsForUser objects.
|
||||||
|
"""
|
||||||
|
self.room_id = room_id
|
||||||
|
self.is_mine_id = hs.is_mine_id
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
|
self.linearizer = Linearizer(name="rules_for_room")
|
||||||
|
|
||||||
|
self.member_map = {} # event_id -> (user_id, state)
|
||||||
|
self.rules_by_user = {} # user_id -> rules
|
||||||
|
|
||||||
|
# The last state group we updated the caches for. If the state_group of
|
||||||
|
# a new event comes along, we know that we can just return the cached
|
||||||
|
# result.
|
||||||
|
# On invalidation of the rules themselves (if the user changes them),
|
||||||
|
# we invalidate everything and set state_group to `object()`
|
||||||
|
self.state_group = object()
|
||||||
|
|
||||||
|
# A sequence number to keep track of when we're allowed to update the
|
||||||
|
# cache. We bump the sequence number when we invalidate the cache. If
|
||||||
|
# the sequence number changes while we're calculating stuff we should
|
||||||
|
# not update the cache with it.
|
||||||
|
self.sequence = 0
|
||||||
|
|
||||||
|
# A cache of user_ids that we *know* aren't interesting, e.g. user_ids
|
||||||
|
# owned by AS's, or remote users, etc. (I.e. users we will never need to
|
||||||
|
# calculate push for)
|
||||||
|
# These never need to be invalidated as we will never set up push for
|
||||||
|
# them.
|
||||||
|
self.uninteresting_user_set = set()
|
||||||
|
|
||||||
|
# We need to be clever on the invalidating caches callbacks, as
|
||||||
|
# otherwise the invalidation callback holds a reference to the object,
|
||||||
|
# potentially causing it to leak.
|
||||||
|
# To get around this we pass a function that on invalidations looks ups
|
||||||
|
# the RoomsForUser entry in the cache, rather than keeping a reference
|
||||||
|
# to self around in the callback.
|
||||||
|
self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_rules(self, event, context):
|
||||||
|
"""Given an event context return the rules for all users who are
|
||||||
|
currently in the room.
|
||||||
|
"""
|
||||||
|
state_group = context.state_group
|
||||||
|
|
||||||
|
with (yield self.linearizer.queue(())):
|
||||||
|
if state_group and self.state_group == state_group:
|
||||||
|
logger.debug("Using cached rules for %r", self.room_id)
|
||||||
|
defer.returnValue(self.rules_by_user)
|
||||||
|
|
||||||
|
ret_rules_by_user = {}
|
||||||
|
missing_member_event_ids = {}
|
||||||
|
if state_group and self.state_group == context.prev_group:
|
||||||
|
# If we have a simple delta then we can reuse most of the previous
|
||||||
|
# results.
|
||||||
|
ret_rules_by_user = self.rules_by_user
|
||||||
|
current_state_ids = context.delta_ids
|
||||||
|
else:
|
||||||
|
current_state_ids = context.current_state_ids
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"Looking for member changes in %r %r", state_group, current_state_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
# Loop through to see which member events we've seen and have rules
|
||||||
|
# for and which we need to fetch
|
||||||
|
for key in current_state_ids:
|
||||||
|
typ, user_id = key
|
||||||
|
if typ != EventTypes.Member:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if user_id in self.uninteresting_user_set:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not self.is_mine_id(user_id):
|
||||||
|
self.uninteresting_user_set.add(user_id)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if self.store.get_if_app_services_interested_in_user(user_id):
|
||||||
|
self.uninteresting_user_set.add(user_id)
|
||||||
|
continue
|
||||||
|
|
||||||
|
event_id = current_state_ids[key]
|
||||||
|
|
||||||
|
res = self.member_map.get(event_id, None)
|
||||||
|
if res:
|
||||||
|
user_id, state = res
|
||||||
|
if state == Membership.JOIN:
|
||||||
|
rules = self.rules_by_user.get(user_id, None)
|
||||||
|
if rules:
|
||||||
|
ret_rules_by_user[user_id] = rules
|
||||||
|
continue
|
||||||
|
|
||||||
|
# If a user has left a room we remove their push rule. If they
|
||||||
|
# joined then we readd it later in _update_rules_with_member_event_ids
|
||||||
|
ret_rules_by_user.pop(user_id, None)
|
||||||
|
missing_member_event_ids[user_id] = event_id
|
||||||
|
|
||||||
|
if missing_member_event_ids:
|
||||||
|
# If we have some memebr events we haven't seen, look them up
|
||||||
|
# and fetch push rules for them if appropriate.
|
||||||
|
logger.debug("Found new member events %r", missing_member_event_ids)
|
||||||
|
yield self._update_rules_with_member_event_ids(
|
||||||
|
ret_rules_by_user, missing_member_event_ids, state_group, event
|
||||||
|
)
|
||||||
|
|
||||||
|
if logger.isEnabledFor(logging.DEBUG):
|
||||||
|
logger.debug(
|
||||||
|
"Returning push rules for %r %r",
|
||||||
|
self.room_id, ret_rules_by_user.keys(),
|
||||||
|
)
|
||||||
|
defer.returnValue(ret_rules_by_user)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _update_rules_with_member_event_ids(self, ret_rules_by_user, member_event_ids,
|
||||||
|
state_group, event):
|
||||||
|
"""Update the partially filled rules_by_user dict by fetching rules for
|
||||||
|
any newly joined users in the `member_event_ids` list.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets
|
||||||
|
updated with any new rules.
|
||||||
|
member_event_ids (list): List of event ids for membership events that
|
||||||
|
have happened since the last time we filled rules_by_user
|
||||||
|
state_group: The state group we are currently computing push rules
|
||||||
|
for. Used when updating the cache.
|
||||||
|
"""
|
||||||
|
sequence = self.sequence
|
||||||
|
|
||||||
|
rows = yield self.store._simple_select_many_batch(
|
||||||
|
table="room_memberships",
|
||||||
|
column="event_id",
|
||||||
|
iterable=member_event_ids.values(),
|
||||||
|
retcols=('user_id', 'membership', 'event_id'),
|
||||||
|
keyvalues={},
|
||||||
|
batch_size=500,
|
||||||
|
desc="_get_rules_for_member_event_ids",
|
||||||
|
)
|
||||||
|
|
||||||
|
members = {
|
||||||
|
row["event_id"]: (row["user_id"], row["membership"])
|
||||||
|
for row in rows
|
||||||
|
}
|
||||||
|
|
||||||
|
# If the event is a join event then it will be in current state evnts
|
||||||
|
# map but not in the DB, so we have to explicitly insert it.
|
||||||
|
if event.type == EventTypes.Member:
|
||||||
|
for event_id in member_event_ids.itervalues():
|
||||||
|
if event_id == event.event_id:
|
||||||
|
members[event_id] = (event.state_key, event.membership)
|
||||||
|
|
||||||
|
if logger.isEnabledFor(logging.DEBUG):
|
||||||
|
logger.debug("Found members %r: %r", self.room_id, members.values())
|
||||||
|
|
||||||
|
interested_in_user_ids = set(
|
||||||
|
user_id for user_id, membership in members.itervalues()
|
||||||
|
if membership == Membership.JOIN
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("Joined: %r", interested_in_user_ids)
|
||||||
|
|
||||||
|
if_users_with_pushers = yield self.store.get_if_users_have_pushers(
|
||||||
|
interested_in_user_ids,
|
||||||
|
on_invalidate=self.invalidate_all_cb,
|
||||||
|
)
|
||||||
|
|
||||||
|
user_ids = set(
|
||||||
|
uid for uid, have_pusher in if_users_with_pushers.iteritems() if have_pusher
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("With pushers: %r", user_ids)
|
||||||
|
|
||||||
|
users_with_receipts = yield self.store.get_users_with_read_receipts_in_room(
|
||||||
|
self.room_id, on_invalidate=self.invalidate_all_cb,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("With receipts: %r", users_with_receipts)
|
||||||
|
|
||||||
|
# any users with pushers must be ours: they have pushers
|
||||||
|
for uid in users_with_receipts:
|
||||||
|
if uid in interested_in_user_ids:
|
||||||
|
user_ids.add(uid)
|
||||||
|
|
||||||
|
rules_by_user = yield self.store.bulk_get_push_rules(
|
||||||
|
user_ids, on_invalidate=self.invalidate_all_cb,
|
||||||
|
)
|
||||||
|
|
||||||
|
ret_rules_by_user.update(
|
||||||
|
item for item in rules_by_user.iteritems() if item[0] is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
self.update_cache(sequence, members, ret_rules_by_user, state_group)
|
||||||
|
|
||||||
|
def invalidate_all(self):
|
||||||
|
# Note: Don't hand this function directly to an invalidation callback
|
||||||
|
# as it keeps a reference to self and will stop this instance from being
|
||||||
|
# GC'd if it gets dropped from the rules_to_user cache. Instead use
|
||||||
|
# `self.invalidate_all_cb`
|
||||||
|
logger.debug("Invalidating RulesForRoom for %r", self.room_id)
|
||||||
|
self.sequence += 1
|
||||||
|
self.state_group = object()
|
||||||
|
self.member_map = {}
|
||||||
|
self.rules_by_user = {}
|
||||||
|
|
||||||
|
def update_cache(self, sequence, members, rules_by_user, state_group):
|
||||||
|
if sequence == self.sequence:
|
||||||
|
self.member_map.update(members)
|
||||||
|
self.rules_by_user = rules_by_user
|
||||||
|
self.state_group = state_group
|
||||||
|
|
||||||
|
|
||||||
|
class _Invalidation(namedtuple("_Invalidation", ("cache", "room_id"))):
|
||||||
|
# We rely on _CacheContext implementing __eq__ and __hash__ sensibly,
|
||||||
|
# which namedtuple does for us (i.e. two _CacheContext are the same if
|
||||||
|
# their caches and keys match). This is important in particular to
|
||||||
|
# dedupe when we add callbacks to lru cache nodes, otherwise the number
|
||||||
|
# of callbacks would grow.
|
||||||
|
def __call__(self):
|
||||||
|
rules = self.cache.get(self.room_id, None, update_metrics=False)
|
||||||
|
if rules:
|
||||||
|
rules.invalidate_all()
|
||||||
|
@ -21,7 +21,6 @@ import logging
|
|||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
from synapse.util.logcontext import LoggingContext
|
from synapse.util.logcontext import LoggingContext
|
||||||
|
|
||||||
from mailer import Mailer
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -56,8 +55,10 @@ class EmailPusher(object):
|
|||||||
This shares quite a bit of code with httpusher: it would be good to
|
This shares quite a bit of code with httpusher: it would be good to
|
||||||
factor out the common parts
|
factor out the common parts
|
||||||
"""
|
"""
|
||||||
def __init__(self, hs, pusherdict):
|
def __init__(self, hs, pusherdict, mailer):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
self.mailer = mailer
|
||||||
|
|
||||||
self.store = self.hs.get_datastore()
|
self.store = self.hs.get_datastore()
|
||||||
self.clock = self.hs.get_clock()
|
self.clock = self.hs.get_clock()
|
||||||
self.pusher_id = pusherdict['id']
|
self.pusher_id = pusherdict['id']
|
||||||
@ -73,16 +74,6 @@ class EmailPusher(object):
|
|||||||
|
|
||||||
self.processing = False
|
self.processing = False
|
||||||
|
|
||||||
if self.hs.config.email_enable_notifs:
|
|
||||||
if 'data' in pusherdict and 'brand' in pusherdict['data']:
|
|
||||||
app_name = pusherdict['data']['brand']
|
|
||||||
else:
|
|
||||||
app_name = self.hs.config.email_app_name
|
|
||||||
|
|
||||||
self.mailer = Mailer(self.hs, app_name)
|
|
||||||
else:
|
|
||||||
self.mailer = None
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_started(self):
|
def on_started(self):
|
||||||
if self.mailer is not None:
|
if self.mailer is not None:
|
||||||
|
@ -275,7 +275,7 @@ class HttpPusher(object):
|
|||||||
if event.type == 'm.room.member':
|
if event.type == 'm.room.member':
|
||||||
d['notification']['membership'] = event.content['membership']
|
d['notification']['membership'] = event.content['membership']
|
||||||
d['notification']['user_is_target'] = event.state_key == self.user_id
|
d['notification']['user_is_target'] = event.state_key == self.user_id
|
||||||
if 'content' in event:
|
if not self.hs.config.push_redact_content and 'content' in event:
|
||||||
d['notification']['content'] = event.content
|
d['notification']['content'] = event.content
|
||||||
|
|
||||||
# We no longer send aliases separately, instead, we send the human
|
# We no longer send aliases separately, instead, we send the human
|
||||||
|
@ -78,23 +78,17 @@ ALLOWED_ATTRS = {
|
|||||||
|
|
||||||
|
|
||||||
class Mailer(object):
|
class Mailer(object):
|
||||||
def __init__(self, hs, app_name):
|
def __init__(self, hs, app_name, notif_template_html, notif_template_text):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
self.notif_template_html = notif_template_html
|
||||||
|
self.notif_template_text = notif_template_text
|
||||||
|
|
||||||
self.store = self.hs.get_datastore()
|
self.store = self.hs.get_datastore()
|
||||||
self.macaroon_gen = self.hs.get_macaroon_generator()
|
self.macaroon_gen = self.hs.get_macaroon_generator()
|
||||||
self.state_handler = self.hs.get_state_handler()
|
self.state_handler = self.hs.get_state_handler()
|
||||||
loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir)
|
|
||||||
self.app_name = app_name
|
self.app_name = app_name
|
||||||
|
|
||||||
logger.info("Created Mailer for app_name %s" % app_name)
|
logger.info("Created Mailer for app_name %s" % app_name)
|
||||||
env = jinja2.Environment(loader=loader)
|
|
||||||
env.filters["format_ts"] = format_ts_filter
|
|
||||||
env.filters["mxc_to_http"] = self.mxc_to_http_filter
|
|
||||||
self.notif_template_html = env.get_template(
|
|
||||||
self.hs.config.email_notif_template_html
|
|
||||||
)
|
|
||||||
self.notif_template_text = env.get_template(
|
|
||||||
self.hs.config.email_notif_template_text
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def send_notification_mail(self, app_id, user_id, email_address,
|
def send_notification_mail(self, app_id, user_id, email_address,
|
||||||
@ -481,28 +475,6 @@ class Mailer(object):
|
|||||||
urllib.urlencode(params),
|
urllib.urlencode(params),
|
||||||
)
|
)
|
||||||
|
|
||||||
def mxc_to_http_filter(self, value, width, height, resize_method="crop"):
|
|
||||||
if value[0:6] != "mxc://":
|
|
||||||
return ""
|
|
||||||
|
|
||||||
serverAndMediaId = value[6:]
|
|
||||||
fragment = None
|
|
||||||
if '#' in serverAndMediaId:
|
|
||||||
(serverAndMediaId, fragment) = serverAndMediaId.split('#', 1)
|
|
||||||
fragment = "#" + fragment
|
|
||||||
|
|
||||||
params = {
|
|
||||||
"width": width,
|
|
||||||
"height": height,
|
|
||||||
"method": resize_method,
|
|
||||||
}
|
|
||||||
return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
|
|
||||||
self.hs.config.public_baseurl,
|
|
||||||
serverAndMediaId,
|
|
||||||
urllib.urlencode(params),
|
|
||||||
fragment or "",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def safe_markup(raw_html):
|
def safe_markup(raw_html):
|
||||||
return jinja2.Markup(bleach.linkify(bleach.clean(
|
return jinja2.Markup(bleach.linkify(bleach.clean(
|
||||||
@ -543,3 +515,52 @@ def string_ordinal_total(s):
|
|||||||
|
|
||||||
def format_ts_filter(value, format):
|
def format_ts_filter(value, format):
|
||||||
return time.strftime(format, time.localtime(value / 1000))
|
return time.strftime(format, time.localtime(value / 1000))
|
||||||
|
|
||||||
|
|
||||||
|
def load_jinja2_templates(config):
|
||||||
|
"""Load the jinja2 email templates from disk
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(notif_template_html, notif_template_text)
|
||||||
|
"""
|
||||||
|
logger.info("loading jinja2")
|
||||||
|
|
||||||
|
loader = jinja2.FileSystemLoader(config.email_template_dir)
|
||||||
|
env = jinja2.Environment(loader=loader)
|
||||||
|
env.filters["format_ts"] = format_ts_filter
|
||||||
|
env.filters["mxc_to_http"] = _create_mxc_to_http_filter(config)
|
||||||
|
|
||||||
|
notif_template_html = env.get_template(
|
||||||
|
config.email_notif_template_html
|
||||||
|
)
|
||||||
|
notif_template_text = env.get_template(
|
||||||
|
config.email_notif_template_text
|
||||||
|
)
|
||||||
|
|
||||||
|
return notif_template_html, notif_template_text
|
||||||
|
|
||||||
|
|
||||||
|
def _create_mxc_to_http_filter(config):
|
||||||
|
def mxc_to_http_filter(value, width, height, resize_method="crop"):
|
||||||
|
if value[0:6] != "mxc://":
|
||||||
|
return ""
|
||||||
|
|
||||||
|
serverAndMediaId = value[6:]
|
||||||
|
fragment = None
|
||||||
|
if '#' in serverAndMediaId:
|
||||||
|
(serverAndMediaId, fragment) = serverAndMediaId.split('#', 1)
|
||||||
|
fragment = "#" + fragment
|
||||||
|
|
||||||
|
params = {
|
||||||
|
"width": width,
|
||||||
|
"height": height,
|
||||||
|
"method": resize_method,
|
||||||
|
}
|
||||||
|
return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
|
||||||
|
config.public_baseurl,
|
||||||
|
serverAndMediaId,
|
||||||
|
urllib.urlencode(params),
|
||||||
|
fragment or "",
|
||||||
|
)
|
||||||
|
|
||||||
|
return mxc_to_http_filter
|
||||||
|
@ -26,22 +26,54 @@ logger = logging.getLogger(__name__)
|
|||||||
# process works fine)
|
# process works fine)
|
||||||
try:
|
try:
|
||||||
from synapse.push.emailpusher import EmailPusher
|
from synapse.push.emailpusher import EmailPusher
|
||||||
|
from synapse.push.mailer import Mailer, load_jinja2_templates
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def create_pusher(hs, pusherdict):
|
class PusherFactory(object):
|
||||||
logger.info("trying to create_pusher for %r", pusherdict)
|
def __init__(self, hs):
|
||||||
|
self.hs = hs
|
||||||
|
|
||||||
PUSHER_TYPES = {
|
self.pusher_types = {
|
||||||
"http": HttpPusher,
|
"http": HttpPusher,
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
|
logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
|
||||||
if hs.config.email_enable_notifs:
|
if hs.config.email_enable_notifs:
|
||||||
PUSHER_TYPES["email"] = EmailPusher
|
self.mailers = {} # app_name -> Mailer
|
||||||
|
|
||||||
|
templates = load_jinja2_templates(hs.config)
|
||||||
|
self.notif_template_html, self.notif_template_text = templates
|
||||||
|
|
||||||
|
self.pusher_types["email"] = self._create_email_pusher
|
||||||
|
|
||||||
logger.info("defined email pusher type")
|
logger.info("defined email pusher type")
|
||||||
|
|
||||||
if pusherdict['kind'] in PUSHER_TYPES:
|
def create_pusher(self, pusherdict):
|
||||||
|
logger.info("trying to create_pusher for %r", pusherdict)
|
||||||
|
|
||||||
|
if pusherdict['kind'] in self.pusher_types:
|
||||||
logger.info("found pusher")
|
logger.info("found pusher")
|
||||||
return PUSHER_TYPES[pusherdict['kind']](hs, pusherdict)
|
return self.pusher_types[pusherdict['kind']](self.hs, pusherdict)
|
||||||
|
|
||||||
|
def _create_email_pusher(self, _hs, pusherdict):
|
||||||
|
app_name = self._app_name_from_pusherdict(pusherdict)
|
||||||
|
mailer = self.mailers.get(app_name)
|
||||||
|
if not mailer:
|
||||||
|
mailer = Mailer(
|
||||||
|
hs=self.hs,
|
||||||
|
app_name=app_name,
|
||||||
|
notif_template_html=self.notif_template_html,
|
||||||
|
notif_template_text=self.notif_template_text,
|
||||||
|
)
|
||||||
|
self.mailers[app_name] = mailer
|
||||||
|
return EmailPusher(self.hs, pusherdict, mailer)
|
||||||
|
|
||||||
|
def _app_name_from_pusherdict(self, pusherdict):
|
||||||
|
if 'data' in pusherdict and 'brand' in pusherdict['data']:
|
||||||
|
app_name = pusherdict['data']['brand']
|
||||||
|
else:
|
||||||
|
app_name = self.hs.config.email_app_name
|
||||||
|
|
||||||
|
return app_name
|
||||||
|
@ -16,7 +16,7 @@
|
|||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
import pusher
|
from .pusher import PusherFactory
|
||||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
|
|
||||||
@ -28,6 +28,7 @@ logger = logging.getLogger(__name__)
|
|||||||
class PusherPool:
|
class PusherPool:
|
||||||
def __init__(self, _hs):
|
def __init__(self, _hs):
|
||||||
self.hs = _hs
|
self.hs = _hs
|
||||||
|
self.pusher_factory = PusherFactory(_hs)
|
||||||
self.start_pushers = _hs.config.start_pushers
|
self.start_pushers = _hs.config.start_pushers
|
||||||
self.store = self.hs.get_datastore()
|
self.store = self.hs.get_datastore()
|
||||||
self.clock = self.hs.get_clock()
|
self.clock = self.hs.get_clock()
|
||||||
@ -48,7 +49,7 @@ class PusherPool:
|
|||||||
# will then get pulled out of the database,
|
# will then get pulled out of the database,
|
||||||
# recreated, added and started: this means we have only one
|
# recreated, added and started: this means we have only one
|
||||||
# code path adding pushers.
|
# code path adding pushers.
|
||||||
pusher.create_pusher(self.hs, {
|
self.pusher_factory.create_pusher({
|
||||||
"id": None,
|
"id": None,
|
||||||
"user_name": user_id,
|
"user_name": user_id,
|
||||||
"kind": kind,
|
"kind": kind,
|
||||||
@ -186,7 +187,7 @@ class PusherPool:
|
|||||||
logger.info("Starting %d pushers", len(pushers))
|
logger.info("Starting %d pushers", len(pushers))
|
||||||
for pusherdict in pushers:
|
for pusherdict in pushers:
|
||||||
try:
|
try:
|
||||||
p = pusher.create_pusher(self.hs, pusherdict)
|
p = self.pusher_factory.create_pusher(pusherdict)
|
||||||
except:
|
except:
|
||||||
logger.exception("Couldn't start a pusher: caught Exception")
|
logger.exception("Couldn't start a pusher: caught Exception")
|
||||||
continue
|
continue
|
||||||
|
@ -16,6 +16,7 @@
|
|||||||
from ._base import BaseSlavedStore
|
from ._base import BaseSlavedStore
|
||||||
from synapse.storage import DataStore
|
from synapse.storage import DataStore
|
||||||
from synapse.config.appservice import load_appservices
|
from synapse.config.appservice import load_appservices
|
||||||
|
from synapse.storage.appservice import _make_exclusive_regex
|
||||||
|
|
||||||
|
|
||||||
class SlavedApplicationServiceStore(BaseSlavedStore):
|
class SlavedApplicationServiceStore(BaseSlavedStore):
|
||||||
@ -25,6 +26,7 @@ class SlavedApplicationServiceStore(BaseSlavedStore):
|
|||||||
hs.config.server_name,
|
hs.config.server_name,
|
||||||
hs.config.app_service_config_files
|
hs.config.app_service_config_files
|
||||||
)
|
)
|
||||||
|
self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
|
||||||
|
|
||||||
get_app_service_by_token = DataStore.get_app_service_by_token.__func__
|
get_app_service_by_token = DataStore.get_app_service_by_token.__func__
|
||||||
get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__
|
get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__
|
||||||
@ -38,3 +40,6 @@ class SlavedApplicationServiceStore(BaseSlavedStore):
|
|||||||
get_appservice_state = DataStore.get_appservice_state.__func__
|
get_appservice_state = DataStore.get_appservice_state.__func__
|
||||||
set_appservice_last_pos = DataStore.set_appservice_last_pos.__func__
|
set_appservice_last_pos = DataStore.set_appservice_last_pos.__func__
|
||||||
set_appservice_state = DataStore.set_appservice_state.__func__
|
set_appservice_state = DataStore.set_appservice_state.__func__
|
||||||
|
get_if_app_services_interested_in_user = (
|
||||||
|
DataStore.get_if_app_services_interested_in_user.__func__
|
||||||
|
)
|
||||||
|
48
synapse/replication/slave/storage/client_ips.py
Normal file
48
synapse/replication/slave/storage/client_ips.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2017 Vector Creations Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from ._base import BaseSlavedStore
|
||||||
|
from synapse.storage.client_ips import LAST_SEEN_GRANULARITY
|
||||||
|
from synapse.util.caches import CACHE_SIZE_FACTOR
|
||||||
|
from synapse.util.caches.descriptors import Cache
|
||||||
|
|
||||||
|
|
||||||
|
class SlavedClientIpStore(BaseSlavedStore):
|
||||||
|
def __init__(self, db_conn, hs):
|
||||||
|
super(SlavedClientIpStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
|
self.client_ip_last_seen = Cache(
|
||||||
|
name="client_ip_last_seen",
|
||||||
|
keylen=4,
|
||||||
|
max_entries=50000 * CACHE_SIZE_FACTOR,
|
||||||
|
)
|
||||||
|
|
||||||
|
def insert_client_ip(self, user, access_token, ip, user_agent, device_id):
|
||||||
|
now = int(self._clock.time_msec())
|
||||||
|
user_id = user.to_string()
|
||||||
|
key = (user_id, access_token, ip)
|
||||||
|
|
||||||
|
try:
|
||||||
|
last_seen = self.client_ip_last_seen.get(key)
|
||||||
|
except KeyError:
|
||||||
|
last_seen = None
|
||||||
|
|
||||||
|
# Rate-limited inserts
|
||||||
|
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.hs.get_tcp_replication().send_user_ip(
|
||||||
|
user_id, access_token, ip, user_agent, device_id, now
|
||||||
|
)
|
@ -16,6 +16,7 @@
|
|||||||
from ._base import BaseSlavedStore
|
from ._base import BaseSlavedStore
|
||||||
from ._slaved_id_tracker import SlavedIdTracker
|
from ._slaved_id_tracker import SlavedIdTracker
|
||||||
from synapse.storage import DataStore
|
from synapse.storage import DataStore
|
||||||
|
from synapse.storage.end_to_end_keys import EndToEndKeyStore
|
||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
|
|
||||||
|
|
||||||
@ -45,6 +46,7 @@ class SlavedDeviceStore(BaseSlavedStore):
|
|||||||
_mark_as_sent_devices_by_remote_txn = (
|
_mark_as_sent_devices_by_remote_txn = (
|
||||||
DataStore._mark_as_sent_devices_by_remote_txn.__func__
|
DataStore._mark_as_sent_devices_by_remote_txn.__func__
|
||||||
)
|
)
|
||||||
|
count_e2e_one_time_keys = EndToEndKeyStore.__dict__["count_e2e_one_time_keys"]
|
||||||
|
|
||||||
def stream_positions(self):
|
def stream_positions(self):
|
||||||
result = super(SlavedDeviceStore, self).stream_positions()
|
result = super(SlavedDeviceStore, self).stream_positions()
|
||||||
|
@ -108,6 +108,8 @@ class SlavedEventStore(BaseSlavedStore):
|
|||||||
get_current_state_ids = (
|
get_current_state_ids = (
|
||||||
StateStore.__dict__["get_current_state_ids"]
|
StateStore.__dict__["get_current_state_ids"]
|
||||||
)
|
)
|
||||||
|
get_state_group_delta = StateStore.__dict__["get_state_group_delta"]
|
||||||
|
_get_joined_hosts_cache = RoomMemberStore.__dict__["_get_joined_hosts_cache"]
|
||||||
has_room_changed_since = DataStore.has_room_changed_since.__func__
|
has_room_changed_since = DataStore.has_room_changed_since.__func__
|
||||||
|
|
||||||
get_unread_push_actions_for_user_in_range_for_http = (
|
get_unread_push_actions_for_user_in_range_for_http = (
|
||||||
@ -151,8 +153,7 @@ class SlavedEventStore(BaseSlavedStore):
|
|||||||
get_room_events_stream_for_rooms = (
|
get_room_events_stream_for_rooms = (
|
||||||
DataStore.get_room_events_stream_for_rooms.__func__
|
DataStore.get_room_events_stream_for_rooms.__func__
|
||||||
)
|
)
|
||||||
is_host_joined = DataStore.is_host_joined.__func__
|
is_host_joined = RoomMemberStore.__dict__["is_host_joined"]
|
||||||
_is_host_joined = RoomMemberStore.__dict__["_is_host_joined"]
|
|
||||||
get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__
|
get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__
|
||||||
|
|
||||||
_set_before_and_after = staticmethod(DataStore._set_before_and_after)
|
_set_before_and_after = staticmethod(DataStore._set_before_and_after)
|
||||||
|
@ -20,6 +20,7 @@ from twisted.internet.protocol import ReconnectingClientFactory
|
|||||||
|
|
||||||
from .commands import (
|
from .commands import (
|
||||||
FederationAckCommand, UserSyncCommand, RemovePusherCommand, InvalidateCacheCommand,
|
FederationAckCommand, UserSyncCommand, RemovePusherCommand, InvalidateCacheCommand,
|
||||||
|
UserIpCommand,
|
||||||
)
|
)
|
||||||
from .protocol import ClientReplicationStreamProtocol
|
from .protocol import ClientReplicationStreamProtocol
|
||||||
|
|
||||||
@ -178,6 +179,12 @@ class ReplicationClientHandler(object):
|
|||||||
cmd = InvalidateCacheCommand(cache_func.__name__, keys)
|
cmd = InvalidateCacheCommand(cache_func.__name__, keys)
|
||||||
self.send_command(cmd)
|
self.send_command(cmd)
|
||||||
|
|
||||||
|
def send_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
|
||||||
|
"""Tell the master that the user made a request.
|
||||||
|
"""
|
||||||
|
cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
|
||||||
|
self.send_command(cmd)
|
||||||
|
|
||||||
def await_sync(self, data):
|
def await_sync(self, data):
|
||||||
"""Returns a deferred that is resolved when we receive a SYNC command
|
"""Returns a deferred that is resolved when we receive a SYNC command
|
||||||
with given data.
|
with given data.
|
||||||
|
@ -304,6 +304,36 @@ class InvalidateCacheCommand(Command):
|
|||||||
return " ".join((self.cache_func, json.dumps(self.keys)))
|
return " ".join((self.cache_func, json.dumps(self.keys)))
|
||||||
|
|
||||||
|
|
||||||
|
class UserIpCommand(Command):
|
||||||
|
"""Sent periodically when a worker sees activity from a client.
|
||||||
|
|
||||||
|
Format::
|
||||||
|
|
||||||
|
USER_IP <user_id>, <access_token>, <ip>, <device_id>, <last_seen>, <user_agent>
|
||||||
|
"""
|
||||||
|
NAME = "USER_IP"
|
||||||
|
|
||||||
|
def __init__(self, user_id, access_token, ip, user_agent, device_id, last_seen):
|
||||||
|
self.user_id = user_id
|
||||||
|
self.access_token = access_token
|
||||||
|
self.ip = ip
|
||||||
|
self.user_agent = user_agent
|
||||||
|
self.device_id = device_id
|
||||||
|
self.last_seen = last_seen
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_line(cls, line):
|
||||||
|
user_id, access_token, ip, device_id, last_seen, user_agent = line.split(" ", 5)
|
||||||
|
|
||||||
|
return cls(user_id, access_token, ip, user_agent, device_id, int(last_seen))
|
||||||
|
|
||||||
|
def to_line(self):
|
||||||
|
return " ".join((
|
||||||
|
self.user_id, self.access_token, self.ip, self.device_id,
|
||||||
|
str(self.last_seen), self.user_agent,
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
# Map of command name to command type.
|
# Map of command name to command type.
|
||||||
COMMAND_MAP = {
|
COMMAND_MAP = {
|
||||||
cmd.NAME: cmd
|
cmd.NAME: cmd
|
||||||
@ -320,6 +350,7 @@ COMMAND_MAP = {
|
|||||||
SyncCommand,
|
SyncCommand,
|
||||||
RemovePusherCommand,
|
RemovePusherCommand,
|
||||||
InvalidateCacheCommand,
|
InvalidateCacheCommand,
|
||||||
|
UserIpCommand,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -342,5 +373,6 @@ VALID_CLIENT_COMMANDS = (
|
|||||||
FederationAckCommand.NAME,
|
FederationAckCommand.NAME,
|
||||||
RemovePusherCommand.NAME,
|
RemovePusherCommand.NAME,
|
||||||
InvalidateCacheCommand.NAME,
|
InvalidateCacheCommand.NAME,
|
||||||
|
UserIpCommand.NAME,
|
||||||
ErrorCommand.NAME,
|
ErrorCommand.NAME,
|
||||||
)
|
)
|
||||||
|
@ -406,6 +406,12 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
|||||||
def on_INVALIDATE_CACHE(self, cmd):
|
def on_INVALIDATE_CACHE(self, cmd):
|
||||||
self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
|
self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
|
||||||
|
|
||||||
|
def on_USER_IP(self, cmd):
|
||||||
|
self.streamer.on_user_ip(
|
||||||
|
cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id,
|
||||||
|
cmd.last_seen,
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def subscribe_to_stream(self, stream_name, token):
|
def subscribe_to_stream(self, stream_name, token):
|
||||||
"""Subscribe the remote to a streams.
|
"""Subscribe the remote to a streams.
|
||||||
|
@ -35,6 +35,7 @@ user_sync_counter = metrics.register_counter("user_sync")
|
|||||||
federation_ack_counter = metrics.register_counter("federation_ack")
|
federation_ack_counter = metrics.register_counter("federation_ack")
|
||||||
remove_pusher_counter = metrics.register_counter("remove_pusher")
|
remove_pusher_counter = metrics.register_counter("remove_pusher")
|
||||||
invalidate_cache_counter = metrics.register_counter("invalidate_cache")
|
invalidate_cache_counter = metrics.register_counter("invalidate_cache")
|
||||||
|
user_ip_cache_counter = metrics.register_counter("user_ip_cache")
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -67,6 +68,7 @@ class ReplicationStreamer(object):
|
|||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.presence_handler = hs.get_presence_handler()
|
self.presence_handler = hs.get_presence_handler()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
self.notifier = hs.get_notifier()
|
||||||
|
|
||||||
# Current connections.
|
# Current connections.
|
||||||
self.connections = []
|
self.connections = []
|
||||||
@ -99,7 +101,7 @@ class ReplicationStreamer(object):
|
|||||||
if not hs.config.send_federation:
|
if not hs.config.send_federation:
|
||||||
self.federation_sender = hs.get_federation_sender()
|
self.federation_sender = hs.get_federation_sender()
|
||||||
|
|
||||||
hs.get_notifier().add_replication_callback(self.on_notifier_poke)
|
self.notifier.add_replication_callback(self.on_notifier_poke)
|
||||||
|
|
||||||
# Keeps track of whether we are currently checking for updates
|
# Keeps track of whether we are currently checking for updates
|
||||||
self.is_looping = False
|
self.is_looping = False
|
||||||
@ -237,6 +239,15 @@ class ReplicationStreamer(object):
|
|||||||
invalidate_cache_counter.inc()
|
invalidate_cache_counter.inc()
|
||||||
getattr(self.store, cache_func).invalidate(tuple(keys))
|
getattr(self.store, cache_func).invalidate(tuple(keys))
|
||||||
|
|
||||||
|
@measure_func("repl.on_user_ip")
|
||||||
|
def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
|
||||||
|
"""The client saw a user request
|
||||||
|
"""
|
||||||
|
user_ip_cache_counter.inc()
|
||||||
|
self.store.insert_client_ip(
|
||||||
|
user_id, access_token, ip, user_agent, device_id, last_seen,
|
||||||
|
)
|
||||||
|
|
||||||
def send_sync_to_all_connections(self, data):
|
def send_sync_to_all_connections(self, data):
|
||||||
"""Sends a SYNC command to all clients.
|
"""Sends a SYNC command to all clients.
|
||||||
|
|
||||||
|
@ -112,6 +112,12 @@ AccountDataStreamRow = namedtuple("AccountDataStream", (
|
|||||||
"data_type", # str
|
"data_type", # str
|
||||||
"data", # dict
|
"data", # dict
|
||||||
))
|
))
|
||||||
|
CurrentStateDeltaStreamRow = namedtuple("CurrentStateDeltaStream", (
|
||||||
|
"room_id", # str
|
||||||
|
"type", # str
|
||||||
|
"state_key", # str
|
||||||
|
"event_id", # str, optional
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
class Stream(object):
|
class Stream(object):
|
||||||
@ -443,6 +449,21 @@ class AccountDataStream(Stream):
|
|||||||
defer.returnValue(results)
|
defer.returnValue(results)
|
||||||
|
|
||||||
|
|
||||||
|
class CurrentStateDeltaStream(Stream):
|
||||||
|
"""Current state for a room was changed
|
||||||
|
"""
|
||||||
|
NAME = "current_state_deltas"
|
||||||
|
ROW_TYPE = CurrentStateDeltaStreamRow
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
store = hs.get_datastore()
|
||||||
|
|
||||||
|
self.current_token = store.get_max_current_state_delta_stream_id
|
||||||
|
self.update_function = store.get_all_updated_current_state_deltas
|
||||||
|
|
||||||
|
super(CurrentStateDeltaStream, self).__init__(hs)
|
||||||
|
|
||||||
|
|
||||||
STREAMS_MAP = {
|
STREAMS_MAP = {
|
||||||
stream.NAME: stream
|
stream.NAME: stream
|
||||||
for stream in (
|
for stream in (
|
||||||
@ -460,5 +481,6 @@ STREAMS_MAP = {
|
|||||||
FederationStream,
|
FederationStream,
|
||||||
TagAccountDataStream,
|
TagAccountDataStream,
|
||||||
AccountDataStream,
|
AccountDataStream,
|
||||||
|
CurrentStateDeltaStream,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -51,6 +51,7 @@ from synapse.rest.client.v2_alpha import (
|
|||||||
devices,
|
devices,
|
||||||
thirdparty,
|
thirdparty,
|
||||||
sendtodevice,
|
sendtodevice,
|
||||||
|
user_directory,
|
||||||
)
|
)
|
||||||
|
|
||||||
from synapse.http.server import JsonResource
|
from synapse.http.server import JsonResource
|
||||||
@ -100,3 +101,4 @@ class ClientRestResource(JsonResource):
|
|||||||
devices.register_servlets(hs, client_resource)
|
devices.register_servlets(hs, client_resource)
|
||||||
thirdparty.register_servlets(hs, client_resource)
|
thirdparty.register_servlets(hs, client_resource)
|
||||||
sendtodevice.register_servlets(hs, client_resource)
|
sendtodevice.register_servlets(hs, client_resource)
|
||||||
|
user_directory.register_servlets(hs, client_resource)
|
||||||
|
@ -15,8 +15,9 @@
|
|||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.api.constants import Membership
|
||||||
from synapse.api.errors import AuthError, SynapseError
|
from synapse.api.errors import AuthError, SynapseError
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID, create_requester
|
||||||
from synapse.http.servlet import parse_json_object_from_request
|
from synapse.http.servlet import parse_json_object_from_request
|
||||||
|
|
||||||
from .base import ClientV1RestServlet, client_path_patterns
|
from .base import ClientV1RestServlet, client_path_patterns
|
||||||
@ -157,6 +158,142 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
|
|||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
|
||||||
|
class ShutdownRoomRestServlet(ClientV1RestServlet):
|
||||||
|
"""Shuts down a room by removing all local users from the room and blocking
|
||||||
|
all future invites and joins to the room. Any local aliases will be repointed
|
||||||
|
to a new room created by `new_room_user_id` and kicked users will be auto
|
||||||
|
joined to the new room.
|
||||||
|
"""
|
||||||
|
PATTERNS = client_path_patterns("/admin/shutdown_room/(?P<room_id>[^/]+)")
|
||||||
|
|
||||||
|
DEFAULT_MESSAGE = (
|
||||||
|
"Sharing illegal content on this server is not permitted and rooms in"
|
||||||
|
" violatation will be blocked."
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(ShutdownRoomRestServlet, self).__init__(hs)
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.handlers = hs.get_handlers()
|
||||||
|
self.state = hs.get_state_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, request, room_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
is_admin = yield self.auth.is_server_admin(requester.user)
|
||||||
|
if not is_admin:
|
||||||
|
raise AuthError(403, "You are not a server admin")
|
||||||
|
|
||||||
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
new_room_user_id = content.get("new_room_user_id")
|
||||||
|
if not new_room_user_id:
|
||||||
|
raise SynapseError(400, "Please provide field `new_room_user_id`")
|
||||||
|
|
||||||
|
room_creator_requester = create_requester(new_room_user_id)
|
||||||
|
|
||||||
|
message = content.get("message", self.DEFAULT_MESSAGE)
|
||||||
|
room_name = content.get("room_name", "Content Violation Notification")
|
||||||
|
|
||||||
|
info = yield self.handlers.room_creation_handler.create_room(
|
||||||
|
room_creator_requester,
|
||||||
|
config={
|
||||||
|
"preset": "public_chat",
|
||||||
|
"name": room_name,
|
||||||
|
"power_level_content_override": {
|
||||||
|
"users_default": -10,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ratelimit=False,
|
||||||
|
)
|
||||||
|
new_room_id = info["room_id"]
|
||||||
|
|
||||||
|
msg_handler = self.handlers.message_handler
|
||||||
|
yield msg_handler.create_and_send_nonmember_event(
|
||||||
|
room_creator_requester,
|
||||||
|
{
|
||||||
|
"type": "m.room.message",
|
||||||
|
"content": {"body": message, "msgtype": "m.text"},
|
||||||
|
"room_id": new_room_id,
|
||||||
|
"sender": new_room_user_id,
|
||||||
|
},
|
||||||
|
ratelimit=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
logger.info("Shutting down room %r", room_id)
|
||||||
|
|
||||||
|
yield self.store.block_room(room_id, requester_user_id)
|
||||||
|
|
||||||
|
users = yield self.state.get_current_user_in_room(room_id)
|
||||||
|
kicked_users = []
|
||||||
|
for user_id in users:
|
||||||
|
if not self.hs.is_mine_id(user_id):
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.info("Kicking %r from %r...", user_id, room_id)
|
||||||
|
|
||||||
|
target_requester = create_requester(user_id)
|
||||||
|
yield self.handlers.room_member_handler.update_membership(
|
||||||
|
requester=target_requester,
|
||||||
|
target=target_requester.user,
|
||||||
|
room_id=room_id,
|
||||||
|
action=Membership.LEAVE,
|
||||||
|
content={},
|
||||||
|
ratelimit=False
|
||||||
|
)
|
||||||
|
|
||||||
|
yield self.handlers.room_member_handler.forget(target_requester.user, room_id)
|
||||||
|
|
||||||
|
yield self.handlers.room_member_handler.update_membership(
|
||||||
|
requester=target_requester,
|
||||||
|
target=target_requester.user,
|
||||||
|
room_id=new_room_id,
|
||||||
|
action=Membership.JOIN,
|
||||||
|
content={},
|
||||||
|
ratelimit=False
|
||||||
|
)
|
||||||
|
|
||||||
|
kicked_users.append(user_id)
|
||||||
|
|
||||||
|
aliases_for_room = yield self.store.get_aliases_for_room(room_id)
|
||||||
|
|
||||||
|
yield self.store.update_aliases_for_room(
|
||||||
|
room_id, new_room_id, requester_user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, {
|
||||||
|
"kicked_users": kicked_users,
|
||||||
|
"local_aliases": aliases_for_room,
|
||||||
|
"new_room_id": new_room_id,
|
||||||
|
}))
|
||||||
|
|
||||||
|
|
||||||
|
class QuarantineMediaInRoom(ClientV1RestServlet):
|
||||||
|
"""Quarantines all media in a room so that no one can download it via
|
||||||
|
this server.
|
||||||
|
"""
|
||||||
|
PATTERNS = client_path_patterns("/admin/quarantine_media/(?P<room_id>[^/]+)")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(QuarantineMediaInRoom, self).__init__(hs)
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, request, room_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
is_admin = yield self.auth.is_server_admin(requester.user)
|
||||||
|
if not is_admin:
|
||||||
|
raise AuthError(403, "You are not a server admin")
|
||||||
|
|
||||||
|
num_quarantined = yield self.store.quarantine_media_ids_in_room(
|
||||||
|
room_id, requester.user.to_string(),
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, {"num_quarantined": num_quarantined}))
|
||||||
|
|
||||||
|
|
||||||
class ResetPasswordRestServlet(ClientV1RestServlet):
|
class ResetPasswordRestServlet(ClientV1RestServlet):
|
||||||
"""Post request to allow an administrator reset password for a user.
|
"""Post request to allow an administrator reset password for a user.
|
||||||
This need a user have a administrator access in Synapse.
|
This need a user have a administrator access in Synapse.
|
||||||
@ -353,3 +490,5 @@ def register_servlets(hs, http_server):
|
|||||||
ResetPasswordRestServlet(hs).register(http_server)
|
ResetPasswordRestServlet(hs).register(http_server)
|
||||||
GetUsersPaginatedRestServlet(hs).register(http_server)
|
GetUsersPaginatedRestServlet(hs).register(http_server)
|
||||||
SearchUsersRestServlet(hs).register(http_server)
|
SearchUsersRestServlet(hs).register(http_server)
|
||||||
|
ShutdownRoomRestServlet(hs).register(http_server)
|
||||||
|
QuarantineMediaInRoom(hs).register(http_server)
|
||||||
|
@ -192,6 +192,7 @@ class SyncRestServlet(RestServlet):
|
|||||||
"invite": invited,
|
"invite": invited,
|
||||||
"leave": archived,
|
"leave": archived,
|
||||||
},
|
},
|
||||||
|
"device_one_time_keys_count": sync_result.device_one_time_keys_count,
|
||||||
"next_batch": sync_result.next_batch.to_string(),
|
"next_batch": sync_result.next_batch.to_string(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
79
synapse/rest/client/v2_alpha/user_directory.py
Normal file
79
synapse/rest/client/v2_alpha/user_directory.py
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2017 Vector Creations Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.api.errors import SynapseError
|
||||||
|
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||||
|
from ._base import client_v2_patterns
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class UserDirectorySearchRestServlet(RestServlet):
|
||||||
|
PATTERNS = client_v2_patterns("/user_directory/search$")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hs (synapse.server.HomeServer): server
|
||||||
|
"""
|
||||||
|
super(UserDirectorySearchRestServlet, self).__init__()
|
||||||
|
self.hs = hs
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.user_directory_handler = hs.get_user_directory_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, request):
|
||||||
|
"""Searches for users in directory
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict of the form::
|
||||||
|
|
||||||
|
{
|
||||||
|
"limited": <bool>, # whether there were more results or not
|
||||||
|
"results": [ # Ordered by best match first
|
||||||
|
{
|
||||||
|
"user_id": <user_id>,
|
||||||
|
"display_name": <display_name>,
|
||||||
|
"avatar_url": <avatar_url>
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
requester = yield self.auth.get_user_by_req(request, allow_guest=False)
|
||||||
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
body = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
limit = body.get("limit", 10)
|
||||||
|
limit = min(limit, 50)
|
||||||
|
|
||||||
|
try:
|
||||||
|
search_term = body["search_term"]
|
||||||
|
except:
|
||||||
|
raise SynapseError(400, "`search_term` is required field")
|
||||||
|
|
||||||
|
results = yield self.user_directory_handler.search_users(
|
||||||
|
user_id, search_term, limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, results))
|
||||||
|
|
||||||
|
|
||||||
|
def register_servlets(hs, http_server):
|
||||||
|
UserDirectorySearchRestServlet(hs).register(http_server)
|
@ -66,13 +66,18 @@ class DownloadResource(Resource):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _respond_local_file(self, request, media_id, name):
|
def _respond_local_file(self, request, media_id, name):
|
||||||
media_info = yield self.store.get_local_media(media_id)
|
media_info = yield self.store.get_local_media(media_id)
|
||||||
if not media_info:
|
if not media_info or media_info["quarantined_by"]:
|
||||||
respond_404(request)
|
respond_404(request)
|
||||||
return
|
return
|
||||||
|
|
||||||
media_type = media_info["media_type"]
|
media_type = media_info["media_type"]
|
||||||
media_length = media_info["media_length"]
|
media_length = media_info["media_length"]
|
||||||
upload_name = name if name else media_info["upload_name"]
|
upload_name = name if name else media_info["upload_name"]
|
||||||
|
if media_info["url_cache"]:
|
||||||
|
# TODO: Check the file still exists, if it doesn't we can redownload
|
||||||
|
# it from the url `media_info["url_cache"]`
|
||||||
|
file_path = self.filepaths.url_cache_filepath(media_id)
|
||||||
|
else:
|
||||||
file_path = self.filepaths.local_media_filepath(media_id)
|
file_path = self.filepaths.local_media_filepath(media_id)
|
||||||
|
|
||||||
yield respond_with_file(
|
yield respond_with_file(
|
||||||
|
@ -71,3 +71,21 @@ class MediaFilePaths(object):
|
|||||||
self.base_path, "remote_thumbnail", server_name,
|
self.base_path, "remote_thumbnail", server_name,
|
||||||
file_id[0:2], file_id[2:4], file_id[4:],
|
file_id[0:2], file_id[2:4], file_id[4:],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def url_cache_filepath(self, media_id):
|
||||||
|
return os.path.join(
|
||||||
|
self.base_path, "url_cache",
|
||||||
|
media_id[0:2], media_id[2:4], media_id[4:]
|
||||||
|
)
|
||||||
|
|
||||||
|
def url_cache_thumbnail(self, media_id, width, height, content_type,
|
||||||
|
method):
|
||||||
|
top_level_type, sub_type = content_type.split("/")
|
||||||
|
file_name = "%i-%i-%s-%s-%s" % (
|
||||||
|
width, height, top_level_type, sub_type, method
|
||||||
|
)
|
||||||
|
return os.path.join(
|
||||||
|
self.base_path, "url_cache_thumbnails",
|
||||||
|
media_id[0:2], media_id[2:4], media_id[4:],
|
||||||
|
file_name
|
||||||
|
)
|
||||||
|
@ -135,6 +135,8 @@ class MediaRepository(object):
|
|||||||
media_info = yield self._download_remote_file(
|
media_info = yield self._download_remote_file(
|
||||||
server_name, media_id
|
server_name, media_id
|
||||||
)
|
)
|
||||||
|
elif media_info["quarantined_by"]:
|
||||||
|
raise NotFoundError()
|
||||||
else:
|
else:
|
||||||
self.recently_accessed_remotes.add((server_name, media_id))
|
self.recently_accessed_remotes.add((server_name, media_id))
|
||||||
yield self.store.update_cached_last_access_time(
|
yield self.store.update_cached_last_access_time(
|
||||||
@ -184,6 +186,7 @@ class MediaRepository(object):
|
|||||||
raise
|
raise
|
||||||
except NotRetryingDestination:
|
except NotRetryingDestination:
|
||||||
logger.warn("Not retrying destination %r", server_name)
|
logger.warn("Not retrying destination %r", server_name)
|
||||||
|
raise SynapseError(502, "Failed to fetch remote media")
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to fetch remote media %s/%s",
|
logger.exception("Failed to fetch remote media %s/%s",
|
||||||
server_name, media_id)
|
server_name, media_id)
|
||||||
@ -323,13 +326,17 @@ class MediaRepository(object):
|
|||||||
defer.returnValue(t_path)
|
defer.returnValue(t_path)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _generate_local_thumbnails(self, media_id, media_info):
|
def _generate_local_thumbnails(self, media_id, media_info, url_cache=False):
|
||||||
media_type = media_info["media_type"]
|
media_type = media_info["media_type"]
|
||||||
requirements = self._get_thumbnail_requirements(media_type)
|
requirements = self._get_thumbnail_requirements(media_type)
|
||||||
if not requirements:
|
if not requirements:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if url_cache:
|
||||||
|
input_path = self.filepaths.url_cache_filepath(media_id)
|
||||||
|
else:
|
||||||
input_path = self.filepaths.local_media_filepath(media_id)
|
input_path = self.filepaths.local_media_filepath(media_id)
|
||||||
|
|
||||||
thumbnailer = Thumbnailer(input_path)
|
thumbnailer = Thumbnailer(input_path)
|
||||||
m_width = thumbnailer.width
|
m_width = thumbnailer.width
|
||||||
m_height = thumbnailer.height
|
m_height = thumbnailer.height
|
||||||
@ -357,6 +364,11 @@ class MediaRepository(object):
|
|||||||
|
|
||||||
for t_width, t_height, t_type in scales:
|
for t_width, t_height, t_type in scales:
|
||||||
t_method = "scale"
|
t_method = "scale"
|
||||||
|
if url_cache:
|
||||||
|
t_path = self.filepaths.url_cache_thumbnail(
|
||||||
|
media_id, t_width, t_height, t_type, t_method
|
||||||
|
)
|
||||||
|
else:
|
||||||
t_path = self.filepaths.local_media_thumbnail(
|
t_path = self.filepaths.local_media_thumbnail(
|
||||||
media_id, t_width, t_height, t_type, t_method
|
media_id, t_width, t_height, t_type, t_method
|
||||||
)
|
)
|
||||||
@ -374,6 +386,11 @@ class MediaRepository(object):
|
|||||||
# thumbnail.
|
# thumbnail.
|
||||||
continue
|
continue
|
||||||
t_method = "crop"
|
t_method = "crop"
|
||||||
|
if url_cache:
|
||||||
|
t_path = self.filepaths.url_cache_thumbnail(
|
||||||
|
media_id, t_width, t_height, t_type, t_method
|
||||||
|
)
|
||||||
|
else:
|
||||||
t_path = self.filepaths.local_media_thumbnail(
|
t_path = self.filepaths.local_media_thumbnail(
|
||||||
media_id, t_width, t_height, t_type, t_method
|
media_id, t_width, t_height, t_type, t_method
|
||||||
)
|
)
|
||||||
|
@ -164,7 +164,7 @@ class PreviewUrlResource(Resource):
|
|||||||
|
|
||||||
if _is_media(media_info['media_type']):
|
if _is_media(media_info['media_type']):
|
||||||
dims = yield self.media_repo._generate_local_thumbnails(
|
dims = yield self.media_repo._generate_local_thumbnails(
|
||||||
media_info['filesystem_id'], media_info
|
media_info['filesystem_id'], media_info, url_cache=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
og = {
|
og = {
|
||||||
@ -210,7 +210,7 @@ class PreviewUrlResource(Resource):
|
|||||||
if _is_media(image_info['media_type']):
|
if _is_media(image_info['media_type']):
|
||||||
# TODO: make sure we don't choke on white-on-transparent images
|
# TODO: make sure we don't choke on white-on-transparent images
|
||||||
dims = yield self.media_repo._generate_local_thumbnails(
|
dims = yield self.media_repo._generate_local_thumbnails(
|
||||||
image_info['filesystem_id'], image_info
|
image_info['filesystem_id'], image_info, url_cache=True,
|
||||||
)
|
)
|
||||||
if dims:
|
if dims:
|
||||||
og["og:image:width"] = dims['width']
|
og["og:image:width"] = dims['width']
|
||||||
@ -256,7 +256,7 @@ class PreviewUrlResource(Resource):
|
|||||||
# XXX: horrible duplication with base_resource's _download_remote_file()
|
# XXX: horrible duplication with base_resource's _download_remote_file()
|
||||||
file_id = random_string(24)
|
file_id = random_string(24)
|
||||||
|
|
||||||
fname = self.filepaths.local_media_filepath(file_id)
|
fname = self.filepaths.url_cache_filepath(file_id)
|
||||||
self.media_repo._makedirs(fname)
|
self.media_repo._makedirs(fname)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -303,6 +303,7 @@ class PreviewUrlResource(Resource):
|
|||||||
upload_name=download_name,
|
upload_name=download_name,
|
||||||
media_length=length,
|
media_length=length,
|
||||||
user_id=user,
|
user_id=user,
|
||||||
|
url_cache=url,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -434,6 +435,8 @@ def _calc_og(tree, media_uri):
|
|||||||
for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
|
for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
|
||||||
)
|
)
|
||||||
og['og:description'] = summarize_paragraphs(text_nodes)
|
og['og:description'] = summarize_paragraphs(text_nodes)
|
||||||
|
else:
|
||||||
|
og['og:description'] = summarize_paragraphs([og['og:description']])
|
||||||
|
|
||||||
# TODO: delete the url downloads to stop diskfilling,
|
# TODO: delete the url downloads to stop diskfilling,
|
||||||
# as we only ever cared about its OG
|
# as we only ever cared about its OG
|
||||||
|
@ -81,7 +81,7 @@ class ThumbnailResource(Resource):
|
|||||||
method, m_type):
|
method, m_type):
|
||||||
media_info = yield self.store.get_local_media(media_id)
|
media_info = yield self.store.get_local_media(media_id)
|
||||||
|
|
||||||
if not media_info:
|
if not media_info or media_info["quarantined_by"]:
|
||||||
respond_404(request)
|
respond_404(request)
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -101,6 +101,13 @@ class ThumbnailResource(Resource):
|
|||||||
t_type = thumbnail_info["thumbnail_type"]
|
t_type = thumbnail_info["thumbnail_type"]
|
||||||
t_method = thumbnail_info["thumbnail_method"]
|
t_method = thumbnail_info["thumbnail_method"]
|
||||||
|
|
||||||
|
if media_info["url_cache"]:
|
||||||
|
# TODO: Check the file still exists, if it doesn't we can redownload
|
||||||
|
# it from the url `media_info["url_cache"]`
|
||||||
|
file_path = self.filepaths.url_cache_thumbnail(
|
||||||
|
media_id, t_width, t_height, t_type, t_method,
|
||||||
|
)
|
||||||
|
else:
|
||||||
file_path = self.filepaths.local_media_thumbnail(
|
file_path = self.filepaths.local_media_thumbnail(
|
||||||
media_id, t_width, t_height, t_type, t_method,
|
media_id, t_width, t_height, t_type, t_method,
|
||||||
)
|
)
|
||||||
@ -117,7 +124,7 @@ class ThumbnailResource(Resource):
|
|||||||
desired_type):
|
desired_type):
|
||||||
media_info = yield self.store.get_local_media(media_id)
|
media_info = yield self.store.get_local_media(media_id)
|
||||||
|
|
||||||
if not media_info:
|
if not media_info or media_info["quarantined_by"]:
|
||||||
respond_404(request)
|
respond_404(request)
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -134,8 +141,17 @@ class ThumbnailResource(Resource):
|
|||||||
t_type = info["thumbnail_type"] == desired_type
|
t_type = info["thumbnail_type"] == desired_type
|
||||||
|
|
||||||
if t_w and t_h and t_method and t_type:
|
if t_w and t_h and t_method and t_type:
|
||||||
|
if media_info["url_cache"]:
|
||||||
|
# TODO: Check the file still exists, if it doesn't we can redownload
|
||||||
|
# it from the url `media_info["url_cache"]`
|
||||||
|
file_path = self.filepaths.url_cache_thumbnail(
|
||||||
|
media_id, desired_width, desired_height, desired_type,
|
||||||
|
desired_method,
|
||||||
|
)
|
||||||
|
else:
|
||||||
file_path = self.filepaths.local_media_thumbnail(
|
file_path = self.filepaths.local_media_thumbnail(
|
||||||
media_id, desired_width, desired_height, desired_type, desired_method,
|
media_id, desired_width, desired_height, desired_type,
|
||||||
|
desired_method,
|
||||||
)
|
)
|
||||||
yield respond_with_file(request, desired_type, file_path)
|
yield respond_with_file(request, desired_type, file_path)
|
||||||
return
|
return
|
||||||
|
@ -49,9 +49,11 @@ from synapse.handlers.events import EventHandler, EventStreamHandler
|
|||||||
from synapse.handlers.initial_sync import InitialSyncHandler
|
from synapse.handlers.initial_sync import InitialSyncHandler
|
||||||
from synapse.handlers.receipts import ReceiptsHandler
|
from synapse.handlers.receipts import ReceiptsHandler
|
||||||
from synapse.handlers.read_marker import ReadMarkerHandler
|
from synapse.handlers.read_marker import ReadMarkerHandler
|
||||||
|
from synapse.handlers.user_directory import UserDirectoyHandler
|
||||||
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
|
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
|
||||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||||
from synapse.notifier import Notifier
|
from synapse.notifier import Notifier
|
||||||
|
from synapse.push.action_generator import ActionGenerator
|
||||||
from synapse.push.pusherpool import PusherPool
|
from synapse.push.pusherpool import PusherPool
|
||||||
from synapse.rest.media.v1.media_repository import MediaRepository
|
from synapse.rest.media.v1.media_repository import MediaRepository
|
||||||
from synapse.state import StateHandler
|
from synapse.state import StateHandler
|
||||||
@ -135,6 +137,8 @@ class HomeServer(object):
|
|||||||
'macaroon_generator',
|
'macaroon_generator',
|
||||||
'tcp_replication',
|
'tcp_replication',
|
||||||
'read_marker_handler',
|
'read_marker_handler',
|
||||||
|
'action_generator',
|
||||||
|
'user_directory_handler',
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, hostname, **kwargs):
|
def __init__(self, hostname, **kwargs):
|
||||||
@ -299,6 +303,12 @@ class HomeServer(object):
|
|||||||
def build_tcp_replication(self):
|
def build_tcp_replication(self):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def build_action_generator(self):
|
||||||
|
return ActionGenerator(self)
|
||||||
|
|
||||||
|
def build_user_directory_handler(self):
|
||||||
|
return UserDirectoyHandler(self)
|
||||||
|
|
||||||
def remove_pusher(self, app_id, push_key, user_id):
|
def remove_pusher(self, app_id, push_key, user_id):
|
||||||
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
|
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
|
||||||
|
|
||||||
|
@ -24,13 +24,13 @@ from synapse.api.constants import EventTypes
|
|||||||
from synapse.api.errors import AuthError
|
from synapse.api.errors import AuthError
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
from synapse.util.async import Linearizer
|
from synapse.util.async import Linearizer
|
||||||
|
from synapse.util.caches import CACHE_SIZE_FACTOR
|
||||||
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from frozendict import frozendict
|
from frozendict import frozendict
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import hashlib
|
import hashlib
|
||||||
import os
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -38,9 +38,6 @@ logger = logging.getLogger(__name__)
|
|||||||
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
|
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
|
||||||
|
|
||||||
|
|
||||||
CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
|
|
||||||
|
|
||||||
|
|
||||||
SIZE_OF_CACHE = int(100000 * CACHE_SIZE_FACTOR)
|
SIZE_OF_CACHE = int(100000 * CACHE_SIZE_FACTOR)
|
||||||
EVICTION_TIMEOUT_SECONDS = 60 * 60
|
EVICTION_TIMEOUT_SECONDS = 60 * 60
|
||||||
|
|
||||||
@ -170,9 +167,7 @@ class StateHandler(object):
|
|||||||
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||||
logger.debug("calling resolve_state_groups from get_current_user_in_room")
|
logger.debug("calling resolve_state_groups from get_current_user_in_room")
|
||||||
entry = yield self.resolve_state_groups(room_id, latest_event_ids)
|
entry = yield self.resolve_state_groups(room_id, latest_event_ids)
|
||||||
joined_users = yield self.store.get_joined_users_from_state(
|
joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
|
||||||
room_id, entry.state_id, entry.state
|
|
||||||
)
|
|
||||||
defer.returnValue(joined_users)
|
defer.returnValue(joined_users)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -181,9 +176,7 @@ class StateHandler(object):
|
|||||||
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||||
logger.debug("calling resolve_state_groups from get_current_hosts_in_room")
|
logger.debug("calling resolve_state_groups from get_current_hosts_in_room")
|
||||||
entry = yield self.resolve_state_groups(room_id, latest_event_ids)
|
entry = yield self.resolve_state_groups(room_id, latest_event_ids)
|
||||||
joined_hosts = yield self.store.get_joined_hosts(
|
joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
|
||||||
room_id, entry.state_id, entry.state
|
|
||||||
)
|
|
||||||
defer.returnValue(joined_hosts)
|
defer.returnValue(joined_hosts)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -195,12 +188,12 @@ class StateHandler(object):
|
|||||||
Returns:
|
Returns:
|
||||||
synapse.events.snapshot.EventContext:
|
synapse.events.snapshot.EventContext:
|
||||||
"""
|
"""
|
||||||
context = EventContext()
|
|
||||||
|
|
||||||
if event.internal_metadata.is_outlier():
|
if event.internal_metadata.is_outlier():
|
||||||
# If this is an outlier, then we know it shouldn't have any current
|
# If this is an outlier, then we know it shouldn't have any current
|
||||||
# state. Certainly store.get_current_state won't return any, and
|
# state. Certainly store.get_current_state won't return any, and
|
||||||
# persisting the event won't store the state group.
|
# persisting the event won't store the state group.
|
||||||
|
context = EventContext()
|
||||||
if old_state:
|
if old_state:
|
||||||
context.prev_state_ids = {
|
context.prev_state_ids = {
|
||||||
(s.type, s.state_key): s.event_id for s in old_state
|
(s.type, s.state_key): s.event_id for s in old_state
|
||||||
@ -219,6 +212,7 @@ class StateHandler(object):
|
|||||||
defer.returnValue(context)
|
defer.returnValue(context)
|
||||||
|
|
||||||
if old_state:
|
if old_state:
|
||||||
|
context = EventContext()
|
||||||
context.prev_state_ids = {
|
context.prev_state_ids = {
|
||||||
(s.type, s.state_key): s.event_id for s in old_state
|
(s.type, s.state_key): s.event_id for s in old_state
|
||||||
}
|
}
|
||||||
@ -239,19 +233,13 @@ class StateHandler(object):
|
|||||||
defer.returnValue(context)
|
defer.returnValue(context)
|
||||||
|
|
||||||
logger.debug("calling resolve_state_groups from compute_event_context")
|
logger.debug("calling resolve_state_groups from compute_event_context")
|
||||||
if event.is_state():
|
|
||||||
entry = yield self.resolve_state_groups(
|
|
||||||
event.room_id, [e for e, _ in event.prev_events],
|
|
||||||
event_type=event.type,
|
|
||||||
state_key=event.state_key,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
entry = yield self.resolve_state_groups(
|
entry = yield self.resolve_state_groups(
|
||||||
event.room_id, [e for e, _ in event.prev_events],
|
event.room_id, [e for e, _ in event.prev_events],
|
||||||
)
|
)
|
||||||
|
|
||||||
curr_state = entry.state
|
curr_state = entry.state
|
||||||
|
|
||||||
|
context = EventContext()
|
||||||
context.prev_state_ids = curr_state
|
context.prev_state_ids = curr_state
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
context.state_group = self.store.get_next_state_group()
|
context.state_group = self.store.get_next_state_group()
|
||||||
@ -264,10 +252,14 @@ class StateHandler(object):
|
|||||||
context.current_state_ids = dict(context.prev_state_ids)
|
context.current_state_ids = dict(context.prev_state_ids)
|
||||||
context.current_state_ids[key] = event.event_id
|
context.current_state_ids[key] = event.event_id
|
||||||
|
|
||||||
|
if entry.state_group:
|
||||||
|
context.prev_group = entry.state_group
|
||||||
|
context.delta_ids = {
|
||||||
|
key: event.event_id
|
||||||
|
}
|
||||||
|
elif entry.prev_group:
|
||||||
context.prev_group = entry.prev_group
|
context.prev_group = entry.prev_group
|
||||||
context.delta_ids = entry.delta_ids
|
context.delta_ids = dict(entry.delta_ids)
|
||||||
if context.delta_ids is not None:
|
|
||||||
context.delta_ids = dict(context.delta_ids)
|
|
||||||
context.delta_ids[key] = event.event_id
|
context.delta_ids[key] = event.event_id
|
||||||
else:
|
else:
|
||||||
if entry.state_group is None:
|
if entry.state_group is None:
|
||||||
@ -284,7 +276,7 @@ class StateHandler(object):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def resolve_state_groups(self, room_id, event_ids, event_type=None, state_key=""):
|
def resolve_state_groups(self, room_id, event_ids):
|
||||||
""" Given a list of event_ids this method fetches the state at each
|
""" Given a list of event_ids this method fetches the state at each
|
||||||
event, resolves conflicts between them and returns them.
|
event, resolves conflicts between them and returns them.
|
||||||
|
|
||||||
@ -309,11 +301,13 @@ class StateHandler(object):
|
|||||||
if len(group_names) == 1:
|
if len(group_names) == 1:
|
||||||
name, state_list = state_groups_ids.items().pop()
|
name, state_list = state_groups_ids.items().pop()
|
||||||
|
|
||||||
|
prev_group, delta_ids = yield self.store.get_state_group_delta(name)
|
||||||
|
|
||||||
defer.returnValue(_StateCacheEntry(
|
defer.returnValue(_StateCacheEntry(
|
||||||
state=state_list,
|
state=state_list,
|
||||||
state_group=name,
|
state_group=name,
|
||||||
prev_group=name,
|
prev_group=prev_group,
|
||||||
delta_ids={},
|
delta_ids=delta_ids,
|
||||||
))
|
))
|
||||||
|
|
||||||
with (yield self.resolve_linearizer.queue(group_names)):
|
with (yield self.resolve_linearizer.queue(group_names)):
|
||||||
@ -357,20 +351,18 @@ class StateHandler(object):
|
|||||||
if new_state_event_ids == frozenset(e_id for e_id in events):
|
if new_state_event_ids == frozenset(e_id for e_id in events):
|
||||||
state_group = sg
|
state_group = sg
|
||||||
break
|
break
|
||||||
if state_group is None:
|
|
||||||
# Worker instances don't have access to this method, but we want
|
# TODO: We want to create a state group for this set of events, to
|
||||||
# to set the state_group on the main instance to increase cache
|
# increase cache hits, but we need to make sure that it doesn't
|
||||||
# hits.
|
# end up as a prev_group without being added to the database
|
||||||
if hasattr(self.store, "get_next_state_group"):
|
|
||||||
state_group = self.store.get_next_state_group()
|
|
||||||
|
|
||||||
prev_group = None
|
prev_group = None
|
||||||
delta_ids = None
|
delta_ids = None
|
||||||
for old_group, old_ids in state_groups_ids.items():
|
for old_group, old_ids in state_groups_ids.iteritems():
|
||||||
if not set(new_state.iterkeys()) - set(old_ids.iterkeys()):
|
if not set(new_state) - set(old_ids):
|
||||||
n_delta_ids = {
|
n_delta_ids = {
|
||||||
k: v
|
k: v
|
||||||
for k, v in new_state.items()
|
for k, v in new_state.iteritems()
|
||||||
if old_ids.get(k) != v
|
if old_ids.get(k) != v
|
||||||
}
|
}
|
||||||
if not delta_ids or len(n_delta_ids) < len(delta_ids):
|
if not delta_ids or len(n_delta_ids) < len(delta_ids):
|
||||||
|
@ -49,6 +49,7 @@ from .tags import TagsStore
|
|||||||
from .account_data import AccountDataStore
|
from .account_data import AccountDataStore
|
||||||
from .openid import OpenIdStore
|
from .openid import OpenIdStore
|
||||||
from .client_ips import ClientIpStore
|
from .client_ips import ClientIpStore
|
||||||
|
from .user_directory import UserDirectoryStore
|
||||||
|
|
||||||
from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
|
from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
|
||||||
from .engines import PostgresEngine
|
from .engines import PostgresEngine
|
||||||
@ -86,6 +87,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||||||
ClientIpStore,
|
ClientIpStore,
|
||||||
DeviceStore,
|
DeviceStore,
|
||||||
DeviceInboxStore,
|
DeviceInboxStore,
|
||||||
|
UserDirectoryStore,
|
||||||
):
|
):
|
||||||
|
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, db_conn, hs):
|
||||||
@ -221,11 +223,24 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||||||
"DeviceListFederationStreamChangeCache", device_list_max,
|
"DeviceListFederationStreamChangeCache", device_list_max,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
|
||||||
|
db_conn, "current_state_delta_stream",
|
||||||
|
entity_column="room_id",
|
||||||
|
stream_column="stream_id",
|
||||||
|
max_value=events_max, # As we share the stream id with events token
|
||||||
|
limit=1000,
|
||||||
|
)
|
||||||
|
self._curr_state_delta_stream_cache = StreamChangeCache(
|
||||||
|
"_curr_state_delta_stream_cache", min_curr_state_delta_id,
|
||||||
|
prefilled_cache=curr_state_delta_prefill,
|
||||||
|
)
|
||||||
|
|
||||||
cur = LoggingTransaction(
|
cur = LoggingTransaction(
|
||||||
db_conn.cursor(),
|
db_conn.cursor(),
|
||||||
name="_find_stream_orderings_for_times_txn",
|
name="_find_stream_orderings_for_times_txn",
|
||||||
database_engine=self.database_engine,
|
database_engine=self.database_engine,
|
||||||
after_callbacks=[]
|
after_callbacks=[],
|
||||||
|
final_callbacks=[],
|
||||||
)
|
)
|
||||||
self._find_stream_orderings_for_times_txn(cur)
|
self._find_stream_orderings_for_times_txn(cur)
|
||||||
cur.close()
|
cur.close()
|
||||||
@ -289,16 +304,6 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||||||
ret = yield self.runInteraction("count_users", _count_users)
|
ret = yield self.runInteraction("count_users", _count_users)
|
||||||
defer.returnValue(ret)
|
defer.returnValue(ret)
|
||||||
|
|
||||||
def get_user_ip_and_agents(self, user):
|
|
||||||
return self._simple_select_list(
|
|
||||||
table="user_ips",
|
|
||||||
keyvalues={"user_id": user.to_string()},
|
|
||||||
retcols=[
|
|
||||||
"access_token", "ip", "user_agent", "last_seen"
|
|
||||||
],
|
|
||||||
desc="get_user_ip_and_agents",
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_users(self):
|
def get_users(self):
|
||||||
"""Function to reterive a list of users in users table.
|
"""Function to reterive a list of users in users table.
|
||||||
|
|
||||||
|
@ -16,6 +16,7 @@ import logging
|
|||||||
|
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
||||||
|
from synapse.util.caches import CACHE_SIZE_FACTOR
|
||||||
from synapse.util.caches.dictionary_cache import DictionaryCache
|
from synapse.util.caches.dictionary_cache import DictionaryCache
|
||||||
from synapse.util.caches.descriptors import Cache
|
from synapse.util.caches.descriptors import Cache
|
||||||
from synapse.storage.engines import PostgresEngine
|
from synapse.storage.engines import PostgresEngine
|
||||||
@ -27,10 +28,6 @@ from twisted.internet import defer
|
|||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import threading
|
import threading
|
||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -52,13 +49,17 @@ class LoggingTransaction(object):
|
|||||||
"""An object that almost-transparently proxies for the 'txn' object
|
"""An object that almost-transparently proxies for the 'txn' object
|
||||||
passed to the constructor. Adds logging and metrics to the .execute()
|
passed to the constructor. Adds logging and metrics to the .execute()
|
||||||
method."""
|
method."""
|
||||||
__slots__ = ["txn", "name", "database_engine", "after_callbacks"]
|
__slots__ = [
|
||||||
|
"txn", "name", "database_engine", "after_callbacks", "final_callbacks",
|
||||||
|
]
|
||||||
|
|
||||||
def __init__(self, txn, name, database_engine, after_callbacks):
|
def __init__(self, txn, name, database_engine, after_callbacks,
|
||||||
|
final_callbacks):
|
||||||
object.__setattr__(self, "txn", txn)
|
object.__setattr__(self, "txn", txn)
|
||||||
object.__setattr__(self, "name", name)
|
object.__setattr__(self, "name", name)
|
||||||
object.__setattr__(self, "database_engine", database_engine)
|
object.__setattr__(self, "database_engine", database_engine)
|
||||||
object.__setattr__(self, "after_callbacks", after_callbacks)
|
object.__setattr__(self, "after_callbacks", after_callbacks)
|
||||||
|
object.__setattr__(self, "final_callbacks", final_callbacks)
|
||||||
|
|
||||||
def call_after(self, callback, *args, **kwargs):
|
def call_after(self, callback, *args, **kwargs):
|
||||||
"""Call the given callback on the main twisted thread after the
|
"""Call the given callback on the main twisted thread after the
|
||||||
@ -67,6 +68,9 @@ class LoggingTransaction(object):
|
|||||||
"""
|
"""
|
||||||
self.after_callbacks.append((callback, args, kwargs))
|
self.after_callbacks.append((callback, args, kwargs))
|
||||||
|
|
||||||
|
def call_finally(self, callback, *args, **kwargs):
|
||||||
|
self.final_callbacks.append((callback, args, kwargs))
|
||||||
|
|
||||||
def __getattr__(self, name):
|
def __getattr__(self, name):
|
||||||
return getattr(self.txn, name)
|
return getattr(self.txn, name)
|
||||||
|
|
||||||
@ -217,8 +221,8 @@ class SQLBaseStore(object):
|
|||||||
|
|
||||||
self._clock.looping_call(loop, 10000)
|
self._clock.looping_call(loop, 10000)
|
||||||
|
|
||||||
def _new_transaction(self, conn, desc, after_callbacks, logging_context,
|
def _new_transaction(self, conn, desc, after_callbacks, final_callbacks,
|
||||||
func, *args, **kwargs):
|
logging_context, func, *args, **kwargs):
|
||||||
start = time.time() * 1000
|
start = time.time() * 1000
|
||||||
txn_id = self._TXN_ID
|
txn_id = self._TXN_ID
|
||||||
|
|
||||||
@ -237,7 +241,8 @@ class SQLBaseStore(object):
|
|||||||
try:
|
try:
|
||||||
txn = conn.cursor()
|
txn = conn.cursor()
|
||||||
txn = LoggingTransaction(
|
txn = LoggingTransaction(
|
||||||
txn, name, self.database_engine, after_callbacks
|
txn, name, self.database_engine, after_callbacks,
|
||||||
|
final_callbacks,
|
||||||
)
|
)
|
||||||
r = func(txn, *args, **kwargs)
|
r = func(txn, *args, **kwargs)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
@ -298,6 +303,7 @@ class SQLBaseStore(object):
|
|||||||
start_time = time.time() * 1000
|
start_time = time.time() * 1000
|
||||||
|
|
||||||
after_callbacks = []
|
after_callbacks = []
|
||||||
|
final_callbacks = []
|
||||||
|
|
||||||
def inner_func(conn, *args, **kwargs):
|
def inner_func(conn, *args, **kwargs):
|
||||||
with LoggingContext("runInteraction") as context:
|
with LoggingContext("runInteraction") as context:
|
||||||
@ -309,7 +315,7 @@ class SQLBaseStore(object):
|
|||||||
|
|
||||||
current_context.copy_to(context)
|
current_context.copy_to(context)
|
||||||
return self._new_transaction(
|
return self._new_transaction(
|
||||||
conn, desc, after_callbacks, current_context,
|
conn, desc, after_callbacks, final_callbacks, current_context,
|
||||||
func, *args, **kwargs
|
func, *args, **kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -318,9 +324,13 @@ class SQLBaseStore(object):
|
|||||||
result = yield self._db_pool.runWithConnection(
|
result = yield self._db_pool.runWithConnection(
|
||||||
inner_func, *args, **kwargs
|
inner_func, *args, **kwargs
|
||||||
)
|
)
|
||||||
finally:
|
|
||||||
for after_callback, after_args, after_kwargs in after_callbacks:
|
for after_callback, after_args, after_kwargs in after_callbacks:
|
||||||
after_callback(*after_args, **after_kwargs)
|
after_callback(*after_args, **after_kwargs)
|
||||||
|
finally:
|
||||||
|
for after_callback, after_args, after_kwargs in final_callbacks:
|
||||||
|
after_callback(*after_args, **after_kwargs)
|
||||||
|
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -425,6 +435,11 @@ class SQLBaseStore(object):
|
|||||||
|
|
||||||
txn.execute(sql, vals)
|
txn.execute(sql, vals)
|
||||||
|
|
||||||
|
def _simple_insert_many(self, table, values, desc):
|
||||||
|
return self.runInteraction(
|
||||||
|
desc, self._simple_insert_many_txn, table, values
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _simple_insert_many_txn(txn, table, values):
|
def _simple_insert_many_txn(txn, table, values):
|
||||||
if not values:
|
if not values:
|
||||||
@ -936,7 +951,7 @@ class SQLBaseStore(object):
|
|||||||
# __exit__ called after the transaction finishes.
|
# __exit__ called after the transaction finishes.
|
||||||
ctx = self._cache_id_gen.get_next()
|
ctx = self._cache_id_gen.get_next()
|
||||||
stream_id = ctx.__enter__()
|
stream_id = ctx.__enter__()
|
||||||
txn.call_after(ctx.__exit__, None, None, None)
|
txn.call_finally(ctx.__exit__, None, None, None)
|
||||||
txn.call_after(self.hs.get_notifier().on_new_replication_data)
|
txn.call_after(self.hs.get_notifier().on_new_replication_data)
|
||||||
|
|
||||||
self._simple_insert_txn(
|
self._simple_insert_txn(
|
||||||
|
@ -13,6 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
import simplejson as json
|
import simplejson as json
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
@ -26,6 +27,25 @@ from ._base import SQLBaseStore
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_exclusive_regex(services_cache):
|
||||||
|
# We precompie a regex constructed from all the regexes that the AS's
|
||||||
|
# have registered for exclusive users.
|
||||||
|
exclusive_user_regexes = [
|
||||||
|
regex.pattern
|
||||||
|
for service in services_cache
|
||||||
|
for regex in service.get_exlusive_user_regexes()
|
||||||
|
]
|
||||||
|
if exclusive_user_regexes:
|
||||||
|
exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
|
||||||
|
exclusive_user_regex = re.compile(exclusive_user_regex)
|
||||||
|
else:
|
||||||
|
# We handle this case specially otherwise the constructed regex
|
||||||
|
# will always match
|
||||||
|
exclusive_user_regex = None
|
||||||
|
|
||||||
|
return exclusive_user_regex
|
||||||
|
|
||||||
|
|
||||||
class ApplicationServiceStore(SQLBaseStore):
|
class ApplicationServiceStore(SQLBaseStore):
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
@ -35,16 +55,17 @@ class ApplicationServiceStore(SQLBaseStore):
|
|||||||
hs.hostname,
|
hs.hostname,
|
||||||
hs.config.app_service_config_files
|
hs.config.app_service_config_files
|
||||||
)
|
)
|
||||||
|
self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
|
||||||
|
|
||||||
def get_app_services(self):
|
def get_app_services(self):
|
||||||
return self.services_cache
|
return self.services_cache
|
||||||
|
|
||||||
def get_if_app_services_interested_in_user(self, user_id):
|
def get_if_app_services_interested_in_user(self, user_id):
|
||||||
"""Check if the user is one associated with an app service
|
"""Check if the user is one associated with an app service (exclusively)
|
||||||
"""
|
"""
|
||||||
for service in self.services_cache:
|
if self.exclusive_user_regex:
|
||||||
if service.is_interested_in_user(user_id):
|
return bool(self.exclusive_user_regex.match(user_id))
|
||||||
return True
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_app_service_by_user_id(self, user_id):
|
def get_app_service_by_user_id(self, user_id):
|
||||||
|
@ -15,11 +15,14 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer, reactor
|
||||||
|
|
||||||
from ._base import Cache
|
from ._base import Cache
|
||||||
from . import background_updates
|
from . import background_updates
|
||||||
|
|
||||||
|
from synapse.util.caches import CACHE_SIZE_FACTOR
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Number of msec of granularity to store the user IP 'last seen' time. Smaller
|
# Number of msec of granularity to store the user IP 'last seen' time. Smaller
|
||||||
@ -33,7 +36,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
|||||||
self.client_ip_last_seen = Cache(
|
self.client_ip_last_seen = Cache(
|
||||||
name="client_ip_last_seen",
|
name="client_ip_last_seen",
|
||||||
keylen=4,
|
keylen=4,
|
||||||
max_entries=5000,
|
max_entries=50000 * CACHE_SIZE_FACTOR,
|
||||||
)
|
)
|
||||||
|
|
||||||
super(ClientIpStore, self).__init__(hs)
|
super(ClientIpStore, self).__init__(hs)
|
||||||
@ -45,7 +48,14 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
|||||||
columns=["user_id", "device_id", "last_seen"],
|
columns=["user_id", "device_id", "last_seen"],
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
# (user_id, access_token, ip) -> (user_agent, device_id, last_seen)
|
||||||
|
self._batch_row_update = {}
|
||||||
|
|
||||||
|
self._client_ip_looper = self._clock.looping_call(
|
||||||
|
self._update_client_ips_batch, 5 * 1000
|
||||||
|
)
|
||||||
|
reactor.addSystemEventTrigger("before", "shutdown", self._update_client_ips_batch)
|
||||||
|
|
||||||
def insert_client_ip(self, user, access_token, ip, user_agent, device_id):
|
def insert_client_ip(self, user, access_token, ip, user_agent, device_id):
|
||||||
now = int(self._clock.time_msec())
|
now = int(self._clock.time_msec())
|
||||||
key = (user.to_string(), access_token, ip)
|
key = (user.to_string(), access_token, ip)
|
||||||
@ -57,34 +67,48 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
|||||||
|
|
||||||
# Rate-limited inserts
|
# Rate-limited inserts
|
||||||
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
|
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
|
||||||
defer.returnValue(None)
|
return
|
||||||
|
|
||||||
self.client_ip_last_seen.prefill(key, now)
|
self.client_ip_last_seen.prefill(key, now)
|
||||||
|
|
||||||
# It's safe not to lock here: a) no unique constraint,
|
self._batch_row_update[key] = (user_agent, device_id, now)
|
||||||
# b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely
|
|
||||||
yield self._simple_upsert(
|
def _update_client_ips_batch(self):
|
||||||
"user_ips",
|
to_update = self._batch_row_update
|
||||||
|
self._batch_row_update = {}
|
||||||
|
return self.runInteraction(
|
||||||
|
"_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
|
||||||
|
)
|
||||||
|
|
||||||
|
def _update_client_ips_batch_txn(self, txn, to_update):
|
||||||
|
self.database_engine.lock_table(txn, "user_ips")
|
||||||
|
|
||||||
|
for entry in to_update.iteritems():
|
||||||
|
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
|
||||||
|
|
||||||
|
self._simple_upsert_txn(
|
||||||
|
txn,
|
||||||
|
table="user_ips",
|
||||||
keyvalues={
|
keyvalues={
|
||||||
"user_id": user.to_string(),
|
"user_id": user_id,
|
||||||
"access_token": access_token,
|
"access_token": access_token,
|
||||||
"ip": ip,
|
"ip": ip,
|
||||||
"user_agent": user_agent,
|
"user_agent": user_agent,
|
||||||
"device_id": device_id,
|
"device_id": device_id,
|
||||||
},
|
},
|
||||||
values={
|
values={
|
||||||
"last_seen": now,
|
"last_seen": last_seen,
|
||||||
},
|
},
|
||||||
desc="insert_client_ip",
|
|
||||||
lock=False,
|
lock=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_last_client_ip_by_device(self, devices):
|
def get_last_client_ip_by_device(self, user_id, device_id):
|
||||||
"""For each device_id listed, give the user_ip it was last seen on
|
"""For each device_id listed, give the user_ip it was last seen on
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
devices (iterable[(str, str)]): list of (user_id, device_id) pairs
|
user_id (str)
|
||||||
|
device_id (str): If None fetches all devices for the user
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred: resolves to a dict, where the keys
|
defer.Deferred: resolves to a dict, where the keys
|
||||||
@ -95,6 +119,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
|||||||
res = yield self.runInteraction(
|
res = yield self.runInteraction(
|
||||||
"get_last_client_ip_by_device",
|
"get_last_client_ip_by_device",
|
||||||
self._get_last_client_ip_by_device_txn,
|
self._get_last_client_ip_by_device_txn,
|
||||||
|
user_id, device_id,
|
||||||
retcols=(
|
retcols=(
|
||||||
"user_id",
|
"user_id",
|
||||||
"access_token",
|
"access_token",
|
||||||
@ -103,19 +128,30 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
|||||||
"device_id",
|
"device_id",
|
||||||
"last_seen",
|
"last_seen",
|
||||||
),
|
),
|
||||||
devices=devices
|
|
||||||
)
|
)
|
||||||
|
|
||||||
ret = {(d["user_id"], d["device_id"]): d for d in res}
|
ret = {(d["user_id"], d["device_id"]): d for d in res}
|
||||||
|
for key in self._batch_row_update:
|
||||||
|
uid, access_token, ip = key
|
||||||
|
if uid == user_id:
|
||||||
|
user_agent, did, last_seen = self._batch_row_update[key]
|
||||||
|
if not device_id or did == device_id:
|
||||||
|
ret[(user_id, device_id)] = {
|
||||||
|
"user_id": user_id,
|
||||||
|
"access_token": access_token,
|
||||||
|
"ip": ip,
|
||||||
|
"user_agent": user_agent,
|
||||||
|
"device_id": did,
|
||||||
|
"last_seen": last_seen,
|
||||||
|
}
|
||||||
defer.returnValue(ret)
|
defer.returnValue(ret)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_last_client_ip_by_device_txn(cls, txn, devices, retcols):
|
def _get_last_client_ip_by_device_txn(cls, txn, user_id, device_id, retcols):
|
||||||
where_clauses = []
|
where_clauses = []
|
||||||
bindings = []
|
bindings = []
|
||||||
for (user_id, device_id) in devices:
|
|
||||||
if device_id is None:
|
if device_id is None:
|
||||||
where_clauses.append("(user_id = ? AND device_id IS NULL)")
|
where_clauses.append("user_id = ?")
|
||||||
bindings.extend((user_id, ))
|
bindings.extend((user_id, ))
|
||||||
else:
|
else:
|
||||||
where_clauses.append("(user_id = ? AND device_id = ?)")
|
where_clauses.append("(user_id = ? AND device_id = ?)")
|
||||||
@ -147,3 +183,37 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
|||||||
|
|
||||||
txn.execute(sql, bindings)
|
txn.execute(sql, bindings)
|
||||||
return cls.cursor_to_dict(txn)
|
return cls.cursor_to_dict(txn)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_user_ip_and_agents(self, user):
|
||||||
|
user_id = user.to_string()
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
for key in self._batch_row_update:
|
||||||
|
uid, access_token, ip = key
|
||||||
|
if uid == user_id:
|
||||||
|
user_agent, _, last_seen = self._batch_row_update[key]
|
||||||
|
results[(access_token, ip)] = (user_agent, last_seen)
|
||||||
|
|
||||||
|
rows = yield self._simple_select_list(
|
||||||
|
table="user_ips",
|
||||||
|
keyvalues={"user_id": user_id},
|
||||||
|
retcols=[
|
||||||
|
"access_token", "ip", "user_agent", "last_seen"
|
||||||
|
],
|
||||||
|
desc="get_user_ip_and_agents",
|
||||||
|
)
|
||||||
|
|
||||||
|
results.update(
|
||||||
|
((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"]))
|
||||||
|
for row in rows
|
||||||
|
)
|
||||||
|
defer.returnValue(list(
|
||||||
|
{
|
||||||
|
"access_token": access_token,
|
||||||
|
"ip": ip,
|
||||||
|
"user_agent": user_agent,
|
||||||
|
"last_seen": last_seen,
|
||||||
|
}
|
||||||
|
for (access_token, ip), (user_agent, last_seen) in results.iteritems()
|
||||||
|
))
|
||||||
|
@ -368,7 +368,7 @@ class DeviceStore(SQLBaseStore):
|
|||||||
|
|
||||||
prev_sent_id_sql = """
|
prev_sent_id_sql = """
|
||||||
SELECT coalesce(max(stream_id), 0) as stream_id
|
SELECT coalesce(max(stream_id), 0) as stream_id
|
||||||
FROM device_lists_outbound_pokes
|
FROM device_lists_outbound_last_success
|
||||||
WHERE destination = ? AND user_id = ? AND stream_id <= ?
|
WHERE destination = ? AND user_id = ? AND stream_id <= ?
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -510,32 +510,43 @@ class DeviceStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
|
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
|
||||||
# First we DELETE all rows such that only the latest row for each
|
# We update the device_lists_outbound_last_success with the successfully
|
||||||
# (destination, user_id is left. We do this by selecting first and
|
# poked users. We do the join to see which users need to be inserted and
|
||||||
# deleting.
|
# which updated.
|
||||||
sql = """
|
sql = """
|
||||||
SELECT user_id, coalesce(max(stream_id), 0) FROM device_lists_outbound_pokes
|
SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
|
||||||
WHERE destination = ? AND stream_id <= ?
|
FROM device_lists_outbound_pokes as o
|
||||||
|
LEFT JOIN device_lists_outbound_last_success as s
|
||||||
|
USING (destination, user_id)
|
||||||
|
WHERE destination = ? AND o.stream_id <= ?
|
||||||
GROUP BY user_id
|
GROUP BY user_id
|
||||||
HAVING count(*) > 1
|
|
||||||
"""
|
"""
|
||||||
txn.execute(sql, (destination, stream_id,))
|
txn.execute(sql, (destination, stream_id,))
|
||||||
rows = txn.fetchall()
|
rows = txn.fetchall()
|
||||||
|
|
||||||
sql = """
|
sql = """
|
||||||
DELETE FROM device_lists_outbound_pokes
|
UPDATE device_lists_outbound_last_success
|
||||||
WHERE destination = ? AND user_id = ? AND stream_id < ?
|
SET stream_id = ?
|
||||||
|
WHERE destination = ? AND user_id = ?
|
||||||
"""
|
"""
|
||||||
txn.executemany(
|
txn.executemany(
|
||||||
sql, ((destination, row[0], row[1],) for row in rows)
|
sql, ((row[1], destination, row[0],) for row in rows if row[2])
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mark everything that is left as sent
|
|
||||||
sql = """
|
sql = """
|
||||||
UPDATE device_lists_outbound_pokes SET sent = ?
|
INSERT INTO device_lists_outbound_last_success
|
||||||
|
(destination, user_id, stream_id) VALUES (?, ?, ?)
|
||||||
|
"""
|
||||||
|
txn.executemany(
|
||||||
|
sql, ((destination, row[0], row[1],) for row in rows if not row[2])
|
||||||
|
)
|
||||||
|
|
||||||
|
# Delete all sent outbound pokes
|
||||||
|
sql = """
|
||||||
|
DELETE FROM device_lists_outbound_pokes
|
||||||
WHERE destination = ? AND stream_id <= ?
|
WHERE destination = ? AND stream_id <= ?
|
||||||
"""
|
"""
|
||||||
txn.execute(sql, (True, destination, stream_id,))
|
txn.execute(sql, (destination, stream_id,))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_user_whose_devices_changed(self, from_key):
|
def get_user_whose_devices_changed(self, from_key):
|
||||||
@ -670,6 +681,14 @@ class DeviceStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Since we've deleted unsent deltas, we need to remove the entry
|
||||||
|
# of last successful sent so that the prev_ids are correctly set.
|
||||||
|
sql = """
|
||||||
|
DELETE FROM device_lists_outbound_last_success
|
||||||
|
WHERE destination = ? AND user_id = ?
|
||||||
|
"""
|
||||||
|
txn.executemany(sql, ((row[0], row[1]) for row in rows))
|
||||||
|
|
||||||
logger.info("Pruned %d device list outbound pokes", txn.rowcount)
|
logger.info("Pruned %d device list outbound pokes", txn.rowcount)
|
||||||
|
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
|
@ -170,3 +170,17 @@ class DirectoryStore(SQLBaseStore):
|
|||||||
"room_alias",
|
"room_alias",
|
||||||
desc="get_aliases_for_room",
|
desc="get_aliases_for_room",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def update_aliases_for_room(self, old_room_id, new_room_id, creator):
|
||||||
|
def _update_aliases_for_room_txn(txn):
|
||||||
|
sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?"
|
||||||
|
txn.execute(sql, (new_room_id, creator, old_room_id,))
|
||||||
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self.get_aliases_for_room, (old_room_id,)
|
||||||
|
)
|
||||||
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self.get_aliases_for_room, (new_room_id,)
|
||||||
|
)
|
||||||
|
return self.runInteraction(
|
||||||
|
"_update_aliases_for_room_txn", _update_aliases_for_room_txn
|
||||||
|
)
|
||||||
|
@ -185,8 +185,8 @@ class EndToEndKeyStore(SQLBaseStore):
|
|||||||
for algorithm, key_id, json_bytes in new_keys
|
for algorithm, key_id, json_bytes in new_keys
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
txn.call_after(
|
self._invalidate_cache_and_stream(
|
||||||
self.count_e2e_one_time_keys.invalidate, (user_id, device_id,)
|
txn, self.count_e2e_one_time_keys, (user_id, device_id,)
|
||||||
)
|
)
|
||||||
yield self.runInteraction(
|
yield self.runInteraction(
|
||||||
"add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
|
"add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
|
||||||
@ -237,24 +237,29 @@ class EndToEndKeyStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
for user_id, device_id, algorithm, key_id in delete:
|
for user_id, device_id, algorithm, key_id in delete:
|
||||||
txn.execute(sql, (user_id, device_id, algorithm, key_id))
|
txn.execute(sql, (user_id, device_id, algorithm, key_id))
|
||||||
txn.call_after(
|
self._invalidate_cache_and_stream(
|
||||||
self.count_e2e_one_time_keys.invalidate, (user_id, device_id,)
|
txn, self.count_e2e_one_time_keys, (user_id, device_id,)
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
|
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def delete_e2e_keys_by_device(self, user_id, device_id):
|
def delete_e2e_keys_by_device(self, user_id, device_id):
|
||||||
yield self._simple_delete(
|
def delete_e2e_keys_by_device_txn(txn):
|
||||||
|
self._simple_delete_txn(
|
||||||
|
txn,
|
||||||
table="e2e_device_keys_json",
|
table="e2e_device_keys_json",
|
||||||
keyvalues={"user_id": user_id, "device_id": device_id},
|
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||||
desc="delete_e2e_device_keys_by_device"
|
|
||||||
)
|
)
|
||||||
yield self._simple_delete(
|
self._simple_delete_txn(
|
||||||
|
txn,
|
||||||
table="e2e_one_time_keys_json",
|
table="e2e_one_time_keys_json",
|
||||||
keyvalues={"user_id": user_id, "device_id": device_id},
|
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||||
desc="delete_e2e_one_time_keys_by_device"
|
|
||||||
)
|
)
|
||||||
self.count_e2e_one_time_keys.invalidate((user_id, device_id,))
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self.count_e2e_one_time_keys, (user_id, device_id,)
|
||||||
|
)
|
||||||
|
return self.runInteraction(
|
||||||
|
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
|
||||||
|
)
|
||||||
|
@ -37,24 +37,54 @@ class EventFederationStore(SQLBaseStore):
|
|||||||
and backfilling from another server respectively.
|
and backfilling from another server respectively.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(EventFederationStore, self).__init__(hs)
|
super(EventFederationStore, self).__init__(hs)
|
||||||
|
|
||||||
|
self.register_background_update_handler(
|
||||||
|
self.EVENT_AUTH_STATE_ONLY,
|
||||||
|
self._background_delete_non_state_event_auth,
|
||||||
|
)
|
||||||
|
|
||||||
hs.get_clock().looping_call(
|
hs.get_clock().looping_call(
|
||||||
self._delete_old_forward_extrem_cache, 60 * 60 * 1000
|
self._delete_old_forward_extrem_cache, 60 * 60 * 1000
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_auth_chain(self, event_ids):
|
def get_auth_chain(self, event_ids, include_given=False):
|
||||||
return self.get_auth_chain_ids(event_ids).addCallback(self._get_events)
|
"""Get auth events for given event_ids. The events *must* be state events.
|
||||||
|
|
||||||
def get_auth_chain_ids(self, event_ids):
|
Args:
|
||||||
|
event_ids (list): state events
|
||||||
|
include_given (bool): include the given events in result
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list of events
|
||||||
|
"""
|
||||||
|
return self.get_auth_chain_ids(
|
||||||
|
event_ids, include_given=include_given,
|
||||||
|
).addCallback(self._get_events)
|
||||||
|
|
||||||
|
def get_auth_chain_ids(self, event_ids, include_given=False):
|
||||||
|
"""Get auth events for given event_ids. The events *must* be state events.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_ids (list): state events
|
||||||
|
include_given (bool): include the given events in result
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list of event_ids
|
||||||
|
"""
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"get_auth_chain_ids",
|
"get_auth_chain_ids",
|
||||||
self._get_auth_chain_ids_txn,
|
self._get_auth_chain_ids_txn,
|
||||||
event_ids
|
event_ids, include_given
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_auth_chain_ids_txn(self, txn, event_ids):
|
def _get_auth_chain_ids_txn(self, txn, event_ids, include_given):
|
||||||
|
if include_given:
|
||||||
|
results = set(event_ids)
|
||||||
|
else:
|
||||||
results = set()
|
results = set()
|
||||||
|
|
||||||
base_sql = (
|
base_sql = (
|
||||||
@ -504,3 +534,52 @@ class EventFederationStore(SQLBaseStore):
|
|||||||
|
|
||||||
txn.execute(query, (room_id,))
|
txn.execute(query, (room_id,))
|
||||||
txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
|
txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _background_delete_non_state_event_auth(self, progress, batch_size):
|
||||||
|
def delete_event_auth(txn):
|
||||||
|
target_min_stream_id = progress.get("target_min_stream_id_inclusive")
|
||||||
|
max_stream_id = progress.get("max_stream_id_exclusive")
|
||||||
|
|
||||||
|
if not target_min_stream_id or not max_stream_id:
|
||||||
|
txn.execute("SELECT COALESCE(MIN(stream_ordering), 0) FROM events")
|
||||||
|
rows = txn.fetchall()
|
||||||
|
target_min_stream_id = rows[0][0]
|
||||||
|
|
||||||
|
txn.execute("SELECT COALESCE(MAX(stream_ordering), 0) FROM events")
|
||||||
|
rows = txn.fetchall()
|
||||||
|
max_stream_id = rows[0][0]
|
||||||
|
|
||||||
|
min_stream_id = max_stream_id - batch_size
|
||||||
|
|
||||||
|
sql = """
|
||||||
|
DELETE FROM event_auth
|
||||||
|
WHERE event_id IN (
|
||||||
|
SELECT event_id FROM events
|
||||||
|
LEFT JOIN state_events USING (room_id, event_id)
|
||||||
|
WHERE ? <= stream_ordering AND stream_ordering < ?
|
||||||
|
AND state_key IS null
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
txn.execute(sql, (min_stream_id, max_stream_id,))
|
||||||
|
|
||||||
|
new_progress = {
|
||||||
|
"target_min_stream_id_inclusive": target_min_stream_id,
|
||||||
|
"max_stream_id_exclusive": min_stream_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
self._background_update_progress_txn(
|
||||||
|
txn, self.EVENT_AUTH_STATE_ONLY, new_progress
|
||||||
|
)
|
||||||
|
|
||||||
|
return min_stream_id >= target_min_stream_id
|
||||||
|
|
||||||
|
result = yield self.runInteraction(
|
||||||
|
self.EVENT_AUTH_STATE_ONLY, delete_event_auth
|
||||||
|
)
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
yield self._end_background_update(self.EVENT_AUTH_STATE_ONLY)
|
||||||
|
|
||||||
|
defer.returnValue(batch_size)
|
||||||
|
@ -403,6 +403,11 @@ class EventsStore(SQLBaseStore):
|
|||||||
(room_id, ), new_state
|
(room_id, ), new_state
|
||||||
)
|
)
|
||||||
|
|
||||||
|
for room_id, latest_event_ids in new_forward_extremeties.iteritems():
|
||||||
|
self.get_latest_event_ids_in_room.prefill(
|
||||||
|
(room_id,), list(latest_event_ids)
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _calculate_new_extremeties(self, room_id, event_contexts, latest_event_ids):
|
def _calculate_new_extremeties(self, room_id, event_contexts, latest_event_ids):
|
||||||
"""Calculates the new forward extremeties for a room given events to
|
"""Calculates the new forward extremeties for a room given events to
|
||||||
@ -647,9 +652,10 @@ class EventsStore(SQLBaseStore):
|
|||||||
list of the event ids which are the forward extremities.
|
list of the event ids which are the forward extremities.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
self._update_current_state_txn(txn, current_state_for_room)
|
|
||||||
|
|
||||||
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
|
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
|
||||||
|
|
||||||
|
self._update_current_state_txn(txn, current_state_for_room, max_stream_order)
|
||||||
|
|
||||||
self._update_forward_extremities_txn(
|
self._update_forward_extremities_txn(
|
||||||
txn,
|
txn,
|
||||||
new_forward_extremities=new_forward_extremeties,
|
new_forward_extremities=new_forward_extremeties,
|
||||||
@ -712,7 +718,7 @@ class EventsStore(SQLBaseStore):
|
|||||||
backfilled=backfilled,
|
backfilled=backfilled,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _update_current_state_txn(self, txn, state_delta_by_room):
|
def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
|
||||||
for room_id, current_state_tuple in state_delta_by_room.iteritems():
|
for room_id, current_state_tuple in state_delta_by_room.iteritems():
|
||||||
to_delete, to_insert, _ = current_state_tuple
|
to_delete, to_insert, _ = current_state_tuple
|
||||||
txn.executemany(
|
txn.executemany(
|
||||||
@ -734,6 +740,29 @@ class EventsStore(SQLBaseStore):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
state_deltas = {key: None for key in to_delete}
|
||||||
|
state_deltas.update(to_insert)
|
||||||
|
|
||||||
|
self._simple_insert_many_txn(
|
||||||
|
txn,
|
||||||
|
table="current_state_delta_stream",
|
||||||
|
values=[
|
||||||
|
{
|
||||||
|
"stream_id": max_stream_order,
|
||||||
|
"room_id": room_id,
|
||||||
|
"type": key[0],
|
||||||
|
"state_key": key[1],
|
||||||
|
"event_id": ev_id,
|
||||||
|
"prev_event_id": to_delete.get(key, None),
|
||||||
|
}
|
||||||
|
for key, ev_id in state_deltas.iteritems()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self._curr_state_delta_stream_cache.entity_has_changed(
|
||||||
|
room_id, max_stream_order,
|
||||||
|
)
|
||||||
|
|
||||||
# Invalidate the various caches
|
# Invalidate the various caches
|
||||||
|
|
||||||
# Figure out the changes of membership to invalidate the
|
# Figure out the changes of membership to invalidate the
|
||||||
@ -742,11 +771,7 @@ class EventsStore(SQLBaseStore):
|
|||||||
# and which we have added, then we invlidate the caches for all
|
# and which we have added, then we invlidate the caches for all
|
||||||
# those users.
|
# those users.
|
||||||
members_changed = set(
|
members_changed = set(
|
||||||
state_key for ev_type, state_key in to_delete.iterkeys()
|
state_key for ev_type, state_key in state_deltas
|
||||||
if ev_type == EventTypes.Member
|
|
||||||
)
|
|
||||||
members_changed.update(
|
|
||||||
state_key for ev_type, state_key in to_insert.iterkeys()
|
|
||||||
if ev_type == EventTypes.Member
|
if ev_type == EventTypes.Member
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -755,6 +780,11 @@ class EventsStore(SQLBaseStore):
|
|||||||
txn, self.get_rooms_for_user, (member,)
|
txn, self.get_rooms_for_user, (member,)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
for host in set(get_domain_from_id(u) for u in members_changed):
|
||||||
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self.is_host_joined, (room_id, host)
|
||||||
|
)
|
||||||
|
|
||||||
self._invalidate_cache_and_stream(
|
self._invalidate_cache_and_stream(
|
||||||
txn, self.get_users_in_room, (room_id,)
|
txn, self.get_users_in_room, (room_id,)
|
||||||
)
|
)
|
||||||
@ -1119,6 +1149,7 @@ class EventsStore(SQLBaseStore):
|
|||||||
}
|
}
|
||||||
for event, _ in events_and_contexts
|
for event, _ in events_and_contexts
|
||||||
for auth_id, _ in event.auth_events
|
for auth_id, _ in event.auth_events
|
||||||
|
if event.is_state()
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1418,7 +1449,7 @@ class EventsStore(SQLBaseStore):
|
|||||||
]
|
]
|
||||||
|
|
||||||
rows = self._new_transaction(
|
rows = self._new_transaction(
|
||||||
conn, "do_fetch", [], None, self._fetch_event_rows, event_ids
|
conn, "do_fetch", [], [], None, self._fetch_event_rows, event_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
row_dict = {
|
row_dict = {
|
||||||
@ -2243,6 +2274,24 @@ class EventsStore(SQLBaseStore):
|
|||||||
|
|
||||||
defer.returnValue((int(res["topological_ordering"]), int(res["stream_ordering"])))
|
defer.returnValue((int(res["topological_ordering"]), int(res["stream_ordering"])))
|
||||||
|
|
||||||
|
def get_max_current_state_delta_stream_id(self):
|
||||||
|
return self._stream_id_gen.get_current_token()
|
||||||
|
|
||||||
|
def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
|
||||||
|
def get_all_updated_current_state_deltas_txn(txn):
|
||||||
|
sql = """
|
||||||
|
SELECT stream_id, room_id, type, state_key, event_id
|
||||||
|
FROM current_state_delta_stream
|
||||||
|
WHERE ? < stream_id AND stream_id <= ?
|
||||||
|
ORDER BY stream_id ASC LIMIT ?
|
||||||
|
"""
|
||||||
|
txn.execute(sql, (from_token, to_token, limit))
|
||||||
|
return txn.fetchall()
|
||||||
|
return self.runInteraction(
|
||||||
|
"get_all_updated_current_state_deltas",
|
||||||
|
get_all_updated_current_state_deltas_txn,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
AllNewEventsResult = namedtuple("AllNewEventsResult", [
|
AllNewEventsResult = namedtuple("AllNewEventsResult", [
|
||||||
"new_forward_events", "new_backfill_events",
|
"new_forward_events", "new_backfill_events",
|
||||||
|
@ -19,6 +19,7 @@ from ._base import SQLBaseStore
|
|||||||
from synapse.api.errors import SynapseError, Codes
|
from synapse.api.errors import SynapseError, Codes
|
||||||
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||||
|
|
||||||
|
from canonicaljson import encode_canonical_json
|
||||||
import simplejson as json
|
import simplejson as json
|
||||||
|
|
||||||
|
|
||||||
@ -46,11 +47,20 @@ class FilteringStore(SQLBaseStore):
|
|||||||
defer.returnValue(json.loads(str(def_json).decode("utf-8")))
|
defer.returnValue(json.loads(str(def_json).decode("utf-8")))
|
||||||
|
|
||||||
def add_user_filter(self, user_localpart, user_filter):
|
def add_user_filter(self, user_localpart, user_filter):
|
||||||
def_json = json.dumps(user_filter).encode("utf-8")
|
def_json = encode_canonical_json(user_filter)
|
||||||
|
|
||||||
# Need an atomic transaction to SELECT the maximal ID so far then
|
# Need an atomic transaction to SELECT the maximal ID so far then
|
||||||
# INSERT a new one
|
# INSERT a new one
|
||||||
def _do_txn(txn):
|
def _do_txn(txn):
|
||||||
|
sql = (
|
||||||
|
"SELECT filter_id FROM user_filters "
|
||||||
|
"WHERE user_id = ? AND filter_json = ?"
|
||||||
|
)
|
||||||
|
txn.execute(sql, (user_localpart, def_json))
|
||||||
|
filter_id_response = txn.fetchone()
|
||||||
|
if filter_id_response is not None:
|
||||||
|
return filter_id_response[0]
|
||||||
|
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT MAX(filter_id) FROM user_filters "
|
"SELECT MAX(filter_id) FROM user_filters "
|
||||||
"WHERE user_id = ?"
|
"WHERE user_id = ?"
|
||||||
|
@ -30,13 +30,16 @@ class MediaRepositoryStore(SQLBaseStore):
|
|||||||
return self._simple_select_one(
|
return self._simple_select_one(
|
||||||
"local_media_repository",
|
"local_media_repository",
|
||||||
{"media_id": media_id},
|
{"media_id": media_id},
|
||||||
("media_type", "media_length", "upload_name", "created_ts"),
|
(
|
||||||
|
"media_type", "media_length", "upload_name", "created_ts",
|
||||||
|
"quarantined_by", "url_cache",
|
||||||
|
),
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
desc="get_local_media",
|
desc="get_local_media",
|
||||||
)
|
)
|
||||||
|
|
||||||
def store_local_media(self, media_id, media_type, time_now_ms, upload_name,
|
def store_local_media(self, media_id, media_type, time_now_ms, upload_name,
|
||||||
media_length, user_id):
|
media_length, user_id, url_cache=None):
|
||||||
return self._simple_insert(
|
return self._simple_insert(
|
||||||
"local_media_repository",
|
"local_media_repository",
|
||||||
{
|
{
|
||||||
@ -46,6 +49,7 @@ class MediaRepositoryStore(SQLBaseStore):
|
|||||||
"upload_name": upload_name,
|
"upload_name": upload_name,
|
||||||
"media_length": media_length,
|
"media_length": media_length,
|
||||||
"user_id": user_id.to_string(),
|
"user_id": user_id.to_string(),
|
||||||
|
"url_cache": url_cache,
|
||||||
},
|
},
|
||||||
desc="store_local_media",
|
desc="store_local_media",
|
||||||
)
|
)
|
||||||
@ -138,7 +142,7 @@ class MediaRepositoryStore(SQLBaseStore):
|
|||||||
{"media_origin": origin, "media_id": media_id},
|
{"media_origin": origin, "media_id": media_id},
|
||||||
(
|
(
|
||||||
"media_type", "media_length", "upload_name", "created_ts",
|
"media_type", "media_length", "upload_name", "created_ts",
|
||||||
"filesystem_id",
|
"filesystem_id", "quarantined_by",
|
||||||
),
|
),
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
desc="get_cached_remote_media",
|
desc="get_cached_remote_media",
|
||||||
|
@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
# Remember to update this number every time a change is made to database
|
# Remember to update this number every time a change is made to database
|
||||||
# schema files, so the users will be informed on server restarts.
|
# schema files, so the users will be informed on server restarts.
|
||||||
SCHEMA_VERSION = 41
|
SCHEMA_VERSION = 43
|
||||||
|
|
||||||
dir_path = os.path.abspath(os.path.dirname(__file__))
|
dir_path = os.path.abspath(os.path.dirname(__file__))
|
||||||
|
|
||||||
|
@ -49,7 +49,7 @@ def _load_rules(rawrules, enabled_map):
|
|||||||
|
|
||||||
|
|
||||||
class PushRuleStore(SQLBaseStore):
|
class PushRuleStore(SQLBaseStore):
|
||||||
@cachedInlineCallbacks()
|
@cachedInlineCallbacks(max_entries=5000)
|
||||||
def get_push_rules_for_user(self, user_id):
|
def get_push_rules_for_user(self, user_id):
|
||||||
rows = yield self._simple_select_list(
|
rows = yield self._simple_select_list(
|
||||||
table="push_rules",
|
table="push_rules",
|
||||||
@ -73,7 +73,7 @@ class PushRuleStore(SQLBaseStore):
|
|||||||
|
|
||||||
defer.returnValue(rules)
|
defer.returnValue(rules)
|
||||||
|
|
||||||
@cachedInlineCallbacks()
|
@cachedInlineCallbacks(max_entries=5000)
|
||||||
def get_push_rules_enabled_for_user(self, user_id):
|
def get_push_rules_enabled_for_user(self, user_id):
|
||||||
results = yield self._simple_select_list(
|
results = yield self._simple_select_list(
|
||||||
table="push_rules_enable",
|
table="push_rules_enable",
|
||||||
|
@ -45,7 +45,9 @@ class ReceiptsStore(SQLBaseStore):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Returns an ObservableDeferred
|
# Returns an ObservableDeferred
|
||||||
res = self.get_users_with_read_receipts_in_room.cache.get((room_id,), None)
|
res = self.get_users_with_read_receipts_in_room.cache.get(
|
||||||
|
room_id, None, update_metrics=False,
|
||||||
|
)
|
||||||
|
|
||||||
if res:
|
if res:
|
||||||
if isinstance(res, defer.Deferred) and res.called:
|
if isinstance(res, defer.Deferred) and res.called:
|
||||||
|
@ -24,6 +24,7 @@ from .engines import PostgresEngine, Sqlite3Engine
|
|||||||
import collections
|
import collections
|
||||||
import logging
|
import logging
|
||||||
import ujson as json
|
import ujson as json
|
||||||
|
import re
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -507,3 +508,98 @@ class RoomStore(SQLBaseStore):
|
|||||||
))
|
))
|
||||||
else:
|
else:
|
||||||
defer.returnValue(None)
|
defer.returnValue(None)
|
||||||
|
|
||||||
|
@cached(max_entries=10000)
|
||||||
|
def is_room_blocked(self, room_id):
|
||||||
|
return self._simple_select_one_onecol(
|
||||||
|
table="blocked_rooms",
|
||||||
|
keyvalues={
|
||||||
|
"room_id": room_id,
|
||||||
|
},
|
||||||
|
retcol="1",
|
||||||
|
allow_none=True,
|
||||||
|
desc="is_room_blocked",
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def block_room(self, room_id, user_id):
|
||||||
|
yield self._simple_insert(
|
||||||
|
table="blocked_rooms",
|
||||||
|
values={
|
||||||
|
"room_id": room_id,
|
||||||
|
"user_id": user_id,
|
||||||
|
},
|
||||||
|
desc="block_room",
|
||||||
|
)
|
||||||
|
self.is_room_blocked.invalidate((room_id,))
|
||||||
|
|
||||||
|
def quarantine_media_ids_in_room(self, room_id, quarantined_by):
|
||||||
|
"""For a room loops through all events with media and quarantines
|
||||||
|
the associated media
|
||||||
|
"""
|
||||||
|
def _get_media_ids_in_room(txn):
|
||||||
|
mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
|
||||||
|
|
||||||
|
next_token = self.get_current_events_token() + 1
|
||||||
|
|
||||||
|
total_media_quarantined = 0
|
||||||
|
|
||||||
|
while next_token:
|
||||||
|
sql = """
|
||||||
|
SELECT stream_ordering, content FROM events
|
||||||
|
WHERE room_id = ?
|
||||||
|
AND stream_ordering < ?
|
||||||
|
AND contains_url = ? AND outlier = ?
|
||||||
|
ORDER BY stream_ordering DESC
|
||||||
|
LIMIT ?
|
||||||
|
"""
|
||||||
|
txn.execute(sql, (room_id, next_token, True, False, 100))
|
||||||
|
|
||||||
|
next_token = None
|
||||||
|
local_media_mxcs = []
|
||||||
|
remote_media_mxcs = []
|
||||||
|
for stream_ordering, content_json in txn:
|
||||||
|
next_token = stream_ordering
|
||||||
|
content = json.loads(content_json)
|
||||||
|
|
||||||
|
content_url = content.get("url")
|
||||||
|
thumbnail_url = content.get("info", {}).get("thumbnail_url")
|
||||||
|
|
||||||
|
for url in (content_url, thumbnail_url):
|
||||||
|
if not url:
|
||||||
|
continue
|
||||||
|
matches = mxc_re.match(url)
|
||||||
|
if matches:
|
||||||
|
hostname = matches.group(1)
|
||||||
|
media_id = matches.group(2)
|
||||||
|
if hostname == self.hostname:
|
||||||
|
local_media_mxcs.append(media_id)
|
||||||
|
else:
|
||||||
|
remote_media_mxcs.append((hostname, media_id))
|
||||||
|
|
||||||
|
# Now update all the tables to set the quarantined_by flag
|
||||||
|
|
||||||
|
txn.executemany("""
|
||||||
|
UPDATE local_media_repository
|
||||||
|
SET quarantined_by = ?
|
||||||
|
WHERE media_id = ?
|
||||||
|
""", ((quarantined_by, media_id) for media_id in local_media_mxcs))
|
||||||
|
|
||||||
|
txn.executemany(
|
||||||
|
"""
|
||||||
|
UPDATE remote_media_cache
|
||||||
|
SET quarantined_by = ?
|
||||||
|
WHERE media_origin AND media_id = ?
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
(quarantined_by, origin, media_id)
|
||||||
|
for origin, media_id in remote_media_mxcs
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
total_media_quarantined += len(local_media_mxcs)
|
||||||
|
total_media_quarantined += len(remote_media_mxcs)
|
||||||
|
|
||||||
|
return total_media_quarantined
|
||||||
|
|
||||||
|
return self.runInteraction("get_media_ids_in_room", _get_media_ids_in_room)
|
||||||
|
@ -18,6 +18,7 @@ from twisted.internet import defer
|
|||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
|
from synapse.util.async import Linearizer
|
||||||
from synapse.util.caches import intern_string
|
from synapse.util.caches import intern_string
|
||||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
||||||
from synapse.util.stringutils import to_ascii
|
from synapse.util.stringutils import to_ascii
|
||||||
@ -392,7 +393,8 @@ class RoomMemberStore(SQLBaseStore):
|
|||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_joined_users_from_state(self, room_id, state_group, state_ids):
|
def get_joined_users_from_state(self, room_id, state_entry):
|
||||||
|
state_group = state_entry.state_group
|
||||||
if not state_group:
|
if not state_group:
|
||||||
# If state_group is None it means it has yet to be assigned a
|
# If state_group is None it means it has yet to be assigned a
|
||||||
# state group, i.e. we need to make sure that calls with a state_group
|
# state group, i.e. we need to make sure that calls with a state_group
|
||||||
@ -401,7 +403,7 @@ class RoomMemberStore(SQLBaseStore):
|
|||||||
state_group = object()
|
state_group = object()
|
||||||
|
|
||||||
return self._get_joined_users_from_context(
|
return self._get_joined_users_from_context(
|
||||||
room_id, state_group, state_ids,
|
room_id, state_group, state_entry.state, context=state_entry,
|
||||||
)
|
)
|
||||||
|
|
||||||
@cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True,
|
@cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True,
|
||||||
@ -499,42 +501,40 @@ class RoomMemberStore(SQLBaseStore):
|
|||||||
|
|
||||||
defer.returnValue(users_in_room)
|
defer.returnValue(users_in_room)
|
||||||
|
|
||||||
def is_host_joined(self, room_id, host, state_group, state_ids):
|
@cachedInlineCallbacks(max_entries=10000)
|
||||||
if not state_group:
|
def is_host_joined(self, room_id, host):
|
||||||
# If state_group is None it means it has yet to be assigned a
|
if '%' in host or '_' in host:
|
||||||
# state group, i.e. we need to make sure that calls with a state_group
|
raise Exception("Invalid host name")
|
||||||
# of None don't hit previous cached calls with a None state_group.
|
|
||||||
# To do this we set the state_group to a new object as object() != object()
|
|
||||||
state_group = object()
|
|
||||||
|
|
||||||
return self._is_host_joined(
|
sql = """
|
||||||
room_id, host, state_group, state_ids
|
SELECT state_key FROM current_state_events AS c
|
||||||
)
|
INNER JOIN room_memberships USING (event_id)
|
||||||
|
WHERE membership = 'join'
|
||||||
|
AND type = 'm.room.member'
|
||||||
|
AND c.room_id = ?
|
||||||
|
AND state_key LIKE ?
|
||||||
|
LIMIT 1
|
||||||
|
"""
|
||||||
|
|
||||||
@cachedInlineCallbacks(num_args=3)
|
# We do need to be careful to ensure that host doesn't have any wild cards
|
||||||
def _is_host_joined(self, room_id, host, state_group, current_state_ids):
|
# in it, but we checked above for known ones and we'll check below that
|
||||||
# We don't use `state_group`, its there so that we can cache based
|
# the returned user actually has the correct domain.
|
||||||
# on it. However, its important that its never None, since two current_state's
|
like_clause = "%:" + host
|
||||||
# with a state_group of None are likely to be different.
|
|
||||||
# See bulk_get_push_rules_for_room for how we work around this.
|
|
||||||
assert state_group is not None
|
|
||||||
|
|
||||||
for (etype, state_key), event_id in current_state_ids.items():
|
rows = yield self._execute("is_host_joined", None, sql, room_id, like_clause)
|
||||||
if etype == EventTypes.Member:
|
|
||||||
try:
|
|
||||||
if get_domain_from_id(state_key) != host:
|
|
||||||
continue
|
|
||||||
except:
|
|
||||||
logger.warn("state_key not user_id: %s", state_key)
|
|
||||||
continue
|
|
||||||
|
|
||||||
event = yield self.get_event(event_id, allow_none=True)
|
|
||||||
if event and event.content["membership"] == Membership.JOIN:
|
|
||||||
defer.returnValue(True)
|
|
||||||
|
|
||||||
|
if not rows:
|
||||||
defer.returnValue(False)
|
defer.returnValue(False)
|
||||||
|
|
||||||
def get_joined_hosts(self, room_id, state_group, state_ids):
|
user_id = rows[0][0]
|
||||||
|
if get_domain_from_id(user_id) != host:
|
||||||
|
# This can only happen if the host name has something funky in it
|
||||||
|
raise Exception("Invalid host name")
|
||||||
|
|
||||||
|
defer.returnValue(True)
|
||||||
|
|
||||||
|
def get_joined_hosts(self, room_id, state_entry):
|
||||||
|
state_group = state_entry.state_group
|
||||||
if not state_group:
|
if not state_group:
|
||||||
# If state_group is None it means it has yet to be assigned a
|
# If state_group is None it means it has yet to be assigned a
|
||||||
# state group, i.e. we need to make sure that calls with a state_group
|
# state group, i.e. we need to make sure that calls with a state_group
|
||||||
@ -543,33 +543,20 @@ class RoomMemberStore(SQLBaseStore):
|
|||||||
state_group = object()
|
state_group = object()
|
||||||
|
|
||||||
return self._get_joined_hosts(
|
return self._get_joined_hosts(
|
||||||
room_id, state_group, state_ids
|
room_id, state_group, state_entry.state, state_entry=state_entry,
|
||||||
)
|
)
|
||||||
|
|
||||||
@cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
|
@cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
|
||||||
def _get_joined_hosts(self, room_id, state_group, current_state_ids):
|
# @defer.inlineCallbacks
|
||||||
|
def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry):
|
||||||
# We don't use `state_group`, its there so that we can cache based
|
# We don't use `state_group`, its there so that we can cache based
|
||||||
# on it. However, its important that its never None, since two current_state's
|
# on it. However, its important that its never None, since two current_state's
|
||||||
# with a state_group of None are likely to be different.
|
# with a state_group of None are likely to be different.
|
||||||
# See bulk_get_push_rules_for_room for how we work around this.
|
# See bulk_get_push_rules_for_room for how we work around this.
|
||||||
assert state_group is not None
|
assert state_group is not None
|
||||||
|
|
||||||
joined_hosts = set()
|
cache = self._get_joined_hosts_cache(room_id)
|
||||||
for etype, state_key in current_state_ids:
|
joined_hosts = yield cache.get_destinations(state_entry)
|
||||||
if etype == EventTypes.Member:
|
|
||||||
try:
|
|
||||||
host = get_domain_from_id(state_key)
|
|
||||||
except:
|
|
||||||
logger.warn("state_key not user_id: %s", state_key)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if host in joined_hosts:
|
|
||||||
continue
|
|
||||||
|
|
||||||
event_id = current_state_ids[(etype, state_key)]
|
|
||||||
event = yield self.get_event(event_id, allow_none=True)
|
|
||||||
if event and event.content["membership"] == Membership.JOIN:
|
|
||||||
joined_hosts.add(intern_string(host))
|
|
||||||
|
|
||||||
defer.returnValue(joined_hosts)
|
defer.returnValue(joined_hosts)
|
||||||
|
|
||||||
@ -647,3 +634,75 @@ class RoomMemberStore(SQLBaseStore):
|
|||||||
yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME)
|
yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME)
|
||||||
|
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
|
@cached(max_entries=10000, iterable=True)
|
||||||
|
def _get_joined_hosts_cache(self, room_id):
|
||||||
|
return _JoinedHostsCache(self, room_id)
|
||||||
|
|
||||||
|
|
||||||
|
class _JoinedHostsCache(object):
|
||||||
|
"""Cache for joined hosts in a room that is optimised to handle updates
|
||||||
|
via state deltas.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, store, room_id):
|
||||||
|
self.store = store
|
||||||
|
self.room_id = room_id
|
||||||
|
|
||||||
|
self.hosts_to_joined_users = {}
|
||||||
|
|
||||||
|
self.state_group = object()
|
||||||
|
|
||||||
|
self.linearizer = Linearizer("_JoinedHostsCache")
|
||||||
|
|
||||||
|
self._len = 0
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_destinations(self, state_entry):
|
||||||
|
"""Get set of destinations for a state entry
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_entry(synapse.state._StateCacheEntry)
|
||||||
|
"""
|
||||||
|
if state_entry.state_group == self.state_group:
|
||||||
|
defer.returnValue(frozenset(self.hosts_to_joined_users))
|
||||||
|
|
||||||
|
with (yield self.linearizer.queue(())):
|
||||||
|
if state_entry.state_group == self.state_group:
|
||||||
|
pass
|
||||||
|
elif state_entry.prev_group == self.state_group:
|
||||||
|
for (typ, state_key), event_id in state_entry.delta_ids.iteritems():
|
||||||
|
if typ != EventTypes.Member:
|
||||||
|
continue
|
||||||
|
|
||||||
|
host = intern_string(get_domain_from_id(state_key))
|
||||||
|
user_id = state_key
|
||||||
|
known_joins = self.hosts_to_joined_users.setdefault(host, set())
|
||||||
|
|
||||||
|
event = yield self.store.get_event(event_id)
|
||||||
|
if event.membership == Membership.JOIN:
|
||||||
|
known_joins.add(user_id)
|
||||||
|
else:
|
||||||
|
known_joins.discard(user_id)
|
||||||
|
|
||||||
|
if not known_joins:
|
||||||
|
self.hosts_to_joined_users.pop(host, None)
|
||||||
|
else:
|
||||||
|
joined_users = yield self.store.get_joined_users_from_state(
|
||||||
|
self.room_id, state_entry,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.hosts_to_joined_users = {}
|
||||||
|
for user_id in joined_users:
|
||||||
|
host = intern_string(get_domain_from_id(user_id))
|
||||||
|
self.hosts_to_joined_users.setdefault(host, set()).add(user_id)
|
||||||
|
|
||||||
|
if state_entry.state_group:
|
||||||
|
self.state_group = state_entry.state_group
|
||||||
|
else:
|
||||||
|
self.state_group = object()
|
||||||
|
self._len = sum(len(v) for v in self.hosts_to_joined_users.itervalues())
|
||||||
|
defer.returnValue(frozenset(self.hosts_to_joined_users))
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self._len
|
||||||
|
26
synapse/storage/schema/delta/42/current_state_delta.sql
Normal file
26
synapse/storage/schema/delta/42/current_state_delta.sql
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
/* Copyright 2017 Vector Creations Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
CREATE TABLE current_state_delta_stream (
|
||||||
|
stream_id BIGINT NOT NULL,
|
||||||
|
room_id TEXT NOT NULL,
|
||||||
|
type TEXT NOT NULL,
|
||||||
|
state_key TEXT NOT NULL,
|
||||||
|
event_id TEXT, -- Is null if the key was removed
|
||||||
|
prev_event_id TEXT -- Is null if the key was added
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX current_state_delta_stream_idx ON current_state_delta_stream(stream_id);
|
33
synapse/storage/schema/delta/42/device_list_last_id.sql
Normal file
33
synapse/storage/schema/delta/42/device_list_last_id.sql
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
/* Copyright 2017 Vector Creations Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
-- Table of last stream_id that we sent to destination for user_id. This is
|
||||||
|
-- used to fill out the `prev_id` fields of outbound device list updates.
|
||||||
|
CREATE TABLE device_lists_outbound_last_success (
|
||||||
|
destination TEXT NOT NULL,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
stream_id BIGINT NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
INSERT INTO device_lists_outbound_last_success
|
||||||
|
SELECT destination, user_id, coalesce(max(stream_id), 0) as stream_id
|
||||||
|
FROM device_lists_outbound_pokes
|
||||||
|
WHERE sent = (1 = 1) -- sqlite doesn't have inbuilt boolean values
|
||||||
|
GROUP BY destination, user_id;
|
||||||
|
|
||||||
|
CREATE INDEX device_lists_outbound_last_success_idx ON device_lists_outbound_last_success(
|
||||||
|
destination, user_id, stream_id
|
||||||
|
);
|
17
synapse/storage/schema/delta/42/event_auth_state_only.sql
Normal file
17
synapse/storage/schema/delta/42/event_auth_state_only.sql
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
/* Copyright 2017 Vector Creations Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
INSERT INTO background_updates (update_name, progress_json) VALUES
|
||||||
|
('event_auth_state_only', '{}');
|
84
synapse/storage/schema/delta/42/user_dir.py
Normal file
84
synapse/storage/schema/delta/42/user_dir.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
# Copyright 2017 Vector Creations Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from synapse.storage.prepare_database import get_statements
|
||||||
|
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
BOTH_TABLES = """
|
||||||
|
CREATE TABLE user_directory_stream_pos (
|
||||||
|
Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
|
||||||
|
stream_id BIGINT,
|
||||||
|
CHECK (Lock='X')
|
||||||
|
);
|
||||||
|
|
||||||
|
INSERT INTO user_directory_stream_pos (stream_id) VALUES (null);
|
||||||
|
|
||||||
|
CREATE TABLE user_directory (
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
room_id TEXT NOT NULL, -- A room_id that we know the user is joined to
|
||||||
|
display_name TEXT,
|
||||||
|
avatar_url TEXT
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX user_directory_room_idx ON user_directory(room_id);
|
||||||
|
CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id);
|
||||||
|
|
||||||
|
CREATE TABLE users_in_pubic_room (
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
room_id TEXT NOT NULL -- A room_id that we know is public
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX users_in_pubic_room_room_idx ON users_in_pubic_room(room_id);
|
||||||
|
CREATE UNIQUE INDEX users_in_pubic_room_user_idx ON users_in_pubic_room(user_id);
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
POSTGRES_TABLE = """
|
||||||
|
CREATE TABLE user_directory_search (
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
vector tsvector
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX user_directory_search_fts_idx ON user_directory_search USING gin(vector);
|
||||||
|
CREATE UNIQUE INDEX user_directory_search_user_idx ON user_directory_search(user_id);
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
SQLITE_TABLE = """
|
||||||
|
CREATE VIRTUAL TABLE user_directory_search
|
||||||
|
USING fts4 ( user_id, value );
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def run_create(cur, database_engine, *args, **kwargs):
|
||||||
|
for statement in get_statements(BOTH_TABLES.splitlines()):
|
||||||
|
cur.execute(statement)
|
||||||
|
|
||||||
|
if isinstance(database_engine, PostgresEngine):
|
||||||
|
for statement in get_statements(POSTGRES_TABLE.splitlines()):
|
||||||
|
cur.execute(statement)
|
||||||
|
elif isinstance(database_engine, Sqlite3Engine):
|
||||||
|
for statement in get_statements(SQLITE_TABLE.splitlines()):
|
||||||
|
cur.execute(statement)
|
||||||
|
else:
|
||||||
|
raise Exception("Unrecognized database engine")
|
||||||
|
|
||||||
|
|
||||||
|
def run_upgrade(*args, **kwargs):
|
||||||
|
pass
|
21
synapse/storage/schema/delta/43/blocked_rooms.sql
Normal file
21
synapse/storage/schema/delta/43/blocked_rooms.sql
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
/* Copyright 2017 Vector Creations Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
CREATE TABLE blocked_rooms (
|
||||||
|
room_id TEXT NOT NULL,
|
||||||
|
user_id TEXT NOT NULL -- Admin who blocked the room
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX blocked_rooms_idx ON blocked_rooms(room_id);
|
17
synapse/storage/schema/delta/43/quarantine_media.sql
Normal file
17
synapse/storage/schema/delta/43/quarantine_media.sql
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
/* Copyright 2017 Vector Creations Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
ALTER TABLE local_media_repository ADD COLUMN quarantined_by TEXT;
|
||||||
|
ALTER TABLE remote_media_cache ADD COLUMN quarantined_by TEXT;
|
16
synapse/storage/schema/delta/43/url_cache.sql
Normal file
16
synapse/storage/schema/delta/43/url_cache.sql
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
/* Copyright 2017 Vector Creations Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
ALTER TABLE local_media_repository ADD COLUMN url_cache TEXT;
|
33
synapse/storage/schema/delta/43/user_share.sql
Normal file
33
synapse/storage/schema/delta/43/user_share.sql
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
/* Copyright 2017 Vector Creations Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
-- Table keeping track of who shares a room with who. We only keep track
|
||||||
|
-- of this for local users, so `user_id` is local users only (but we do keep track
|
||||||
|
-- of which remote users share a room)
|
||||||
|
CREATE TABLE users_who_share_rooms (
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
other_user_id TEXT NOT NULL,
|
||||||
|
room_id TEXT NOT NULL,
|
||||||
|
share_private BOOLEAN NOT NULL -- is the shared room private? i.e. they share a private room
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX users_who_share_rooms_u_idx ON users_who_share_rooms(user_id, other_user_id);
|
||||||
|
CREATE INDEX users_who_share_rooms_r_idx ON users_who_share_rooms(room_id);
|
||||||
|
CREATE INDEX users_who_share_rooms_o_idx ON users_who_share_rooms(other_user_id);
|
||||||
|
|
||||||
|
|
||||||
|
-- Make sure that we popualte the table initially
|
||||||
|
UPDATE user_directory_stream_pos SET stream_id = NULL;
|
@ -20,6 +20,7 @@ from synapse.util.stringutils import to_ascii
|
|||||||
from synapse.storage.engines import PostgresEngine
|
from synapse.storage.engines import PostgresEngine
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@ -29,6 +30,16 @@ logger = logging.getLogger(__name__)
|
|||||||
MAX_STATE_DELTA_HOPS = 100
|
MAX_STATE_DELTA_HOPS = 100
|
||||||
|
|
||||||
|
|
||||||
|
class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))):
|
||||||
|
"""Return type of get_state_group_delta that implements __len__, which lets
|
||||||
|
us use the itrable flag when caching
|
||||||
|
"""
|
||||||
|
__slots__ = []
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.delta_ids) if self.delta_ids else 0
|
||||||
|
|
||||||
|
|
||||||
class StateStore(SQLBaseStore):
|
class StateStore(SQLBaseStore):
|
||||||
""" Keeps track of the state at a given event.
|
""" Keeps track of the state at a given event.
|
||||||
|
|
||||||
@ -98,6 +109,46 @@ class StateStore(SQLBaseStore):
|
|||||||
_get_current_state_ids_txn,
|
_get_current_state_ids_txn,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@cached(max_entries=10000, iterable=True)
|
||||||
|
def get_state_group_delta(self, state_group):
|
||||||
|
"""Given a state group try to return a previous group and a delta between
|
||||||
|
the old and the new.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(prev_group, delta_ids), where both may be None.
|
||||||
|
"""
|
||||||
|
def _get_state_group_delta_txn(txn):
|
||||||
|
prev_group = self._simple_select_one_onecol_txn(
|
||||||
|
txn,
|
||||||
|
table="state_group_edges",
|
||||||
|
keyvalues={
|
||||||
|
"state_group": state_group,
|
||||||
|
},
|
||||||
|
retcol="prev_state_group",
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not prev_group:
|
||||||
|
return _GetStateGroupDelta(None, None)
|
||||||
|
|
||||||
|
delta_ids = self._simple_select_list_txn(
|
||||||
|
txn,
|
||||||
|
table="state_groups_state",
|
||||||
|
keyvalues={
|
||||||
|
"state_group": state_group,
|
||||||
|
},
|
||||||
|
retcols=("type", "state_key", "event_id",)
|
||||||
|
)
|
||||||
|
|
||||||
|
return _GetStateGroupDelta(prev_group, {
|
||||||
|
(row["type"], row["state_key"]): row["event_id"]
|
||||||
|
for row in delta_ids
|
||||||
|
})
|
||||||
|
return self.runInteraction(
|
||||||
|
"get_state_group_delta",
|
||||||
|
_get_state_group_delta_txn,
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_state_groups_ids(self, room_id, event_ids):
|
def get_state_groups_ids(self, room_id, event_ids):
|
||||||
if not event_ids:
|
if not event_ids:
|
||||||
@ -184,6 +235,19 @@ class StateStore(SQLBaseStore):
|
|||||||
# We persist as a delta if we can, while also ensuring the chain
|
# We persist as a delta if we can, while also ensuring the chain
|
||||||
# of deltas isn't tooo long, as otherwise read performance degrades.
|
# of deltas isn't tooo long, as otherwise read performance degrades.
|
||||||
if context.prev_group:
|
if context.prev_group:
|
||||||
|
is_in_db = self._simple_select_one_onecol_txn(
|
||||||
|
txn,
|
||||||
|
table="state_groups",
|
||||||
|
keyvalues={"id": context.prev_group},
|
||||||
|
retcol="id",
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
if not is_in_db:
|
||||||
|
raise Exception(
|
||||||
|
"Trying to persist state with unpersisted prev_group: %r"
|
||||||
|
% (context.prev_group,)
|
||||||
|
)
|
||||||
|
|
||||||
potential_hops = self._count_state_group_hops_txn(
|
potential_hops = self._count_state_group_hops_txn(
|
||||||
txn, context.prev_group
|
txn, context.prev_group
|
||||||
)
|
)
|
||||||
@ -251,6 +315,12 @@ class StateStore(SQLBaseStore):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
for event_id, state_group_id in state_groups.iteritems():
|
||||||
|
txn.call_after(
|
||||||
|
self._get_state_group_for_event.prefill,
|
||||||
|
(event_id,), state_group_id
|
||||||
|
)
|
||||||
|
|
||||||
def _count_state_group_hops_txn(self, txn, state_group):
|
def _count_state_group_hops_txn(self, txn, state_group):
|
||||||
"""Given a state group, count how many hops there are in the tree.
|
"""Given a state group, count how many hops there are in the tree.
|
||||||
|
|
||||||
@ -520,8 +590,8 @@ class StateStore(SQLBaseStore):
|
|||||||
state_map = yield self.get_state_ids_for_events([event_id], types)
|
state_map = yield self.get_state_ids_for_events([event_id], types)
|
||||||
defer.returnValue(state_map[event_id])
|
defer.returnValue(state_map[event_id])
|
||||||
|
|
||||||
@cached(num_args=2, max_entries=50000)
|
@cached(max_entries=50000)
|
||||||
def _get_state_group_for_event(self, room_id, event_id):
|
def _get_state_group_for_event(self, event_id):
|
||||||
return self._simple_select_one_onecol(
|
return self._simple_select_one_onecol(
|
||||||
table="event_to_state_groups",
|
table="event_to_state_groups",
|
||||||
keyvalues={
|
keyvalues={
|
||||||
@ -563,20 +633,22 @@ class StateStore(SQLBaseStore):
|
|||||||
where a `state_key` of `None` matches all state_keys for the
|
where a `state_key` of `None` matches all state_keys for the
|
||||||
`type`.
|
`type`.
|
||||||
"""
|
"""
|
||||||
is_all, state_dict_ids = self._state_group_cache.get(group)
|
is_all, known_absent, state_dict_ids = self._state_group_cache.get(group)
|
||||||
|
|
||||||
type_to_key = {}
|
type_to_key = {}
|
||||||
missing_types = set()
|
missing_types = set()
|
||||||
|
|
||||||
for typ, state_key in types:
|
for typ, state_key in types:
|
||||||
|
key = (typ, state_key)
|
||||||
if state_key is None:
|
if state_key is None:
|
||||||
type_to_key[typ] = None
|
type_to_key[typ] = None
|
||||||
missing_types.add((typ, state_key))
|
missing_types.add(key)
|
||||||
else:
|
else:
|
||||||
if type_to_key.get(typ, object()) is not None:
|
if type_to_key.get(typ, object()) is not None:
|
||||||
type_to_key.setdefault(typ, set()).add(state_key)
|
type_to_key.setdefault(typ, set()).add(state_key)
|
||||||
|
|
||||||
if (typ, state_key) not in state_dict_ids:
|
if key not in state_dict_ids and key not in known_absent:
|
||||||
missing_types.add((typ, state_key))
|
missing_types.add(key)
|
||||||
|
|
||||||
sentinel = object()
|
sentinel = object()
|
||||||
|
|
||||||
@ -590,7 +662,7 @@ class StateStore(SQLBaseStore):
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
got_all = not (missing_types or types is None)
|
got_all = is_all or not missing_types
|
||||||
|
|
||||||
return {
|
return {
|
||||||
k: v for k, v in state_dict_ids.iteritems()
|
k: v for k, v in state_dict_ids.iteritems()
|
||||||
@ -607,7 +679,7 @@ class StateStore(SQLBaseStore):
|
|||||||
Args:
|
Args:
|
||||||
group: The state group to lookup
|
group: The state group to lookup
|
||||||
"""
|
"""
|
||||||
is_all, state_dict_ids = self._state_group_cache.get(group)
|
is_all, _, state_dict_ids = self._state_group_cache.get(group)
|
||||||
|
|
||||||
return state_dict_ids, is_all
|
return state_dict_ids, is_all
|
||||||
|
|
||||||
@ -624,7 +696,7 @@ class StateStore(SQLBaseStore):
|
|||||||
missing_groups = []
|
missing_groups = []
|
||||||
if types is not None:
|
if types is not None:
|
||||||
for group in set(groups):
|
for group in set(groups):
|
||||||
state_dict_ids, missing_types, got_all = self._get_some_state_from_cache(
|
state_dict_ids, _, got_all = self._get_some_state_from_cache(
|
||||||
group, types
|
group, types
|
||||||
)
|
)
|
||||||
results[group] = state_dict_ids
|
results[group] = state_dict_ids
|
||||||
@ -653,18 +725,6 @@ class StateStore(SQLBaseStore):
|
|||||||
# Now we want to update the cache with all the things we fetched
|
# Now we want to update the cache with all the things we fetched
|
||||||
# from the database.
|
# from the database.
|
||||||
for group, group_state_dict in group_to_state_dict.iteritems():
|
for group, group_state_dict in group_to_state_dict.iteritems():
|
||||||
if types:
|
|
||||||
# We delibrately put key -> None mappings into the cache to
|
|
||||||
# cache absence of the key, on the assumption that if we've
|
|
||||||
# explicitly asked for some types then we will probably ask
|
|
||||||
# for them again.
|
|
||||||
state_dict = {
|
|
||||||
(intern_string(etype), intern_string(state_key)): None
|
|
||||||
for (etype, state_key) in types
|
|
||||||
}
|
|
||||||
state_dict.update(results[group])
|
|
||||||
results[group] = state_dict
|
|
||||||
else:
|
|
||||||
state_dict = results[group]
|
state_dict = results[group]
|
||||||
|
|
||||||
state_dict.update(
|
state_dict.update(
|
||||||
@ -677,17 +737,9 @@ class StateStore(SQLBaseStore):
|
|||||||
key=group,
|
key=group,
|
||||||
value=state_dict,
|
value=state_dict,
|
||||||
full=(types is None),
|
full=(types is None),
|
||||||
|
known_absent=types,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Remove all the entries with None values. The None values were just
|
|
||||||
# used for bookkeeping in the cache.
|
|
||||||
for group, state_dict in results.iteritems():
|
|
||||||
results[group] = {
|
|
||||||
key: event_id
|
|
||||||
for key, event_id in state_dict.iteritems()
|
|
||||||
if event_id
|
|
||||||
}
|
|
||||||
|
|
||||||
defer.returnValue(results)
|
defer.returnValue(results)
|
||||||
|
|
||||||
def get_next_state_group(self):
|
def get_next_state_group(self):
|
||||||
|
743
synapse/storage/user_directory.py
Normal file
743
synapse/storage/user_directory.py
Normal file
@ -0,0 +1,743 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2017 Vector Creations Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from ._base import SQLBaseStore
|
||||||
|
|
||||||
|
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
||||||
|
from synapse.api.constants import EventTypes, JoinRules
|
||||||
|
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||||
|
from synapse.types import get_domain_from_id, get_localpart_from_id
|
||||||
|
|
||||||
|
import re
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class UserDirectoryStore(SQLBaseStore):
|
||||||
|
@cachedInlineCallbacks(cache_context=True)
|
||||||
|
def is_room_world_readable_or_publicly_joinable(self, room_id, cache_context):
|
||||||
|
"""Check if the room is either world_readable or publically joinable
|
||||||
|
"""
|
||||||
|
current_state_ids = yield self.get_current_state_ids(
|
||||||
|
room_id, on_invalidate=cache_context.invalidate
|
||||||
|
)
|
||||||
|
|
||||||
|
join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
|
||||||
|
if join_rules_id:
|
||||||
|
join_rule_ev = yield self.get_event(join_rules_id, allow_none=True)
|
||||||
|
if join_rule_ev:
|
||||||
|
if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC:
|
||||||
|
defer.returnValue(True)
|
||||||
|
|
||||||
|
hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, ""))
|
||||||
|
if hist_vis_id:
|
||||||
|
hist_vis_ev = yield self.get_event(hist_vis_id, allow_none=True)
|
||||||
|
if hist_vis_ev:
|
||||||
|
if hist_vis_ev.content.get("history_visibility") == "world_readable":
|
||||||
|
defer.returnValue(True)
|
||||||
|
|
||||||
|
defer.returnValue(False)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def add_users_to_public_room(self, room_id, user_ids):
|
||||||
|
"""Add user to the list of users in public rooms
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id (str): A room_id that all users are in that is world_readable
|
||||||
|
or publically joinable
|
||||||
|
user_ids (list(str)): Users to add
|
||||||
|
"""
|
||||||
|
yield self._simple_insert_many(
|
||||||
|
table="users_in_pubic_room",
|
||||||
|
values=[
|
||||||
|
{
|
||||||
|
"user_id": user_id,
|
||||||
|
"room_id": room_id,
|
||||||
|
}
|
||||||
|
for user_id in user_ids
|
||||||
|
],
|
||||||
|
desc="add_users_to_public_room"
|
||||||
|
)
|
||||||
|
for user_id in user_ids:
|
||||||
|
self.get_user_in_public_room.invalidate((user_id,))
|
||||||
|
|
||||||
|
def add_profiles_to_user_dir(self, room_id, users_with_profile):
|
||||||
|
"""Add profiles to the user directory
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id (str): A room_id that all users are joined to
|
||||||
|
users_with_profile (dict): Users to add to directory in the form of
|
||||||
|
mapping of user_id -> ProfileInfo
|
||||||
|
"""
|
||||||
|
if isinstance(self.database_engine, PostgresEngine):
|
||||||
|
# We weight the loclpart most highly, then display name and finally
|
||||||
|
# server name
|
||||||
|
sql = """
|
||||||
|
INSERT INTO user_directory_search(user_id, vector)
|
||||||
|
VALUES (?,
|
||||||
|
setweight(to_tsvector('english', ?), 'A')
|
||||||
|
|| setweight(to_tsvector('english', ?), 'D')
|
||||||
|
|| setweight(to_tsvector('english', COALESCE(?, '')), 'B')
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
args = (
|
||||||
|
(
|
||||||
|
user_id, get_localpart_from_id(user_id), get_domain_from_id(user_id),
|
||||||
|
profile.display_name,
|
||||||
|
)
|
||||||
|
for user_id, profile in users_with_profile.iteritems()
|
||||||
|
)
|
||||||
|
elif isinstance(self.database_engine, Sqlite3Engine):
|
||||||
|
sql = """
|
||||||
|
INSERT INTO user_directory_search(user_id, value)
|
||||||
|
VALUES (?,?)
|
||||||
|
"""
|
||||||
|
args = (
|
||||||
|
(
|
||||||
|
user_id,
|
||||||
|
"%s %s" % (user_id, p.display_name,) if p.display_name else user_id
|
||||||
|
)
|
||||||
|
for user_id, p in users_with_profile.iteritems()
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# This should be unreachable.
|
||||||
|
raise Exception("Unrecognized database engine")
|
||||||
|
|
||||||
|
def _add_profiles_to_user_dir_txn(txn):
|
||||||
|
txn.executemany(sql, args)
|
||||||
|
self._simple_insert_many_txn(
|
||||||
|
txn,
|
||||||
|
table="user_directory",
|
||||||
|
values=[
|
||||||
|
{
|
||||||
|
"user_id": user_id,
|
||||||
|
"room_id": room_id,
|
||||||
|
"display_name": profile.display_name,
|
||||||
|
"avatar_url": profile.avatar_url,
|
||||||
|
}
|
||||||
|
for user_id, profile in users_with_profile.iteritems()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
for user_id in users_with_profile:
|
||||||
|
txn.call_after(
|
||||||
|
self.get_user_in_directory.invalidate, (user_id,)
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.runInteraction(
|
||||||
|
"add_profiles_to_user_dir", _add_profiles_to_user_dir_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def update_user_in_user_dir(self, user_id, room_id):
|
||||||
|
yield self._simple_update_one(
|
||||||
|
table="user_directory",
|
||||||
|
keyvalues={"user_id": user_id},
|
||||||
|
updatevalues={"room_id": room_id},
|
||||||
|
desc="update_user_in_user_dir",
|
||||||
|
)
|
||||||
|
self.get_user_in_directory.invalidate((user_id,))
|
||||||
|
|
||||||
|
def update_profile_in_user_dir(self, user_id, display_name, avatar_url, room_id):
|
||||||
|
def _update_profile_in_user_dir_txn(txn):
|
||||||
|
new_entry = self._simple_upsert_txn(
|
||||||
|
txn,
|
||||||
|
table="user_directory",
|
||||||
|
keyvalues={"user_id": user_id},
|
||||||
|
insertion_values={"room_id": room_id},
|
||||||
|
values={"display_name": display_name, "avatar_url": avatar_url},
|
||||||
|
lock=False, # We're only inserter
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(self.database_engine, PostgresEngine):
|
||||||
|
# We weight the loclpart most highly, then display name and finally
|
||||||
|
# server name
|
||||||
|
if new_entry:
|
||||||
|
sql = """
|
||||||
|
INSERT INTO user_directory_search(user_id, vector)
|
||||||
|
VALUES (?,
|
||||||
|
setweight(to_tsvector('english', ?), 'A')
|
||||||
|
|| setweight(to_tsvector('english', ?), 'D')
|
||||||
|
|| setweight(to_tsvector('english', COALESCE(?, '')), 'B')
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
txn.execute(
|
||||||
|
sql,
|
||||||
|
(
|
||||||
|
user_id, get_localpart_from_id(user_id),
|
||||||
|
get_domain_from_id(user_id), display_name,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sql = """
|
||||||
|
UPDATE user_directory_search
|
||||||
|
SET vector = setweight(to_tsvector('english', ?), 'A')
|
||||||
|
|| setweight(to_tsvector('english', ?), 'D')
|
||||||
|
|| setweight(to_tsvector('english', COALESCE(?, '')), 'B')
|
||||||
|
WHERE user_id = ?
|
||||||
|
"""
|
||||||
|
txn.execute(
|
||||||
|
sql,
|
||||||
|
(
|
||||||
|
get_localpart_from_id(user_id), get_domain_from_id(user_id),
|
||||||
|
display_name, user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif isinstance(self.database_engine, Sqlite3Engine):
|
||||||
|
value = "%s %s" % (user_id, display_name,) if display_name else user_id
|
||||||
|
self._simple_upsert_txn(
|
||||||
|
txn,
|
||||||
|
table="user_directory_search",
|
||||||
|
keyvalues={"user_id": user_id},
|
||||||
|
values={"value": value},
|
||||||
|
lock=False, # We're only inserter
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# This should be unreachable.
|
||||||
|
raise Exception("Unrecognized database engine")
|
||||||
|
|
||||||
|
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
|
||||||
|
|
||||||
|
return self.runInteraction(
|
||||||
|
"update_profile_in_user_dir", _update_profile_in_user_dir_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def update_user_in_public_user_list(self, user_id, room_id):
|
||||||
|
yield self._simple_update_one(
|
||||||
|
table="users_in_pubic_room",
|
||||||
|
keyvalues={"user_id": user_id},
|
||||||
|
updatevalues={"room_id": room_id},
|
||||||
|
desc="update_user_in_public_user_list",
|
||||||
|
)
|
||||||
|
self.get_user_in_public_room.invalidate((user_id,))
|
||||||
|
|
||||||
|
def remove_from_user_dir(self, user_id):
|
||||||
|
def _remove_from_user_dir_txn(txn):
|
||||||
|
self._simple_delete_txn(
|
||||||
|
txn,
|
||||||
|
table="user_directory",
|
||||||
|
keyvalues={"user_id": user_id},
|
||||||
|
)
|
||||||
|
self._simple_delete_txn(
|
||||||
|
txn,
|
||||||
|
table="user_directory_search",
|
||||||
|
keyvalues={"user_id": user_id},
|
||||||
|
)
|
||||||
|
self._simple_delete_txn(
|
||||||
|
txn,
|
||||||
|
table="users_in_pubic_room",
|
||||||
|
keyvalues={"user_id": user_id},
|
||||||
|
)
|
||||||
|
txn.call_after(
|
||||||
|
self.get_user_in_directory.invalidate, (user_id,)
|
||||||
|
)
|
||||||
|
txn.call_after(
|
||||||
|
self.get_user_in_public_room.invalidate, (user_id,)
|
||||||
|
)
|
||||||
|
return self.runInteraction(
|
||||||
|
"remove_from_user_dir", _remove_from_user_dir_txn,
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def remove_from_user_in_public_room(self, user_id):
|
||||||
|
yield self._simple_delete(
|
||||||
|
table="users_in_pubic_room",
|
||||||
|
keyvalues={"user_id": user_id},
|
||||||
|
desc="remove_from_user_in_public_room",
|
||||||
|
)
|
||||||
|
self.get_user_in_public_room.invalidate((user_id,))
|
||||||
|
|
||||||
|
def get_users_in_public_due_to_room(self, room_id):
|
||||||
|
"""Get all user_ids that are in the room directory becuase they're
|
||||||
|
in the given room_id
|
||||||
|
"""
|
||||||
|
return self._simple_select_onecol(
|
||||||
|
table="users_in_pubic_room",
|
||||||
|
keyvalues={"room_id": room_id},
|
||||||
|
retcol="user_id",
|
||||||
|
desc="get_users_in_public_due_to_room",
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_users_in_dir_due_to_room(self, room_id):
|
||||||
|
"""Get all user_ids that are in the room directory becuase they're
|
||||||
|
in the given room_id
|
||||||
|
"""
|
||||||
|
user_ids_dir = yield self._simple_select_onecol(
|
||||||
|
table="user_directory",
|
||||||
|
keyvalues={"room_id": room_id},
|
||||||
|
retcol="user_id",
|
||||||
|
desc="get_users_in_dir_due_to_room",
|
||||||
|
)
|
||||||
|
|
||||||
|
user_ids_pub = yield self._simple_select_onecol(
|
||||||
|
table="users_in_pubic_room",
|
||||||
|
keyvalues={"room_id": room_id},
|
||||||
|
retcol="user_id",
|
||||||
|
desc="get_users_in_dir_due_to_room",
|
||||||
|
)
|
||||||
|
|
||||||
|
user_ids_share = yield self._simple_select_onecol(
|
||||||
|
table="users_who_share_rooms",
|
||||||
|
keyvalues={"room_id": room_id},
|
||||||
|
retcol="user_id",
|
||||||
|
desc="get_users_in_dir_due_to_room",
|
||||||
|
)
|
||||||
|
|
||||||
|
user_ids = set(user_ids_dir)
|
||||||
|
user_ids.update(user_ids_pub)
|
||||||
|
user_ids.update(user_ids_share)
|
||||||
|
|
||||||
|
defer.returnValue(user_ids)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_all_rooms(self):
|
||||||
|
"""Get all room_ids we've ever known about, in ascending order of "size"
|
||||||
|
"""
|
||||||
|
sql = """
|
||||||
|
SELECT room_id FROM current_state_events
|
||||||
|
GROUP BY room_id
|
||||||
|
ORDER BY count(*) ASC
|
||||||
|
"""
|
||||||
|
rows = yield self._execute("get_all_rooms", None, sql)
|
||||||
|
defer.returnValue([room_id for room_id, in rows])
|
||||||
|
|
||||||
|
def add_users_who_share_room(self, room_id, share_private, user_id_tuples):
|
||||||
|
"""Insert entries into the users_who_share_rooms table. The first
|
||||||
|
user should be a local user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id (str)
|
||||||
|
share_private (bool): Is the room private
|
||||||
|
user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
|
||||||
|
"""
|
||||||
|
def _add_users_who_share_room_txn(txn):
|
||||||
|
self._simple_insert_many_txn(
|
||||||
|
txn,
|
||||||
|
table="users_who_share_rooms",
|
||||||
|
values=[
|
||||||
|
{
|
||||||
|
"user_id": user_id,
|
||||||
|
"other_user_id": other_user_id,
|
||||||
|
"room_id": room_id,
|
||||||
|
"share_private": share_private,
|
||||||
|
}
|
||||||
|
for user_id, other_user_id in user_id_tuples
|
||||||
|
],
|
||||||
|
)
|
||||||
|
for user_id, other_user_id in user_id_tuples:
|
||||||
|
txn.call_after(
|
||||||
|
self.get_users_who_share_room_from_dir.invalidate,
|
||||||
|
(user_id,),
|
||||||
|
)
|
||||||
|
txn.call_after(
|
||||||
|
self.get_if_users_share_a_room.invalidate,
|
||||||
|
(user_id, other_user_id),
|
||||||
|
)
|
||||||
|
return self.runInteraction(
|
||||||
|
"add_users_who_share_room", _add_users_who_share_room_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_users_who_share_room(self, room_id, share_private, user_id_sets):
|
||||||
|
"""Updates entries in the users_who_share_rooms table. The first
|
||||||
|
user should be a local user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id (str)
|
||||||
|
share_private (bool): Is the room private
|
||||||
|
user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
|
||||||
|
"""
|
||||||
|
def _update_users_who_share_room_txn(txn):
|
||||||
|
sql = """
|
||||||
|
UPDATE users_who_share_rooms
|
||||||
|
SET room_id = ?, share_private = ?
|
||||||
|
WHERE user_id = ? AND other_user_id = ?
|
||||||
|
"""
|
||||||
|
txn.executemany(
|
||||||
|
sql,
|
||||||
|
(
|
||||||
|
(room_id, share_private, uid, oid)
|
||||||
|
for uid, oid in user_id_sets
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for user_id, other_user_id in user_id_sets:
|
||||||
|
txn.call_after(
|
||||||
|
self.get_users_who_share_room_from_dir.invalidate,
|
||||||
|
(user_id,),
|
||||||
|
)
|
||||||
|
txn.call_after(
|
||||||
|
self.get_if_users_share_a_room.invalidate,
|
||||||
|
(user_id, other_user_id),
|
||||||
|
)
|
||||||
|
return self.runInteraction(
|
||||||
|
"update_users_who_share_room", _update_users_who_share_room_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
def remove_user_who_share_room(self, user_id, other_user_id):
|
||||||
|
"""Deletes entries in the users_who_share_rooms table. The first
|
||||||
|
user should be a local user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id (str)
|
||||||
|
share_private (bool): Is the room private
|
||||||
|
user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
|
||||||
|
"""
|
||||||
|
def _remove_user_who_share_room_txn(txn):
|
||||||
|
self._simple_delete_txn(
|
||||||
|
txn,
|
||||||
|
table="users_who_share_rooms",
|
||||||
|
keyvalues={
|
||||||
|
"user_id": user_id,
|
||||||
|
"other_user_id": other_user_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
txn.call_after(
|
||||||
|
self.get_users_who_share_room_from_dir.invalidate,
|
||||||
|
(user_id,),
|
||||||
|
)
|
||||||
|
txn.call_after(
|
||||||
|
self.get_if_users_share_a_room.invalidate,
|
||||||
|
(user_id, other_user_id),
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.runInteraction(
|
||||||
|
"remove_user_who_share_room", _remove_user_who_share_room_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
@cached(max_entries=500000)
|
||||||
|
def get_if_users_share_a_room(self, user_id, other_user_id):
|
||||||
|
"""Gets if users share a room.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): Must be a local user_id
|
||||||
|
other_user_id (str)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool|None: None if they don't share a room, otherwise whether they
|
||||||
|
share a private room or not.
|
||||||
|
"""
|
||||||
|
return self._simple_select_one_onecol(
|
||||||
|
table="users_who_share_rooms",
|
||||||
|
keyvalues={
|
||||||
|
"user_id": user_id,
|
||||||
|
"other_user_id": other_user_id,
|
||||||
|
},
|
||||||
|
retcol="share_private",
|
||||||
|
allow_none=True,
|
||||||
|
desc="get_if_users_share_a_room",
|
||||||
|
)
|
||||||
|
|
||||||
|
@cachedInlineCallbacks(max_entries=500000, iterable=True)
|
||||||
|
def get_users_who_share_room_from_dir(self, user_id):
|
||||||
|
"""Returns the set of users who share a room with `user_id`
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id(str): Must be a local user
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: user_id -> share_private mapping
|
||||||
|
"""
|
||||||
|
rows = yield self._simple_select_list(
|
||||||
|
table="users_who_share_rooms",
|
||||||
|
keyvalues={
|
||||||
|
"user_id": user_id,
|
||||||
|
},
|
||||||
|
retcols=("other_user_id", "share_private",),
|
||||||
|
desc="get_users_who_share_room_with_user",
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue({
|
||||||
|
row["other_user_id"]: row["share_private"]
|
||||||
|
for row in rows
|
||||||
|
})
|
||||||
|
|
||||||
|
def get_users_in_share_dir_with_room_id(self, user_id, room_id):
|
||||||
|
"""Get all user tuples that are in the users_who_share_rooms due to the
|
||||||
|
given room_id.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[(user_id, other_user_id)]: where one of the two will match the given
|
||||||
|
user_id.
|
||||||
|
"""
|
||||||
|
sql = """
|
||||||
|
SELECT user_id, other_user_id FROM users_who_share_rooms
|
||||||
|
WHERE room_id = ? AND (user_id = ? OR other_user_id = ?)
|
||||||
|
"""
|
||||||
|
return self._execute(
|
||||||
|
"get_users_in_share_dir_with_room_id", None, sql, room_id, user_id, user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_rooms_in_common_for_users(self, user_id, other_user_id):
|
||||||
|
"""Given two user_ids find out the list of rooms they share.
|
||||||
|
"""
|
||||||
|
sql = """
|
||||||
|
SELECT room_id FROM (
|
||||||
|
SELECT c.room_id FROM current_state_events AS c
|
||||||
|
INNER JOIN room_memberships USING (event_id)
|
||||||
|
WHERE type = 'm.room.member'
|
||||||
|
AND membership = 'join'
|
||||||
|
AND state_key = ?
|
||||||
|
) AS f1 INNER JOIN (
|
||||||
|
SELECT c.room_id FROM current_state_events AS c
|
||||||
|
INNER JOIN room_memberships USING (event_id)
|
||||||
|
WHERE type = 'm.room.member'
|
||||||
|
AND membership = 'join'
|
||||||
|
AND state_key = ?
|
||||||
|
) f2 USING (room_id)
|
||||||
|
"""
|
||||||
|
|
||||||
|
rows = yield self._execute(
|
||||||
|
"get_rooms_in_common_for_users", None, sql, user_id, other_user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue([room_id for room_id, in rows])
|
||||||
|
|
||||||
|
def delete_all_from_user_dir(self):
|
||||||
|
"""Delete the entire user directory
|
||||||
|
"""
|
||||||
|
def _delete_all_from_user_dir_txn(txn):
|
||||||
|
txn.execute("DELETE FROM user_directory")
|
||||||
|
txn.execute("DELETE FROM user_directory_search")
|
||||||
|
txn.execute("DELETE FROM users_in_pubic_room")
|
||||||
|
txn.execute("DELETE FROM users_who_share_rooms")
|
||||||
|
txn.call_after(self.get_user_in_directory.invalidate_all)
|
||||||
|
txn.call_after(self.get_user_in_public_room.invalidate_all)
|
||||||
|
txn.call_after(self.get_users_who_share_room_from_dir.invalidate_all)
|
||||||
|
txn.call_after(self.get_if_users_share_a_room.invalidate_all)
|
||||||
|
return self.runInteraction(
|
||||||
|
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
@cached()
|
||||||
|
def get_user_in_directory(self, user_id):
|
||||||
|
return self._simple_select_one(
|
||||||
|
table="user_directory",
|
||||||
|
keyvalues={"user_id": user_id},
|
||||||
|
retcols=("room_id", "display_name", "avatar_url",),
|
||||||
|
allow_none=True,
|
||||||
|
desc="get_user_in_directory",
|
||||||
|
)
|
||||||
|
|
||||||
|
@cached()
|
||||||
|
def get_user_in_public_room(self, user_id):
|
||||||
|
return self._simple_select_one(
|
||||||
|
table="users_in_pubic_room",
|
||||||
|
keyvalues={"user_id": user_id},
|
||||||
|
retcols=("room_id",),
|
||||||
|
allow_none=True,
|
||||||
|
desc="get_user_in_public_room",
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_user_directory_stream_pos(self):
|
||||||
|
return self._simple_select_one_onecol(
|
||||||
|
table="user_directory_stream_pos",
|
||||||
|
keyvalues={},
|
||||||
|
retcol="stream_id",
|
||||||
|
desc="get_user_directory_stream_pos",
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_user_directory_stream_pos(self, stream_id):
|
||||||
|
return self._simple_update_one(
|
||||||
|
table="user_directory_stream_pos",
|
||||||
|
keyvalues={},
|
||||||
|
updatevalues={"stream_id": stream_id},
|
||||||
|
desc="update_user_directory_stream_pos",
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_current_state_deltas(self, prev_stream_id):
|
||||||
|
prev_stream_id = int(prev_stream_id)
|
||||||
|
if not self._curr_state_delta_stream_cache.has_any_entity_changed(prev_stream_id):
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_current_state_deltas_txn(txn):
|
||||||
|
# First we calculate the max stream id that will give us less than
|
||||||
|
# N results.
|
||||||
|
# We arbitarily limit to 100 stream_id entries to ensure we don't
|
||||||
|
# select toooo many.
|
||||||
|
sql = """
|
||||||
|
SELECT stream_id, count(*)
|
||||||
|
FROM current_state_delta_stream
|
||||||
|
WHERE stream_id > ?
|
||||||
|
GROUP BY stream_id
|
||||||
|
ORDER BY stream_id ASC
|
||||||
|
LIMIT 100
|
||||||
|
"""
|
||||||
|
txn.execute(sql, (prev_stream_id,))
|
||||||
|
|
||||||
|
total = 0
|
||||||
|
max_stream_id = prev_stream_id
|
||||||
|
for max_stream_id, count in txn:
|
||||||
|
total += count
|
||||||
|
if total > 100:
|
||||||
|
# We arbitarily limit to 100 entries to ensure we don't
|
||||||
|
# select toooo many.
|
||||||
|
break
|
||||||
|
|
||||||
|
# Now actually get the deltas
|
||||||
|
sql = """
|
||||||
|
SELECT stream_id, room_id, type, state_key, event_id, prev_event_id
|
||||||
|
FROM current_state_delta_stream
|
||||||
|
WHERE ? < stream_id AND stream_id <= ?
|
||||||
|
ORDER BY stream_id ASC
|
||||||
|
"""
|
||||||
|
txn.execute(sql, (prev_stream_id, max_stream_id,))
|
||||||
|
return self.cursor_to_dict(txn)
|
||||||
|
|
||||||
|
return self.runInteraction(
|
||||||
|
"get_current_state_deltas", get_current_state_deltas_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_max_stream_id_in_current_state_deltas(self):
|
||||||
|
return self._simple_select_one_onecol(
|
||||||
|
table="current_state_delta_stream",
|
||||||
|
keyvalues={},
|
||||||
|
retcol="COALESCE(MAX(stream_id), -1)",
|
||||||
|
desc="get_max_stream_id_in_current_state_deltas",
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def search_user_dir(self, user_id, search_term, limit):
|
||||||
|
"""Searches for users in directory
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict of the form::
|
||||||
|
|
||||||
|
{
|
||||||
|
"limited": <bool>, # whether there were more results or not
|
||||||
|
"results": [ # Ordered by best match first
|
||||||
|
{
|
||||||
|
"user_id": <user_id>,
|
||||||
|
"display_name": <display_name>,
|
||||||
|
"avatar_url": <avatar_url>
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
if isinstance(self.database_engine, PostgresEngine):
|
||||||
|
full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
|
||||||
|
|
||||||
|
# We order by rank and then if they have profile info
|
||||||
|
# The ranking algorithm is hand tweaked for "best" results. Broadly
|
||||||
|
# the idea is we give a higher weight to exact matches.
|
||||||
|
# The array of numbers are the weights for the various part of the
|
||||||
|
# search: (domain, _, display name, localpart)
|
||||||
|
sql = """
|
||||||
|
SELECT d.user_id, display_name, avatar_url
|
||||||
|
FROM user_directory_search
|
||||||
|
INNER JOIN user_directory AS d USING (user_id)
|
||||||
|
LEFT JOIN users_in_pubic_room AS p USING (user_id)
|
||||||
|
LEFT JOIN (
|
||||||
|
SELECT other_user_id AS user_id FROM users_who_share_rooms
|
||||||
|
WHERE user_id = ? AND share_private
|
||||||
|
) AS s USING (user_id)
|
||||||
|
WHERE
|
||||||
|
(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)
|
||||||
|
AND vector @@ to_tsquery('english', ?)
|
||||||
|
ORDER BY
|
||||||
|
(CASE WHEN s.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
|
||||||
|
* (CASE WHEN display_name IS NOT NULL THEN 1.2 ELSE 1.0 END)
|
||||||
|
* (CASE WHEN avatar_url IS NOT NULL THEN 1.2 ELSE 1.0 END)
|
||||||
|
* (
|
||||||
|
3 * ts_rank_cd(
|
||||||
|
'{0.1, 0.1, 0.9, 1.0}',
|
||||||
|
vector,
|
||||||
|
to_tsquery('english', ?),
|
||||||
|
8
|
||||||
|
)
|
||||||
|
+ ts_rank_cd(
|
||||||
|
'{0.1, 0.1, 0.9, 1.0}',
|
||||||
|
vector,
|
||||||
|
to_tsquery('english', ?),
|
||||||
|
8
|
||||||
|
)
|
||||||
|
)
|
||||||
|
DESC,
|
||||||
|
display_name IS NULL,
|
||||||
|
avatar_url IS NULL
|
||||||
|
LIMIT ?
|
||||||
|
"""
|
||||||
|
args = (user_id, full_query, exact_query, prefix_query, limit + 1,)
|
||||||
|
elif isinstance(self.database_engine, Sqlite3Engine):
|
||||||
|
search_query = _parse_query_sqlite(search_term)
|
||||||
|
|
||||||
|
sql = """
|
||||||
|
SELECT d.user_id, display_name, avatar_url
|
||||||
|
FROM user_directory_search
|
||||||
|
INNER JOIN user_directory AS d USING (user_id)
|
||||||
|
LEFT JOIN users_in_pubic_room AS p USING (user_id)
|
||||||
|
LEFT JOIN (
|
||||||
|
SELECT other_user_id AS user_id FROM users_who_share_rooms
|
||||||
|
WHERE user_id = ? AND share_private
|
||||||
|
) AS s USING (user_id)
|
||||||
|
WHERE
|
||||||
|
(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)
|
||||||
|
AND value MATCH ?
|
||||||
|
ORDER BY
|
||||||
|
rank(matchinfo(user_directory_search)) DESC,
|
||||||
|
display_name IS NULL,
|
||||||
|
avatar_url IS NULL
|
||||||
|
LIMIT ?
|
||||||
|
"""
|
||||||
|
args = (user_id, search_query, limit + 1)
|
||||||
|
else:
|
||||||
|
# This should be unreachable.
|
||||||
|
raise Exception("Unrecognized database engine")
|
||||||
|
|
||||||
|
results = yield self._execute(
|
||||||
|
"search_user_dir", self.cursor_to_dict, sql, *args
|
||||||
|
)
|
||||||
|
|
||||||
|
limited = len(results) > limit
|
||||||
|
|
||||||
|
defer.returnValue({
|
||||||
|
"limited": limited,
|
||||||
|
"results": results,
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_query_sqlite(search_term):
|
||||||
|
"""Takes a plain unicode string from the user and converts it into a form
|
||||||
|
that can be passed to database.
|
||||||
|
We use this so that we can add prefix matching, which isn't something
|
||||||
|
that is supported by default.
|
||||||
|
|
||||||
|
We specifically add both a prefix and non prefix matching term so that
|
||||||
|
exact matches get ranked higher.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Pull out the individual words, discarding any non-word characters.
|
||||||
|
results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
|
||||||
|
return " & ".join("(%s* | %s)" % (result, result,) for result in results)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_query_postgres(search_term):
|
||||||
|
"""Takes a plain unicode string from the user and converts it into a form
|
||||||
|
that can be passed to database.
|
||||||
|
We use this so that we can add prefix matching, which isn't something
|
||||||
|
that is supported by default.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Pull out the individual words, discarding any non-word characters.
|
||||||
|
results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
|
||||||
|
|
||||||
|
both = " & ".join("(%s:* | %s)" % (result, result,) for result in results)
|
||||||
|
exact = " & ".join("%s" % (result,) for result in results)
|
||||||
|
prefix = " & ".join("%s:*" % (result,) for result in results)
|
||||||
|
|
||||||
|
return both, exact, prefix
|
@ -62,6 +62,13 @@ def get_domain_from_id(string):
|
|||||||
return string[idx + 1:]
|
return string[idx + 1:]
|
||||||
|
|
||||||
|
|
||||||
|
def get_localpart_from_id(string):
|
||||||
|
idx = string.find(":")
|
||||||
|
if idx == -1:
|
||||||
|
raise SynapseError(400, "Invalid ID: %r" % (string,))
|
||||||
|
return string[1:idx]
|
||||||
|
|
||||||
|
|
||||||
class DomainSpecificString(
|
class DomainSpecificString(
|
||||||
namedtuple("DomainSpecificString", ("localpart", "domain"))
|
namedtuple("DomainSpecificString", ("localpart", "domain"))
|
||||||
):
|
):
|
||||||
|
@ -16,7 +16,7 @@
|
|||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
import os
|
import os
|
||||||
|
|
||||||
CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
|
CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5))
|
||||||
|
|
||||||
metrics = synapse.metrics.get_metrics_for("synapse.util.caches")
|
metrics = synapse.metrics.get_metrics_for("synapse.util.caches")
|
||||||
|
|
||||||
|
@ -16,6 +16,7 @@ import logging
|
|||||||
|
|
||||||
from synapse.util.async import ObservableDeferred
|
from synapse.util.async import ObservableDeferred
|
||||||
from synapse.util import unwrapFirstError, logcontext
|
from synapse.util import unwrapFirstError, logcontext
|
||||||
|
from synapse.util.caches import CACHE_SIZE_FACTOR
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
|
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
|
||||||
from synapse.util.stringutils import to_ascii
|
from synapse.util.stringutils import to_ascii
|
||||||
@ -25,7 +26,6 @@ from . import register_cache
|
|||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
import os
|
|
||||||
import functools
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
import threading
|
import threading
|
||||||
@ -37,9 +37,6 @@ logger = logging.getLogger(__name__)
|
|||||||
_CacheSentinel = object()
|
_CacheSentinel = object()
|
||||||
|
|
||||||
|
|
||||||
CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
|
|
||||||
|
|
||||||
|
|
||||||
class CacheEntry(object):
|
class CacheEntry(object):
|
||||||
__slots__ = [
|
__slots__ = [
|
||||||
"deferred", "sequence", "callbacks", "invalidated"
|
"deferred", "sequence", "callbacks", "invalidated"
|
||||||
@ -404,6 +401,7 @@ class CacheDescriptor(_CacheDescriptorBase):
|
|||||||
|
|
||||||
wrapped.invalidate_all = cache.invalidate_all
|
wrapped.invalidate_all = cache.invalidate_all
|
||||||
wrapped.cache = cache
|
wrapped.cache = cache
|
||||||
|
wrapped.num_args = self.num_args
|
||||||
|
|
||||||
obj.__dict__[self.orig.__name__] = wrapped
|
obj.__dict__[self.orig.__name__] = wrapped
|
||||||
|
|
||||||
@ -451,8 +449,9 @@ class CacheListDescriptor(_CacheDescriptorBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __get__(self, obj, objtype=None):
|
def __get__(self, obj, objtype=None):
|
||||||
|
cached_method = getattr(obj, self.cached_method_name)
|
||||||
cache = getattr(obj, self.cached_method_name).cache
|
cache = cached_method.cache
|
||||||
|
num_args = cached_method.num_args
|
||||||
|
|
||||||
@functools.wraps(self.orig)
|
@functools.wraps(self.orig)
|
||||||
def wrapped(*args, **kwargs):
|
def wrapped(*args, **kwargs):
|
||||||
@ -469,12 +468,23 @@ class CacheListDescriptor(_CacheDescriptorBase):
|
|||||||
results = {}
|
results = {}
|
||||||
cached_defers = {}
|
cached_defers = {}
|
||||||
missing = []
|
missing = []
|
||||||
for arg in list_args:
|
|
||||||
key = list(keyargs)
|
|
||||||
key[self.list_pos] = arg
|
|
||||||
|
|
||||||
|
# If the cache takes a single arg then that is used as the key,
|
||||||
|
# otherwise a tuple is used.
|
||||||
|
if num_args == 1:
|
||||||
|
def cache_get(arg):
|
||||||
|
return cache.get(arg, callback=invalidate_callback)
|
||||||
|
else:
|
||||||
|
key = list(keyargs)
|
||||||
|
|
||||||
|
def cache_get(arg):
|
||||||
|
key[self.list_pos] = arg
|
||||||
|
return cache.get(tuple(key), callback=invalidate_callback)
|
||||||
|
|
||||||
|
for arg in list_args:
|
||||||
try:
|
try:
|
||||||
res = cache.get(tuple(key), callback=invalidate_callback)
|
res = cache_get(arg)
|
||||||
|
|
||||||
if not isinstance(res, ObservableDeferred):
|
if not isinstance(res, ObservableDeferred):
|
||||||
results[arg] = res
|
results[arg] = res
|
||||||
elif not res.has_succeeded():
|
elif not res.has_succeeded():
|
||||||
@ -505,6 +515,17 @@ class CacheListDescriptor(_CacheDescriptorBase):
|
|||||||
|
|
||||||
observer = ObservableDeferred(observer)
|
observer = ObservableDeferred(observer)
|
||||||
|
|
||||||
|
if num_args == 1:
|
||||||
|
cache.set(
|
||||||
|
arg, observer,
|
||||||
|
callback=invalidate_callback
|
||||||
|
)
|
||||||
|
|
||||||
|
def invalidate(f, key):
|
||||||
|
cache.invalidate(key)
|
||||||
|
return f
|
||||||
|
observer.addErrback(invalidate, arg)
|
||||||
|
else:
|
||||||
key = list(keyargs)
|
key = list(keyargs)
|
||||||
key[self.list_pos] = arg
|
key[self.list_pos] = arg
|
||||||
cache.set(
|
cache.set(
|
||||||
|
@ -23,7 +23,17 @@ import logging
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "value"))):
|
class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "value"))):
|
||||||
|
"""Returned when getting an entry from the cache
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
full (bool): Whether the cache has the full or dict or just some keys.
|
||||||
|
If not full then not all requested keys will necessarily be present
|
||||||
|
in `value`
|
||||||
|
known_absent (set): Keys that were looked up in the dict and were not
|
||||||
|
there.
|
||||||
|
value (dict): The full or partial dict value
|
||||||
|
"""
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.value)
|
return len(self.value)
|
||||||
|
|
||||||
@ -58,21 +68,31 @@ class DictionaryCache(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get(self, key, dict_keys=None):
|
def get(self, key, dict_keys=None):
|
||||||
|
"""Fetch an entry out of the cache
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key
|
||||||
|
dict_key(list): If given a set of keys then return only those keys
|
||||||
|
that exist in the cache.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DictionaryEntry
|
||||||
|
"""
|
||||||
entry = self.cache.get(key, self.sentinel)
|
entry = self.cache.get(key, self.sentinel)
|
||||||
if entry is not self.sentinel:
|
if entry is not self.sentinel:
|
||||||
self.metrics.inc_hits()
|
self.metrics.inc_hits()
|
||||||
|
|
||||||
if dict_keys is None:
|
if dict_keys is None:
|
||||||
return DictionaryEntry(entry.full, dict(entry.value))
|
return DictionaryEntry(entry.full, entry.known_absent, dict(entry.value))
|
||||||
else:
|
else:
|
||||||
return DictionaryEntry(entry.full, {
|
return DictionaryEntry(entry.full, entry.known_absent, {
|
||||||
k: entry.value[k]
|
k: entry.value[k]
|
||||||
for k in dict_keys
|
for k in dict_keys
|
||||||
if k in entry.value
|
if k in entry.value
|
||||||
})
|
})
|
||||||
|
|
||||||
self.metrics.inc_misses()
|
self.metrics.inc_misses()
|
||||||
return DictionaryEntry(False, {})
|
return DictionaryEntry(False, set(), {})
|
||||||
|
|
||||||
def invalidate(self, key):
|
def invalidate(self, key):
|
||||||
self.check_thread()
|
self.check_thread()
|
||||||
@ -87,19 +107,34 @@ class DictionaryCache(object):
|
|||||||
self.sequence += 1
|
self.sequence += 1
|
||||||
self.cache.clear()
|
self.cache.clear()
|
||||||
|
|
||||||
def update(self, sequence, key, value, full=False):
|
def update(self, sequence, key, value, full=False, known_absent=None):
|
||||||
|
"""Updates the entry in the cache
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sequence
|
||||||
|
key
|
||||||
|
value (dict): The value to update the cache with.
|
||||||
|
full (bool): Whether the given value is the full dict, or just a
|
||||||
|
partial subset there of. If not full then any existing entries
|
||||||
|
for the key will be updated.
|
||||||
|
known_absent (set): Set of keys that we know don't exist in the full
|
||||||
|
dict.
|
||||||
|
"""
|
||||||
self.check_thread()
|
self.check_thread()
|
||||||
if self.sequence == sequence:
|
if self.sequence == sequence:
|
||||||
# Only update the cache if the caches sequence number matches the
|
# Only update the cache if the caches sequence number matches the
|
||||||
# number that the cache had before the SELECT was started (SYN-369)
|
# number that the cache had before the SELECT was started (SYN-369)
|
||||||
|
if known_absent is None:
|
||||||
|
known_absent = set()
|
||||||
if full:
|
if full:
|
||||||
self._insert(key, value)
|
self._insert(key, value, known_absent)
|
||||||
else:
|
else:
|
||||||
self._update_or_insert(key, value)
|
self._update_or_insert(key, value, known_absent)
|
||||||
|
|
||||||
def _update_or_insert(self, key, value):
|
def _update_or_insert(self, key, value, known_absent):
|
||||||
entry = self.cache.setdefault(key, DictionaryEntry(False, {}))
|
entry = self.cache.setdefault(key, DictionaryEntry(False, set(), {}))
|
||||||
entry.value.update(value)
|
entry.value.update(value)
|
||||||
|
entry.known_absent.update(known_absent)
|
||||||
|
|
||||||
def _insert(self, key, value):
|
def _insert(self, key, value, known_absent):
|
||||||
self.cache[key] = DictionaryEntry(True, value)
|
self.cache[key] = DictionaryEntry(True, known_absent, value)
|
||||||
|
@ -94,6 +94,9 @@ class ExpiringCache(object):
|
|||||||
|
|
||||||
return entry.value
|
return entry.value
|
||||||
|
|
||||||
|
def __contains__(self, key):
|
||||||
|
return key in self._cache
|
||||||
|
|
||||||
def get(self, key, default=None):
|
def get(self, key, default=None):
|
||||||
try:
|
try:
|
||||||
return self[key]
|
return self[key]
|
||||||
|
@ -13,20 +13,16 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from synapse.util.caches import register_cache
|
from synapse.util.caches import register_cache, CACHE_SIZE_FACTOR
|
||||||
|
|
||||||
|
|
||||||
from blist import sorteddict
|
from blist import sorteddict
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
|
|
||||||
|
|
||||||
|
|
||||||
class StreamChangeCache(object):
|
class StreamChangeCache(object):
|
||||||
"""Keeps track of the stream positions of the latest change in a set of entities.
|
"""Keeps track of the stream positions of the latest change in a set of entities.
|
||||||
|
|
||||||
@ -89,6 +85,21 @@ class StreamChangeCache(object):
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def has_any_entity_changed(self, stream_pos):
|
||||||
|
"""Returns if any entity has changed
|
||||||
|
"""
|
||||||
|
assert type(stream_pos) is int
|
||||||
|
|
||||||
|
if stream_pos >= self._earliest_known_stream_pos:
|
||||||
|
self.metrics.inc_hits()
|
||||||
|
keys = self._cache.keys()
|
||||||
|
i = keys.bisect_right(stream_pos)
|
||||||
|
|
||||||
|
return i < len(keys)
|
||||||
|
else:
|
||||||
|
self.metrics.inc_misses()
|
||||||
|
return True
|
||||||
|
|
||||||
def get_all_entities_changed(self, stream_pos):
|
def get_all_entities_changed(self, stream_pos):
|
||||||
"""Returns all entites that have had new things since the given
|
"""Returns all entites that have had new things since the given
|
||||||
position. If the position is too old it will return None.
|
position. If the position is too old it will return None.
|
||||||
|
@ -241,7 +241,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
|
|||||||
callcount2 = [0]
|
callcount2 = [0]
|
||||||
|
|
||||||
class A(object):
|
class A(object):
|
||||||
@cached(max_entries=20) # HACK: This makes it 2 due to cache factor
|
@cached(max_entries=4) # HACK: This makes it 2 due to cache factor
|
||||||
def func(self, key):
|
def func(self, key):
|
||||||
callcount[0] += 1
|
callcount[0] += 1
|
||||||
return key
|
return key
|
||||||
|
@ -43,10 +43,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
|
|||||||
"access_token", "ip", "user_agent", "device_id",
|
"access_token", "ip", "user_agent", "device_id",
|
||||||
)
|
)
|
||||||
|
|
||||||
# deliberately use an iterable here to make sure that the lookup
|
result = yield self.store.get_last_client_ip_by_device(user_id, "device_id")
|
||||||
# method doesn't iterate it twice
|
|
||||||
device_list = iter(((user_id, "device_id"),))
|
|
||||||
result = yield self.store.get_last_client_ip_by_device(device_list)
|
|
||||||
|
|
||||||
r = result[(user_id, "device_id")]
|
r = result[(user_id, "device_id")]
|
||||||
self.assertDictContainsSubset(
|
self.assertDictContainsSubset(
|
||||||
|
@ -143,6 +143,7 @@ class StateTestCase(unittest.TestCase):
|
|||||||
"add_event_hashes",
|
"add_event_hashes",
|
||||||
"get_events",
|
"get_events",
|
||||||
"get_next_state_group",
|
"get_next_state_group",
|
||||||
|
"get_state_group_delta",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
hs = Mock(spec_set=[
|
hs = Mock(spec_set=[
|
||||||
@ -154,6 +155,7 @@ class StateTestCase(unittest.TestCase):
|
|||||||
hs.get_auth.return_value = Auth(hs)
|
hs.get_auth.return_value = Auth(hs)
|
||||||
|
|
||||||
self.store.get_next_state_group.side_effect = Mock
|
self.store.get_next_state_group.side_effect = Mock
|
||||||
|
self.store.get_state_group_delta.return_value = (None, None)
|
||||||
|
|
||||||
self.state = StateHandler(hs)
|
self.state = StateHandler(hs)
|
||||||
self.event_id = 0
|
self.event_id = 0
|
||||||
|
@ -28,7 +28,7 @@ class DictCacheTestCase(unittest.TestCase):
|
|||||||
key = "test_simple_cache_hit_full"
|
key = "test_simple_cache_hit_full"
|
||||||
|
|
||||||
v = self.cache.get(key)
|
v = self.cache.get(key)
|
||||||
self.assertEqual((False, {}), v)
|
self.assertEqual((False, set(), {}), v)
|
||||||
|
|
||||||
seq = self.cache.sequence
|
seq = self.cache.sequence
|
||||||
test_value = {"test": "test_simple_cache_hit_full"}
|
test_value = {"test": "test_simple_cache_hit_full"}
|
||||||
|
@ -55,6 +55,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
|
|||||||
config.password_providers = []
|
config.password_providers = []
|
||||||
config.worker_replication_url = ""
|
config.worker_replication_url = ""
|
||||||
config.worker_app = None
|
config.worker_app = None
|
||||||
|
config.email_enable_notifs = False
|
||||||
|
|
||||||
config.use_frozen_dicts = True
|
config.use_frozen_dicts = True
|
||||||
config.database_config = {"name": "sqlite3"}
|
config.database_config = {"name": "sqlite3"}
|
||||||
|
Loading…
Reference in New Issue
Block a user