Merge branch 'release-v0.24.0' of github.com:matrix-org/synapse

This commit is contained in:
Erik Johnston 2017-10-23 13:13:53 +01:00
commit ffba978077
83 changed files with 6543 additions and 576 deletions

View File

@ -1,3 +1,45 @@
Changes in synapse v0.24.0 (2017-10-23)
=======================================
No changes since v0.24.0-rc1
Changes in synapse v0.24.0-rc1 (2017-10-19)
===========================================
Features:
* Add Group Server (PR #2352, #2363, #2374, #2377, #2378, #2382, #2410, #2426,
#2430, #2454, #2471, #2472, #2544)
* Add support for channel notifications (PR #2501)
* Add basic implementation of backup media store (PR #2538)
* Add config option to auto-join new users to rooms (PR #2545)
Changes:
* Make the spam checker a module (PR #2474)
* Delete expired url cache data (PR #2478)
* Ignore incoming events for rooms that we have left (PR #2490)
* Allow spam checker to reject invites too (PR #2492)
* Add room creation checks to spam checker (PR #2495)
* Spam checking: add the invitee to user_may_invite (PR #2502)
* Process events from federation for different rooms in parallel (PR #2520)
* Allow error strings from spam checker (PR #2531)
* Improve error handling for missing files in config (PR #2551)
Bug fixes:
* Fix handling SERVFAILs when doing AAAA lookups for federation (PR #2477)
* Fix incompatibility with newer versions of ujson (PR #2483) Thanks to
@jeremycline!
* Fix notification keywords that start/end with non-word chars (PR #2500)
* Fix stack overflow and logcontexts from linearizer (PR #2532)
* Fix 500 error when fields missing from power_levels event (PR #2552)
* Fix 500 error when we get an error handling a PDU (PR #2553)
Changes in synapse v0.23.1 (2017-10-02) Changes in synapse v0.23.1 (2017-10-02)
======================================= =======================================

View File

@ -50,7 +50,7 @@ master_doc = 'index'
# General information about the project. # General information about the project.
project = u'Synapse' project = u'Synapse'
copyright = u'2014, TNG' copyright = u'Copyright 2014-2017 OpenMarket Ltd, 2017 Vector Creations Ltd, 2017 New Vector Ltd'
# The version info for the project you're documenting, acts as replacement for # The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the # |version| and |release|, also used in various other places throughout the

View File

@ -376,10 +376,13 @@ class Porter(object):
" VALUES (?,?,?,?,to_tsvector('english', ?),?,?)" " VALUES (?,?,?,?,to_tsvector('english', ?),?,?)"
) )
rows_dict = [ rows_dict = []
dict(zip(headers, row)) for row in rows:
for row in rows d = dict(zip(headers, row))
] if "\0" in d['value']:
logger.warn('dropping search row %s', d)
else:
rows_dict.append(d)
txn.executemany(sql, [ txn.executemany(sql, [
( (

View File

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.23.1" __version__ = "0.24.0"

View File

@ -40,6 +40,7 @@ from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.groups import SlavedGroupServerStore
from synapse.replication.tcp.client import ReplicationClientHandler from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.client.v1 import events from synapse.rest.client.v1 import events
from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
@ -69,6 +70,7 @@ class SynchrotronSlavedStore(
SlavedRegistrationStore, SlavedRegistrationStore,
SlavedFilteringStore, SlavedFilteringStore,
SlavedPresenceStore, SlavedPresenceStore,
SlavedGroupServerStore,
SlavedDeviceInboxStore, SlavedDeviceInboxStore,
SlavedDeviceStore, SlavedDeviceStore,
SlavedClientIpStore, SlavedClientIpStore,
@ -403,6 +405,10 @@ class SyncReplicationHandler(ReplicationClientHandler):
) )
elif stream_name == "presence": elif stream_name == "presence":
yield self.presence_handler.process_replication_rows(token, rows) yield self.presence_handler.process_replication_rows(token, rows)
elif stream_name == "receipts":
self.notifier.on_new_event(
"groups_key", token, users=[row.user_id for row in rows],
)
def start(config_options): def start(config_options):

View File

@ -81,22 +81,38 @@ class Config(object):
def abspath(file_path): def abspath(file_path):
return os.path.abspath(file_path) if file_path else file_path return os.path.abspath(file_path) if file_path else file_path
@classmethod
def path_exists(cls, file_path):
"""Check if a file exists
Unlike os.path.exists, this throws an exception if there is an error
checking if the file exists (for example, if there is a perms error on
the parent dir).
Returns:
bool: True if the file exists; False if not.
"""
try:
os.stat(file_path)
return True
except OSError as e:
if e.errno != errno.ENOENT:
raise e
return False
@classmethod @classmethod
def check_file(cls, file_path, config_name): def check_file(cls, file_path, config_name):
if file_path is None: if file_path is None:
raise ConfigError( raise ConfigError(
"Missing config for %s." "Missing config for %s."
" You must specify a path for the config file. You can "
"do this with the -c or --config-path option. "
"Adding --generate-config along with --server-name "
"<server name> will generate a config file at the given path."
% (config_name,) % (config_name,)
) )
if not os.path.exists(file_path): try:
os.stat(file_path)
except OSError as e:
raise ConfigError( raise ConfigError(
"File %s config for %s doesn't exist." "Error accessing file '%s' (config for %s): %s"
" Try running again with --generate-config" % (file_path, config_name, e.strerror)
% (file_path, config_name,)
) )
return cls.abspath(file_path) return cls.abspath(file_path)
@ -248,7 +264,7 @@ class Config(object):
" -c CONFIG-FILE\"" " -c CONFIG-FILE\""
) )
(config_path,) = config_files (config_path,) = config_files
if not os.path.exists(config_path): if not cls.path_exists(config_path):
if config_args.keys_directory: if config_args.keys_directory:
config_dir_path = config_args.keys_directory config_dir_path = config_args.keys_directory
else: else:
@ -261,7 +277,7 @@ class Config(object):
"Must specify a server_name to a generate config for." "Must specify a server_name to a generate config for."
" Pass -H server.name." " Pass -H server.name."
) )
if not os.path.exists(config_dir_path): if not cls.path_exists(config_dir_path):
os.makedirs(config_dir_path) os.makedirs(config_dir_path)
with open(config_path, "wb") as config_file: with open(config_path, "wb") as config_file:
config_bytes, config = obj.generate_config( config_bytes, config = obj.generate_config(

32
synapse/config/groups.py Normal file
View File

@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
class GroupsConfig(Config):
def read_config(self, config):
self.enable_group_creation = config.get("enable_group_creation", False)
self.group_creation_prefix = config.get("group_creation_prefix", "")
def default_config(self, **kwargs):
return """\
# Whether to allow non server admins to create groups on this server
enable_group_creation: false
# If enabled, non server admins can only create groups with local parts
# starting with this prefix
# group_creation_prefix: "unofficial/"
"""

View File

@ -34,6 +34,8 @@ from .password_auth_providers import PasswordAuthProviderConfig
from .emailconfig import EmailConfig from .emailconfig import EmailConfig
from .workers import WorkerConfig from .workers import WorkerConfig
from .push import PushConfig from .push import PushConfig
from .spam_checker import SpamCheckerConfig
from .groups import GroupsConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
@ -41,7 +43,8 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig, VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
AppServiceConfig, KeyConfig, SAML2Config, CasConfig, AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
JWTConfig, PasswordConfig, EmailConfig, JWTConfig, PasswordConfig, EmailConfig,
WorkerConfig, PasswordAuthProviderConfig, PushConfig,): WorkerConfig, PasswordAuthProviderConfig, PushConfig,
SpamCheckerConfig, GroupsConfig,):
pass pass

View File

@ -118,10 +118,9 @@ class KeyConfig(Config):
signing_keys = self.read_file(signing_key_path, "signing_key") signing_keys = self.read_file(signing_key_path, "signing_key")
try: try:
return read_signing_keys(signing_keys.splitlines(True)) return read_signing_keys(signing_keys.splitlines(True))
except Exception: except Exception as e:
raise ConfigError( raise ConfigError(
"Error reading signing_key." "Error reading signing_key: %s" % (str(e))
" Try running again with --generate-config"
) )
def read_old_signing_keys(self, old_signing_keys): def read_old_signing_keys(self, old_signing_keys):
@ -141,7 +140,8 @@ class KeyConfig(Config):
def generate_files(self, config): def generate_files(self, config):
signing_key_path = config["signing_key_path"] signing_key_path = config["signing_key_path"]
if not os.path.exists(signing_key_path):
if not self.path_exists(signing_key_path):
with open(signing_key_path, "w") as signing_key_file: with open(signing_key_path, "w") as signing_key_file:
key_id = "a_" + random_string(4) key_id = "a_" + random_string(4)
write_signing_keys( write_signing_keys(

View File

@ -15,13 +15,15 @@
from ._base import Config, ConfigError from ._base import Config, ConfigError
import importlib from synapse.util.module_loader import load_module
class PasswordAuthProviderConfig(Config): class PasswordAuthProviderConfig(Config):
def read_config(self, config): def read_config(self, config):
self.password_providers = [] self.password_providers = []
provider_config = None
# We want to be backwards compatible with the old `ldap_config` # We want to be backwards compatible with the old `ldap_config`
# param. # param.
ldap_config = config.get("ldap_config", {}) ldap_config = config.get("ldap_config", {})
@ -38,19 +40,15 @@ class PasswordAuthProviderConfig(Config):
if provider['module'] == "synapse.util.ldap_auth_provider.LdapAuthProvider": if provider['module'] == "synapse.util.ldap_auth_provider.LdapAuthProvider":
from ldap_auth_provider import LdapAuthProvider from ldap_auth_provider import LdapAuthProvider
provider_class = LdapAuthProvider provider_class = LdapAuthProvider
try:
provider_config = provider_class.parse_config(provider["config"])
except Exception as e:
raise ConfigError(
"Failed to parse config for %r: %r" % (provider['module'], e)
)
else: else:
# We need to import the module, and then pick the class out of (provider_class, provider_config) = load_module(provider)
# that, so we split based on the last dot.
module, clz = provider['module'].rsplit(".", 1)
module = importlib.import_module(module)
provider_class = getattr(module, clz)
try:
provider_config = provider_class.parse_config(provider["config"])
except Exception as e:
raise ConfigError(
"Failed to parse config for %r: %r" % (provider['module'], e)
)
self.password_providers.append((provider_class, provider_config)) self.password_providers.append((provider_class, provider_config))
def default_config(self, **kwargs): def default_config(self, **kwargs):

View File

@ -41,6 +41,8 @@ class RegistrationConfig(Config):
self.allow_guest_access and config.get("invite_3pid_guest", False) self.allow_guest_access and config.get("invite_3pid_guest", False)
) )
self.auto_join_rooms = config.get("auto_join_rooms", [])
def default_config(self, **kwargs): def default_config(self, **kwargs):
registration_shared_secret = random_string_with_symbols(50) registration_shared_secret = random_string_with_symbols(50)
@ -70,6 +72,11 @@ class RegistrationConfig(Config):
- matrix.org - matrix.org
- vector.im - vector.im
- riot.im - riot.im
# Users who register on this homeserver will automatically be joined
# to these rooms
#auto_join_rooms:
# - "#example:example.com"
""" % locals() """ % locals()
def add_arguments(self, parser): def add_arguments(self, parser):

View File

@ -70,7 +70,19 @@ class ContentRepositoryConfig(Config):
self.max_upload_size = self.parse_size(config["max_upload_size"]) self.max_upload_size = self.parse_size(config["max_upload_size"])
self.max_image_pixels = self.parse_size(config["max_image_pixels"]) self.max_image_pixels = self.parse_size(config["max_image_pixels"])
self.max_spider_size = self.parse_size(config["max_spider_size"]) self.max_spider_size = self.parse_size(config["max_spider_size"])
self.media_store_path = self.ensure_directory(config["media_store_path"]) self.media_store_path = self.ensure_directory(config["media_store_path"])
self.backup_media_store_path = config.get("backup_media_store_path")
if self.backup_media_store_path:
self.backup_media_store_path = self.ensure_directory(
self.backup_media_store_path
)
self.synchronous_backup_media_store = config.get(
"synchronous_backup_media_store", False
)
self.uploads_path = self.ensure_directory(config["uploads_path"]) self.uploads_path = self.ensure_directory(config["uploads_path"])
self.dynamic_thumbnails = config["dynamic_thumbnails"] self.dynamic_thumbnails = config["dynamic_thumbnails"]
self.thumbnail_requirements = parse_thumbnail_requirements( self.thumbnail_requirements = parse_thumbnail_requirements(
@ -115,6 +127,14 @@ class ContentRepositoryConfig(Config):
# Directory where uploaded images and attachments are stored. # Directory where uploaded images and attachments are stored.
media_store_path: "%(media_store)s" media_store_path: "%(media_store)s"
# A secondary directory where uploaded images and attachments are
# stored as a backup.
# backup_media_store_path: "%(media_store)s"
# Whether to wait for successful write to backup media store before
# returning successfully.
# synchronous_backup_media_store: false
# Directory where in-progress uploads are stored. # Directory where in-progress uploads are stored.
uploads_path: "%(uploads_path)s" uploads_path: "%(uploads_path)s"

View File

@ -0,0 +1,35 @@
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util.module_loader import load_module
from ._base import Config
class SpamCheckerConfig(Config):
def read_config(self, config):
self.spam_checker = None
provider = config.get("spam_checker", None)
if provider is not None:
self.spam_checker = load_module(provider)
def default_config(self, **kwargs):
return """\
# spam_checker:
# module: "my_custom_project.SuperSpamChecker"
# config:
# example_option: 'things'
"""

View File

@ -126,7 +126,7 @@ class TlsConfig(Config):
tls_private_key_path = config["tls_private_key_path"] tls_private_key_path = config["tls_private_key_path"]
tls_dh_params_path = config["tls_dh_params_path"] tls_dh_params_path = config["tls_dh_params_path"]
if not os.path.exists(tls_private_key_path): if not self.path_exists(tls_private_key_path):
with open(tls_private_key_path, "w") as private_key_file: with open(tls_private_key_path, "w") as private_key_file:
tls_private_key = crypto.PKey() tls_private_key = crypto.PKey()
tls_private_key.generate_key(crypto.TYPE_RSA, 2048) tls_private_key.generate_key(crypto.TYPE_RSA, 2048)
@ -141,7 +141,7 @@ class TlsConfig(Config):
crypto.FILETYPE_PEM, private_key_pem crypto.FILETYPE_PEM, private_key_pem
) )
if not os.path.exists(tls_certificate_path): if not self.path_exists(tls_certificate_path):
with open(tls_certificate_path, "w") as certificate_file: with open(tls_certificate_path, "w") as certificate_file:
cert = crypto.X509() cert = crypto.X509()
subject = cert.get_subject() subject = cert.get_subject()
@ -159,7 +159,7 @@ class TlsConfig(Config):
certificate_file.write(cert_pem) certificate_file.write(cert_pem)
if not os.path.exists(tls_dh_params_path): if not self.path_exists(tls_dh_params_path):
if GENERATE_DH_PARAMS: if GENERATE_DH_PARAMS:
subprocess.check_call([ subprocess.check_call([
"openssl", "dhparam", "openssl", "dhparam",

View File

@ -470,14 +470,14 @@ def _check_power_levels(event, auth_events):
("invite", None), ("invite", None),
] ]
old_list = current_state.content.get("users") old_list = current_state.content.get("users", {})
for user in set(old_list.keys() + user_list.keys()): for user in set(old_list.keys() + user_list.keys()):
levels_to_check.append( levels_to_check.append(
(user, "users") (user, "users")
) )
old_list = current_state.content.get("events") old_list = current_state.content.get("events", {})
new_list = event.content.get("events") new_list = event.content.get("events", {})
for ev_id in set(old_list.keys() + new_list.keys()): for ev_id in set(old_list.keys() + new_list.keys()):
levels_to_check.append( levels_to_check.append(
(ev_id, "events") (ev_id, "events")

View File

@ -14,25 +14,100 @@
# limitations under the License. # limitations under the License.
def check_event_for_spam(event): class SpamChecker(object):
"""Checks if a given event is considered "spammy" by this server. def __init__(self, hs):
self.spam_checker = None
If the server considers an event spammy, then it will be rejected if module = None
sent by a local user. If it is sent by a user on another server, then config = None
users receive a blank event. try:
module, config = hs.config.spam_checker
except:
pass
Args: if module is not None:
event (synapse.events.EventBase): the event to be checked self.spam_checker = module(config=config)
Returns: def check_event_for_spam(self, event):
bool: True if the event is spammy. """Checks if a given event is considered "spammy" by this server.
"""
if not hasattr(event, "content") or "body" not in event.content:
return False
# for example: If the server considers an event spammy, then it will be rejected if
# sent by a local user. If it is sent by a user on another server, then
# if "the third flower is green" in event.content["body"]: users receive a blank event.
# return True
return False Args:
event (synapse.events.EventBase): the event to be checked
Returns:
bool: True if the event is spammy.
"""
if self.spam_checker is None:
return False
return self.spam_checker.check_event_for_spam(event)
def user_may_invite(self, inviter_userid, invitee_userid, room_id):
"""Checks if a given user may send an invite
If this method returns false, the invite will be rejected.
Args:
userid (string): The sender's user ID
Returns:
bool: True if the user may send an invite, otherwise False
"""
if self.spam_checker is None:
return True
return self.spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id)
def user_may_create_room(self, userid):
"""Checks if a given user may create a room
If this method returns false, the creation request will be rejected.
Args:
userid (string): The sender's user ID
Returns:
bool: True if the user may create a room, otherwise False
"""
if self.spam_checker is None:
return True
return self.spam_checker.user_may_create_room(userid)
def user_may_create_room_alias(self, userid, room_alias):
"""Checks if a given user may create a room alias
If this method returns false, the association request will be rejected.
Args:
userid (string): The sender's user ID
room_alias (string): The alias to be created
Returns:
bool: True if the user may create a room alias, otherwise False
"""
if self.spam_checker is None:
return True
return self.spam_checker.user_may_create_room_alias(userid, room_alias)
def user_may_publish_room(self, userid, room_id):
"""Checks if a given user may publish a room to the directory
If this method returns false, the publish request will be rejected.
Args:
userid (string): The sender's user ID
room_id (string): The ID of the room that would be published
Returns:
bool: True if the user may publish the room, otherwise False
"""
if self.spam_checker is None:
return True
return self.spam_checker.user_may_publish_room(userid, room_id)

View File

@ -16,7 +16,6 @@ import logging
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.crypto.event_signing import check_event_content_hash from synapse.crypto.event_signing import check_event_content_hash
from synapse.events import spamcheck
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.util import unwrapFirstError, logcontext from synapse.util import unwrapFirstError, logcontext
from twisted.internet import defer from twisted.internet import defer
@ -26,7 +25,7 @@ logger = logging.getLogger(__name__)
class FederationBase(object): class FederationBase(object):
def __init__(self, hs): def __init__(self, hs):
pass self.spam_checker = hs.get_spam_checker()
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False, def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
@ -144,7 +143,7 @@ class FederationBase(object):
) )
return redacted return redacted
if spamcheck.check_event_for_spam(pdu): if self.spam_checker.check_event_for_spam(pdu):
logger.warn( logger.warn(
"Event contains spam, redacting %s: %s", "Event contains spam, redacting %s: %s",
pdu.event_id, pdu.get_pdu_json() pdu.event_id, pdu.get_pdu_json()

View File

@ -12,14 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer
from .federation_base import FederationBase from .federation_base import FederationBase
from .units import Transaction, Edu from .units import Transaction, Edu
from synapse.util.async import Linearizer from synapse.util import async
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
@ -33,6 +31,9 @@ from synapse.crypto.event_signing import compute_event_signature
import simplejson as json import simplejson as json
import logging import logging
# when processing incoming transactions, we try to handle multiple rooms in
# parallel, up to this limit.
TRANSACTION_CONCURRENCY_LIMIT = 10
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -52,7 +53,8 @@ class FederationServer(FederationBase):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self._server_linearizer = Linearizer("fed_server") self._server_linearizer = async.Linearizer("fed_server")
self._transaction_linearizer = async.Linearizer("fed_txn_handler")
# We cache responses to state queries, as they take a while and often # We cache responses to state queries, as they take a while and often
# come in waves. # come in waves.
@ -109,25 +111,41 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_incoming_transaction(self, transaction_data): def on_incoming_transaction(self, transaction_data):
# keep this as early as possible to make the calculated origin ts as
# accurate as possible.
request_time = self._clock.time_msec()
transaction = Transaction(**transaction_data) transaction = Transaction(**transaction_data)
received_pdus_counter.inc_by(len(transaction.pdus)) if not transaction.transaction_id:
raise Exception("Transaction missing transaction_id")
for p in transaction.pdus: if not transaction.origin:
if "unsigned" in p: raise Exception("Transaction missing origin")
unsigned = p["unsigned"]
if "age" in unsigned:
p["age"] = unsigned["age"]
if "age" in p:
p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
del p["age"]
pdu_list = [
self.event_from_pdu_json(p) for p in transaction.pdus
]
logger.debug("[%s] Got transaction", transaction.transaction_id) logger.debug("[%s] Got transaction", transaction.transaction_id)
# use a linearizer to ensure that we don't process the same transaction
# multiple times in parallel.
with (yield self._transaction_linearizer.queue(
(transaction.origin, transaction.transaction_id),
)):
result = yield self._handle_incoming_transaction(
transaction, request_time,
)
defer.returnValue(result)
@defer.inlineCallbacks
def _handle_incoming_transaction(self, transaction, request_time):
""" Process an incoming transaction and return the HTTP response
Args:
transaction (Transaction): incoming transaction
request_time (int): timestamp that the HTTP request arrived at
Returns:
Deferred[(int, object)]: http response code and body
"""
response = yield self.transaction_actions.have_responded(transaction) response = yield self.transaction_actions.have_responded(transaction)
if response: if response:
@ -140,42 +158,49 @@ class FederationServer(FederationBase):
logger.debug("[%s] Transaction is new", transaction.transaction_id) logger.debug("[%s] Transaction is new", transaction.transaction_id)
results = [] received_pdus_counter.inc_by(len(transaction.pdus))
for pdu in pdu_list: pdus_by_room = {}
# check that it's actually being sent from a valid destination to
# workaround bug #1753 in 0.18.5 and 0.18.6
if transaction.origin != get_domain_from_id(pdu.event_id):
# We continue to accept join events from any server; this is
# necessary for the federation join dance to work correctly.
# (When we join over federation, the "helper" server is
# responsible for sending out the join event, rather than the
# origin. See bug #1893).
if not (
pdu.type == 'm.room.member' and
pdu.content and
pdu.content.get("membership", None) == 'join'
):
logger.info(
"Discarding PDU %s from invalid origin %s",
pdu.event_id, transaction.origin
)
continue
else:
logger.info(
"Accepting join PDU %s from %s",
pdu.event_id, transaction.origin
)
try: for p in transaction.pdus:
yield self._handle_received_pdu(transaction.origin, pdu) if "unsigned" in p:
results.append({}) unsigned = p["unsigned"]
except FederationError as e: if "age" in unsigned:
self.send_failure(e, transaction.origin) p["age"] = unsigned["age"]
results.append({"error": str(e)}) if "age" in p:
except Exception as e: p["age_ts"] = request_time - int(p["age"])
results.append({"error": str(e)}) del p["age"]
logger.exception("Failed to handle PDU")
event = self.event_from_pdu_json(p)
room_id = event.room_id
pdus_by_room.setdefault(room_id, []).append(event)
pdu_results = {}
# we can process different rooms in parallel (which is useful if they
# require callouts to other servers to fetch missing events), but
# impose a limit to avoid going too crazy with ram/cpu.
@defer.inlineCallbacks
def process_pdus_for_room(room_id):
logger.debug("Processing PDUs for %s", room_id)
for pdu in pdus_by_room[room_id]:
event_id = pdu.event_id
try:
yield self._handle_received_pdu(
transaction.origin, pdu
)
pdu_results[event_id] = {}
except FederationError as e:
logger.warn("Error handling PDU %s: %s", event_id, e)
pdu_results[event_id] = {"error": str(e)}
except Exception as e:
pdu_results[event_id] = {"error": str(e)}
logger.exception("Failed to handle PDU %s", event_id)
yield async.concurrently_execute(
process_pdus_for_room, pdus_by_room.keys(),
TRANSACTION_CONCURRENCY_LIMIT,
)
if hasattr(transaction, "edus"): if hasattr(transaction, "edus"):
for edu in (Edu(**x) for x in transaction.edus): for edu in (Edu(**x) for x in transaction.edus):
@ -185,17 +210,16 @@ class FederationServer(FederationBase):
edu.content edu.content
) )
for failure in getattr(transaction, "pdu_failures", []): pdu_failures = getattr(transaction, "pdu_failures", [])
logger.info("Got failure %r", failure) for failure in pdu_failures:
logger.info("Got failure %r", failure)
logger.debug("Returning: %s", str(results))
response = { response = {
"pdus": dict(zip( "pdus": pdu_results,
(p.event_id for p in pdu_list), results
)),
} }
logger.debug("Returning: %s", str(response))
yield self.transaction_actions.set_response( yield self.transaction_actions.set_response(
transaction, transaction,
200, response 200, response
@ -520,6 +544,30 @@ class FederationServer(FederationBase):
Returns (Deferred): completes with None Returns (Deferred): completes with None
Raises: FederationError if the signatures / hash do not match Raises: FederationError if the signatures / hash do not match
""" """
# check that it's actually being sent from a valid destination to
# workaround bug #1753 in 0.18.5 and 0.18.6
if origin != get_domain_from_id(pdu.event_id):
# We continue to accept join events from any server; this is
# necessary for the federation join dance to work correctly.
# (When we join over federation, the "helper" server is
# responsible for sending out the join event, rather than the
# origin. See bug #1893).
if not (
pdu.type == 'm.room.member' and
pdu.content and
pdu.content.get("membership", None) == 'join'
):
logger.info(
"Discarding PDU %s from invalid origin %s",
pdu.event_id, origin
)
return
else:
logger.info(
"Accepting join PDU %s from %s",
pdu.event_id, origin
)
# Check signature. # Check signature.
try: try:
pdu = yield self._check_sigs_and_hash(pdu) pdu = yield self._check_sigs_and_hash(pdu)

View File

@ -20,8 +20,8 @@ from .persistence import TransactionActions
from .units import Transaction, Edu from .units import Transaction, Edu
from synapse.api.errors import HttpResponseException from synapse.api.errors import HttpResponseException
from synapse.util import logcontext
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.logcontext import preserve_context_over_fn, preserve_fn
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
from synapse.handlers.presence import format_user_presence_state, get_interested_remotes from synapse.handlers.presence import format_user_presence_state, get_interested_remotes
@ -231,11 +231,9 @@ class TransactionQueue(object):
(pdu, order) (pdu, order)
) )
preserve_context_over_fn( self._attempt_new_transaction(destination)
self._attempt_new_transaction, destination
)
@preserve_fn # the caller should not yield on this @logcontext.preserve_fn # the caller should not yield on this
@defer.inlineCallbacks @defer.inlineCallbacks
def send_presence(self, states): def send_presence(self, states):
"""Send the new presence states to the appropriate destinations. """Send the new presence states to the appropriate destinations.
@ -299,7 +297,7 @@ class TransactionQueue(object):
state.user_id: state for state in states state.user_id: state for state in states
}) })
preserve_fn(self._attempt_new_transaction)(destination) self._attempt_new_transaction(destination)
def send_edu(self, destination, edu_type, content, key=None): def send_edu(self, destination, edu_type, content, key=None):
edu = Edu( edu = Edu(
@ -321,9 +319,7 @@ class TransactionQueue(object):
else: else:
self.pending_edus_by_dest.setdefault(destination, []).append(edu) self.pending_edus_by_dest.setdefault(destination, []).append(edu)
preserve_context_over_fn( self._attempt_new_transaction(destination)
self._attempt_new_transaction, destination
)
def send_failure(self, failure, destination): def send_failure(self, failure, destination):
if destination == self.server_name or destination == "localhost": if destination == self.server_name or destination == "localhost":
@ -336,9 +332,7 @@ class TransactionQueue(object):
destination, [] destination, []
).append(failure) ).append(failure)
preserve_context_over_fn( self._attempt_new_transaction(destination)
self._attempt_new_transaction, destination
)
def send_device_messages(self, destination): def send_device_messages(self, destination):
if destination == self.server_name or destination == "localhost": if destination == self.server_name or destination == "localhost":
@ -347,15 +341,24 @@ class TransactionQueue(object):
if not self.can_send_to(destination): if not self.can_send_to(destination):
return return
preserve_context_over_fn( self._attempt_new_transaction(destination)
self._attempt_new_transaction, destination
)
def get_current_token(self): def get_current_token(self):
return 0 return 0
@defer.inlineCallbacks
def _attempt_new_transaction(self, destination): def _attempt_new_transaction(self, destination):
"""Try to start a new transaction to this destination
If there is already a transaction in progress to this destination,
returns immediately. Otherwise kicks off the process of sending a
transaction in the background.
Args:
destination (str):
Returns:
None
"""
# list of (pending_pdu, deferred, order) # list of (pending_pdu, deferred, order)
if destination in self.pending_transactions: if destination in self.pending_transactions:
# XXX: pending_transactions can get stuck on by a never-ending # XXX: pending_transactions can get stuck on by a never-ending
@ -368,6 +371,19 @@ class TransactionQueue(object):
) )
return return
logger.debug("TX [%s] Starting transaction loop", destination)
# Drop the logcontext before starting the transaction. It doesn't
# really make sense to log all the outbound transactions against
# whatever path led us to this point: that's pretty arbitrary really.
#
# (this also means we can fire off _perform_transaction without
# yielding)
with logcontext.PreserveLoggingContext():
self._transaction_transmission_loop(destination)
@defer.inlineCallbacks
def _transaction_transmission_loop(self, destination):
pending_pdus = [] pending_pdus = []
try: try:
self.pending_transactions[destination] = 1 self.pending_transactions[destination] = 1

View File

@ -471,3 +471,384 @@ class TransportLayerClient(object):
) )
defer.returnValue(content) defer.returnValue(content)
@log_function
def get_group_profile(self, destination, group_id, requester_user_id):
"""Get a group profile
"""
path = PREFIX + "/groups/%s/profile" % (group_id,)
return self.client.get_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
ignore_backoff=True,
)
@log_function
def get_group_summary(self, destination, group_id, requester_user_id):
"""Get a group summary
"""
path = PREFIX + "/groups/%s/summary" % (group_id,)
return self.client.get_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
ignore_backoff=True,
)
@log_function
def get_rooms_in_group(self, destination, group_id, requester_user_id):
"""Get all rooms in a group
"""
path = PREFIX + "/groups/%s/rooms" % (group_id,)
return self.client.get_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
ignore_backoff=True,
)
def add_room_to_group(self, destination, group_id, requester_user_id, room_id,
content):
"""Add a room to a group
"""
path = PREFIX + "/groups/%s/room/%s" % (group_id, room_id,)
return self.client.post_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
data=content,
ignore_backoff=True,
)
def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
"""Remove a room from a group
"""
path = PREFIX + "/groups/%s/room/%s" % (group_id, room_id,)
return self.client.delete_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
ignore_backoff=True,
)
@log_function
def get_users_in_group(self, destination, group_id, requester_user_id):
"""Get users in a group
"""
path = PREFIX + "/groups/%s/users" % (group_id,)
return self.client.get_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
ignore_backoff=True,
)
@log_function
def get_invited_users_in_group(self, destination, group_id, requester_user_id):
"""Get users that have been invited to a group
"""
path = PREFIX + "/groups/%s/invited_users" % (group_id,)
return self.client.get_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
ignore_backoff=True,
)
@log_function
def accept_group_invite(self, destination, group_id, user_id, content):
"""Accept a group invite
"""
path = PREFIX + "/groups/%s/users/%s/accept_invite" % (group_id, user_id)
return self.client.post_json(
destination=destination,
path=path,
data=content,
ignore_backoff=True,
)
@log_function
def invite_to_group(self, destination, group_id, user_id, requester_user_id, content):
"""Invite a user to a group
"""
path = PREFIX + "/groups/%s/users/%s/invite" % (group_id, user_id)
return self.client.post_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
data=content,
ignore_backoff=True,
)
@log_function
def invite_to_group_notification(self, destination, group_id, user_id, content):
"""Sent by group server to inform a user's server that they have been
invited.
"""
path = PREFIX + "/groups/local/%s/users/%s/invite" % (group_id, user_id)
return self.client.post_json(
destination=destination,
path=path,
data=content,
ignore_backoff=True,
)
@log_function
def remove_user_from_group(self, destination, group_id, requester_user_id,
user_id, content):
"""Remove a user fron a group
"""
path = PREFIX + "/groups/%s/users/%s/remove" % (group_id, user_id)
return self.client.post_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
data=content,
ignore_backoff=True,
)
@log_function
def remove_user_from_group_notification(self, destination, group_id, user_id,
content):
"""Sent by group server to inform a user's server that they have been
kicked from the group.
"""
path = PREFIX + "/groups/local/%s/users/%s/remove" % (group_id, user_id)
return self.client.post_json(
destination=destination,
path=path,
data=content,
ignore_backoff=True,
)
@log_function
def renew_group_attestation(self, destination, group_id, user_id, content):
"""Sent by either a group server or a user's server to periodically update
the attestations
"""
path = PREFIX + "/groups/%s/renew_attestation/%s" % (group_id, user_id)
return self.client.post_json(
destination=destination,
path=path,
data=content,
ignore_backoff=True,
)
@log_function
def update_group_summary_room(self, destination, group_id, user_id, room_id,
category_id, content):
"""Update a room entry in a group summary
"""
if category_id:
path = PREFIX + "/groups/%s/summary/categories/%s/rooms/%s" % (
group_id, category_id, room_id,
)
else:
path = PREFIX + "/groups/%s/summary/rooms/%s" % (group_id, room_id,)
return self.client.post_json(
destination=destination,
path=path,
args={"requester_user_id": user_id},
data=content,
ignore_backoff=True,
)
@log_function
def delete_group_summary_room(self, destination, group_id, user_id, room_id,
category_id):
"""Delete a room entry in a group summary
"""
if category_id:
path = PREFIX + "/groups/%s/summary/categories/%s/rooms/%s" % (
group_id, category_id, room_id,
)
else:
path = PREFIX + "/groups/%s/summary/rooms/%s" % (group_id, room_id,)
return self.client.delete_json(
destination=destination,
path=path,
args={"requester_user_id": user_id},
ignore_backoff=True,
)
@log_function
def get_group_categories(self, destination, group_id, requester_user_id):
"""Get all categories in a group
"""
path = PREFIX + "/groups/%s/categories" % (group_id,)
return self.client.get_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
ignore_backoff=True,
)
@log_function
def get_group_category(self, destination, group_id, requester_user_id, category_id):
"""Get category info in a group
"""
path = PREFIX + "/groups/%s/categories/%s" % (group_id, category_id,)
return self.client.get_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
ignore_backoff=True,
)
@log_function
def update_group_category(self, destination, group_id, requester_user_id, category_id,
content):
"""Update a category in a group
"""
path = PREFIX + "/groups/%s/categories/%s" % (group_id, category_id,)
return self.client.post_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
data=content,
ignore_backoff=True,
)
@log_function
def delete_group_category(self, destination, group_id, requester_user_id,
category_id):
"""Delete a category in a group
"""
path = PREFIX + "/groups/%s/categories/%s" % (group_id, category_id,)
return self.client.delete_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
ignore_backoff=True,
)
@log_function
def get_group_roles(self, destination, group_id, requester_user_id):
"""Get all roles in a group
"""
path = PREFIX + "/groups/%s/roles" % (group_id,)
return self.client.get_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
ignore_backoff=True,
)
@log_function
def get_group_role(self, destination, group_id, requester_user_id, role_id):
"""Get a roles info
"""
path = PREFIX + "/groups/%s/roles/%s" % (group_id, role_id,)
return self.client.get_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
ignore_backoff=True,
)
@log_function
def update_group_role(self, destination, group_id, requester_user_id, role_id,
content):
"""Update a role in a group
"""
path = PREFIX + "/groups/%s/roles/%s" % (group_id, role_id,)
return self.client.post_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
data=content,
ignore_backoff=True,
)
@log_function
def delete_group_role(self, destination, group_id, requester_user_id, role_id):
"""Delete a role in a group
"""
path = PREFIX + "/groups/%s/roles/%s" % (group_id, role_id,)
return self.client.delete_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
ignore_backoff=True,
)
@log_function
def update_group_summary_user(self, destination, group_id, requester_user_id,
user_id, role_id, content):
"""Update a users entry in a group
"""
if role_id:
path = PREFIX + "/groups/%s/summary/roles/%s/users/%s" % (
group_id, role_id, user_id,
)
else:
path = PREFIX + "/groups/%s/summary/users/%s" % (group_id, user_id,)
return self.client.post_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
data=content,
ignore_backoff=True,
)
@log_function
def delete_group_summary_user(self, destination, group_id, requester_user_id,
user_id, role_id):
"""Delete a users entry in a group
"""
if role_id:
path = PREFIX + "/groups/%s/summary/roles/%s/users/%s" % (
group_id, role_id, user_id,
)
else:
path = PREFIX + "/groups/%s/summary/users/%s" % (group_id, user_id,)
return self.client.delete_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
ignore_backoff=True,
)
def bulk_get_publicised_groups(self, destination, user_ids):
"""Get the groups a list of users are publicising
"""
path = PREFIX + "/get_groups_publicised"
content = {"user_ids": user_ids}
return self.client.post_json(
destination=destination,
path=path,
data=content,
ignore_backoff=True,
)

View File

@ -25,7 +25,7 @@ from synapse.http.servlet import (
from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.versionstring import get_version_string from synapse.util.versionstring import get_version_string
from synapse.util.logcontext import preserve_fn from synapse.util.logcontext import preserve_fn
from synapse.types import ThirdPartyInstanceID from synapse.types import ThirdPartyInstanceID, get_domain_from_id
import functools import functools
import logging import logging
@ -609,6 +609,493 @@ class FederationVersionServlet(BaseFederationServlet):
})) }))
class FederationGroupsProfileServlet(BaseFederationServlet):
"""Get the basic profile of a group on behalf of a user
"""
PATH = "/groups/(?P<group_id>[^/]*)/profile$"
@defer.inlineCallbacks
def on_GET(self, origin, content, query, group_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
new_content = yield self.handler.get_group_profile(
group_id, requester_user_id
)
defer.returnValue((200, new_content))
class FederationGroupsSummaryServlet(BaseFederationServlet):
PATH = "/groups/(?P<group_id>[^/]*)/summary$"
@defer.inlineCallbacks
def on_GET(self, origin, content, query, group_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
new_content = yield self.handler.get_group_summary(
group_id, requester_user_id
)
defer.returnValue((200, new_content))
@defer.inlineCallbacks
def on_POST(self, origin, content, query, group_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
new_content = yield self.handler.update_group_profile(
group_id, requester_user_id, content
)
defer.returnValue((200, new_content))
class FederationGroupsRoomsServlet(BaseFederationServlet):
"""Get the rooms in a group on behalf of a user
"""
PATH = "/groups/(?P<group_id>[^/]*)/rooms$"
@defer.inlineCallbacks
def on_GET(self, origin, content, query, group_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
new_content = yield self.handler.get_rooms_in_group(
group_id, requester_user_id
)
defer.returnValue((200, new_content))
class FederationGroupsAddRoomsServlet(BaseFederationServlet):
"""Add/remove room from group
"""
PATH = "/groups/(?P<group_id>[^/]*)/room/(?<room_id>)$"
@defer.inlineCallbacks
def on_POST(self, origin, content, query, group_id, room_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
new_content = yield self.handler.add_room_to_group(
group_id, requester_user_id, room_id, content
)
defer.returnValue((200, new_content))
@defer.inlineCallbacks
def on_DELETE(self, origin, content, query, group_id, room_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
new_content = yield self.handler.remove_room_from_group(
group_id, requester_user_id, room_id,
)
defer.returnValue((200, new_content))
class FederationGroupsUsersServlet(BaseFederationServlet):
"""Get the users in a group on behalf of a user
"""
PATH = "/groups/(?P<group_id>[^/]*)/users$"
@defer.inlineCallbacks
def on_GET(self, origin, content, query, group_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
new_content = yield self.handler.get_users_in_group(
group_id, requester_user_id
)
defer.returnValue((200, new_content))
class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
"""Get the users that have been invited to a group
"""
PATH = "/groups/(?P<group_id>[^/]*)/invited_users$"
@defer.inlineCallbacks
def on_GET(self, origin, content, query, group_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
new_content = yield self.handler.get_invited_users_in_group(
group_id, requester_user_id
)
defer.returnValue((200, new_content))
class FederationGroupsInviteServlet(BaseFederationServlet):
"""Ask a group server to invite someone to the group
"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite$"
@defer.inlineCallbacks
def on_POST(self, origin, content, query, group_id, user_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
new_content = yield self.handler.invite_to_group(
group_id, user_id, requester_user_id, content,
)
defer.returnValue((200, new_content))
class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
"""Accept an invitation from the group server
"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite$"
@defer.inlineCallbacks
def on_POST(self, origin, content, query, group_id, user_id):
if get_domain_from_id(user_id) != origin:
raise SynapseError(403, "user_id doesn't match origin")
new_content = yield self.handler.accept_invite(
group_id, user_id, content,
)
defer.returnValue((200, new_content))
class FederationGroupsRemoveUserServlet(BaseFederationServlet):
"""Leave or kick a user from the group
"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove$"
@defer.inlineCallbacks
def on_POST(self, origin, content, query, group_id, user_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
new_content = yield self.handler.remove_user_from_group(
group_id, user_id, requester_user_id, content,
)
defer.returnValue((200, new_content))
class FederationGroupsLocalInviteServlet(BaseFederationServlet):
"""A group server has invited a local user
"""
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite$"
@defer.inlineCallbacks
def on_POST(self, origin, content, query, group_id, user_id):
if get_domain_from_id(group_id) != origin:
raise SynapseError(403, "group_id doesn't match origin")
new_content = yield self.handler.on_invite(
group_id, user_id, content,
)
defer.returnValue((200, new_content))
class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
"""A group server has removed a local user
"""
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove$"
@defer.inlineCallbacks
def on_POST(self, origin, content, query, group_id, user_id):
if get_domain_from_id(group_id) != origin:
raise SynapseError(403, "user_id doesn't match origin")
new_content = yield self.handler.user_removed_from_group(
group_id, user_id, content,
)
defer.returnValue((200, new_content))
class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
"""A group or user's server renews their attestation
"""
PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)$"
@defer.inlineCallbacks
def on_POST(self, origin, content, query, group_id, user_id):
# We don't need to check auth here as we check the attestation signatures
new_content = yield self.handler.on_renew_attestation(
group_id, user_id, content
)
defer.returnValue((200, new_content))
class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
"""Add/remove a room from the group summary, with optional category.
Matches both:
- /groups/:group/summary/rooms/:room_id
- /groups/:group/summary/categories/:category/rooms/:room_id
"""
PATH = (
"/groups/(?P<group_id>[^/]*)/summary"
"(/categories/(?P<category_id>[^/]+))?"
"/rooms/(?P<room_id>[^/]*)$"
)
@defer.inlineCallbacks
def on_POST(self, origin, content, query, group_id, category_id, room_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
if category_id == "":
raise SynapseError(400, "category_id cannot be empty string")
resp = yield self.handler.update_group_summary_room(
group_id, requester_user_id,
room_id=room_id,
category_id=category_id,
content=content,
)
defer.returnValue((200, resp))
@defer.inlineCallbacks
def on_DELETE(self, origin, content, query, group_id, category_id, room_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
if category_id == "":
raise SynapseError(400, "category_id cannot be empty string")
resp = yield self.handler.delete_group_summary_room(
group_id, requester_user_id,
room_id=room_id,
category_id=category_id,
)
defer.returnValue((200, resp))
class FederationGroupsCategoriesServlet(BaseFederationServlet):
"""Get all categories for a group
"""
PATH = (
"/groups/(?P<group_id>[^/]*)/categories/$"
)
@defer.inlineCallbacks
def on_GET(self, origin, content, query, group_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
resp = yield self.handler.get_group_categories(
group_id, requester_user_id,
)
defer.returnValue((200, resp))
class FederationGroupsCategoryServlet(BaseFederationServlet):
"""Add/remove/get a category in a group
"""
PATH = (
"/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
)
@defer.inlineCallbacks
def on_GET(self, origin, content, query, group_id, category_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
resp = yield self.handler.get_group_category(
group_id, requester_user_id, category_id
)
defer.returnValue((200, resp))
@defer.inlineCallbacks
def on_POST(self, origin, content, query, group_id, category_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
if category_id == "":
raise SynapseError(400, "category_id cannot be empty string")
resp = yield self.handler.upsert_group_category(
group_id, requester_user_id, category_id, content,
)
defer.returnValue((200, resp))
@defer.inlineCallbacks
def on_DELETE(self, origin, content, query, group_id, category_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
if category_id == "":
raise SynapseError(400, "category_id cannot be empty string")
resp = yield self.handler.delete_group_category(
group_id, requester_user_id, category_id,
)
defer.returnValue((200, resp))
class FederationGroupsRolesServlet(BaseFederationServlet):
"""Get roles in a group
"""
PATH = (
"/groups/(?P<group_id>[^/]*)/roles/$"
)
@defer.inlineCallbacks
def on_GET(self, origin, content, query, group_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
resp = yield self.handler.get_group_roles(
group_id, requester_user_id,
)
defer.returnValue((200, resp))
class FederationGroupsRoleServlet(BaseFederationServlet):
"""Add/remove/get a role in a group
"""
PATH = (
"/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$"
)
@defer.inlineCallbacks
def on_GET(self, origin, content, query, group_id, role_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
resp = yield self.handler.get_group_role(
group_id, requester_user_id, role_id
)
defer.returnValue((200, resp))
@defer.inlineCallbacks
def on_POST(self, origin, content, query, group_id, role_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
if role_id == "":
raise SynapseError(400, "role_id cannot be empty string")
resp = yield self.handler.update_group_role(
group_id, requester_user_id, role_id, content,
)
defer.returnValue((200, resp))
@defer.inlineCallbacks
def on_DELETE(self, origin, content, query, group_id, role_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
if role_id == "":
raise SynapseError(400, "role_id cannot be empty string")
resp = yield self.handler.delete_group_role(
group_id, requester_user_id, role_id,
)
defer.returnValue((200, resp))
class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
"""Add/remove a user from the group summary, with optional role.
Matches both:
- /groups/:group/summary/users/:user_id
- /groups/:group/summary/roles/:role/users/:user_id
"""
PATH = (
"/groups/(?P<group_id>[^/]*)/summary"
"(/roles/(?P<role_id>[^/]+))?"
"/users/(?P<user_id>[^/]*)$"
)
@defer.inlineCallbacks
def on_POST(self, origin, content, query, group_id, role_id, user_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
if role_id == "":
raise SynapseError(400, "role_id cannot be empty string")
resp = yield self.handler.update_group_summary_user(
group_id, requester_user_id,
user_id=user_id,
role_id=role_id,
content=content,
)
defer.returnValue((200, resp))
@defer.inlineCallbacks
def on_DELETE(self, origin, content, query, group_id, role_id, user_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
if role_id == "":
raise SynapseError(400, "role_id cannot be empty string")
resp = yield self.handler.delete_group_summary_user(
group_id, requester_user_id,
user_id=user_id,
role_id=role_id,
)
defer.returnValue((200, resp))
class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
"""Get roles in a group
"""
PATH = (
"/get_groups_publicised$"
)
@defer.inlineCallbacks
def on_POST(self, origin, content, query):
resp = yield self.handler.bulk_get_publicised_groups(
content["user_ids"], proxy=False,
)
defer.returnValue((200, resp))
FEDERATION_SERVLET_CLASSES = ( FEDERATION_SERVLET_CLASSES = (
FederationSendServlet, FederationSendServlet,
FederationPullServlet, FederationPullServlet,
@ -635,10 +1122,40 @@ FEDERATION_SERVLET_CLASSES = (
FederationVersionServlet, FederationVersionServlet,
) )
ROOM_LIST_CLASSES = ( ROOM_LIST_CLASSES = (
PublicRoomList, PublicRoomList,
) )
GROUP_SERVER_SERVLET_CLASSES = (
FederationGroupsProfileServlet,
FederationGroupsSummaryServlet,
FederationGroupsRoomsServlet,
FederationGroupsUsersServlet,
FederationGroupsInvitedUsersServlet,
FederationGroupsInviteServlet,
FederationGroupsAcceptInviteServlet,
FederationGroupsRemoveUserServlet,
FederationGroupsSummaryRoomsServlet,
FederationGroupsCategoriesServlet,
FederationGroupsCategoryServlet,
FederationGroupsRolesServlet,
FederationGroupsRoleServlet,
FederationGroupsSummaryUsersServlet,
)
GROUP_LOCAL_SERVLET_CLASSES = (
FederationGroupsLocalInviteServlet,
FederationGroupsRemoveLocalUserServlet,
FederationGroupsBulkPublicisedServlet,
)
GROUP_ATTESTATION_SERVLET_CLASSES = (
FederationGroupsRenewAttestaionServlet,
)
def register_servlets(hs, resource, authenticator, ratelimiter): def register_servlets(hs, resource, authenticator, ratelimiter):
for servletclass in FEDERATION_SERVLET_CLASSES: for servletclass in FEDERATION_SERVLET_CLASSES:
@ -656,3 +1173,27 @@ def register_servlets(hs, resource, authenticator, ratelimiter):
ratelimiter=ratelimiter, ratelimiter=ratelimiter,
server_name=hs.hostname, server_name=hs.hostname,
).register(resource) ).register(resource)
for servletclass in GROUP_SERVER_SERVLET_CLASSES:
servletclass(
handler=hs.get_groups_server_handler(),
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
).register(resource)
for servletclass in GROUP_LOCAL_SERVLET_CLASSES:
servletclass(
handler=hs.get_groups_local_handler(),
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
).register(resource)
for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES:
servletclass(
handler=hs.get_groups_attestation_renewer(),
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
).register(resource)

View File

View File

@ -0,0 +1,151 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.types import get_domain_from_id
from synapse.util.logcontext import preserve_fn
from signedjson.sign import sign_json
# Default validity duration for new attestations we create
DEFAULT_ATTESTATION_LENGTH_MS = 3 * 24 * 60 * 60 * 1000
# Start trying to update our attestations when they come this close to expiring
UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000
class GroupAttestationSigning(object):
"""Creates and verifies group attestations.
"""
def __init__(self, hs):
self.keyring = hs.get_keyring()
self.clock = hs.get_clock()
self.server_name = hs.hostname
self.signing_key = hs.config.signing_key[0]
@defer.inlineCallbacks
def verify_attestation(self, attestation, group_id, user_id, server_name=None):
"""Verifies that the given attestation matches the given parameters.
An optional server_name can be supplied to explicitly set which server's
signature is expected. Otherwise assumes that either the group_id or user_id
is local and uses the other's server as the one to check.
"""
if not server_name:
if get_domain_from_id(group_id) == self.server_name:
server_name = get_domain_from_id(user_id)
elif get_domain_from_id(user_id) == self.server_name:
server_name = get_domain_from_id(group_id)
else:
raise Exception("Expected either group_id or user_id to be local")
if user_id != attestation["user_id"]:
raise SynapseError(400, "Attestation has incorrect user_id")
if group_id != attestation["group_id"]:
raise SynapseError(400, "Attestation has incorrect group_id")
valid_until_ms = attestation["valid_until_ms"]
# TODO: We also want to check that *new* attestations that people give
# us to store are valid for at least a little while.
if valid_until_ms < self.clock.time_msec():
raise SynapseError(400, "Attestation expired")
yield self.keyring.verify_json_for_server(server_name, attestation)
def create_attestation(self, group_id, user_id):
"""Create an attestation for the group_id and user_id with default
validity length.
"""
return sign_json({
"group_id": group_id,
"user_id": user_id,
"valid_until_ms": self.clock.time_msec() + DEFAULT_ATTESTATION_LENGTH_MS,
}, self.server_name, self.signing_key)
class GroupAttestionRenewer(object):
"""Responsible for sending and receiving attestation updates.
"""
def __init__(self, hs):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.assestations = hs.get_groups_attestation_signing()
self.transport_client = hs.get_federation_transport_client()
self.is_mine_id = hs.is_mine_id
self.attestations = hs.get_groups_attestation_signing()
self._renew_attestations_loop = self.clock.looping_call(
self._renew_attestations, 30 * 60 * 1000,
)
@defer.inlineCallbacks
def on_renew_attestation(self, group_id, user_id, content):
"""When a remote updates an attestation
"""
attestation = content["attestation"]
if not self.is_mine_id(group_id) and not self.is_mine_id(user_id):
raise SynapseError(400, "Neither user not group are on this server")
yield self.attestations.verify_attestation(
attestation,
user_id=user_id,
group_id=group_id,
)
yield self.store.update_remote_attestion(group_id, user_id, attestation)
defer.returnValue({})
@defer.inlineCallbacks
def _renew_attestations(self):
"""Called periodically to check if we need to update any of our attestations
"""
now = self.clock.time_msec()
rows = yield self.store.get_attestations_need_renewals(
now + UPDATE_ATTESTATION_TIME_MS
)
@defer.inlineCallbacks
def _renew_attestation(group_id, user_id):
attestation = self.attestations.create_attestation(group_id, user_id)
if self.is_mine_id(group_id):
destination = get_domain_from_id(user_id)
else:
destination = get_domain_from_id(group_id)
yield self.transport_client.renew_group_attestation(
destination, group_id, user_id,
content={"attestation": attestation},
)
yield self.store.update_attestation_renewal(
group_id, user_id, attestation
)
for row in rows:
group_id = row["group_id"]
user_id = row["user_id"]
preserve_fn(_renew_attestation)(group_id, user_id)

View File

@ -0,0 +1,803 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.types import UserID, get_domain_from_id, RoomID, GroupID
import logging
import urllib
logger = logging.getLogger(__name__)
# TODO: Allow users to "knock" or simpkly join depending on rules
# TODO: Federation admin APIs
# TODO: is_priveged flag to users and is_public to users and rooms
# TODO: Audit log for admins (profile updates, membership changes, users who tried
# to join but were rejected, etc)
# TODO: Flairs
class GroupsServerHandler(object):
def __init__(self, hs):
self.hs = hs
self.store = hs.get_datastore()
self.room_list_handler = hs.get_room_list_handler()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.keyring = hs.get_keyring()
self.is_mine_id = hs.is_mine_id
self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname
self.attestations = hs.get_groups_attestation_signing()
self.transport_client = hs.get_federation_transport_client()
self.profile_handler = hs.get_profile_handler()
# Ensure attestations get renewed
hs.get_groups_attestation_renewer()
@defer.inlineCallbacks
def check_group_is_ours(self, group_id, and_exists=False, and_is_admin=None):
"""Check that the group is ours, and optionally if it exists.
If group does exist then return group.
Args:
group_id (str)
and_exists (bool): whether to also check if group exists
and_is_admin (str): whether to also check if given str is a user_id
that is an admin
"""
if not self.is_mine_id(group_id):
raise SynapseError(400, "Group not on this server")
group = yield self.store.get_group(group_id)
if and_exists and not group:
raise SynapseError(404, "Unknown group")
if and_is_admin:
is_admin = yield self.store.is_user_admin_in_group(group_id, and_is_admin)
if not is_admin:
raise SynapseError(403, "User is not admin in group")
defer.returnValue(group)
@defer.inlineCallbacks
def get_group_summary(self, group_id, requester_user_id):
"""Get the summary for a group as seen by requester_user_id.
The group summary consists of the profile of the room, and a curated
list of users and rooms. These list *may* be organised by role/category.
The roles/categories are ordered, and so are the users/rooms within them.
A user/room may appear in multiple roles/categories.
"""
yield self.check_group_is_ours(group_id, and_exists=True)
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
profile = yield self.get_group_profile(group_id, requester_user_id)
users, roles = yield self.store.get_users_for_summary_by_role(
group_id, include_private=is_user_in_group,
)
# TODO: Add profiles to users
rooms, categories = yield self.store.get_rooms_for_summary_by_category(
group_id, include_private=is_user_in_group,
)
for room_entry in rooms:
room_id = room_entry["room_id"]
joined_users = yield self.store.get_users_in_room(room_id)
entry = yield self.room_list_handler.generate_room_entry(
room_id, len(joined_users),
with_alias=False, allow_private=True,
)
entry = dict(entry) # so we don't change whats cached
entry.pop("room_id", None)
room_entry["profile"] = entry
rooms.sort(key=lambda e: e.get("order", 0))
for entry in users:
user_id = entry["user_id"]
if not self.is_mine_id(requester_user_id):
attestation = yield self.store.get_remote_attestation(group_id, user_id)
if not attestation:
continue
entry["attestation"] = attestation
else:
entry["attestation"] = self.attestations.create_attestation(
group_id, user_id,
)
user_profile = yield self.profile_handler.get_profile_from_cache(user_id)
entry.update(user_profile)
users.sort(key=lambda e: e.get("order", 0))
membership_info = yield self.store.get_users_membership_info_in_group(
group_id, requester_user_id,
)
defer.returnValue({
"profile": profile,
"users_section": {
"users": users,
"roles": roles,
"total_user_count_estimate": 0, # TODO
},
"rooms_section": {
"rooms": rooms,
"categories": categories,
"total_room_count_estimate": 0, # TODO
},
"user": membership_info,
})
@defer.inlineCallbacks
def update_group_summary_room(self, group_id, user_id, room_id, category_id, content):
"""Add/update a room to the group summary
"""
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
RoomID.from_string(room_id) # Ensure valid room id
order = content.get("order", None)
is_public = _parse_visibility_from_contents(content)
yield self.store.add_room_to_summary(
group_id=group_id,
room_id=room_id,
category_id=category_id,
order=order,
is_public=is_public,
)
defer.returnValue({})
@defer.inlineCallbacks
def delete_group_summary_room(self, group_id, user_id, room_id, category_id):
"""Remove a room from the summary
"""
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
yield self.store.remove_room_from_summary(
group_id=group_id,
room_id=room_id,
category_id=category_id,
)
defer.returnValue({})
@defer.inlineCallbacks
def get_group_categories(self, group_id, user_id):
"""Get all categories in a group (as seen by user)
"""
yield self.check_group_is_ours(group_id, and_exists=True)
categories = yield self.store.get_group_categories(
group_id=group_id,
)
defer.returnValue({"categories": categories})
@defer.inlineCallbacks
def get_group_category(self, group_id, user_id, category_id):
"""Get a specific category in a group (as seen by user)
"""
yield self.check_group_is_ours(group_id, and_exists=True)
res = yield self.store.get_group_category(
group_id=group_id,
category_id=category_id,
)
defer.returnValue(res)
@defer.inlineCallbacks
def update_group_category(self, group_id, user_id, category_id, content):
"""Add/Update a group category
"""
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
is_public = _parse_visibility_from_contents(content)
profile = content.get("profile")
yield self.store.upsert_group_category(
group_id=group_id,
category_id=category_id,
is_public=is_public,
profile=profile,
)
defer.returnValue({})
@defer.inlineCallbacks
def delete_group_category(self, group_id, user_id, category_id):
"""Delete a group category
"""
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
yield self.store.remove_group_category(
group_id=group_id,
category_id=category_id,
)
defer.returnValue({})
@defer.inlineCallbacks
def get_group_roles(self, group_id, user_id):
"""Get all roles in a group (as seen by user)
"""
yield self.check_group_is_ours(group_id, and_exists=True)
roles = yield self.store.get_group_roles(
group_id=group_id,
)
defer.returnValue({"roles": roles})
@defer.inlineCallbacks
def get_group_role(self, group_id, user_id, role_id):
"""Get a specific role in a group (as seen by user)
"""
yield self.check_group_is_ours(group_id, and_exists=True)
res = yield self.store.get_group_role(
group_id=group_id,
role_id=role_id,
)
defer.returnValue(res)
@defer.inlineCallbacks
def update_group_role(self, group_id, user_id, role_id, content):
"""Add/update a role in a group
"""
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
is_public = _parse_visibility_from_contents(content)
profile = content.get("profile")
yield self.store.upsert_group_role(
group_id=group_id,
role_id=role_id,
is_public=is_public,
profile=profile,
)
defer.returnValue({})
@defer.inlineCallbacks
def delete_group_role(self, group_id, user_id, role_id):
"""Remove role from group
"""
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
yield self.store.remove_group_role(
group_id=group_id,
role_id=role_id,
)
defer.returnValue({})
@defer.inlineCallbacks
def update_group_summary_user(self, group_id, requester_user_id, user_id, role_id,
content):
"""Add/update a users entry in the group summary
"""
yield self.check_group_is_ours(
group_id, and_exists=True, and_is_admin=requester_user_id,
)
order = content.get("order", None)
is_public = _parse_visibility_from_contents(content)
yield self.store.add_user_to_summary(
group_id=group_id,
user_id=user_id,
role_id=role_id,
order=order,
is_public=is_public,
)
defer.returnValue({})
@defer.inlineCallbacks
def delete_group_summary_user(self, group_id, requester_user_id, user_id, role_id):
"""Remove a user from the group summary
"""
yield self.check_group_is_ours(
group_id, and_exists=True, and_is_admin=requester_user_id,
)
yield self.store.remove_user_from_summary(
group_id=group_id,
user_id=user_id,
role_id=role_id,
)
defer.returnValue({})
@defer.inlineCallbacks
def get_group_profile(self, group_id, requester_user_id):
"""Get the group profile as seen by requester_user_id
"""
yield self.check_group_is_ours(group_id)
group_description = yield self.store.get_group(group_id)
if group_description:
defer.returnValue(group_description)
else:
raise SynapseError(404, "Unknown group")
@defer.inlineCallbacks
def update_group_profile(self, group_id, requester_user_id, content):
"""Update the group profile
"""
yield self.check_group_is_ours(
group_id, and_exists=True, and_is_admin=requester_user_id,
)
profile = {}
for keyname in ("name", "avatar_url", "short_description",
"long_description"):
if keyname in content:
value = content[keyname]
if not isinstance(value, basestring):
raise SynapseError(400, "%r value is not a string" % (keyname,))
profile[keyname] = value
yield self.store.update_group_profile(group_id, profile)
@defer.inlineCallbacks
def get_users_in_group(self, group_id, requester_user_id):
"""Get the users in group as seen by requester_user_id.
The ordering is arbitrary at the moment
"""
yield self.check_group_is_ours(group_id, and_exists=True)
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
user_results = yield self.store.get_users_in_group(
group_id, include_private=is_user_in_group,
)
chunk = []
for user_result in user_results:
g_user_id = user_result["user_id"]
is_public = user_result["is_public"]
entry = {"user_id": g_user_id}
profile = yield self.profile_handler.get_profile_from_cache(g_user_id)
entry.update(profile)
if not is_public:
entry["is_public"] = False
if not self.is_mine_id(g_user_id):
attestation = yield self.store.get_remote_attestation(group_id, g_user_id)
if not attestation:
continue
entry["attestation"] = attestation
else:
entry["attestation"] = self.attestations.create_attestation(
group_id, g_user_id,
)
chunk.append(entry)
# TODO: If admin add lists of users whose attestations have timed out
defer.returnValue({
"chunk": chunk,
"total_user_count_estimate": len(user_results),
})
@defer.inlineCallbacks
def get_invited_users_in_group(self, group_id, requester_user_id):
"""Get the users that have been invited to a group as seen by requester_user_id.
The ordering is arbitrary at the moment
"""
yield self.check_group_is_ours(group_id, and_exists=True)
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
if not is_user_in_group:
raise SynapseError(403, "User not in group")
invited_users = yield self.store.get_invited_users_in_group(group_id)
user_profiles = []
for user_id in invited_users:
user_profile = {
"user_id": user_id
}
try:
profile = yield self.profile_handler.get_profile_from_cache(user_id)
user_profile.update(profile)
except Exception as e:
logger.warn("Error getting profile for %s: %s", user_id, e)
user_profiles.append(user_profile)
defer.returnValue({
"chunk": user_profiles,
"total_user_count_estimate": len(invited_users),
})
@defer.inlineCallbacks
def get_rooms_in_group(self, group_id, requester_user_id):
"""Get the rooms in group as seen by requester_user_id
This returns rooms in order of decreasing number of joined users
"""
yield self.check_group_is_ours(group_id, and_exists=True)
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
room_results = yield self.store.get_rooms_in_group(
group_id, include_private=is_user_in_group,
)
chunk = []
for room_result in room_results:
room_id = room_result["room_id"]
is_public = room_result["is_public"]
joined_users = yield self.store.get_users_in_room(room_id)
entry = yield self.room_list_handler.generate_room_entry(
room_id, len(joined_users),
with_alias=False, allow_private=True,
)
if not entry:
continue
if not is_public:
entry["is_public"] = False
chunk.append(entry)
chunk.sort(key=lambda e: -e["num_joined_members"])
defer.returnValue({
"chunk": chunk,
"total_room_count_estimate": len(room_results),
})
@defer.inlineCallbacks
def add_room_to_group(self, group_id, requester_user_id, room_id, content):
"""Add room to group
"""
RoomID.from_string(room_id) # Ensure valid room id
yield self.check_group_is_ours(
group_id, and_exists=True, and_is_admin=requester_user_id
)
is_public = _parse_visibility_from_contents(content)
yield self.store.add_room_to_group(group_id, room_id, is_public=is_public)
defer.returnValue({})
@defer.inlineCallbacks
def remove_room_from_group(self, group_id, requester_user_id, room_id):
"""Remove room from group
"""
yield self.check_group_is_ours(
group_id, and_exists=True, and_is_admin=requester_user_id
)
yield self.store.remove_room_from_group(group_id, room_id)
defer.returnValue({})
@defer.inlineCallbacks
def invite_to_group(self, group_id, user_id, requester_user_id, content):
"""Invite user to group
"""
group = yield self.check_group_is_ours(
group_id, and_exists=True, and_is_admin=requester_user_id
)
# TODO: Check if user knocked
# TODO: Check if user is already invited
content = {
"profile": {
"name": group["name"],
"avatar_url": group["avatar_url"],
},
"inviter": requester_user_id,
}
if self.hs.is_mine_id(user_id):
groups_local = self.hs.get_groups_local_handler()
res = yield groups_local.on_invite(group_id, user_id, content)
local_attestation = None
else:
local_attestation = self.attestations.create_attestation(group_id, user_id)
content.update({
"attestation": local_attestation,
})
res = yield self.transport_client.invite_to_group_notification(
get_domain_from_id(user_id), group_id, user_id, content
)
user_profile = res.get("user_profile", {})
yield self.store.add_remote_profile_cache(
user_id,
displayname=user_profile.get("displayname"),
avatar_url=user_profile.get("avatar_url"),
)
if res["state"] == "join":
if not self.hs.is_mine_id(user_id):
remote_attestation = res["attestation"]
yield self.attestations.verify_attestation(
remote_attestation,
user_id=user_id,
group_id=group_id,
)
else:
remote_attestation = None
yield self.store.add_user_to_group(
group_id, user_id,
is_admin=False,
is_public=False, # TODO
local_attestation=local_attestation,
remote_attestation=remote_attestation,
)
elif res["state"] == "invite":
yield self.store.add_group_invite(
group_id, user_id,
)
defer.returnValue({
"state": "invite"
})
elif res["state"] == "reject":
defer.returnValue({
"state": "reject"
})
else:
raise SynapseError(502, "Unknown state returned by HS")
@defer.inlineCallbacks
def accept_invite(self, group_id, user_id, content):
"""User tries to accept an invite to the group.
This is different from them asking to join, and so should error if no
invite exists (and they're not a member of the group)
"""
yield self.check_group_is_ours(group_id, and_exists=True)
if not self.store.is_user_invited_to_local_group(group_id, user_id):
raise SynapseError(403, "User not invited to group")
if not self.hs.is_mine_id(user_id):
remote_attestation = content["attestation"]
yield self.attestations.verify_attestation(
remote_attestation,
user_id=user_id,
group_id=group_id,
)
else:
remote_attestation = None
local_attestation = self.attestations.create_attestation(group_id, user_id)
is_public = _parse_visibility_from_contents(content)
yield self.store.add_user_to_group(
group_id, user_id,
is_admin=False,
is_public=is_public,
local_attestation=local_attestation,
remote_attestation=remote_attestation,
)
defer.returnValue({
"state": "join",
"attestation": local_attestation,
})
@defer.inlineCallbacks
def knock(self, group_id, user_id, content):
"""A user requests becoming a member of the group
"""
yield self.check_group_is_ours(group_id, and_exists=True)
raise NotImplementedError()
@defer.inlineCallbacks
def accept_knock(self, group_id, user_id, content):
"""Accept a users knock to the room.
Errors if the user hasn't knocked, rather than inviting them.
"""
yield self.check_group_is_ours(group_id, and_exists=True)
raise NotImplementedError()
@defer.inlineCallbacks
def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
"""Remove a user from the group; either a user is leaving or and admin
kicked htem.
"""
yield self.check_group_is_ours(group_id, and_exists=True)
is_kick = False
if requester_user_id != user_id:
is_admin = yield self.store.is_user_admin_in_group(
group_id, requester_user_id
)
if not is_admin:
raise SynapseError(403, "User is not admin in group")
is_kick = True
yield self.store.remove_user_from_group(
group_id, user_id,
)
if is_kick:
if self.hs.is_mine_id(user_id):
groups_local = self.hs.get_groups_local_handler()
yield groups_local.user_removed_from_group(group_id, user_id, {})
else:
yield self.transport_client.remove_user_from_group_notification(
get_domain_from_id(user_id), group_id, user_id, {}
)
if not self.hs.is_mine_id(user_id):
yield self.store.maybe_delete_remote_profile_cache(user_id)
defer.returnValue({})
@defer.inlineCallbacks
def create_group(self, group_id, user_id, content):
group = yield self.check_group_is_ours(group_id)
_validate_group_id(group_id)
logger.info("Attempting to create group with ID: %r", group_id)
if group:
raise SynapseError(400, "Group already exists")
is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id))
if not is_admin:
if not self.hs.config.enable_group_creation:
raise SynapseError(
403, "Only server admin can create group on this server",
)
localpart = GroupID.from_string(group_id).localpart
if not localpart.startswith(self.hs.config.group_creation_prefix):
raise SynapseError(
400,
"Can only create groups with prefix %r on this server" % (
self.hs.config.group_creation_prefix,
),
)
profile = content.get("profile", {})
name = profile.get("name")
avatar_url = profile.get("avatar_url")
short_description = profile.get("short_description")
long_description = profile.get("long_description")
user_profile = content.get("user_profile", {})
yield self.store.create_group(
group_id,
user_id,
name=name,
avatar_url=avatar_url,
short_description=short_description,
long_description=long_description,
)
if not self.hs.is_mine_id(user_id):
remote_attestation = content["attestation"]
yield self.attestations.verify_attestation(
remote_attestation,
user_id=user_id,
group_id=group_id,
)
local_attestation = self.attestations.create_attestation(group_id, user_id)
else:
local_attestation = None
remote_attestation = None
yield self.store.add_user_to_group(
group_id, user_id,
is_admin=True,
is_public=True, # TODO
local_attestation=local_attestation,
remote_attestation=remote_attestation,
)
if not self.hs.is_mine_id(user_id):
yield self.store.add_remote_profile_cache(
user_id,
displayname=user_profile.get("displayname"),
avatar_url=user_profile.get("avatar_url"),
)
defer.returnValue({
"group_id": group_id,
})
def _parse_visibility_from_contents(content):
"""Given a content for a request parse out whether the entity should be
public or not
"""
visibility = content.get("visibility")
if visibility:
vis_type = visibility["type"]
if vis_type not in ("public", "private"):
raise SynapseError(
400, "Synapse only supports 'public'/'private' visibility"
)
is_public = vis_type == "public"
else:
is_public = True
return is_public
def _validate_group_id(group_id):
"""Validates the group ID is valid for creation on this home server
"""
localpart = GroupID.from_string(group_id).localpart
if localpart.lower() != localpart:
raise SynapseError(400, "Group ID must be lower case")
if urllib.quote(localpart.encode('utf-8')) != localpart:
raise SynapseError(
400,
"Group ID can only contain characters a-z, 0-9, or '_-./'",
)

View File

@ -20,7 +20,6 @@ from .room import (
from .room_member import RoomMemberHandler from .room_member import RoomMemberHandler
from .message import MessageHandler from .message import MessageHandler
from .federation import FederationHandler from .federation import FederationHandler
from .profile import ProfileHandler
from .directory import DirectoryHandler from .directory import DirectoryHandler
from .admin import AdminHandler from .admin import AdminHandler
from .identity import IdentityHandler from .identity import IdentityHandler
@ -52,7 +51,6 @@ class Handlers(object):
self.room_creation_handler = RoomCreationHandler(hs) self.room_creation_handler = RoomCreationHandler(hs)
self.room_member_handler = RoomMemberHandler(hs) self.room_member_handler = RoomMemberHandler(hs)
self.federation_handler = FederationHandler(hs) self.federation_handler = FederationHandler(hs)
self.profile_handler = ProfileHandler(hs)
self.directory_handler = DirectoryHandler(hs) self.directory_handler = DirectoryHandler(hs)
self.admin_handler = AdminHandler(hs) self.admin_handler = AdminHandler(hs)
self.identity_handler = IdentityHandler(hs) self.identity_handler = IdentityHandler(hs)

View File

@ -40,6 +40,8 @@ class DirectoryHandler(BaseHandler):
"directory", self.on_directory_query "directory", self.on_directory_query
) )
self.spam_checker = hs.get_spam_checker()
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_association(self, room_alias, room_id, servers=None, creator=None): def _create_association(self, room_alias, room_id, servers=None, creator=None):
# general association creation for both human users and app services # general association creation for both human users and app services
@ -73,6 +75,11 @@ class DirectoryHandler(BaseHandler):
# association creation for human users # association creation for human users
# TODO(erikj): Do user auth. # TODO(erikj): Do user auth.
if not self.spam_checker.user_may_create_room_alias(user_id, room_alias):
raise SynapseError(
403, "This user is not permitted to create this alias",
)
can_create = yield self.can_modify_alias( can_create = yield self.can_modify_alias(
room_alias, room_alias,
user_id=user_id user_id=user_id
@ -327,6 +334,14 @@ class DirectoryHandler(BaseHandler):
room_id (str) room_id (str)
visibility (str): "public" or "private" visibility (str): "public" or "private"
""" """
if not self.spam_checker.user_may_publish_room(
requester.user.to_string(), room_id
):
raise AuthError(
403,
"This user is not permitted to publish rooms to the room list"
)
if requester.is_guest: if requester.is_guest:
raise AuthError(403, "Guests cannot edit the published room list") raise AuthError(403, "Guests cannot edit the published room list")

View File

@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
"""Contains handlers for federation events.""" """Contains handlers for federation events."""
import synapse.util.logcontext
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
@ -26,10 +25,7 @@ from synapse.api.errors import (
) )
from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError, logcontext
from synapse.util.logcontext import (
preserve_fn, preserve_context_over_deferred
)
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor, Linearizer from synapse.util.async import run_on_reactor, Linearizer
@ -77,6 +73,7 @@ class FederationHandler(BaseHandler):
self.action_generator = hs.get_action_generator() self.action_generator = hs.get_action_generator()
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self.pusher_pool = hs.get_pusherpool() self.pusher_pool = hs.get_pusherpool()
self.spam_checker = hs.get_spam_checker()
self.replication_layer.set_handler(self) self.replication_layer.set_handler(self)
@ -125,6 +122,28 @@ class FederationHandler(BaseHandler):
self.room_queues[pdu.room_id].append((pdu, origin)) self.room_queues[pdu.room_id].append((pdu, origin))
return return
# If we're no longer in the room just ditch the event entirely. This
# is probably an old server that has come back and thinks we're still
# in the room (or we've been rejoined to the room by a state reset).
#
# If we were never in the room then maybe our database got vaped and
# we should check if we *are* in fact in the room. If we are then we
# can magically rejoin the room.
is_in_room = yield self.auth.check_host_in_room(
pdu.room_id,
self.server_name
)
if not is_in_room:
was_in_room = yield self.store.was_host_joined(
pdu.room_id, self.server_name,
)
if was_in_room:
logger.info(
"Ignoring PDU %s for room %s from %s as we've left the room!",
pdu.event_id, pdu.room_id, origin,
)
return
state = None state = None
auth_chain = [] auth_chain = []
@ -591,9 +610,9 @@ class FederationHandler(BaseHandler):
missing_auth - failed_to_fetch missing_auth - failed_to_fetch
) )
results = yield preserve_context_over_deferred(defer.gatherResults( results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[ [
preserve_fn(self.replication_layer.get_pdu)( logcontext.preserve_fn(self.replication_layer.get_pdu)(
[dest], [dest],
event_id, event_id,
outlier=True, outlier=True,
@ -785,10 +804,14 @@ class FederationHandler(BaseHandler):
event_ids = list(extremities.keys()) event_ids = list(extremities.keys())
logger.debug("calling resolve_state_groups in _maybe_backfill") logger.debug("calling resolve_state_groups in _maybe_backfill")
states = yield preserve_context_over_deferred(defer.gatherResults([ states = yield logcontext.make_deferred_yieldable(defer.gatherResults(
preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e]) [
for e in event_ids logcontext.preserve_fn(self.state_handler.resolve_state_groups)(
])) room_id, [e]
)
for e in event_ids
], consumeErrors=True,
))
states = dict(zip(event_ids, [s.state for s in states])) states = dict(zip(event_ids, [s.state for s in states]))
state_map = yield self.store.get_events( state_map = yield self.store.get_events(
@ -941,9 +964,7 @@ class FederationHandler(BaseHandler):
# lots of requests for missing prev_events which we do actually # lots of requests for missing prev_events which we do actually
# have. Hence we fire off the deferred, but don't wait for it. # have. Hence we fire off the deferred, but don't wait for it.
synapse.util.logcontext.preserve_fn(self._handle_queued_pdus)( logcontext.preserve_fn(self._handle_queued_pdus)(room_queue)
room_queue
)
defer.returnValue(True) defer.returnValue(True)
@ -1070,6 +1091,9 @@ class FederationHandler(BaseHandler):
""" """
event = pdu event = pdu
if event.state_key is None:
raise SynapseError(400, "The invite event did not have a state key")
is_blocked = yield self.store.is_room_blocked(event.room_id) is_blocked = yield self.store.is_room_blocked(event.room_id)
if is_blocked: if is_blocked:
raise SynapseError(403, "This room has been blocked on this server") raise SynapseError(403, "This room has been blocked on this server")
@ -1077,6 +1101,13 @@ class FederationHandler(BaseHandler):
if self.hs.config.block_non_admin_invites: if self.hs.config.block_non_admin_invites:
raise SynapseError(403, "This server does not accept room invites") raise SynapseError(403, "This server does not accept room invites")
if not self.spam_checker.user_may_invite(
event.sender, event.state_key, event.room_id,
):
raise SynapseError(
403, "This user is not permitted to send invites to this server/user"
)
membership = event.content.get("membership") membership = event.content.get("membership")
if event.type != EventTypes.Member or membership != Membership.INVITE: if event.type != EventTypes.Member or membership != Membership.INVITE:
raise SynapseError(400, "The event was not an m.room.member invite event") raise SynapseError(400, "The event was not an m.room.member invite event")
@ -1085,9 +1116,6 @@ class FederationHandler(BaseHandler):
if sender_domain != origin: if sender_domain != origin:
raise SynapseError(400, "The invite event was not from the server sending it") raise SynapseError(400, "The invite event was not from the server sending it")
if event.state_key is None:
raise SynapseError(400, "The invite event did not have a state key")
if not self.is_mine_id(event.state_key): if not self.is_mine_id(event.state_key):
raise SynapseError(400, "The invite event must be for this server") raise SynapseError(400, "The invite event must be for this server")
@ -1430,7 +1458,7 @@ class FederationHandler(BaseHandler):
if not backfilled: if not backfilled:
# this intentionally does not yield: we don't care about the result # this intentionally does not yield: we don't care about the result
# and don't need to wait for it. # and don't need to wait for it.
preserve_fn(self.pusher_pool.on_new_notifications)( logcontext.preserve_fn(self.pusher_pool.on_new_notifications)(
event_stream_id, max_stream_id event_stream_id, max_stream_id
) )
@ -1443,16 +1471,16 @@ class FederationHandler(BaseHandler):
a bunch of outliers, but not a chunk of individual events that depend a bunch of outliers, but not a chunk of individual events that depend
on each other for state calculations. on each other for state calculations.
""" """
contexts = yield preserve_context_over_deferred(defer.gatherResults( contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[ [
preserve_fn(self._prep_event)( logcontext.preserve_fn(self._prep_event)(
origin, origin,
ev_info["event"], ev_info["event"],
state=ev_info.get("state"), state=ev_info.get("state"),
auth_events=ev_info.get("auth_events"), auth_events=ev_info.get("auth_events"),
) )
for ev_info in event_infos for ev_info in event_infos
] ], consumeErrors=True,
)) ))
yield self.store.persist_events( yield self.store.persist_events(
@ -1760,18 +1788,17 @@ class FederationHandler(BaseHandler):
# Do auth conflict res. # Do auth conflict res.
logger.info("Different auth: %s", different_auth) logger.info("Different auth: %s", different_auth)
different_events = yield preserve_context_over_deferred(defer.gatherResults( different_events = yield logcontext.make_deferred_yieldable(
[ defer.gatherResults([
preserve_fn(self.store.get_event)( logcontext.preserve_fn(self.store.get_event)(
d, d,
allow_none=True, allow_none=True,
allow_rejected=False, allow_rejected=False,
) )
for d in different_auth for d in different_auth
if d in have_events and not have_events[d] if d in have_events and not have_events[d]
], ], consumeErrors=True)
consumeErrors=True ).addErrback(unwrapFirstError)
)).addErrback(unwrapFirstError)
if different_events: if different_events:
local_view = dict(auth_events) local_view = dict(auth_events)

View File

@ -0,0 +1,417 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.types import get_domain_from_id
import logging
logger = logging.getLogger(__name__)
def _create_rerouter(func_name):
"""Returns a function that looks at the group id and calls the function
on federation or the local group server if the group is local
"""
def f(self, group_id, *args, **kwargs):
if self.is_mine_id(group_id):
return getattr(self.groups_server_handler, func_name)(
group_id, *args, **kwargs
)
else:
destination = get_domain_from_id(group_id)
return getattr(self.transport_client, func_name)(
destination, group_id, *args, **kwargs
)
return f
class GroupsLocalHandler(object):
def __init__(self, hs):
self.hs = hs
self.store = hs.get_datastore()
self.room_list_handler = hs.get_room_list_handler()
self.groups_server_handler = hs.get_groups_server_handler()
self.transport_client = hs.get_federation_transport_client()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.keyring = hs.get_keyring()
self.is_mine_id = hs.is_mine_id
self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname
self.notifier = hs.get_notifier()
self.attestations = hs.get_groups_attestation_signing()
self.profile_handler = hs.get_profile_handler()
# Ensure attestations get renewed
hs.get_groups_attestation_renewer()
# The following functions merely route the query to the local groups server
# or federation depending on if the group is local or remote
get_group_profile = _create_rerouter("get_group_profile")
update_group_profile = _create_rerouter("update_group_profile")
get_rooms_in_group = _create_rerouter("get_rooms_in_group")
get_invited_users_in_group = _create_rerouter("get_invited_users_in_group")
add_room_to_group = _create_rerouter("add_room_to_group")
remove_room_from_group = _create_rerouter("remove_room_from_group")
update_group_summary_room = _create_rerouter("update_group_summary_room")
delete_group_summary_room = _create_rerouter("delete_group_summary_room")
update_group_category = _create_rerouter("update_group_category")
delete_group_category = _create_rerouter("delete_group_category")
get_group_category = _create_rerouter("get_group_category")
get_group_categories = _create_rerouter("get_group_categories")
update_group_summary_user = _create_rerouter("update_group_summary_user")
delete_group_summary_user = _create_rerouter("delete_group_summary_user")
update_group_role = _create_rerouter("update_group_role")
delete_group_role = _create_rerouter("delete_group_role")
get_group_role = _create_rerouter("get_group_role")
get_group_roles = _create_rerouter("get_group_roles")
@defer.inlineCallbacks
def get_group_summary(self, group_id, requester_user_id):
"""Get the group summary for a group.
If the group is remote we check that the users have valid attestations.
"""
if self.is_mine_id(group_id):
res = yield self.groups_server_handler.get_group_summary(
group_id, requester_user_id
)
else:
res = yield self.transport_client.get_group_summary(
get_domain_from_id(group_id), group_id, requester_user_id,
)
group_server_name = get_domain_from_id(group_id)
# Loop through the users and validate the attestations.
chunk = res["users_section"]["users"]
valid_users = []
for entry in chunk:
g_user_id = entry["user_id"]
attestation = entry.pop("attestation", {})
try:
if get_domain_from_id(g_user_id) != group_server_name:
yield self.attestations.verify_attestation(
attestation,
group_id=group_id,
user_id=g_user_id,
server_name=get_domain_from_id(g_user_id),
)
valid_users.append(entry)
except Exception as e:
logger.info("Failed to verify user is in group: %s", e)
res["users_section"]["users"] = valid_users
res["users_section"]["users"].sort(key=lambda e: e.get("order", 0))
res["rooms_section"]["rooms"].sort(key=lambda e: e.get("order", 0))
# Add `is_publicised` flag to indicate whether the user has publicised their
# membership of the group on their profile
result = yield self.store.get_publicised_groups_for_user(requester_user_id)
is_publicised = group_id in result
res.setdefault("user", {})["is_publicised"] = is_publicised
defer.returnValue(res)
@defer.inlineCallbacks
def create_group(self, group_id, user_id, content):
"""Create a group
"""
logger.info("Asking to create group with ID: %r", group_id)
if self.is_mine_id(group_id):
res = yield self.groups_server_handler.create_group(
group_id, user_id, content
)
local_attestation = None
remote_attestation = None
else:
local_attestation = self.attestations.create_attestation(group_id, user_id)
content["attestation"] = local_attestation
content["user_profile"] = yield self.profile_handler.get_profile(user_id)
res = yield self.transport_client.create_group(
get_domain_from_id(group_id), group_id, user_id, content,
)
remote_attestation = res["attestation"]
yield self.attestations.verify_attestation(
remote_attestation,
group_id=group_id,
user_id=user_id,
server_name=get_domain_from_id(group_id),
)
is_publicised = content.get("publicise", False)
token = yield self.store.register_user_group_membership(
group_id, user_id,
membership="join",
is_admin=True,
local_attestation=local_attestation,
remote_attestation=remote_attestation,
is_publicised=is_publicised,
)
self.notifier.on_new_event(
"groups_key", token, users=[user_id],
)
defer.returnValue(res)
@defer.inlineCallbacks
def get_users_in_group(self, group_id, requester_user_id):
"""Get users in a group
"""
if self.is_mine_id(group_id):
res = yield self.groups_server_handler.get_users_in_group(
group_id, requester_user_id
)
defer.returnValue(res)
group_server_name = get_domain_from_id(group_id)
res = yield self.transport_client.get_users_in_group(
get_domain_from_id(group_id), group_id, requester_user_id,
)
chunk = res["chunk"]
valid_entries = []
for entry in chunk:
g_user_id = entry["user_id"]
attestation = entry.pop("attestation", {})
try:
if get_domain_from_id(g_user_id) != group_server_name:
yield self.attestations.verify_attestation(
attestation,
group_id=group_id,
user_id=g_user_id,
server_name=get_domain_from_id(g_user_id),
)
valid_entries.append(entry)
except Exception as e:
logger.info("Failed to verify user is in group: %s", e)
res["chunk"] = valid_entries
defer.returnValue(res)
@defer.inlineCallbacks
def join_group(self, group_id, user_id, content):
"""Request to join a group
"""
raise NotImplementedError() # TODO
@defer.inlineCallbacks
def accept_invite(self, group_id, user_id, content):
"""Accept an invite to a group
"""
if self.is_mine_id(group_id):
yield self.groups_server_handler.accept_invite(
group_id, user_id, content
)
local_attestation = None
remote_attestation = None
else:
local_attestation = self.attestations.create_attestation(group_id, user_id)
content["attestation"] = local_attestation
res = yield self.transport_client.accept_group_invite(
get_domain_from_id(group_id), group_id, user_id, content,
)
remote_attestation = res["attestation"]
yield self.attestations.verify_attestation(
remote_attestation,
group_id=group_id,
user_id=user_id,
server_name=get_domain_from_id(group_id),
)
# TODO: Check that the group is public and we're being added publically
is_publicised = content.get("publicise", False)
token = yield self.store.register_user_group_membership(
group_id, user_id,
membership="join",
is_admin=False,
local_attestation=local_attestation,
remote_attestation=remote_attestation,
is_publicised=is_publicised,
)
self.notifier.on_new_event(
"groups_key", token, users=[user_id],
)
defer.returnValue({})
@defer.inlineCallbacks
def invite(self, group_id, user_id, requester_user_id, config):
"""Invite a user to a group
"""
content = {
"requester_user_id": requester_user_id,
"config": config,
}
if self.is_mine_id(group_id):
res = yield self.groups_server_handler.invite_to_group(
group_id, user_id, requester_user_id, content,
)
else:
res = yield self.transport_client.invite_to_group(
get_domain_from_id(group_id), group_id, user_id, requester_user_id,
content,
)
defer.returnValue(res)
@defer.inlineCallbacks
def on_invite(self, group_id, user_id, content):
"""One of our users were invited to a group
"""
# TODO: Support auto join and rejection
if not self.is_mine_id(user_id):
raise SynapseError(400, "User not on this server")
local_profile = {}
if "profile" in content:
if "name" in content["profile"]:
local_profile["name"] = content["profile"]["name"]
if "avatar_url" in content["profile"]:
local_profile["avatar_url"] = content["profile"]["avatar_url"]
token = yield self.store.register_user_group_membership(
group_id, user_id,
membership="invite",
content={"profile": local_profile, "inviter": content["inviter"]},
)
self.notifier.on_new_event(
"groups_key", token, users=[user_id],
)
try:
user_profile = yield self.profile_handler.get_profile(user_id)
except Exception as e:
logger.warn("No profile for user %s: %s", user_id, e)
user_profile = {}
defer.returnValue({"state": "invite", "user_profile": user_profile})
@defer.inlineCallbacks
def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
"""Remove a user from a group
"""
if user_id == requester_user_id:
token = yield self.store.register_user_group_membership(
group_id, user_id,
membership="leave",
)
self.notifier.on_new_event(
"groups_key", token, users=[user_id],
)
# TODO: Should probably remember that we tried to leave so that we can
# retry if the group server is currently down.
if self.is_mine_id(group_id):
res = yield self.groups_server_handler.remove_user_from_group(
group_id, user_id, requester_user_id, content,
)
else:
content["requester_user_id"] = requester_user_id
res = yield self.transport_client.remove_user_from_group(
get_domain_from_id(group_id), group_id, requester_user_id,
user_id, content,
)
defer.returnValue(res)
@defer.inlineCallbacks
def user_removed_from_group(self, group_id, user_id, content):
"""One of our users was removed/kicked from a group
"""
# TODO: Check if user in group
token = yield self.store.register_user_group_membership(
group_id, user_id,
membership="leave",
)
self.notifier.on_new_event(
"groups_key", token, users=[user_id],
)
@defer.inlineCallbacks
def get_joined_groups(self, user_id):
group_ids = yield self.store.get_joined_groups(user_id)
defer.returnValue({"groups": group_ids})
@defer.inlineCallbacks
def get_publicised_groups_for_user(self, user_id):
if self.hs.is_mine_id(user_id):
result = yield self.store.get_publicised_groups_for_user(user_id)
defer.returnValue({"groups": result})
else:
result = yield self.transport_client.get_publicised_groups_for_user(
get_domain_from_id(user_id), user_id
)
# TODO: Verify attestations
defer.returnValue(result)
@defer.inlineCallbacks
def bulk_get_publicised_groups(self, user_ids, proxy=True):
destinations = {}
local_users = set()
for user_id in user_ids:
if self.hs.is_mine_id(user_id):
local_users.add(user_id)
else:
destinations.setdefault(
get_domain_from_id(user_id), set()
).add(user_id)
if not proxy and destinations:
raise SynapseError(400, "Some user_ids are not local")
results = {}
failed_results = []
for destination, dest_user_ids in destinations.iteritems():
try:
r = yield self.transport_client.bulk_get_publicised_groups(
destination, list(dest_user_ids),
)
results.update(r["users"])
except Exception:
failed_results.extend(dest_user_ids)
for uid in local_users:
results[uid] = yield self.store.get_publicised_groups_for_user(
uid
)
defer.returnValue({"users": results})

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd # Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2017 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.events import spamcheck
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
@ -26,6 +26,7 @@ from synapse.types import (
from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter
from synapse.util.logcontext import preserve_fn from synapse.util.logcontext import preserve_fn
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
from synapse.util.frozenutils import unfreeze
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
from ._base import BaseHandler from ._base import BaseHandler
@ -47,6 +48,7 @@ class MessageHandler(BaseHandler):
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.validator = EventValidator() self.validator = EventValidator()
self.profile_handler = hs.get_profile_handler()
self.pagination_lock = ReadWriteLock() self.pagination_lock = ReadWriteLock()
@ -58,6 +60,8 @@ class MessageHandler(BaseHandler):
self.action_generator = hs.get_action_generator() self.action_generator = hs.get_action_generator()
self.spam_checker = hs.get_spam_checker()
@defer.inlineCallbacks @defer.inlineCallbacks
def purge_history(self, room_id, event_id): def purge_history(self, room_id, event_id):
event = yield self.store.get_event(event_id) event = yield self.store.get_event(event_id)
@ -210,7 +214,7 @@ class MessageHandler(BaseHandler):
if membership in {Membership.JOIN, Membership.INVITE}: if membership in {Membership.JOIN, Membership.INVITE}:
# If event doesn't include a display name, add one. # If event doesn't include a display name, add one.
profile = self.hs.get_handlers().profile_handler profile = self.profile_handler
content = builder.content content = builder.content
try: try:
@ -322,9 +326,12 @@ class MessageHandler(BaseHandler):
txn_id=txn_id txn_id=txn_id
) )
if spamcheck.check_event_for_spam(event): spam_error = self.spam_checker.check_event_for_spam(event)
if spam_error:
if not isinstance(spam_error, basestring):
spam_error = "Spam is not permitted here"
raise SynapseError( raise SynapseError(
403, "Spam is not permitted here", Codes.FORBIDDEN 403, spam_error, Codes.FORBIDDEN
) )
yield self.send_nonmember_event( yield self.send_nonmember_event(
@ -418,6 +425,51 @@ class MessageHandler(BaseHandler):
[serialize_event(c, now) for c in room_state.values()] [serialize_event(c, now) for c in room_state.values()]
) )
@defer.inlineCallbacks
def get_joined_members(self, requester, room_id):
"""Get all the joined members in the room and their profile information.
If the user has left the room return the state events from when they left.
Args:
requester(Requester): The user requesting state events.
room_id(str): The room ID to get all state events from.
Returns:
A dict of user_id to profile info
"""
user_id = requester.user.to_string()
if not requester.app_service:
# We check AS auth after fetching the room membership, as it
# requires us to pull out all joined members anyway.
membership, _ = yield self._check_in_room_or_world_readable(
room_id, user_id
)
if membership != Membership.JOIN:
raise NotImplementedError(
"Getting joined members after leaving is not implemented"
)
users_with_profile = yield self.state.get_current_user_in_room(room_id)
# If this is an AS, double check that they are allowed to see the members.
# This can either be because the AS user is in the room or becuase there
# is a user in the room that the AS is "interested in"
if requester.app_service and user_id not in users_with_profile:
for uid in users_with_profile:
if requester.app_service.is_interested_in_user(uid):
break
else:
# Loop fell through, AS has no interested users in room
raise AuthError(403, "Appservice not in room")
defer.returnValue({
user_id: {
"avatar_url": profile.avatar_url,
"display_name": profile.display_name,
}
for user_id, profile in users_with_profile.iteritems()
})
@measure_func("_create_new_client_event") @measure_func("_create_new_client_event")
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_new_client_event(self, builder, requester=None, prev_event_ids=None): def _create_new_client_event(self, builder, requester=None, prev_event_ids=None):
@ -509,7 +561,7 @@ class MessageHandler(BaseHandler):
# Ensure that we can round trip before trying to persist in db # Ensure that we can round trip before trying to persist in db
try: try:
dump = ujson.dumps(event.content) dump = ujson.dumps(unfreeze(event.content))
ujson.loads(dump) ujson.loads(dump)
except: except:
logger.exception("Failed to encode content: %r", event.content) logger.exception("Failed to encode content: %r", event.content)

View File

@ -19,14 +19,15 @@ from twisted.internet import defer
import synapse.types import synapse.types
from synapse.api.errors import SynapseError, AuthError, CodeMessageException from synapse.api.errors import SynapseError, AuthError, CodeMessageException
from synapse.types import UserID from synapse.types import UserID, get_domain_from_id
from ._base import BaseHandler from ._base import BaseHandler
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ProfileHandler(BaseHandler): class ProfileHandler(BaseHandler):
PROFILE_UPDATE_MS = 60 * 1000
PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
def __init__(self, hs): def __init__(self, hs):
super(ProfileHandler, self).__init__(hs) super(ProfileHandler, self).__init__(hs)
@ -36,6 +37,63 @@ class ProfileHandler(BaseHandler):
"profile", self.on_profile_query "profile", self.on_profile_query
) )
self.clock.looping_call(self._update_remote_profile_cache, self.PROFILE_UPDATE_MS)
@defer.inlineCallbacks
def get_profile(self, user_id):
target_user = UserID.from_string(user_id)
if self.hs.is_mine(target_user):
displayname = yield self.store.get_profile_displayname(
target_user.localpart
)
avatar_url = yield self.store.get_profile_avatar_url(
target_user.localpart
)
defer.returnValue({
"displayname": displayname,
"avatar_url": avatar_url,
})
else:
try:
result = yield self.federation.make_query(
destination=target_user.domain,
query_type="profile",
args={
"user_id": user_id,
},
ignore_backoff=True,
)
defer.returnValue(result)
except CodeMessageException as e:
if e.code != 404:
logger.exception("Failed to get displayname")
raise
@defer.inlineCallbacks
def get_profile_from_cache(self, user_id):
"""Get the profile information from our local cache. If the user is
ours then the profile information will always be corect. Otherwise,
it may be out of date/missing.
"""
target_user = UserID.from_string(user_id)
if self.hs.is_mine(target_user):
displayname = yield self.store.get_profile_displayname(
target_user.localpart
)
avatar_url = yield self.store.get_profile_avatar_url(
target_user.localpart
)
defer.returnValue({
"displayname": displayname,
"avatar_url": avatar_url,
})
else:
profile = yield self.store.get_from_remote_profile_cache(user_id)
defer.returnValue(profile or {})
@defer.inlineCallbacks @defer.inlineCallbacks
def get_displayname(self, target_user): def get_displayname(self, target_user):
if self.hs.is_mine(target_user): if self.hs.is_mine(target_user):
@ -182,3 +240,44 @@ class ProfileHandler(BaseHandler):
"Failed to update join event for room %s - %s", "Failed to update join event for room %s - %s",
room_id, str(e.message) room_id, str(e.message)
) )
def _update_remote_profile_cache(self):
"""Called periodically to check profiles of remote users we haven't
checked in a while.
"""
entries = yield self.store.get_remote_profile_cache_entries_that_expire(
last_checked=self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS
)
for user_id, displayname, avatar_url in entries:
is_subscribed = yield self.store.is_subscribed_remote_profile_for_user(
user_id,
)
if not is_subscribed:
yield self.store.maybe_delete_remote_profile_cache(user_id)
continue
try:
profile = yield self.federation.make_query(
destination=get_domain_from_id(user_id),
query_type="profile",
args={
"user_id": user_id,
},
ignore_backoff=True,
)
except:
logger.exception("Failed to get avatar_url")
yield self.store.update_remote_profile_cache(
user_id, displayname, avatar_url
)
continue
new_name = profile.get("displayname")
new_avatar = profile.get("avatar_url")
# We always hit update to update the last_check timestamp
yield self.store.update_remote_profile_cache(
user_id, new_name, new_avatar
)

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.util import logcontext
from ._base import BaseHandler from ._base import BaseHandler
@ -59,6 +60,8 @@ class ReceiptsHandler(BaseHandler):
is_new = yield self._handle_new_receipts([receipt]) is_new = yield self._handle_new_receipts([receipt])
if is_new: if is_new:
# fire off a process in the background to send the receipt to
# remote servers
self._push_remotes([receipt]) self._push_remotes([receipt])
@defer.inlineCallbacks @defer.inlineCallbacks
@ -126,6 +129,7 @@ class ReceiptsHandler(BaseHandler):
defer.returnValue(True) defer.returnValue(True)
@logcontext.preserve_fn # caller should not yield on this
@defer.inlineCallbacks @defer.inlineCallbacks
def _push_remotes(self, receipts): def _push_remotes(self, receipts):
"""Given a list of receipts, works out which remote servers should be """Given a list of receipts, works out which remote servers should be

View File

@ -36,6 +36,7 @@ class RegistrationHandler(BaseHandler):
super(RegistrationHandler, self).__init__(hs) super(RegistrationHandler, self).__init__(hs)
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.profile_handler = hs.get_profile_handler()
self.captcha_client = CaptchaServerHttpClient(hs) self.captcha_client = CaptchaServerHttpClient(hs)
self._next_generated_user_id = None self._next_generated_user_id = None
@ -423,8 +424,7 @@ class RegistrationHandler(BaseHandler):
if displayname is not None: if displayname is not None:
logger.info("setting user display name: %s -> %s", user_id, displayname) logger.info("setting user display name: %s -> %s", user_id, displayname)
profile_handler = self.hs.get_handlers().profile_handler yield self.profile_handler.set_displayname(
yield profile_handler.set_displayname(
user, requester, displayname, by_admin=True, user, requester, displayname, by_admin=True,
) )

View File

@ -60,6 +60,11 @@ class RoomCreationHandler(BaseHandler):
}, },
} }
def __init__(self, hs):
super(RoomCreationHandler, self).__init__(hs)
self.spam_checker = hs.get_spam_checker()
@defer.inlineCallbacks @defer.inlineCallbacks
def create_room(self, requester, config, ratelimit=True): def create_room(self, requester, config, ratelimit=True):
""" Creates a new room. """ Creates a new room.
@ -75,6 +80,9 @@ class RoomCreationHandler(BaseHandler):
""" """
user_id = requester.user.to_string() user_id = requester.user.to_string()
if not self.spam_checker.user_may_create_room(user_id):
raise SynapseError(403, "You are not permitted to create rooms")
if ratelimit: if ratelimit:
yield self.ratelimit(requester) yield self.ratelimit(requester)

View File

@ -276,13 +276,14 @@ class RoomListHandler(BaseHandler):
# We've already got enough, so lets just drop it. # We've already got enough, so lets just drop it.
return return
result = yield self._generate_room_entry(room_id, num_joined_users) result = yield self.generate_room_entry(room_id, num_joined_users)
if result and _matches_room_entry(result, search_filter): if result and _matches_room_entry(result, search_filter):
chunk.append(result) chunk.append(result)
@cachedInlineCallbacks(num_args=1, cache_context=True) @cachedInlineCallbacks(num_args=1, cache_context=True)
def _generate_room_entry(self, room_id, num_joined_users, cache_context): def generate_room_entry(self, room_id, num_joined_users, cache_context,
with_alias=True, allow_private=False):
"""Returns the entry for a room """Returns the entry for a room
""" """
result = { result = {
@ -316,14 +317,15 @@ class RoomListHandler(BaseHandler):
join_rules_event = current_state.get((EventTypes.JoinRules, "")) join_rules_event = current_state.get((EventTypes.JoinRules, ""))
if join_rules_event: if join_rules_event:
join_rule = join_rules_event.content.get("join_rule", None) join_rule = join_rules_event.content.get("join_rule", None)
if join_rule and join_rule != JoinRules.PUBLIC: if not allow_private and join_rule and join_rule != JoinRules.PUBLIC:
defer.returnValue(None) defer.returnValue(None)
aliases = yield self.store.get_aliases_for_room( if with_alias:
room_id, on_invalidate=cache_context.invalidate aliases = yield self.store.get_aliases_for_room(
) room_id, on_invalidate=cache_context.invalidate
if aliases: )
result["aliases"] = aliases if aliases:
result["aliases"] = aliases
name_event = yield current_state.get((EventTypes.Name, "")) name_event = yield current_state.get((EventTypes.Name, ""))
if name_event: if name_event:

View File

@ -45,9 +45,12 @@ class RoomMemberHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(RoomMemberHandler, self).__init__(hs) super(RoomMemberHandler, self).__init__(hs)
self.profile_handler = hs.get_profile_handler()
self.member_linearizer = Linearizer(name="member") self.member_linearizer = Linearizer(name="member")
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.spam_checker = hs.get_spam_checker()
self.distributor = hs.get_distributor() self.distributor = hs.get_distributor()
self.distributor.declare("user_joined_room") self.distributor.declare("user_joined_room")
@ -210,12 +213,26 @@ class RoomMemberHandler(BaseHandler):
if is_blocked: if is_blocked:
raise SynapseError(403, "This room has been blocked on this server") raise SynapseError(403, "This room has been blocked on this server")
if (effective_membership_state == "invite" and if effective_membership_state == "invite":
self.hs.config.block_non_admin_invites): block_invite = False
is_requester_admin = yield self.auth.is_server_admin( is_requester_admin = yield self.auth.is_server_admin(
requester.user, requester.user,
) )
if not is_requester_admin: if not is_requester_admin:
if self.hs.config.block_non_admin_invites:
logger.info(
"Blocking invite: user is not admin and non-admin "
"invites disabled"
)
block_invite = True
if not self.spam_checker.user_may_invite(
requester.user.to_string(), target.to_string(), room_id,
):
logger.info("Blocking invite due to spam checker")
block_invite = True
if block_invite:
raise SynapseError( raise SynapseError(
403, "Invites have been disabled on this server", 403, "Invites have been disabled on this server",
) )
@ -267,7 +284,7 @@ class RoomMemberHandler(BaseHandler):
content["membership"] = Membership.JOIN content["membership"] = Membership.JOIN
profile = self.hs.get_handlers().profile_handler profile = self.profile_handler
if not content_specified: if not content_specified:
content["displayname"] = yield profile.get_displayname(target) content["displayname"] = yield profile.get_displayname(target)
content["avatar_url"] = yield profile.get_avatar_url(target) content["avatar_url"] = yield profile.get_avatar_url(target)

View File

@ -108,6 +108,17 @@ class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [
return True return True
class GroupsSyncResult(collections.namedtuple("GroupsSyncResult", [
"join",
"invite",
"leave",
])):
__slots__ = []
def __nonzero__(self):
return bool(self.join or self.invite or self.leave)
class DeviceLists(collections.namedtuple("DeviceLists", [ class DeviceLists(collections.namedtuple("DeviceLists", [
"changed", # list of user_ids whose devices may have changed "changed", # list of user_ids whose devices may have changed
"left", # list of user_ids whose devices we no longer track "left", # list of user_ids whose devices we no longer track
@ -129,6 +140,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
"device_lists", # List of user_ids whose devices have chanegd "device_lists", # List of user_ids whose devices have chanegd
"device_one_time_keys_count", # Dict of algorithm to count for one time keys "device_one_time_keys_count", # Dict of algorithm to count for one time keys
# for this device # for this device
"groups",
])): ])):
__slots__ = [] __slots__ = []
@ -144,7 +156,8 @@ class SyncResult(collections.namedtuple("SyncResult", [
self.archived or self.archived or
self.account_data or self.account_data or
self.to_device or self.to_device or
self.device_lists self.device_lists or
self.groups
) )
@ -595,6 +608,8 @@ class SyncHandler(object):
user_id, device_id user_id, device_id
) )
yield self._generate_sync_entry_for_groups(sync_result_builder)
defer.returnValue(SyncResult( defer.returnValue(SyncResult(
presence=sync_result_builder.presence, presence=sync_result_builder.presence,
account_data=sync_result_builder.account_data, account_data=sync_result_builder.account_data,
@ -603,10 +618,57 @@ class SyncHandler(object):
archived=sync_result_builder.archived, archived=sync_result_builder.archived,
to_device=sync_result_builder.to_device, to_device=sync_result_builder.to_device,
device_lists=device_lists, device_lists=device_lists,
groups=sync_result_builder.groups,
device_one_time_keys_count=one_time_key_counts, device_one_time_keys_count=one_time_key_counts,
next_batch=sync_result_builder.now_token, next_batch=sync_result_builder.now_token,
)) ))
@measure_func("_generate_sync_entry_for_groups")
@defer.inlineCallbacks
def _generate_sync_entry_for_groups(self, sync_result_builder):
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
now_token = sync_result_builder.now_token
if since_token and since_token.groups_key:
results = yield self.store.get_groups_changes_for_user(
user_id, since_token.groups_key, now_token.groups_key,
)
else:
results = yield self.store.get_all_groups_for_user(
user_id, now_token.groups_key,
)
invited = {}
joined = {}
left = {}
for result in results:
membership = result["membership"]
group_id = result["group_id"]
gtype = result["type"]
content = result["content"]
if membership == "join":
if gtype == "membership":
# TODO: Add profile
content.pop("membership", None)
joined[group_id] = content["content"]
else:
joined.setdefault(group_id, {})[gtype] = content
elif membership == "invite":
if gtype == "membership":
content.pop("membership", None)
invited[group_id] = content["content"]
else:
if gtype == "membership":
left[group_id] = content["content"]
sync_result_builder.groups = GroupsSyncResult(
join=joined,
invite=invited,
leave=left,
)
@measure_func("_generate_sync_entry_for_device_list") @measure_func("_generate_sync_entry_for_device_list")
@defer.inlineCallbacks @defer.inlineCallbacks
def _generate_sync_entry_for_device_list(self, sync_result_builder, def _generate_sync_entry_for_device_list(self, sync_result_builder,
@ -1368,6 +1430,7 @@ class SyncResultBuilder(object):
self.invited = [] self.invited = []
self.archived = [] self.archived = []
self.device = [] self.device = []
self.groups = None
self.to_device = [] self.to_device = []

View File

@ -354,16 +354,28 @@ def _get_hosts_for_srv_record(dns_client, host):
return res[0] return res[0]
def eb(res): def eb(res, record_type):
res.trap(DNSNameError) if res.check(DNSNameError):
return [] return []
logger.warn("Error looking up %s for %s: %s",
record_type, host, res, res.value)
return res
# no logcontexts here, so we can safely fire these off and gatherResults # no logcontexts here, so we can safely fire these off and gatherResults
d1 = dns_client.lookupAddress(host).addCallbacks(cb, eb) d1 = dns_client.lookupAddress(host).addCallbacks(cb, eb)
d2 = dns_client.lookupIPV6Address(host).addCallbacks(cb, eb) d2 = dns_client.lookupIPV6Address(host).addCallbacks(cb, eb)
results = yield defer.gatherResults([d1, d2], consumeErrors=True) results = yield defer.DeferredList(
[d1, d2], consumeErrors=True)
# if all of the lookups failed, raise an exception rather than blowing out
# the cache with an empty result.
if results and all(s == defer.FAILURE for (s, _) in results):
defer.returnValue(results[0][1])
for (success, result) in results:
if success == defer.FAILURE:
continue
for result in results:
for answer in result: for answer in result:
if not answer.payload: if not answer.payload:
continue continue

View File

@ -204,18 +204,15 @@ class MatrixFederationHttpClient(object):
raise raise
logger.warn( logger.warn(
"{%s} Sending request failed to %s: %s %s: %s - %s", "{%s} Sending request failed to %s: %s %s: %s",
txn_id, txn_id,
destination, destination,
method, method,
url_bytes, url_bytes,
type(e).__name__,
_flatten_response_never_received(e), _flatten_response_never_received(e),
) )
log_result = "%s - %s" % ( log_result = _flatten_response_never_received(e)
type(e).__name__, _flatten_response_never_received(e),
)
if retries_left and not timeout: if retries_left and not timeout:
if long_retries: if long_retries:
@ -347,7 +344,7 @@ class MatrixFederationHttpClient(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def post_json(self, destination, path, data={}, long_retries=False, def post_json(self, destination, path, data={}, long_retries=False,
timeout=None, ignore_backoff=False): timeout=None, ignore_backoff=False, args={}):
""" Sends the specifed json data using POST """ Sends the specifed json data using POST
Args: Args:
@ -383,6 +380,7 @@ class MatrixFederationHttpClient(object):
destination, destination,
"POST", "POST",
path, path,
query_bytes=encode_query_args(args),
body_callback=body_callback, body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]}, headers_dict={"Content-Type": ["application/json"]},
long_retries=long_retries, long_retries=long_retries,
@ -427,13 +425,6 @@ class MatrixFederationHttpClient(object):
""" """
logger.debug("get_json args: %s", args) logger.debug("get_json args: %s", args)
encoded_args = {}
for k, vs in args.items():
if isinstance(vs, basestring):
vs = [vs]
encoded_args[k] = [v.encode("UTF-8") for v in vs]
query_bytes = urllib.urlencode(encoded_args, True)
logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail) logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
def body_callback(method, url_bytes, headers_dict): def body_callback(method, url_bytes, headers_dict):
@ -444,7 +435,7 @@ class MatrixFederationHttpClient(object):
destination, destination,
"GET", "GET",
path, path,
query_bytes=query_bytes, query_bytes=encode_query_args(args),
body_callback=body_callback, body_callback=body_callback,
retry_on_dns_fail=retry_on_dns_fail, retry_on_dns_fail=retry_on_dns_fail,
timeout=timeout, timeout=timeout,
@ -460,6 +451,52 @@ class MatrixFederationHttpClient(object):
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@defer.inlineCallbacks
def delete_json(self, destination, path, long_retries=False,
timeout=None, ignore_backoff=False, args={}):
"""Send a DELETE request to the remote expecting some json response
Args:
destination (str): The remote server to send the HTTP request
to.
path (str): The HTTP path.
long_retries (bool): A boolean that indicates whether we should
retry for a short or long time.
timeout(int): How long to try (in ms) the destination for before
giving up. None indicates no timeout.
ignore_backoff (bool): true to ignore the historical backoff data and
try the request anyway.
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body.
Fails with ``HTTPRequestException`` if we get an HTTP response
code >= 300.
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
"""
response = yield self._request(
destination,
"DELETE",
path,
query_bytes=encode_query_args(args),
headers_dict={"Content-Type": ["application/json"]},
long_retries=long_retries,
timeout=timeout,
ignore_backoff=ignore_backoff,
)
if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent?
check_content_type_is_json(response.headers)
with logcontext.PreserveLoggingContext():
body = yield readBody(response)
defer.returnValue(json.loads(body))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_file(self, destination, path, output_stream, args={}, def get_file(self, destination, path, output_stream, args={},
retry_on_dns_fail=True, max_size=None, retry_on_dns_fail=True, max_size=None,
@ -578,12 +615,14 @@ class _JsonProducer(object):
def _flatten_response_never_received(e): def _flatten_response_never_received(e):
if hasattr(e, "reasons"): if hasattr(e, "reasons"):
return ", ".join( reasons = ", ".join(
_flatten_response_never_received(f.value) _flatten_response_never_received(f.value)
for f in e.reasons for f in e.reasons
) )
return "%s:[%s]" % (type(e).__name__, reasons)
else: else:
return "%s: %s" % (type(e).__name__, e.message,) return repr(e)
def check_content_type_is_json(headers): def check_content_type_is_json(headers):
@ -610,3 +649,15 @@ def check_content_type_is_json(headers):
raise RuntimeError( raise RuntimeError(
"Content-Type not application/json: was '%s'" % c_type "Content-Type not application/json: was '%s'" % c_type
) )
def encode_query_args(args):
encoded_args = {}
for k, vs in args.items():
if isinstance(vs, basestring):
vs = [vs]
encoded_args[k] = [v.encode("UTF-8") for v in vs]
query_bytes = urllib.urlencode(encoded_args, True)
return query_bytes

View File

@ -145,7 +145,9 @@ def wrap_request_handler(request_handler, include_metrics=False):
"error": "Internal server error", "error": "Internal server error",
"errcode": Codes.UNKNOWN, "errcode": Codes.UNKNOWN,
}, },
send_cors=True send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
version_string=self.version_string,
) )
finally: finally:
try: try:

View File

@ -1,4 +1,5 @@
# Copyright 2015, 2016 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -238,6 +239,28 @@ BASE_APPEND_OVERRIDE_RULES = [
} }
] ]
}, },
{
'rule_id': 'global/override/.m.rule.roomnotif',
'conditions': [
{
'kind': 'event_match',
'key': 'content.body',
'pattern': '@room',
'_id': '_roomnotif_content',
},
{
'kind': 'sender_notification_permission',
'key': 'room',
'_id': '_roomnotif_pl',
},
],
'actions': [
'notify', {
'set_tweak': 'highlight',
'value': True,
}
]
}
] ]

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd # Copyright 2015 OpenMarket Ltd
# Copyright 2017 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -19,11 +20,13 @@ from twisted.internet import defer
from .push_rule_evaluator import PushRuleEvaluatorForEvent from .push_rule_evaluator import PushRuleEvaluatorForEvent
from synapse.event_auth import get_user_power_level
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.metrics import get_metrics_for from synapse.metrics import get_metrics_for
from synapse.util.caches import metrics as cache_metrics from synapse.util.caches import metrics as cache_metrics
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from synapse.util.async import Linearizer from synapse.util.async import Linearizer
from synapse.state import POWER_KEY
from collections import namedtuple from collections import namedtuple
@ -59,6 +62,7 @@ class BulkPushRuleEvaluator(object):
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.room_push_rule_cache_metrics = cache_metrics.register_cache( self.room_push_rule_cache_metrics = cache_metrics.register_cache(
"cache", "cache",
@ -108,6 +112,29 @@ class BulkPushRuleEvaluator(object):
self.room_push_rule_cache_metrics, self.room_push_rule_cache_metrics,
) )
@defer.inlineCallbacks
def _get_power_levels_and_sender_level(self, event, context):
pl_event_id = context.prev_state_ids.get(POWER_KEY)
if pl_event_id:
# fastpath: if there's a power level event, that's all we need, and
# not having a power level event is an extreme edge case
pl_event = yield self.store.get_event(pl_event_id)
auth_events = {POWER_KEY: pl_event}
else:
auth_events_ids = yield self.auth.compute_auth_events(
event, context.prev_state_ids, for_verification=False,
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {
(e.type, e.state_key): e for e in auth_events.itervalues()
}
sender_level = get_user_power_level(event.sender, auth_events)
pl_event = auth_events.get(POWER_KEY)
defer.returnValue((pl_event.content if pl_event else {}, sender_level))
@defer.inlineCallbacks @defer.inlineCallbacks
def action_for_event_by_user(self, event, context): def action_for_event_by_user(self, event, context):
"""Given an event and context, evaluate the push rules and return """Given an event and context, evaluate the push rules and return
@ -123,7 +150,13 @@ class BulkPushRuleEvaluator(object):
event, context event, context
) )
evaluator = PushRuleEvaluatorForEvent(event, len(room_members)) (power_levels, sender_power_level) = (
yield self._get_power_levels_and_sender_level(event, context)
)
evaluator = PushRuleEvaluatorForEvent(
event, len(room_members), sender_power_level, power_levels,
)
condition_cache = {} condition_cache = {}

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -29,6 +30,21 @@ INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
def _room_member_count(ev, condition, room_member_count): def _room_member_count(ev, condition, room_member_count):
return _test_ineq_condition(condition, room_member_count)
def _sender_notification_permission(ev, condition, sender_power_level, power_levels):
notif_level_key = condition.get('key')
if notif_level_key is None:
return False
notif_levels = power_levels.get('notifications', {})
room_notif_level = notif_levels.get(notif_level_key, 50)
return sender_power_level >= room_notif_level
def _test_ineq_condition(condition, number):
if 'is' not in condition: if 'is' not in condition:
return False return False
m = INEQUALITY_EXPR.match(condition['is']) m = INEQUALITY_EXPR.match(condition['is'])
@ -41,15 +57,15 @@ def _room_member_count(ev, condition, room_member_count):
rhs = int(rhs) rhs = int(rhs)
if ineq == '' or ineq == '==': if ineq == '' or ineq == '==':
return room_member_count == rhs return number == rhs
elif ineq == '<': elif ineq == '<':
return room_member_count < rhs return number < rhs
elif ineq == '>': elif ineq == '>':
return room_member_count > rhs return number > rhs
elif ineq == '>=': elif ineq == '>=':
return room_member_count >= rhs return number >= rhs
elif ineq == '<=': elif ineq == '<=':
return room_member_count <= rhs return number <= rhs
else: else:
return False return False
@ -65,9 +81,11 @@ def tweaks_for_actions(actions):
class PushRuleEvaluatorForEvent(object): class PushRuleEvaluatorForEvent(object):
def __init__(self, event, room_member_count): def __init__(self, event, room_member_count, sender_power_level, power_levels):
self._event = event self._event = event
self._room_member_count = room_member_count self._room_member_count = room_member_count
self._sender_power_level = sender_power_level
self._power_levels = power_levels
# Maps strings of e.g. 'content.body' -> event["content"]["body"] # Maps strings of e.g. 'content.body' -> event["content"]["body"]
self._value_cache = _flatten_dict(event) self._value_cache = _flatten_dict(event)
@ -81,6 +99,10 @@ class PushRuleEvaluatorForEvent(object):
return _room_member_count( return _room_member_count(
self._event, condition, self._room_member_count self._event, condition, self._room_member_count
) )
elif condition['kind'] == 'sender_notification_permission':
return _sender_notification_permission(
self._event, condition, self._sender_power_level, self._power_levels,
)
else: else:
return True return True
@ -183,7 +205,7 @@ def _glob_to_re(glob, word_boundary):
r, r,
) )
if word_boundary: if word_boundary:
r = r"\b%s\b" % (r,) r = _re_word_boundary(r)
return re.compile(r, flags=re.IGNORECASE) return re.compile(r, flags=re.IGNORECASE)
else: else:
@ -192,7 +214,7 @@ def _glob_to_re(glob, word_boundary):
return re.compile(r, flags=re.IGNORECASE) return re.compile(r, flags=re.IGNORECASE)
elif word_boundary: elif word_boundary:
r = re.escape(glob) r = re.escape(glob)
r = r"\b%s\b" % (r,) r = _re_word_boundary(r)
return re.compile(r, flags=re.IGNORECASE) return re.compile(r, flags=re.IGNORECASE)
else: else:
@ -200,6 +222,18 @@ def _glob_to_re(glob, word_boundary):
return re.compile(r, flags=re.IGNORECASE) return re.compile(r, flags=re.IGNORECASE)
def _re_word_boundary(r):
"""
Adds word boundary characters to the start and end of an
expression to require that the match occur as a whole word,
but do so respecting the fact that strings starting or ending
with non-word characters will change word boundaries.
"""
# we can't use \b as it chokes on unicode. however \W seems to be okay
# as shorthand for [^0-9A-Za-z_].
return r"(^|\W)%s(\W|$)" % (r,)
def _flatten_dict(d, prefix=[], result=None): def _flatten_dict(d, prefix=[], result=None):
if result is None: if result is None:
result = {} result = {}

View File

@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedGroupServerStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedGroupServerStore, self).__init__(db_conn, hs)
self.hs = hs
self._group_updates_id_gen = SlavedIdTracker(
db_conn, "local_group_updates", "stream_id",
)
self._group_updates_stream_cache = StreamChangeCache(
"_group_updates_stream_cache", self._group_updates_id_gen.get_current_token(),
)
get_groups_changes_for_user = DataStore.get_groups_changes_for_user.__func__
get_group_stream_token = DataStore.get_group_stream_token.__func__
get_all_groups_for_user = DataStore.get_all_groups_for_user.__func__
def stream_positions(self):
result = super(SlavedGroupServerStore, self).stream_positions()
result["groups"] = self._group_updates_id_gen.get_current_token()
return result
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "groups":
self._group_updates_id_gen.advance(token)
for row in rows:
self._group_updates_stream_cache.entity_has_changed(
row.user_id, token
)
return super(SlavedGroupServerStore, self).process_replication_rows(
stream_name, token, rows
)

View File

@ -160,7 +160,11 @@ class ReplicationStreamer(object):
"Getting stream: %s: %s -> %s", "Getting stream: %s: %s -> %s",
stream.NAME, stream.last_token, stream.upto_token stream.NAME, stream.last_token, stream.upto_token
) )
updates, current_token = yield stream.get_updates() try:
updates, current_token = yield stream.get_updates()
except:
logger.info("Failed to handle stream %s", stream.NAME)
raise
logger.debug( logger.debug(
"Sending %d updates to %d connections", "Sending %d updates to %d connections",

View File

@ -118,6 +118,12 @@ CurrentStateDeltaStreamRow = namedtuple("CurrentStateDeltaStream", (
"state_key", # str "state_key", # str
"event_id", # str, optional "event_id", # str, optional
)) ))
GroupsStreamRow = namedtuple("GroupsStreamRow", (
"group_id", # str
"user_id", # str
"type", # str
"content", # dict
))
class Stream(object): class Stream(object):
@ -464,6 +470,19 @@ class CurrentStateDeltaStream(Stream):
super(CurrentStateDeltaStream, self).__init__(hs) super(CurrentStateDeltaStream, self).__init__(hs)
class GroupServerStream(Stream):
NAME = "groups"
ROW_TYPE = GroupsStreamRow
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_group_stream_token
self.update_function = store.get_all_groups_changes
super(GroupServerStream, self).__init__(hs)
STREAMS_MAP = { STREAMS_MAP = {
stream.NAME: stream stream.NAME: stream
for stream in ( for stream in (
@ -482,5 +501,6 @@ STREAMS_MAP = {
TagAccountDataStream, TagAccountDataStream,
AccountDataStream, AccountDataStream,
CurrentStateDeltaStream, CurrentStateDeltaStream,
GroupServerStream,
) )
} }

View File

@ -52,6 +52,7 @@ from synapse.rest.client.v2_alpha import (
thirdparty, thirdparty,
sendtodevice, sendtodevice,
user_directory, user_directory,
groups,
) )
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
@ -102,3 +103,4 @@ class ClientRestResource(JsonResource):
thirdparty.register_servlets(hs, client_resource) thirdparty.register_servlets(hs, client_resource)
sendtodevice.register_servlets(hs, client_resource) sendtodevice.register_servlets(hs, client_resource)
user_directory.register_servlets(hs, client_resource) user_directory.register_servlets(hs, client_resource)
groups.register_servlets(hs, client_resource)

View File

@ -26,13 +26,13 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(ProfileDisplaynameRestServlet, self).__init__(hs) super(ProfileDisplaynameRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.profile_handler = hs.get_profile_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
displayname = yield self.handlers.profile_handler.get_displayname( displayname = yield self.profile_handler.get_displayname(
user, user,
) )
@ -55,7 +55,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
except: except:
defer.returnValue((400, "Unable to parse name")) defer.returnValue((400, "Unable to parse name"))
yield self.handlers.profile_handler.set_displayname( yield self.profile_handler.set_displayname(
user, requester, new_name, is_admin) user, requester, new_name, is_admin)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -69,13 +69,13 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(ProfileAvatarURLRestServlet, self).__init__(hs) super(ProfileAvatarURLRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.profile_handler = hs.get_profile_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
avatar_url = yield self.handlers.profile_handler.get_avatar_url( avatar_url = yield self.profile_handler.get_avatar_url(
user, user,
) )
@ -97,7 +97,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
except: except:
defer.returnValue((400, "Unable to parse name")) defer.returnValue((400, "Unable to parse name"))
yield self.handlers.profile_handler.set_avatar_url( yield self.profile_handler.set_avatar_url(
user, requester, new_name, is_admin) user, requester, new_name, is_admin)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -111,16 +111,16 @@ class ProfileRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(ProfileRestServlet, self).__init__(hs) super(ProfileRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.profile_handler = hs.get_profile_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
displayname = yield self.handlers.profile_handler.get_displayname( displayname = yield self.profile_handler.get_displayname(
user, user,
) )
avatar_url = yield self.handlers.profile_handler.get_avatar_url( avatar_url = yield self.profile_handler.get_avatar_url(
user, user,
) )

View File

@ -398,22 +398,18 @@ class JoinedRoomMemberListRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(JoinedRoomMemberListRestServlet, self).__init__(hs) super(JoinedRoomMemberListRestServlet, self).__init__(hs)
self.state = hs.get_state_handler() self.message_handler = hs.get_handlers().message_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
users_with_profile = yield self.state.get_current_user_in_room(room_id) users_with_profile = yield self.message_handler.get_joined_members(
requester, room_id,
)
defer.returnValue((200, { defer.returnValue((200, {
"joined": { "joined": users_with_profile,
user_id: {
"avatar_url": profile.avatar_url,
"display_name": profile.display_name,
}
for user_id, profile in users_with_profile.iteritems()
}
})) }))

View File

@ -0,0 +1,717 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import GroupID
from ._base import client_v2_patterns
import logging
logger = logging.getLogger(__name__)
class GroupServlet(RestServlet):
"""Get the group profile
"""
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/profile$")
def __init__(self, hs):
super(GroupServlet, self).__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks
def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
group_description = yield self.groups_handler.get_group_profile(group_id, user_id)
defer.returnValue((200, group_description))
@defer.inlineCallbacks
def on_POST(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
yield self.groups_handler.update_group_profile(
group_id, user_id, content,
)
defer.returnValue((200, {}))
class GroupSummaryServlet(RestServlet):
"""Get the full group summary
"""
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/summary$")
def __init__(self, hs):
super(GroupSummaryServlet, self).__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks
def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
get_group_summary = yield self.groups_handler.get_group_summary(group_id, user_id)
defer.returnValue((200, get_group_summary))
class GroupSummaryRoomsCatServlet(RestServlet):
"""Update/delete a rooms entry in the summary.
Matches both:
- /groups/:group/summary/rooms/:room_id
- /groups/:group/summary/categories/:category/rooms/:room_id
"""
PATTERNS = client_v2_patterns(
"/groups/(?P<group_id>[^/]*)/summary"
"(/categories/(?P<category_id>[^/]+))?"
"/rooms/(?P<room_id>[^/]*)$"
)
def __init__(self, hs):
super(GroupSummaryRoomsCatServlet, self).__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks
def on_PUT(self, request, group_id, category_id, room_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
resp = yield self.groups_handler.update_group_summary_room(
group_id, user_id,
room_id=room_id,
category_id=category_id,
content=content,
)
defer.returnValue((200, resp))
@defer.inlineCallbacks
def on_DELETE(self, request, group_id, category_id, room_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
resp = yield self.groups_handler.delete_group_summary_room(
group_id, user_id,
room_id=room_id,
category_id=category_id,
)
defer.returnValue((200, resp))
class GroupCategoryServlet(RestServlet):
"""Get/add/update/delete a group category
"""
PATTERNS = client_v2_patterns(
"/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
)
def __init__(self, hs):
super(GroupCategoryServlet, self).__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks
def on_GET(self, request, group_id, category_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
category = yield self.groups_handler.get_group_category(
group_id, user_id,
category_id=category_id,
)
defer.returnValue((200, category))
@defer.inlineCallbacks
def on_PUT(self, request, group_id, category_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
resp = yield self.groups_handler.update_group_category(
group_id, user_id,
category_id=category_id,
content=content,
)
defer.returnValue((200, resp))
@defer.inlineCallbacks
def on_DELETE(self, request, group_id, category_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
resp = yield self.groups_handler.delete_group_category(
group_id, user_id,
category_id=category_id,
)
defer.returnValue((200, resp))
class GroupCategoriesServlet(RestServlet):
"""Get all group categories
"""
PATTERNS = client_v2_patterns(
"/groups/(?P<group_id>[^/]*)/categories/$"
)
def __init__(self, hs):
super(GroupCategoriesServlet, self).__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks
def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
category = yield self.groups_handler.get_group_categories(
group_id, user_id,
)
defer.returnValue((200, category))
class GroupRoleServlet(RestServlet):
"""Get/add/update/delete a group role
"""
PATTERNS = client_v2_patterns(
"/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$"
)
def __init__(self, hs):
super(GroupRoleServlet, self).__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks
def on_GET(self, request, group_id, role_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
category = yield self.groups_handler.get_group_role(
group_id, user_id,
role_id=role_id,
)
defer.returnValue((200, category))
@defer.inlineCallbacks
def on_PUT(self, request, group_id, role_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
resp = yield self.groups_handler.update_group_role(
group_id, user_id,
role_id=role_id,
content=content,
)
defer.returnValue((200, resp))
@defer.inlineCallbacks
def on_DELETE(self, request, group_id, role_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
resp = yield self.groups_handler.delete_group_role(
group_id, user_id,
role_id=role_id,
)
defer.returnValue((200, resp))
class GroupRolesServlet(RestServlet):
"""Get all group roles
"""
PATTERNS = client_v2_patterns(
"/groups/(?P<group_id>[^/]*)/roles/$"
)
def __init__(self, hs):
super(GroupRolesServlet, self).__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks
def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
category = yield self.groups_handler.get_group_roles(
group_id, user_id,
)
defer.returnValue((200, category))
class GroupSummaryUsersRoleServlet(RestServlet):
"""Update/delete a user's entry in the summary.
Matches both:
- /groups/:group/summary/users/:room_id
- /groups/:group/summary/roles/:role/users/:user_id
"""
PATTERNS = client_v2_patterns(
"/groups/(?P<group_id>[^/]*)/summary"
"(/roles/(?P<role_id>[^/]+))?"
"/users/(?P<user_id>[^/]*)$"
)
def __init__(self, hs):
super(GroupSummaryUsersRoleServlet, self).__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks
def on_PUT(self, request, group_id, role_id, user_id):
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
resp = yield self.groups_handler.update_group_summary_user(
group_id, requester_user_id,
user_id=user_id,
role_id=role_id,
content=content,
)
defer.returnValue((200, resp))
@defer.inlineCallbacks
def on_DELETE(self, request, group_id, role_id, user_id):
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
resp = yield self.groups_handler.delete_group_summary_user(
group_id, requester_user_id,
user_id=user_id,
role_id=role_id,
)
defer.returnValue((200, resp))
class GroupRoomServlet(RestServlet):
"""Get all rooms in a group
"""
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
def __init__(self, hs):
super(GroupRoomServlet, self).__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks
def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
result = yield self.groups_handler.get_rooms_in_group(group_id, user_id)
defer.returnValue((200, result))
class GroupUsersServlet(RestServlet):
"""Get all users in a group
"""
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/users$")
def __init__(self, hs):
super(GroupUsersServlet, self).__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks
def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
result = yield self.groups_handler.get_users_in_group(group_id, user_id)
defer.returnValue((200, result))
class GroupInvitedUsersServlet(RestServlet):
"""Get users invited to a group
"""
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
def __init__(self, hs):
super(GroupInvitedUsersServlet, self).__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks
def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
result = yield self.groups_handler.get_invited_users_in_group(group_id, user_id)
defer.returnValue((200, result))
class GroupCreateServlet(RestServlet):
"""Create a group
"""
PATTERNS = client_v2_patterns("/create_group$")
def __init__(self, hs):
super(GroupCreateServlet, self).__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
self.server_name = hs.hostname
@defer.inlineCallbacks
def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
# TODO: Create group on remote server
content = parse_json_object_from_request(request)
localpart = content.pop("localpart")
group_id = GroupID.create(localpart, self.server_name).to_string()
result = yield self.groups_handler.create_group(group_id, user_id, content)
defer.returnValue((200, result))
class GroupAdminRoomsServlet(RestServlet):
"""Add a room to the group
"""
PATTERNS = client_v2_patterns(
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$"
)
def __init__(self, hs):
super(GroupAdminRoomsServlet, self).__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks
def on_PUT(self, request, group_id, room_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
result = yield self.groups_handler.add_room_to_group(
group_id, user_id, room_id, content,
)
defer.returnValue((200, result))
@defer.inlineCallbacks
def on_DELETE(self, request, group_id, room_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
result = yield self.groups_handler.remove_room_from_group(
group_id, user_id, room_id,
)
defer.returnValue((200, result))
class GroupAdminUsersInviteServlet(RestServlet):
"""Invite a user to the group
"""
PATTERNS = client_v2_patterns(
"/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$"
)
def __init__(self, hs):
super(GroupAdminUsersInviteServlet, self).__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
self.store = hs.get_datastore()
self.is_mine_id = hs.is_mine_id
@defer.inlineCallbacks
def on_PUT(self, request, group_id, user_id):
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
config = content.get("config", {})
result = yield self.groups_handler.invite(
group_id, user_id, requester_user_id, config,
)
defer.returnValue((200, result))
class GroupAdminUsersKickServlet(RestServlet):
"""Kick a user from the group
"""
PATTERNS = client_v2_patterns(
"/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$"
)
def __init__(self, hs):
super(GroupAdminUsersKickServlet, self).__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks
def on_PUT(self, request, group_id, user_id):
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
result = yield self.groups_handler.remove_user_from_group(
group_id, user_id, requester_user_id, content,
)
defer.returnValue((200, result))
class GroupSelfLeaveServlet(RestServlet):
"""Leave a joined group
"""
PATTERNS = client_v2_patterns(
"/groups/(?P<group_id>[^/]*)/self/leave$"
)
def __init__(self, hs):
super(GroupSelfLeaveServlet, self).__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks
def on_PUT(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
result = yield self.groups_handler.remove_user_from_group(
group_id, requester_user_id, requester_user_id, content,
)
defer.returnValue((200, result))
class GroupSelfJoinServlet(RestServlet):
"""Attempt to join a group, or knock
"""
PATTERNS = client_v2_patterns(
"/groups/(?P<group_id>[^/]*)/self/join$"
)
def __init__(self, hs):
super(GroupSelfJoinServlet, self).__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks
def on_PUT(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
result = yield self.groups_handler.join_group(
group_id, requester_user_id, content,
)
defer.returnValue((200, result))
class GroupSelfAcceptInviteServlet(RestServlet):
"""Accept a group invite
"""
PATTERNS = client_v2_patterns(
"/groups/(?P<group_id>[^/]*)/self/accept_invite$"
)
def __init__(self, hs):
super(GroupSelfAcceptInviteServlet, self).__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks
def on_PUT(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
result = yield self.groups_handler.accept_invite(
group_id, requester_user_id, content,
)
defer.returnValue((200, result))
class GroupSelfUpdatePublicityServlet(RestServlet):
"""Update whether we publicise a users membership of a group
"""
PATTERNS = client_v2_patterns(
"/groups/(?P<group_id>[^/]*)/self/update_publicity$"
)
def __init__(self, hs):
super(GroupSelfUpdatePublicityServlet, self).__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastore()
@defer.inlineCallbacks
def on_PUT(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
publicise = content["publicise"]
yield self.store.update_group_publicity(
group_id, requester_user_id, publicise,
)
defer.returnValue((200, {}))
class PublicisedGroupsForUserServlet(RestServlet):
"""Get the list of groups a user is advertising
"""
PATTERNS = client_v2_patterns(
"/publicised_groups/(?P<user_id>[^/]*)$"
)
def __init__(self, hs):
super(PublicisedGroupsForUserServlet, self).__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
yield self.auth.get_user_by_req(request)
result = yield self.groups_handler.get_publicised_groups_for_user(
user_id
)
defer.returnValue((200, result))
class PublicisedGroupsForUsersServlet(RestServlet):
"""Get the list of groups a user is advertising
"""
PATTERNS = client_v2_patterns(
"/publicised_groups$"
)
def __init__(self, hs):
super(PublicisedGroupsForUsersServlet, self).__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks
def on_POST(self, request):
yield self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
user_ids = content["user_ids"]
result = yield self.groups_handler.bulk_get_publicised_groups(
user_ids
)
defer.returnValue((200, result))
class GroupsForUserServlet(RestServlet):
"""Get all groups the logged in user is joined to
"""
PATTERNS = client_v2_patterns(
"/joined_groups$"
)
def __init__(self, hs):
super(GroupsForUserServlet, self).__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
result = yield self.groups_handler.get_joined_groups(user_id)
defer.returnValue((200, result))
def register_servlets(hs, http_server):
GroupServlet(hs).register(http_server)
GroupSummaryServlet(hs).register(http_server)
GroupInvitedUsersServlet(hs).register(http_server)
GroupUsersServlet(hs).register(http_server)
GroupRoomServlet(hs).register(http_server)
GroupCreateServlet(hs).register(http_server)
GroupAdminRoomsServlet(hs).register(http_server)
GroupAdminUsersInviteServlet(hs).register(http_server)
GroupAdminUsersKickServlet(hs).register(http_server)
GroupSelfLeaveServlet(hs).register(http_server)
GroupSelfJoinServlet(hs).register(http_server)
GroupSelfAcceptInviteServlet(hs).register(http_server)
GroupsForUserServlet(hs).register(http_server)
GroupCategoryServlet(hs).register(http_server)
GroupCategoriesServlet(hs).register(http_server)
GroupSummaryRoomsCatServlet(hs).register(http_server)
GroupRoleServlet(hs).register(http_server)
GroupRolesServlet(hs).register(http_server)
GroupSelfUpdatePublicityServlet(hs).register(http_server)
GroupSummaryUsersRoleServlet(hs).register(http_server)
PublicisedGroupsForUserServlet(hs).register(http_server)
PublicisedGroupsForUsersServlet(hs).register(http_server)

View File

@ -17,8 +17,10 @@
from twisted.internet import defer from twisted.internet import defer
import synapse import synapse
import synapse.types
from synapse.api.auth import get_access_token_from_request, has_access_token from synapse.api.auth import get_access_token_from_request, has_access_token
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.types import RoomID, RoomAlias
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string
@ -170,6 +172,7 @@ class RegisterRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_handlers().registration_handler self.registration_handler = hs.get_handlers().registration_handler
self.identity_handler = hs.get_handlers().identity_handler self.identity_handler = hs.get_handlers().identity_handler
self.room_member_handler = hs.get_handlers().room_member_handler
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator() self.macaroon_gen = hs.get_macaroon_generator()
@ -340,6 +343,14 @@ class RegisterRestServlet(RestServlet):
generate_token=False, generate_token=False,
) )
# auto-join the user to any rooms we're supposed to dump them into
fake_requester = synapse.types.create_requester(registered_user_id)
for r in self.hs.config.auto_join_rooms:
try:
yield self._join_user_to_room(fake_requester, r)
except Exception as e:
logger.error("Failed to join new user to %r: %r", r, e)
# remember that we've now registered that user account, and with # remember that we've now registered that user account, and with
# what user ID (since the user may not have specified) # what user ID (since the user may not have specified)
self.auth_handler.set_session_data( self.auth_handler.set_session_data(
@ -372,6 +383,29 @@ class RegisterRestServlet(RestServlet):
def on_OPTIONS(self, _): def on_OPTIONS(self, _):
return 200, {} return 200, {}
@defer.inlineCallbacks
def _join_user_to_room(self, requester, room_identifier):
room_id = None
if RoomID.is_valid(room_identifier):
room_id = room_identifier
elif RoomAlias.is_valid(room_identifier):
room_alias = RoomAlias.from_string(room_identifier)
room_id, remote_room_hosts = (
yield self.room_member_handler.lookup_room_alias(room_alias)
)
room_id = room_id.to_string()
else:
raise SynapseError(400, "%s was not legal room ID or room alias" % (
room_identifier,
))
yield self.room_member_handler.update_membership(
requester=requester,
target=requester.user,
room_id=room_id,
action="join",
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_appservice_registration(self, username, as_token, body): def _do_appservice_registration(self, username, as_token, body):
user_id = yield self.registration_handler.appservice_register( user_id = yield self.registration_handler.appservice_register(

View File

@ -200,6 +200,11 @@ class SyncRestServlet(RestServlet):
"invite": invited, "invite": invited,
"leave": archived, "leave": archived,
}, },
"groups": {
"join": sync_result.groups.join,
"invite": sync_result.groups.invite,
"leave": sync_result.groups.leave,
},
"device_one_time_keys_count": sync_result.device_one_time_keys_count, "device_one_time_keys_count": sync_result.device_one_time_keys_count,
"next_batch": sync_result.next_batch.to_string(), "next_batch": sync_result.next_batch.to_string(),
} }

View File

@ -14,78 +14,200 @@
# limitations under the License. # limitations under the License.
import os import os
import re
import functools
NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
def _wrap_in_base_path(func):
"""Takes a function that returns a relative path and turns it into an
absolute path based on the location of the primary media store
"""
@functools.wraps(func)
def _wrapped(self, *args, **kwargs):
path = func(self, *args, **kwargs)
return os.path.join(self.base_path, path)
return _wrapped
class MediaFilePaths(object): class MediaFilePaths(object):
"""Describes where files are stored on disk.
def __init__(self, base_path): Most of the functions have a `*_rel` variant which returns a file path that
self.base_path = base_path is relative to the base media store path. This is mainly used when we want
to write to the backup media store (when one is configured)
"""
def default_thumbnail(self, default_top_level, default_sub_type, width, def __init__(self, primary_base_path):
height, content_type, method): self.base_path = primary_base_path
def default_thumbnail_rel(self, default_top_level, default_sub_type, width,
height, content_type, method):
top_level_type, sub_type = content_type.split("/") top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % ( file_name = "%i-%i-%s-%s-%s" % (
width, height, top_level_type, sub_type, method width, height, top_level_type, sub_type, method
) )
return os.path.join( return os.path.join(
self.base_path, "default_thumbnails", default_top_level, "default_thumbnails", default_top_level,
default_sub_type, file_name default_sub_type, file_name
) )
def local_media_filepath(self, media_id): default_thumbnail = _wrap_in_base_path(default_thumbnail_rel)
def local_media_filepath_rel(self, media_id):
return os.path.join( return os.path.join(
self.base_path, "local_content", "local_content",
media_id[0:2], media_id[2:4], media_id[4:] media_id[0:2], media_id[2:4], media_id[4:]
) )
def local_media_thumbnail(self, media_id, width, height, content_type, local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)
method):
def local_media_thumbnail_rel(self, media_id, width, height, content_type,
method):
top_level_type, sub_type = content_type.split("/") top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % ( file_name = "%i-%i-%s-%s-%s" % (
width, height, top_level_type, sub_type, method width, height, top_level_type, sub_type, method
) )
return os.path.join( return os.path.join(
self.base_path, "local_thumbnails", "local_thumbnails",
media_id[0:2], media_id[2:4], media_id[4:], media_id[0:2], media_id[2:4], media_id[4:],
file_name file_name
) )
def remote_media_filepath(self, server_name, file_id): local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel)
def remote_media_filepath_rel(self, server_name, file_id):
return os.path.join( return os.path.join(
self.base_path, "remote_content", server_name, "remote_content", server_name,
file_id[0:2], file_id[2:4], file_id[4:] file_id[0:2], file_id[2:4], file_id[4:]
) )
def remote_media_thumbnail(self, server_name, file_id, width, height, remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)
content_type, method):
def remote_media_thumbnail_rel(self, server_name, file_id, width, height,
content_type, method):
top_level_type, sub_type = content_type.split("/") top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type) file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
return os.path.join( return os.path.join(
self.base_path, "remote_thumbnail", server_name, "remote_thumbnail", server_name,
file_id[0:2], file_id[2:4], file_id[4:], file_id[0:2], file_id[2:4], file_id[4:],
file_name file_name
) )
remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel)
def remote_media_thumbnail_dir(self, server_name, file_id): def remote_media_thumbnail_dir(self, server_name, file_id):
return os.path.join( return os.path.join(
self.base_path, "remote_thumbnail", server_name, self.base_path, "remote_thumbnail", server_name,
file_id[0:2], file_id[2:4], file_id[4:], file_id[0:2], file_id[2:4], file_id[4:],
) )
def url_cache_filepath(self, media_id): def url_cache_filepath_rel(self, media_id):
return os.path.join( if NEW_FORMAT_ID_RE.match(media_id):
self.base_path, "url_cache", # Media id is of the form <DATE><RANDOM_STRING>
media_id[0:2], media_id[2:4], media_id[4:] # E.g.: 2017-09-28-fsdRDt24DS234dsf
) return os.path.join(
"url_cache",
media_id[:10], media_id[11:]
)
else:
return os.path.join(
"url_cache",
media_id[0:2], media_id[2:4], media_id[4:],
)
url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)
def url_cache_filepath_dirs_to_delete(self, media_id):
"The dirs to try and remove if we delete the media_id file"
if NEW_FORMAT_ID_RE.match(media_id):
return [
os.path.join(
self.base_path, "url_cache",
media_id[:10],
),
]
else:
return [
os.path.join(
self.base_path, "url_cache",
media_id[0:2], media_id[2:4],
),
os.path.join(
self.base_path, "url_cache",
media_id[0:2],
),
]
def url_cache_thumbnail_rel(self, media_id, width, height, content_type,
method):
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
def url_cache_thumbnail(self, media_id, width, height, content_type,
method):
top_level_type, sub_type = content_type.split("/") top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % ( file_name = "%i-%i-%s-%s-%s" % (
width, height, top_level_type, sub_type, method width, height, top_level_type, sub_type, method
) )
return os.path.join(
self.base_path, "url_cache_thumbnails", if NEW_FORMAT_ID_RE.match(media_id):
media_id[0:2], media_id[2:4], media_id[4:], return os.path.join(
file_name "url_cache_thumbnails",
) media_id[:10], media_id[11:],
file_name
)
else:
return os.path.join(
"url_cache_thumbnails",
media_id[0:2], media_id[2:4], media_id[4:],
file_name
)
url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)
def url_cache_thumbnail_directory(self, media_id):
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
if NEW_FORMAT_ID_RE.match(media_id):
return os.path.join(
self.base_path, "url_cache_thumbnails",
media_id[:10], media_id[11:],
)
else:
return os.path.join(
self.base_path, "url_cache_thumbnails",
media_id[0:2], media_id[2:4], media_id[4:],
)
def url_cache_thumbnail_dirs_to_delete(self, media_id):
"The dirs to try and remove if we delete the media_id thumbnails"
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
if NEW_FORMAT_ID_RE.match(media_id):
return [
os.path.join(
self.base_path, "url_cache_thumbnails",
media_id[:10], media_id[11:],
),
os.path.join(
self.base_path, "url_cache_thumbnails",
media_id[:10],
),
]
else:
return [
os.path.join(
self.base_path, "url_cache_thumbnails",
media_id[0:2], media_id[2:4], media_id[4:],
),
os.path.join(
self.base_path, "url_cache_thumbnails",
media_id[0:2], media_id[2:4],
),
os.path.join(
self.base_path, "url_cache_thumbnails",
media_id[0:2],
),
]

View File

@ -33,7 +33,7 @@ from synapse.api.errors import SynapseError, HttpResponseException, \
from synapse.util.async import Linearizer from synapse.util.async import Linearizer
from synapse.util.stringutils import is_ascii from synapse.util.stringutils import is_ascii
from synapse.util.logcontext import preserve_context_over_fn from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
import os import os
@ -59,7 +59,14 @@ class MediaRepository(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.max_upload_size = hs.config.max_upload_size self.max_upload_size = hs.config.max_upload_size
self.max_image_pixels = hs.config.max_image_pixels self.max_image_pixels = hs.config.max_image_pixels
self.filepaths = MediaFilePaths(hs.config.media_store_path)
self.primary_base_path = hs.config.media_store_path
self.filepaths = MediaFilePaths(self.primary_base_path)
self.backup_base_path = hs.config.backup_media_store_path
self.synchronous_backup_media_store = hs.config.synchronous_backup_media_store
self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.thumbnail_requirements = hs.config.thumbnail_requirements self.thumbnail_requirements = hs.config.thumbnail_requirements
@ -87,18 +94,86 @@ class MediaRepository(object):
if not os.path.exists(dirname): if not os.path.exists(dirname):
os.makedirs(dirname) os.makedirs(dirname)
@staticmethod
def _write_file_synchronously(source, fname):
"""Write `source` to the path `fname` synchronously. Should be called
from a thread.
Args:
source: A file like object to be written
fname (str): Path to write to
"""
MediaRepository._makedirs(fname)
source.seek(0) # Ensure we read from the start of the file
with open(fname, "wb") as f:
shutil.copyfileobj(source, f)
@defer.inlineCallbacks
def write_to_file_and_backup(self, source, path):
"""Write `source` to the on disk media store, and also the backup store
if configured.
Args:
source: A file like object that should be written
path (str): Relative path to write file to
Returns:
Deferred[str]: the file path written to in the primary media store
"""
fname = os.path.join(self.primary_base_path, path)
# Write to the main repository
yield make_deferred_yieldable(threads.deferToThread(
self._write_file_synchronously, source, fname,
))
# Write to backup repository
yield self.copy_to_backup(path)
defer.returnValue(fname)
@defer.inlineCallbacks
def copy_to_backup(self, path):
"""Copy a file from the primary to backup media store, if configured.
Args:
path(str): Relative path to write file to
"""
if self.backup_base_path:
primary_fname = os.path.join(self.primary_base_path, path)
backup_fname = os.path.join(self.backup_base_path, path)
# We can either wait for successful writing to the backup repository
# or write in the background and immediately return
if self.synchronous_backup_media_store:
yield make_deferred_yieldable(threads.deferToThread(
shutil.copyfile, primary_fname, backup_fname,
))
else:
preserve_fn(threads.deferToThread)(
shutil.copyfile, primary_fname, backup_fname,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def create_content(self, media_type, upload_name, content, content_length, def create_content(self, media_type, upload_name, content, content_length,
auth_user): auth_user):
"""Store uploaded content for a local user and return the mxc URL
Args:
media_type(str): The content type of the file
upload_name(str): The name of the file
content: A file like object that is the content to store
content_length(int): The length of the content
auth_user(str): The user_id of the uploader
Returns:
Deferred[str]: The mxc url of the stored content
"""
media_id = random_string(24) media_id = random_string(24)
fname = self.filepaths.local_media_filepath(media_id) fname = yield self.write_to_file_and_backup(
self._makedirs(fname) content, self.filepaths.local_media_filepath_rel(media_id)
)
# This shouldn't block for very long because the content will have
# already been uploaded at this point.
with open(fname, "wb") as f:
f.write(content)
logger.info("Stored local media in file %r", fname) logger.info("Stored local media in file %r", fname)
@ -115,7 +190,7 @@ class MediaRepository(object):
"media_length": content_length, "media_length": content_length,
} }
yield self._generate_local_thumbnails(media_id, media_info) yield self._generate_thumbnails(None, media_id, media_info)
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id)) defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
@ -148,9 +223,10 @@ class MediaRepository(object):
def _download_remote_file(self, server_name, media_id): def _download_remote_file(self, server_name, media_id):
file_id = random_string(24) file_id = random_string(24)
fname = self.filepaths.remote_media_filepath( fpath = self.filepaths.remote_media_filepath_rel(
server_name, file_id server_name, file_id
) )
fname = os.path.join(self.primary_base_path, fpath)
self._makedirs(fname) self._makedirs(fname)
try: try:
@ -192,6 +268,8 @@ class MediaRepository(object):
server_name, media_id) server_name, media_id)
raise SynapseError(502, "Failed to fetch remote media") raise SynapseError(502, "Failed to fetch remote media")
yield self.copy_to_backup(fpath)
media_type = headers["Content-Type"][0] media_type = headers["Content-Type"][0]
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
@ -244,7 +322,7 @@ class MediaRepository(object):
"filesystem_id": file_id, "filesystem_id": file_id,
} }
yield self._generate_remote_thumbnails( yield self._generate_thumbnails(
server_name, media_id, media_info server_name, media_id, media_info
) )
@ -253,9 +331,8 @@ class MediaRepository(object):
def _get_thumbnail_requirements(self, media_type): def _get_thumbnail_requirements(self, media_type):
return self.thumbnail_requirements.get(media_type, ()) return self.thumbnail_requirements.get(media_type, ())
def _generate_thumbnail(self, input_path, t_path, t_width, t_height, def _generate_thumbnail(self, thumbnailer, t_width, t_height,
t_method, t_type): t_method, t_type):
thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width m_width = thumbnailer.width
m_height = thumbnailer.height m_height = thumbnailer.height
@ -267,72 +344,105 @@ class MediaRepository(object):
return return
if t_method == "crop": if t_method == "crop":
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type) t_byte_source = thumbnailer.crop(t_width, t_height, t_type)
elif t_method == "scale": elif t_method == "scale":
t_width, t_height = thumbnailer.aspect(t_width, t_height) t_width, t_height = thumbnailer.aspect(t_width, t_height)
t_width = min(m_width, t_width) t_width = min(m_width, t_width)
t_height = min(m_height, t_height) t_height = min(m_height, t_height)
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) t_byte_source = thumbnailer.scale(t_width, t_height, t_type)
else: else:
t_len = None t_byte_source = None
return t_len return t_byte_source
@defer.inlineCallbacks @defer.inlineCallbacks
def generate_local_exact_thumbnail(self, media_id, t_width, t_height, def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
t_method, t_type): t_method, t_type):
input_path = self.filepaths.local_media_filepath(media_id) input_path = self.filepaths.local_media_filepath(media_id)
t_path = self.filepaths.local_media_thumbnail( thumbnailer = Thumbnailer(input_path)
media_id, t_width, t_height, t_type, t_method t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
)
self._makedirs(t_path)
t_len = yield preserve_context_over_fn(
threads.deferToThread,
self._generate_thumbnail, self._generate_thumbnail,
input_path, t_path, t_width, t_height, t_method, t_type thumbnailer, t_width, t_height, t_method, t_type
) ))
if t_byte_source:
try:
output_path = yield self.write_to_file_and_backup(
t_byte_source,
self.filepaths.local_media_thumbnail_rel(
media_id, t_width, t_height, t_type, t_method
)
)
finally:
t_byte_source.close()
logger.info("Stored thumbnail in file %r", output_path)
t_len = os.path.getsize(output_path)
if t_len:
yield self.store.store_local_thumbnail( yield self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len media_id, t_width, t_height, t_type, t_method, t_len
) )
defer.returnValue(t_path) defer.returnValue(output_path)
@defer.inlineCallbacks @defer.inlineCallbacks
def generate_remote_exact_thumbnail(self, server_name, file_id, media_id, def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
t_width, t_height, t_method, t_type): t_width, t_height, t_method, t_type):
input_path = self.filepaths.remote_media_filepath(server_name, file_id) input_path = self.filepaths.remote_media_filepath(server_name, file_id)
t_path = self.filepaths.remote_media_thumbnail( thumbnailer = Thumbnailer(input_path)
server_name, file_id, t_width, t_height, t_type, t_method t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
)
self._makedirs(t_path)
t_len = yield preserve_context_over_fn(
threads.deferToThread,
self._generate_thumbnail, self._generate_thumbnail,
input_path, t_path, t_width, t_height, t_method, t_type thumbnailer, t_width, t_height, t_method, t_type
) ))
if t_byte_source:
try:
output_path = yield self.write_to_file_and_backup(
t_byte_source,
self.filepaths.remote_media_thumbnail_rel(
server_name, file_id, t_width, t_height, t_type, t_method
)
)
finally:
t_byte_source.close()
logger.info("Stored thumbnail in file %r", output_path)
t_len = os.path.getsize(output_path)
if t_len:
yield self.store.store_remote_media_thumbnail( yield self.store.store_remote_media_thumbnail(
server_name, media_id, file_id, server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len t_width, t_height, t_type, t_method, t_len
) )
defer.returnValue(t_path) defer.returnValue(output_path)
@defer.inlineCallbacks @defer.inlineCallbacks
def _generate_local_thumbnails(self, media_id, media_info, url_cache=False): def _generate_thumbnails(self, server_name, media_id, media_info, url_cache=False):
"""Generate and store thumbnails for an image.
Args:
server_name(str|None): The server name if remote media, else None if local
media_id(str)
media_info(dict)
url_cache(bool): If we are thumbnailing images downloaded for the URL cache,
used exclusively by the url previewer
Returns:
Deferred[dict]: Dict with "width" and "height" keys of original image
"""
media_type = media_info["media_type"] media_type = media_info["media_type"]
file_id = media_info.get("filesystem_id")
requirements = self._get_thumbnail_requirements(media_type) requirements = self._get_thumbnail_requirements(media_type)
if not requirements: if not requirements:
return return
if url_cache: if server_name:
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
elif url_cache:
input_path = self.filepaths.url_cache_filepath(media_id) input_path = self.filepaths.url_cache_filepath(media_id)
else: else:
input_path = self.filepaths.local_media_filepath(media_id) input_path = self.filepaths.local_media_filepath(media_id)
@ -348,135 +458,72 @@ class MediaRepository(object):
) )
return return
local_thumbnails = [] # We deduplicate the thumbnail sizes by ignoring the cropped versions if
# they have the same dimensions of a scaled one.
thumbnails = {}
for r_width, r_height, r_method, r_type in requirements:
if r_method == "crop":
thumbnails.setdefault((r_width, r_height, r_type), r_method)
elif r_method == "scale":
t_width, t_height = thumbnailer.aspect(r_width, r_height)
t_width = min(m_width, t_width)
t_height = min(m_height, t_height)
thumbnails[(t_width, t_height, r_type)] = r_method
def generate_thumbnails(): # Now we generate the thumbnails for each dimension, store it
scales = set() for (t_width, t_height, t_type), t_method in thumbnails.iteritems():
crops = set() # Work out the correct file name for thumbnail
for r_width, r_height, r_method, r_type in requirements: if server_name:
if r_method == "scale": file_path = self.filepaths.remote_media_thumbnail_rel(
t_width, t_height = thumbnailer.aspect(r_width, r_height)
scales.add((
min(m_width, t_width), min(m_height, t_height), r_type,
))
elif r_method == "crop":
crops.add((r_width, r_height, r_type))
for t_width, t_height, t_type in scales:
t_method = "scale"
if url_cache:
t_path = self.filepaths.url_cache_thumbnail(
media_id, t_width, t_height, t_type, t_method
)
else:
t_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
local_thumbnails.append((
media_id, t_width, t_height, t_type, t_method, t_len
))
for t_width, t_height, t_type in crops:
if (t_width, t_height, t_type) in scales:
# If the aspect ratio of the cropped thumbnail matches a purely
# scaled one then there is no point in calculating a separate
# thumbnail.
continue
t_method = "crop"
if url_cache:
t_path = self.filepaths.url_cache_thumbnail(
media_id, t_width, t_height, t_type, t_method
)
else:
t_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
local_thumbnails.append((
media_id, t_width, t_height, t_type, t_method, t_len
))
yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
for l in local_thumbnails:
yield self.store.store_local_thumbnail(*l)
defer.returnValue({
"width": m_width,
"height": m_height,
})
@defer.inlineCallbacks
def _generate_remote_thumbnails(self, server_name, media_id, media_info):
media_type = media_info["media_type"]
file_id = media_info["filesystem_id"]
requirements = self._get_thumbnail_requirements(media_type)
if not requirements:
return
remote_thumbnails = []
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width
m_height = thumbnailer.height
def generate_thumbnails():
if m_width * m_height >= self.max_image_pixels:
logger.info(
"Image too large to thumbnail %r x %r > %r",
m_width, m_height, self.max_image_pixels
)
return
scales = set()
crops = set()
for r_width, r_height, r_method, r_type in requirements:
if r_method == "scale":
t_width, t_height = thumbnailer.aspect(r_width, r_height)
scales.add((
min(m_width, t_width), min(m_height, t_height), r_type,
))
elif r_method == "crop":
crops.add((r_width, r_height, r_type))
for t_width, t_height, t_type in scales:
t_method = "scale"
t_path = self.filepaths.remote_media_thumbnail(
server_name, file_id, t_width, t_height, t_type, t_method server_name, file_id, t_width, t_height, t_type, t_method
) )
self._makedirs(t_path) elif url_cache:
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) file_path = self.filepaths.url_cache_thumbnail_rel(
remote_thumbnails.append([ media_id, t_width, t_height, t_type, t_method
server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len
])
for t_width, t_height, t_type in crops:
if (t_width, t_height, t_type) in scales:
# If the aspect ratio of the cropped thumbnail matches a purely
# scaled one then there is no point in calculating a separate
# thumbnail.
continue
t_method = "crop"
t_path = self.filepaths.remote_media_thumbnail(
server_name, file_id, t_width, t_height, t_type, t_method
) )
self._makedirs(t_path) else:
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type) file_path = self.filepaths.local_media_thumbnail_rel(
remote_thumbnails.append([ media_id, t_width, t_height, t_type, t_method
)
# Generate the thumbnail
if t_method == "crop":
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
thumbnailer.crop,
t_width, t_height, t_type,
))
elif t_method == "scale":
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
thumbnailer.scale,
t_width, t_height, t_type,
))
else:
logger.error("Unrecognized method: %r", t_method)
continue
if not t_byte_source:
continue
try:
# Write to disk
output_path = yield self.write_to_file_and_backup(
t_byte_source, file_path,
)
finally:
t_byte_source.close()
t_len = os.path.getsize(output_path)
# Write to database
if server_name:
yield self.store.store_remote_media_thumbnail(
server_name, media_id, file_id, server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len t_width, t_height, t_type, t_method, t_len
]) )
else:
yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails) yield self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
for r in remote_thumbnails: )
yield self.store.store_remote_media_thumbnail(*r)
defer.returnValue({ defer.returnValue({
"width": m_width, "width": m_width,
@ -497,6 +544,8 @@ class MediaRepository(object):
logger.info("Deleting: %r", key) logger.info("Deleting: %r", key)
# TODO: Should we delete from the backup store
with (yield self.remote_media_linearizer.queue(key)): with (yield self.remote_media_linearizer.queue(key)):
full_path = self.filepaths.remote_media_filepath(origin, file_id) full_path = self.filepaths.remote_media_filepath(origin, file_id)
try: try:

View File

@ -36,6 +36,9 @@ import cgi
import ujson as json import ujson as json
import urlparse import urlparse
import itertools import itertools
import datetime
import errno
import shutil
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -56,6 +59,7 @@ class PreviewUrlResource(Resource):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.client = SpiderHttpClient(hs) self.client = SpiderHttpClient(hs)
self.media_repo = media_repo self.media_repo = media_repo
self.primary_base_path = media_repo.primary_base_path
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
@ -70,6 +74,10 @@ class PreviewUrlResource(Resource):
self.downloads = {} self.downloads = {}
self._cleaner_loop = self.clock.looping_call(
self._expire_url_cache_data, 10 * 1000
)
def render_GET(self, request): def render_GET(self, request):
self._async_render_GET(request) self._async_render_GET(request)
return NOT_DONE_YET return NOT_DONE_YET
@ -130,7 +138,7 @@ class PreviewUrlResource(Resource):
cache_result = yield self.store.get_url_cache(url, ts) cache_result = yield self.store.get_url_cache(url, ts)
if ( if (
cache_result and cache_result and
cache_result["download_ts"] + cache_result["expires"] > ts and cache_result["expires_ts"] > ts and
cache_result["response_code"] / 100 == 2 cache_result["response_code"] / 100 == 2
): ):
respond_with_json_bytes( respond_with_json_bytes(
@ -163,8 +171,8 @@ class PreviewUrlResource(Resource):
logger.debug("got media_info of '%s'" % media_info) logger.debug("got media_info of '%s'" % media_info)
if _is_media(media_info['media_type']): if _is_media(media_info['media_type']):
dims = yield self.media_repo._generate_local_thumbnails( dims = yield self.media_repo._generate_thumbnails(
media_info['filesystem_id'], media_info, url_cache=True, None, media_info['filesystem_id'], media_info, url_cache=True,
) )
og = { og = {
@ -209,8 +217,8 @@ class PreviewUrlResource(Resource):
if _is_media(image_info['media_type']): if _is_media(image_info['media_type']):
# TODO: make sure we don't choke on white-on-transparent images # TODO: make sure we don't choke on white-on-transparent images
dims = yield self.media_repo._generate_local_thumbnails( dims = yield self.media_repo._generate_thumbnails(
image_info['filesystem_id'], image_info, url_cache=True, None, image_info['filesystem_id'], image_info, url_cache=True,
) )
if dims: if dims:
og["og:image:width"] = dims['width'] og["og:image:width"] = dims['width']
@ -239,7 +247,7 @@ class PreviewUrlResource(Resource):
url, url,
media_info["response_code"], media_info["response_code"],
media_info["etag"], media_info["etag"],
media_info["expires"], media_info["expires"] + media_info["created_ts"],
json.dumps(og), json.dumps(og),
media_info["filesystem_id"], media_info["filesystem_id"],
media_info["created_ts"], media_info["created_ts"],
@ -253,10 +261,10 @@ class PreviewUrlResource(Resource):
# we're most likely being explicitly triggered by a human rather than a # we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot? # bot, so are we really a robot?
# XXX: horrible duplication with base_resource's _download_remote_file() file_id = datetime.date.today().isoformat() + '_' + random_string(16)
file_id = random_string(24)
fname = self.filepaths.url_cache_filepath(file_id) fpath = self.filepaths.url_cache_filepath_rel(file_id)
fname = os.path.join(self.primary_base_path, fpath)
self.media_repo._makedirs(fname) self.media_repo._makedirs(fname)
try: try:
@ -267,6 +275,8 @@ class PreviewUrlResource(Resource):
) )
# FIXME: pass through 404s and other error messages nicely # FIXME: pass through 404s and other error messages nicely
yield self.media_repo.copy_to_backup(fpath)
media_type = headers["Content-Type"][0] media_type = headers["Content-Type"][0]
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
@ -328,6 +338,91 @@ class PreviewUrlResource(Resource):
"etag": headers["ETag"][0] if "ETag" in headers else None, "etag": headers["ETag"][0] if "ETag" in headers else None,
}) })
@defer.inlineCallbacks
def _expire_url_cache_data(self):
"""Clean up expired url cache content, media and thumbnails.
"""
# TODO: Delete from backup media store
now = self.clock.time_msec()
# First we delete expired url cache entries
media_ids = yield self.store.get_expired_url_cache(now)
removed_media = []
for media_id in media_ids:
fname = self.filepaths.url_cache_filepath(media_id)
try:
os.remove(fname)
except OSError as e:
# If the path doesn't exist, meh
if e.errno != errno.ENOENT:
logger.warn("Failed to remove media: %r: %s", media_id, e)
continue
removed_media.append(media_id)
try:
dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
for dir in dirs:
os.rmdir(dir)
except:
pass
yield self.store.delete_url_cache(removed_media)
if removed_media:
logger.info("Deleted %d entries from url cache", len(removed_media))
# Now we delete old images associated with the url cache.
# These may be cached for a bit on the client (i.e., they
# may have a room open with a preview url thing open).
# So we wait a couple of days before deleting, just in case.
expire_before = now - 2 * 24 * 60 * 60 * 1000
media_ids = yield self.store.get_url_cache_media_before(expire_before)
removed_media = []
for media_id in media_ids:
fname = self.filepaths.url_cache_filepath(media_id)
try:
os.remove(fname)
except OSError as e:
# If the path doesn't exist, meh
if e.errno != errno.ENOENT:
logger.warn("Failed to remove media: %r: %s", media_id, e)
continue
try:
dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
for dir in dirs:
os.rmdir(dir)
except:
pass
thumbnail_dir = self.filepaths.url_cache_thumbnail_directory(media_id)
try:
shutil.rmtree(thumbnail_dir)
except OSError as e:
# If the path doesn't exist, meh
if e.errno != errno.ENOENT:
logger.warn("Failed to remove media: %r: %s", media_id, e)
continue
removed_media.append(media_id)
try:
dirs = self.filepaths.url_cache_thumbnail_dirs_to_delete(media_id)
for dir in dirs:
os.rmdir(dir)
except:
pass
yield self.store.delete_url_cache_media(removed_media)
if removed_media:
logger.info("Deleted %d media from url cache", len(removed_media))
def decode_and_calc_og(body, media_uri, request_encoding=None): def decode_and_calc_og(body, media_uri, request_encoding=None):
from lxml import etree from lxml import etree

View File

@ -50,12 +50,16 @@ class Thumbnailer(object):
else: else:
return ((max_height * self.width) // self.height, max_height) return ((max_height * self.width) // self.height, max_height)
def scale(self, output_path, width, height, output_type): def scale(self, width, height, output_type):
"""Rescales the image to the given dimensions""" """Rescales the image to the given dimensions.
scaled = self.image.resize((width, height), Image.ANTIALIAS)
return self.save_image(scaled, output_type, output_path)
def crop(self, output_path, width, height, output_type): Returns:
BytesIO: the bytes of the encoded image ready to be written to disk
"""
scaled = self.image.resize((width, height), Image.ANTIALIAS)
return self._encode_image(scaled, output_type)
def crop(self, width, height, output_type):
"""Rescales and crops the image to the given dimensions preserving """Rescales and crops the image to the given dimensions preserving
aspect:: aspect::
(w_in / h_in) = (w_scaled / h_scaled) (w_in / h_in) = (w_scaled / h_scaled)
@ -65,6 +69,9 @@ class Thumbnailer(object):
Args: Args:
max_width: The largest possible width. max_width: The largest possible width.
max_height: The larget possible height. max_height: The larget possible height.
Returns:
BytesIO: the bytes of the encoded image ready to be written to disk
""" """
if width * self.height > height * self.width: if width * self.height > height * self.width:
scaled_height = (width * self.height) // self.width scaled_height = (width * self.height) // self.width
@ -82,13 +89,9 @@ class Thumbnailer(object):
crop_left = (scaled_width - width) // 2 crop_left = (scaled_width - width) // 2
crop_right = width + crop_left crop_right = width + crop_left
cropped = scaled_image.crop((crop_left, 0, crop_right, height)) cropped = scaled_image.crop((crop_left, 0, crop_right, height))
return self.save_image(cropped, output_type, output_path) return self._encode_image(cropped, output_type)
def save_image(self, output_image, output_type, output_path): def _encode_image(self, output_image, output_type):
output_bytes_io = BytesIO() output_bytes_io = BytesIO()
output_image.save(output_bytes_io, self.FORMATS[output_type], quality=80) output_image.save(output_bytes_io, self.FORMATS[output_type], quality=80)
output_bytes = output_bytes_io.getvalue() return output_bytes_io
with open(output_path, "wb") as output_file:
output_file.write(output_bytes)
logger.info("Stored thumbnail in file %r", output_path)
return len(output_bytes)

View File

@ -93,7 +93,7 @@ class UploadResource(Resource):
# TODO(markjh): parse content-dispostion # TODO(markjh): parse content-dispostion
content_uri = yield self.media_repo.create_content( content_uri = yield self.media_repo.create_content(
media_type, upload_name, request.content.read(), media_type, upload_name, request.content,
content_length, requester.user content_length, requester.user
) )

View File

@ -31,6 +31,7 @@ from synapse.appservice.api import ApplicationServiceApi
from synapse.appservice.scheduler import ApplicationServiceScheduler from synapse.appservice.scheduler import ApplicationServiceScheduler
from synapse.crypto.keyring import Keyring from synapse.crypto.keyring import Keyring
from synapse.events.builder import EventBuilderFactory from synapse.events.builder import EventBuilderFactory
from synapse.events.spamcheck import SpamChecker
from synapse.federation import initialize_http_replication from synapse.federation import initialize_http_replication
from synapse.federation.send_queue import FederationRemoteSendQueue from synapse.federation.send_queue import FederationRemoteSendQueue
from synapse.federation.transport.client import TransportLayerClient from synapse.federation.transport.client import TransportLayerClient
@ -50,6 +51,10 @@ from synapse.handlers.initial_sync import InitialSyncHandler
from synapse.handlers.receipts import ReceiptsHandler from synapse.handlers.receipts import ReceiptsHandler
from synapse.handlers.read_marker import ReadMarkerHandler from synapse.handlers.read_marker import ReadMarkerHandler
from synapse.handlers.user_directory import UserDirectoyHandler from synapse.handlers.user_directory import UserDirectoyHandler
from synapse.handlers.groups_local import GroupsLocalHandler
from synapse.handlers.profile import ProfileHandler
from synapse.groups.groups_server import GroupsServerHandler
from synapse.groups.attestations import GroupAttestionRenewer, GroupAttestationSigning
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.notifier import Notifier from synapse.notifier import Notifier
@ -111,6 +116,7 @@ class HomeServer(object):
'application_service_scheduler', 'application_service_scheduler',
'application_service_handler', 'application_service_handler',
'device_message_handler', 'device_message_handler',
'profile_handler',
'notifier', 'notifier',
'distributor', 'distributor',
'client_resource', 'client_resource',
@ -139,6 +145,11 @@ class HomeServer(object):
'read_marker_handler', 'read_marker_handler',
'action_generator', 'action_generator',
'user_directory_handler', 'user_directory_handler',
'groups_local_handler',
'groups_server_handler',
'groups_attestation_signing',
'groups_attestation_renewer',
'spam_checker',
] ]
def __init__(self, hostname, **kwargs): def __init__(self, hostname, **kwargs):
@ -251,6 +262,9 @@ class HomeServer(object):
def build_initial_sync_handler(self): def build_initial_sync_handler(self):
return InitialSyncHandler(self) return InitialSyncHandler(self)
def build_profile_handler(self):
return ProfileHandler(self)
def build_event_sources(self): def build_event_sources(self):
return EventSources(self) return EventSources(self)
@ -309,6 +323,21 @@ class HomeServer(object):
def build_user_directory_handler(self): def build_user_directory_handler(self):
return UserDirectoyHandler(self) return UserDirectoyHandler(self)
def build_groups_local_handler(self):
return GroupsLocalHandler(self)
def build_groups_server_handler(self):
return GroupsServerHandler(self)
def build_groups_attestation_signing(self):
return GroupAttestationSigning(self)
def build_groups_attestation_renewer(self):
return GroupAttestionRenewer(self)
def build_spam_checker(self):
return SpamChecker(self)
def remove_pusher(self, app_id, push_key, user_id): def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)

View File

@ -1,4 +1,6 @@
import synapse.api.auth import synapse.api.auth
import synapse.federation.transaction_queue
import synapse.federation.transport.client
import synapse.handlers import synapse.handlers
import synapse.handlers.auth import synapse.handlers.auth
import synapse.handlers.device import synapse.handlers.device
@ -27,3 +29,9 @@ class HomeServer(object):
def get_state_handler(self) -> synapse.state.StateHandler: def get_state_handler(self) -> synapse.state.StateHandler:
pass pass
def get_federation_sender(self) -> synapse.federation.transaction_queue.TransactionQueue:
pass
def get_federation_transport_client(self) -> synapse.federation.transport.client.TransportLayerClient:
pass

View File

@ -288,6 +288,9 @@ class StateHandler(object):
""" """
logger.debug("resolve_state_groups event_ids %s", event_ids) logger.debug("resolve_state_groups event_ids %s", event_ids)
# map from state group id to the state in that state group (where
# 'state' is a map from state key to event id)
# dict[int, dict[(str, str), str]]
state_groups_ids = yield self.store.get_state_groups_ids( state_groups_ids = yield self.store.get_state_groups_ids(
room_id, event_ids room_id, event_ids
) )
@ -320,11 +323,15 @@ class StateHandler(object):
"Resolving state for %s with %d groups", room_id, len(state_groups_ids) "Resolving state for %s with %d groups", room_id, len(state_groups_ids)
) )
# build a map from state key to the event_ids which set that state.
# dict[(str, str), set[str])
state = {} state = {}
for st in state_groups_ids.values(): for st in state_groups_ids.values():
for key, e_id in st.items(): for key, e_id in st.items():
state.setdefault(key, set()).add(e_id) state.setdefault(key, set()).add(e_id)
# build a map from state key to the event_ids which set that state,
# including only those where there are state keys in conflict.
conflicted_state = { conflicted_state = {
k: list(v) k: list(v)
for k, v in state.items() for k, v in state.items()
@ -494,8 +501,14 @@ def _resolve_with_state_fac(unconflicted_state, conflicted_state,
logger.info("Asking for %d conflicted events", len(needed_events)) logger.info("Asking for %d conflicted events", len(needed_events))
# dict[str, FrozenEvent]: a map from state event id to event. Only includes
# the state events which are in conflict.
state_map = yield state_map_factory(needed_events) state_map = yield state_map_factory(needed_events)
# get the ids of the auth events which allow us to authenticate the
# conflicted state, picking only from the unconflicting state.
#
# dict[(str, str), str]: a map from state key to event id
auth_events = _create_auth_events_from_maps( auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map unconflicted_state, conflicted_state, state_map
) )

View File

@ -37,7 +37,7 @@ from .media_repository import MediaRepositoryStore
from .rejections import RejectionsStore from .rejections import RejectionsStore
from .event_push_actions import EventPushActionsStore from .event_push_actions import EventPushActionsStore
from .deviceinbox import DeviceInboxStore from .deviceinbox import DeviceInboxStore
from .group_server import GroupServerStore
from .state import StateStore from .state import StateStore
from .signatures import SignatureStore from .signatures import SignatureStore
from .filtering import FilteringStore from .filtering import FilteringStore
@ -88,6 +88,7 @@ class DataStore(RoomMemberStore, RoomStore,
DeviceStore, DeviceStore,
DeviceInboxStore, DeviceInboxStore,
UserDirectoryStore, UserDirectoryStore,
GroupServerStore,
): ):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
@ -135,6 +136,9 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn, "pushers", "id", db_conn, "pushers", "id",
extra_tables=[("deleted_pushers", "stream_id")], extra_tables=[("deleted_pushers", "stream_id")],
) )
self._group_updates_id_gen = StreamIdGenerator(
db_conn, "local_group_updates", "stream_id",
)
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = StreamIdGenerator( self._cache_id_gen = StreamIdGenerator(
@ -235,6 +239,18 @@ class DataStore(RoomMemberStore, RoomStore,
prefilled_cache=curr_state_delta_prefill, prefilled_cache=curr_state_delta_prefill,
) )
_group_updates_prefill, min_group_updates_id = self._get_cache_dict(
db_conn, "local_group_updates",
entity_column="user_id",
stream_column="stream_id",
max_value=self._group_updates_id_gen.get_current_token(),
limit=1000,
)
self._group_updates_stream_cache = StreamChangeCache(
"_group_updates_stream_cache", min_group_updates_id,
prefilled_cache=_group_updates_prefill,
)
cur = LoggingTransaction( cur = LoggingTransaction(
db_conn.cursor(), db_conn.cursor(),
name="_find_stream_orderings_for_times_txn", name="_find_stream_orderings_for_times_txn",

View File

@ -743,6 +743,33 @@ class SQLBaseStore(object):
txn.execute(sql, values) txn.execute(sql, values)
return cls.cursor_to_dict(txn) return cls.cursor_to_dict(txn)
def _simple_update(self, table, keyvalues, updatevalues, desc):
return self.runInteraction(
desc,
self._simple_update_txn,
table, keyvalues, updatevalues,
)
@staticmethod
def _simple_update_txn(txn, table, keyvalues, updatevalues):
if keyvalues:
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
else:
where = ""
update_sql = "UPDATE %s SET %s %s" % (
table,
", ".join("%s = ?" % (k,) for k in updatevalues),
where,
)
txn.execute(
update_sql,
updatevalues.values() + keyvalues.values()
)
return txn.rowcount
def _simple_update_one(self, table, keyvalues, updatevalues, def _simple_update_one(self, table, keyvalues, updatevalues,
desc="_simple_update_one"): desc="_simple_update_one"):
"""Executes an UPDATE query on the named table, setting new values for """Executes an UPDATE query on the named table, setting new values for
@ -768,27 +795,13 @@ class SQLBaseStore(object):
table, keyvalues, updatevalues, table, keyvalues, updatevalues,
) )
@staticmethod @classmethod
def _simple_update_one_txn(txn, table, keyvalues, updatevalues): def _simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
if keyvalues: rowcount = cls._simple_update_txn(txn, table, keyvalues, updatevalues)
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
else:
where = ""
update_sql = "UPDATE %s SET %s %s" % ( if rowcount == 0:
table,
", ".join("%s = ?" % (k,) for k in updatevalues),
where,
)
txn.execute(
update_sql,
updatevalues.values() + keyvalues.values()
)
if txn.rowcount == 0:
raise StoreError(404, "No row found") raise StoreError(404, "No row found")
if txn.rowcount > 1: if rowcount > 1:
raise StoreError(500, "More than one row matched") raise StoreError(500, "More than one row matched")
@staticmethod @staticmethod

View File

@ -21,7 +21,7 @@ from synapse.events.utils import prune_event
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import ( from synapse.util.logcontext import (
preserve_fn, PreserveLoggingContext, preserve_context_over_deferred preserve_fn, PreserveLoggingContext, make_deferred_yieldable
) )
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -88,13 +88,23 @@ class _EventPeristenceQueue(object):
def add_to_queue(self, room_id, events_and_contexts, backfilled): def add_to_queue(self, room_id, events_and_contexts, backfilled):
"""Add events to the queue, with the given persist_event options. """Add events to the queue, with the given persist_event options.
NB: due to the normal usage pattern of this method, it does *not*
follow the synapse logcontext rules, and leaves the logcontext in
place whether or not the returned deferred is ready.
Args: Args:
room_id (str): room_id (str):
events_and_contexts (list[(EventBase, EventContext)]): events_and_contexts (list[(EventBase, EventContext)]):
backfilled (bool): backfilled (bool):
Returns:
defer.Deferred: a deferred which will resolve once the events are
persisted. Runs its callbacks *without* a logcontext.
""" """
queue = self._event_persist_queues.setdefault(room_id, deque()) queue = self._event_persist_queues.setdefault(room_id, deque())
if queue: if queue:
# if the last item in the queue has the same `backfilled` setting,
# we can just add these new events to that item.
end_item = queue[-1] end_item = queue[-1]
if end_item.backfilled == backfilled: if end_item.backfilled == backfilled:
end_item.events_and_contexts.extend(events_and_contexts) end_item.events_and_contexts.extend(events_and_contexts)
@ -113,11 +123,11 @@ class _EventPeristenceQueue(object):
def handle_queue(self, room_id, per_item_callback): def handle_queue(self, room_id, per_item_callback):
"""Attempts to handle the queue for a room if not already being handled. """Attempts to handle the queue for a room if not already being handled.
The given callback will be invoked with for each item in the queue,1 The given callback will be invoked with for each item in the queue,
of type _EventPersistQueueItem. The per_item_callback will continuously of type _EventPersistQueueItem. The per_item_callback will continuously
be called with new items, unless the queue becomnes empty. The return be called with new items, unless the queue becomnes empty. The return
value of the function will be given to the deferreds waiting on the item, value of the function will be given to the deferreds waiting on the item,
exceptions will be passed to the deferres as well. exceptions will be passed to the deferreds as well.
This function should therefore be called whenever anything is added This function should therefore be called whenever anything is added
to the queue. to the queue.
@ -233,7 +243,7 @@ class EventsStore(SQLBaseStore):
deferreds = [] deferreds = []
for room_id, evs_ctxs in partitioned.iteritems(): for room_id, evs_ctxs in partitioned.iteritems():
d = preserve_fn(self._event_persist_queue.add_to_queue)( d = self._event_persist_queue.add_to_queue(
room_id, evs_ctxs, room_id, evs_ctxs,
backfilled=backfilled, backfilled=backfilled,
) )
@ -242,7 +252,7 @@ class EventsStore(SQLBaseStore):
for room_id in partitioned: for room_id in partitioned:
self._maybe_start_persisting(room_id) self._maybe_start_persisting(room_id)
return preserve_context_over_deferred( return make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True) defer.gatherResults(deferreds, consumeErrors=True)
) )
@ -267,7 +277,7 @@ class EventsStore(SQLBaseStore):
self._maybe_start_persisting(event.room_id) self._maybe_start_persisting(event.room_id)
yield preserve_context_over_deferred(deferred) yield make_deferred_yieldable(deferred)
max_persisted_id = yield self._stream_id_gen.get_current_token() max_persisted_id = yield self._stream_id_gen.get_current_token()
defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id)) defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id))
@ -784,6 +794,9 @@ class EventsStore(SQLBaseStore):
self._invalidate_cache_and_stream( self._invalidate_cache_and_stream(
txn, self.is_host_joined, (room_id, host) txn, self.is_host_joined, (room_id, host)
) )
self._invalidate_cache_and_stream(
txn, self.was_host_joined, (room_id, host)
)
self._invalidate_cache_and_stream( self._invalidate_cache_and_stream(
txn, self.get_users_in_room, (room_id,) txn, self.get_users_in_room, (room_id,)
@ -1523,7 +1536,7 @@ class EventsStore(SQLBaseStore):
if not allow_rejected: if not allow_rejected:
rows[:] = [r for r in rows if not r["rejects"]] rows[:] = [r for r in rows if not r["rejects"]]
res = yield preserve_context_over_deferred(defer.gatherResults( res = yield make_deferred_yieldable(defer.gatherResults(
[ [
preserve_fn(self._get_event_from_row)( preserve_fn(self._get_event_from_row)(
row["internal_metadata"], row["json"], row["redacts"], row["internal_metadata"], row["json"], row["redacts"],

File diff suppressed because it is too large Load Diff

View File

@ -62,7 +62,7 @@ class MediaRepositoryStore(SQLBaseStore):
def get_url_cache_txn(txn): def get_url_cache_txn(txn):
# get the most recently cached result (relative to the given ts) # get the most recently cached result (relative to the given ts)
sql = ( sql = (
"SELECT response_code, etag, expires, og, media_id, download_ts" "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
" FROM local_media_repository_url_cache" " FROM local_media_repository_url_cache"
" WHERE url = ? AND download_ts <= ?" " WHERE url = ? AND download_ts <= ?"
" ORDER BY download_ts DESC LIMIT 1" " ORDER BY download_ts DESC LIMIT 1"
@ -74,7 +74,7 @@ class MediaRepositoryStore(SQLBaseStore):
# ...or if we've requested a timestamp older than the oldest # ...or if we've requested a timestamp older than the oldest
# copy in the cache, return the oldest copy (if any) # copy in the cache, return the oldest copy (if any)
sql = ( sql = (
"SELECT response_code, etag, expires, og, media_id, download_ts" "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
" FROM local_media_repository_url_cache" " FROM local_media_repository_url_cache"
" WHERE url = ? AND download_ts > ?" " WHERE url = ? AND download_ts > ?"
" ORDER BY download_ts ASC LIMIT 1" " ORDER BY download_ts ASC LIMIT 1"
@ -86,14 +86,14 @@ class MediaRepositoryStore(SQLBaseStore):
return None return None
return dict(zip(( return dict(zip((
'response_code', 'etag', 'expires', 'og', 'media_id', 'download_ts' 'response_code', 'etag', 'expires_ts', 'og', 'media_id', 'download_ts'
), row)) ), row))
return self.runInteraction( return self.runInteraction(
"get_url_cache", get_url_cache_txn "get_url_cache", get_url_cache_txn
) )
def store_url_cache(self, url, response_code, etag, expires, og, media_id, def store_url_cache(self, url, response_code, etag, expires_ts, og, media_id,
download_ts): download_ts):
return self._simple_insert( return self._simple_insert(
"local_media_repository_url_cache", "local_media_repository_url_cache",
@ -101,7 +101,7 @@ class MediaRepositoryStore(SQLBaseStore):
"url": url, "url": url,
"response_code": response_code, "response_code": response_code,
"etag": etag, "etag": etag,
"expires": expires, "expires_ts": expires_ts,
"og": og, "og": og,
"media_id": media_id, "media_id": media_id,
"download_ts": download_ts, "download_ts": download_ts,
@ -238,3 +238,64 @@ class MediaRepositoryStore(SQLBaseStore):
}, },
) )
return self.runInteraction("delete_remote_media", delete_remote_media_txn) return self.runInteraction("delete_remote_media", delete_remote_media_txn)
def get_expired_url_cache(self, now_ts):
sql = (
"SELECT media_id FROM local_media_repository_url_cache"
" WHERE expires_ts < ?"
" ORDER BY expires_ts ASC"
" LIMIT 500"
)
def _get_expired_url_cache_txn(txn):
txn.execute(sql, (now_ts,))
return [row[0] for row in txn]
return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn)
def delete_url_cache(self, media_ids):
sql = (
"DELETE FROM local_media_repository_url_cache"
" WHERE media_id = ?"
)
def _delete_url_cache_txn(txn):
txn.executemany(sql, [(media_id,) for media_id in media_ids])
return self.runInteraction("delete_url_cache", _delete_url_cache_txn)
def get_url_cache_media_before(self, before_ts):
sql = (
"SELECT media_id FROM local_media_repository"
" WHERE created_ts < ? AND url_cache IS NOT NULL"
" ORDER BY created_ts ASC"
" LIMIT 500"
)
def _get_url_cache_media_before_txn(txn):
txn.execute(sql, (before_ts,))
return [row[0] for row in txn]
return self.runInteraction(
"get_url_cache_media_before", _get_url_cache_media_before_txn,
)
def delete_url_cache_media(self, media_ids):
def _delete_url_cache_media_txn(txn):
sql = (
"DELETE FROM local_media_repository"
" WHERE media_id = ?"
)
txn.executemany(sql, [(media_id,) for media_id in media_ids])
sql = (
"DELETE FROM local_media_repository_thumbnails"
" WHERE media_id = ?"
)
txn.executemany(sql, [(media_id,) for media_id in media_ids])
return self.runInteraction(
"delete_url_cache_media", _delete_url_cache_media_txn,
)

View File

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 43 SCHEMA_VERSION = 45
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))

View File

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from ._base import SQLBaseStore from ._base import SQLBaseStore
@ -55,3 +57,99 @@ class ProfileStore(SQLBaseStore):
updatevalues={"avatar_url": new_avatar_url}, updatevalues={"avatar_url": new_avatar_url},
desc="set_profile_avatar_url", desc="set_profile_avatar_url",
) )
def get_from_remote_profile_cache(self, user_id):
return self._simple_select_one(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
retcols=("displayname", "avatar_url",),
allow_none=True,
desc="get_from_remote_profile_cache",
)
def add_remote_profile_cache(self, user_id, displayname, avatar_url):
"""Ensure we are caching the remote user's profiles.
This should only be called when `is_subscribed_remote_profile_for_user`
would return true for the user.
"""
return self._simple_upsert(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
values={
"displayname": displayname,
"avatar_url": avatar_url,
"last_check": self._clock.time_msec(),
},
desc="add_remote_profile_cache",
)
def update_remote_profile_cache(self, user_id, displayname, avatar_url):
return self._simple_update(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
values={
"displayname": displayname,
"avatar_url": avatar_url,
"last_check": self._clock.time_msec(),
},
desc="update_remote_profile_cache",
)
@defer.inlineCallbacks
def maybe_delete_remote_profile_cache(self, user_id):
"""Check if we still care about the remote user's profile, and if we
don't then remove their profile from the cache
"""
subscribed = yield self.is_subscribed_remote_profile_for_user(user_id)
if not subscribed:
yield self._simple_delete(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
desc="delete_remote_profile_cache",
)
def get_remote_profile_cache_entries_that_expire(self, last_checked):
"""Get all users who haven't been checked since `last_checked`
"""
def _get_remote_profile_cache_entries_that_expire_txn(txn):
sql = """
SELECT user_id, displayname, avatar_url
FROM remote_profile_cache
WHERE last_check < ?
"""
txn.execute(sql, (last_checked,))
return self.cursor_to_dict(txn)
return self.runInteraction(
"get_remote_profile_cache_entries_that_expire",
_get_remote_profile_cache_entries_that_expire_txn,
)
@defer.inlineCallbacks
def is_subscribed_remote_profile_for_user(self, user_id):
"""Check whether we are interested in a remote user's profile.
"""
res = yield self._simple_select_one_onecol(
table="group_users",
keyvalues={"user_id": user_id},
retcol="user_id",
allow_none=True,
desc="should_update_remote_profile_cache_for_user",
)
if res:
defer.returnValue(True)
res = yield self._simple_select_one_onecol(
table="group_invites",
keyvalues={"user_id": user_id},
retcol="user_id",
allow_none=True,
desc="should_update_remote_profile_cache_for_user",
)
if res:
defer.returnValue(True)

View File

@ -533,6 +533,46 @@ class RoomMemberStore(SQLBaseStore):
defer.returnValue(True) defer.returnValue(True)
@cachedInlineCallbacks()
def was_host_joined(self, room_id, host):
"""Check whether the server is or ever was in the room.
Args:
room_id (str)
host (str)
Returns:
Deferred: Resolves to True if the host is/was in the room, otherwise
False.
"""
if '%' in host or '_' in host:
raise Exception("Invalid host name")
sql = """
SELECT user_id FROM room_memberships
WHERE room_id = ?
AND user_id LIKE ?
AND membership = 'join'
LIMIT 1
"""
# We do need to be careful to ensure that host doesn't have any wild cards
# in it, but we checked above for known ones and we'll check below that
# the returned user actually has the correct domain.
like_clause = "%:" + host
rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause)
if not rows:
defer.returnValue(False)
user_id = rows[0][0]
if get_domain_from_id(user_id) != host:
# This can only happen if the host name has something funky in it
raise Exception("Invalid host name")
defer.returnValue(True)
def get_joined_hosts(self, room_id, state_entry): def get_joined_hosts(self, room_id, state_entry):
state_group = state_entry.state_group state_group = state_entry.state_group
if not state_group: if not state_group:

View File

@ -0,0 +1,38 @@
/* Copyright 2017 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE INDEX local_media_repository_url_idx ON local_media_repository(created_ts) WHERE url_cache IS NOT NULL;
-- we need to change `expires` to `expires_ts` so that we can index on it. SQLite doesn't support
-- indices on expressions until 3.9.
CREATE TABLE local_media_repository_url_cache_new(
url TEXT,
response_code INTEGER,
etag TEXT,
expires_ts BIGINT,
og TEXT,
media_id TEXT,
download_ts BIGINT
);
INSERT INTO local_media_repository_url_cache_new
SELECT url, response_code, etag, expires + download_ts, og, media_id, download_ts FROM local_media_repository_url_cache;
DROP TABLE local_media_repository_url_cache;
ALTER TABLE local_media_repository_url_cache_new RENAME TO local_media_repository_url_cache;
CREATE INDEX local_media_repository_url_cache_expires_idx ON local_media_repository_url_cache(expires_ts);
CREATE INDEX local_media_repository_url_cache_by_url_download_ts ON local_media_repository_url_cache(url, download_ts);
CREATE INDEX local_media_repository_url_cache_media_idx ON local_media_repository_url_cache(media_id);

View File

@ -0,0 +1,167 @@
/* Copyright 2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE groups (
group_id TEXT NOT NULL,
name TEXT, -- the display name of the room
avatar_url TEXT,
short_description TEXT,
long_description TEXT
);
CREATE UNIQUE INDEX groups_idx ON groups(group_id);
-- list of users the group server thinks are joined
CREATE TABLE group_users (
group_id TEXT NOT NULL,
user_id TEXT NOT NULL,
is_admin BOOLEAN NOT NULL,
is_public BOOLEAN NOT NULL -- whether the users membership can be seen by everyone
);
CREATE INDEX groups_users_g_idx ON group_users(group_id, user_id);
CREATE INDEX groups_users_u_idx ON group_users(user_id);
-- list of users the group server thinks are invited
CREATE TABLE group_invites (
group_id TEXT NOT NULL,
user_id TEXT NOT NULL
);
CREATE INDEX groups_invites_g_idx ON group_invites(group_id, user_id);
CREATE INDEX groups_invites_u_idx ON group_invites(user_id);
CREATE TABLE group_rooms (
group_id TEXT NOT NULL,
room_id TEXT NOT NULL,
is_public BOOLEAN NOT NULL -- whether the room can be seen by everyone
);
CREATE UNIQUE INDEX groups_rooms_g_idx ON group_rooms(group_id, room_id);
CREATE INDEX groups_rooms_r_idx ON group_rooms(room_id);
-- Rooms to include in the summary
CREATE TABLE group_summary_rooms (
group_id TEXT NOT NULL,
room_id TEXT NOT NULL,
category_id TEXT NOT NULL,
room_order BIGINT NOT NULL,
is_public BOOLEAN NOT NULL, -- whether the room should be show to everyone
UNIQUE (group_id, category_id, room_id, room_order),
CHECK (room_order > 0)
);
CREATE UNIQUE INDEX group_summary_rooms_g_idx ON group_summary_rooms(group_id, room_id, category_id);
-- Categories to include in the summary
CREATE TABLE group_summary_room_categories (
group_id TEXT NOT NULL,
category_id TEXT NOT NULL,
cat_order BIGINT NOT NULL,
UNIQUE (group_id, category_id, cat_order),
CHECK (cat_order > 0)
);
-- The categories in the group
CREATE TABLE group_room_categories (
group_id TEXT NOT NULL,
category_id TEXT NOT NULL,
profile TEXT NOT NULL,
is_public BOOLEAN NOT NULL, -- whether the category should be show to everyone
UNIQUE (group_id, category_id)
);
-- The users to include in the group summary
CREATE TABLE group_summary_users (
group_id TEXT NOT NULL,
user_id TEXT NOT NULL,
role_id TEXT NOT NULL,
user_order BIGINT NOT NULL,
is_public BOOLEAN NOT NULL -- whether the user should be show to everyone
);
CREATE INDEX group_summary_users_g_idx ON group_summary_users(group_id);
-- The roles to include in the group summary
CREATE TABLE group_summary_roles (
group_id TEXT NOT NULL,
role_id TEXT NOT NULL,
role_order BIGINT NOT NULL,
UNIQUE (group_id, role_id, role_order),
CHECK (role_order > 0)
);
-- The roles in a groups
CREATE TABLE group_roles (
group_id TEXT NOT NULL,
role_id TEXT NOT NULL,
profile TEXT NOT NULL,
is_public BOOLEAN NOT NULL, -- whether the role should be show to everyone
UNIQUE (group_id, role_id)
);
-- List of attestations we've given out and need to renew
CREATE TABLE group_attestations_renewals (
group_id TEXT NOT NULL,
user_id TEXT NOT NULL,
valid_until_ms BIGINT NOT NULL
);
CREATE INDEX group_attestations_renewals_g_idx ON group_attestations_renewals(group_id, user_id);
CREATE INDEX group_attestations_renewals_u_idx ON group_attestations_renewals(user_id);
CREATE INDEX group_attestations_renewals_v_idx ON group_attestations_renewals(valid_until_ms);
-- List of attestations we've received from remotes and are interested in.
CREATE TABLE group_attestations_remote (
group_id TEXT NOT NULL,
user_id TEXT NOT NULL,
valid_until_ms BIGINT NOT NULL,
attestation_json TEXT NOT NULL
);
CREATE INDEX group_attestations_remote_g_idx ON group_attestations_remote(group_id, user_id);
CREATE INDEX group_attestations_remote_u_idx ON group_attestations_remote(user_id);
CREATE INDEX group_attestations_remote_v_idx ON group_attestations_remote(valid_until_ms);
-- The group membership for the HS's users
CREATE TABLE local_group_membership (
group_id TEXT NOT NULL,
user_id TEXT NOT NULL,
is_admin BOOLEAN NOT NULL,
membership TEXT NOT NULL,
is_publicised BOOLEAN NOT NULL, -- if the user is publicising their membership
content TEXT NOT NULL
);
CREATE INDEX local_group_membership_u_idx ON local_group_membership(user_id, group_id);
CREATE INDEX local_group_membership_g_idx ON local_group_membership(group_id);
CREATE TABLE local_group_updates (
stream_id BIGINT NOT NULL,
group_id TEXT NOT NULL,
user_id TEXT NOT NULL,
type TEXT NOT NULL,
content TEXT NOT NULL
);

View File

@ -0,0 +1,28 @@
/* Copyright 2017 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- A subset of remote users whose profiles we have cached.
-- Whether a user is in this table or not is defined by the storage function
-- `is_subscribed_remote_profile_for_user`
CREATE TABLE remote_profile_cache (
user_id TEXT NOT NULL,
displayname TEXT,
avatar_url TEXT,
last_check BIGINT NOT NULL
);
CREATE UNIQUE INDEX remote_profile_cache_user_id ON remote_profile_cache(user_id);
CREATE INDEX remote_profile_cache_time ON remote_profile_cache(last_check);

View File

@ -45,6 +45,7 @@ class EventSources(object):
push_rules_key, _ = self.store.get_push_rules_stream_token() push_rules_key, _ = self.store.get_push_rules_stream_token()
to_device_key = self.store.get_to_device_stream_token() to_device_key = self.store.get_to_device_stream_token()
device_list_key = self.store.get_device_stream_token() device_list_key = self.store.get_device_stream_token()
groups_key = self.store.get_group_stream_token()
token = StreamToken( token = StreamToken(
room_key=( room_key=(
@ -65,6 +66,7 @@ class EventSources(object):
push_rules_key=push_rules_key, push_rules_key=push_rules_key,
to_device_key=to_device_key, to_device_key=to_device_key,
device_list_key=device_list_key, device_list_key=device_list_key,
groups_key=groups_key,
) )
defer.returnValue(token) defer.returnValue(token)
@ -73,6 +75,7 @@ class EventSources(object):
push_rules_key, _ = self.store.get_push_rules_stream_token() push_rules_key, _ = self.store.get_push_rules_stream_token()
to_device_key = self.store.get_to_device_stream_token() to_device_key = self.store.get_to_device_stream_token()
device_list_key = self.store.get_device_stream_token() device_list_key = self.store.get_device_stream_token()
groups_key = self.store.get_group_stream_token()
token = StreamToken( token = StreamToken(
room_key=( room_key=(
@ -93,5 +96,6 @@ class EventSources(object):
push_rules_key=push_rules_key, push_rules_key=push_rules_key,
to_device_key=to_device_key, to_device_key=to_device_key,
device_list_key=device_list_key, device_list_key=device_list_key,
groups_key=groups_key,
) )
defer.returnValue(token) defer.returnValue(token)

View File

@ -156,6 +156,11 @@ class EventID(DomainSpecificString):
SIGIL = "$" SIGIL = "$"
class GroupID(DomainSpecificString):
"""Structure representing a group ID."""
SIGIL = "+"
class StreamToken( class StreamToken(
namedtuple("Token", ( namedtuple("Token", (
"room_key", "room_key",
@ -166,6 +171,7 @@ class StreamToken(
"push_rules_key", "push_rules_key",
"to_device_key", "to_device_key",
"device_list_key", "device_list_key",
"groups_key",
)) ))
): ):
_SEPARATOR = "_" _SEPARATOR = "_"
@ -204,6 +210,7 @@ class StreamToken(
or (int(other.push_rules_key) < int(self.push_rules_key)) or (int(other.push_rules_key) < int(self.push_rules_key))
or (int(other.to_device_key) < int(self.to_device_key)) or (int(other.to_device_key) < int(self.to_device_key))
or (int(other.device_list_key) < int(self.device_list_key)) or (int(other.device_list_key) < int(self.device_list_key))
or (int(other.groups_key) < int(self.groups_key))
) )
def copy_and_advance(self, key, new_value): def copy_and_advance(self, key, new_value):

View File

@ -19,7 +19,7 @@ from twisted.internet import defer, reactor
from .logcontext import ( from .logcontext import (
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred, PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
) )
from synapse.util import unwrapFirstError from synapse.util import logcontext, unwrapFirstError
from contextlib import contextmanager from contextlib import contextmanager
@ -53,6 +53,11 @@ class ObservableDeferred(object):
Cancelling or otherwise resolving an observer will not affect the original Cancelling or otherwise resolving an observer will not affect the original
ObservableDeferred. ObservableDeferred.
NB that it does not attempt to do anything with logcontexts; in general
you should probably make_deferred_yieldable the deferreds
returned by `observe`, and ensure that the original deferred runs its
callbacks in the sentinel logcontext.
""" """
__slots__ = ["_deferred", "_observers", "_result"] __slots__ = ["_deferred", "_observers", "_result"]
@ -155,7 +160,7 @@ def concurrently_execute(func, args, limit):
except StopIteration: except StopIteration:
pass pass
return preserve_context_over_deferred(defer.gatherResults([ return logcontext.make_deferred_yieldable(defer.gatherResults([
preserve_fn(_concurrently_execute_inner)() preserve_fn(_concurrently_execute_inner)()
for _ in xrange(limit) for _ in xrange(limit)
], consumeErrors=True)).addErrback(unwrapFirstError) ], consumeErrors=True)).addErrback(unwrapFirstError)
@ -203,7 +208,26 @@ class Linearizer(object):
except: except:
logger.exception("Unexpected exception in Linearizer") logger.exception("Unexpected exception in Linearizer")
logger.info("Acquired linearizer lock %r for key %r", self.name, key) logger.info("Acquired linearizer lock %r for key %r", self.name,
key)
# if the code holding the lock completes synchronously, then it
# will recursively run the next claimant on the list. That can
# relatively rapidly lead to stack exhaustion. This is essentially
# the same problem as http://twistedmatrix.com/trac/ticket/9304.
#
# In order to break the cycle, we add a cheeky sleep(0) here to
# ensure that we fall back to the reactor between each iteration.
#
# (There's no particular need for it to happen before we return
# the context manager, but it needs to happen while we hold the
# lock, and the context manager's exit code must be synchronous,
# so actually this is the only sensible place.
yield run_on_reactor()
else:
logger.info("Acquired uncontended linearizer lock %r for key %r",
self.name, key)
@contextmanager @contextmanager
def _ctx_manager(): def _ctx_manager():
@ -211,7 +235,8 @@ class Linearizer(object):
yield yield
finally: finally:
logger.info("Releasing linearizer lock %r for key %r", self.name, key) logger.info("Releasing linearizer lock %r for key %r", self.name, key)
new_defer.callback(None) with PreserveLoggingContext():
new_defer.callback(None)
current_d = self.key_to_defer.get(key) current_d = self.key_to_defer.get(key)
if current_d is new_defer: if current_d is new_defer:
self.key_to_defer.pop(key, None) self.key_to_defer.pop(key, None)

View File

@ -0,0 +1,51 @@
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import StringIO
import logging
import traceback
class LogFormatter(logging.Formatter):
"""Log formatter which gives more detail for exceptions
This is the same as the standard log formatter, except that when logging
exceptions [typically via log.foo("msg", exc_info=1)], it prints the
sequence that led up to the point at which the exception was caught.
(Normally only stack frames between the point the exception was raised and
where it was caught are logged).
"""
def __init__(self, *args, **kwargs):
super(LogFormatter, self).__init__(*args, **kwargs)
def formatException(self, ei):
sio = StringIO.StringIO()
(typ, val, tb) = ei
# log the stack above the exception capture point if possible, but
# check that we actually have an f_back attribute to work around
# https://twistedmatrix.com/trac/ticket/9305
if tb and hasattr(tb.tb_frame, 'f_back'):
sio.write("Capture point (most recent call last):\n")
traceback.print_stack(tb.tb_frame.f_back, None, sio)
traceback.print_exception(typ, val, tb, None, sio)
s = sio.getvalue()
sio.close()
if s[-1:] == "\n":
s = s[:-1]
return s

View File

@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from synapse.config._base import ConfigError
def load_module(provider):
""" Loads a module with its config
Take a dict with keys 'module' (the module name) and 'config'
(the config dict).
Returns
Tuple of (provider class, parsed config object)
"""
# We need to import the module, and then pick the class out of
# that, so we split based on the last dot.
module, clz = provider['module'].rsplit(".", 1)
module = importlib.import_module(module)
provider_class = getattr(module, clz)
try:
provider_config = provider_class.parse_config(provider["config"])
except Exception as e:
raise ConfigError(
"Failed to parse config for %r: %r" % (provider['module'], e)
)
return provider_class, provider_config

View File

@ -62,8 +62,6 @@ class ProfileTestCase(unittest.TestCase):
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0) self.ratelimiter.send_message.return_value = (True, 0)
hs.handlers = ProfileHandlers(hs)
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.frank = UserID.from_string("@1234ABCD:test") self.frank = UserID.from_string("@1234ABCD:test")
@ -72,7 +70,7 @@ class ProfileTestCase(unittest.TestCase):
yield self.store.create_profile(self.frank.localpart) yield self.store.create_profile(self.frank.localpart)
self.handler = hs.get_handlers().profile_handler self.handler = hs.get_profile_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_my_name(self): def test_get_my_name(self):

View File

@ -40,13 +40,14 @@ class RegistrationTestCase(unittest.TestCase):
self.hs = yield setup_test_homeserver( self.hs = yield setup_test_homeserver(
handlers=None, handlers=None,
http_client=None, http_client=None,
expire_access_token=True) expire_access_token=True,
profile_handler=Mock(),
)
self.macaroon_generator = Mock( self.macaroon_generator = Mock(
generate_access_token=Mock(return_value='secret')) generate_access_token=Mock(return_value='secret'))
self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator) self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
self.hs.handlers = RegistrationHandlers(self.hs) self.hs.handlers = RegistrationHandlers(self.hs)
self.handler = self.hs.get_handlers().registration_handler self.handler = self.hs.get_handlers().registration_handler
self.hs.get_handlers().profile_handler = Mock()
@defer.inlineCallbacks @defer.inlineCallbacks
def test_user_is_created_and_logged_in_if_doesnt_exist(self): def test_user_is_created_and_logged_in_if_doesnt_exist(self):

View File

@ -46,6 +46,7 @@ class ProfileTestCase(unittest.TestCase):
resource_for_client=self.mock_resource, resource_for_client=self.mock_resource,
federation=Mock(), federation=Mock(),
replication_layer=Mock(), replication_layer=Mock(),
profile_handler=self.mock_handler
) )
def _get_user_by_req(request=None, allow_guest=False): def _get_user_by_req(request=None, allow_guest=False):
@ -53,8 +54,6 @@ class ProfileTestCase(unittest.TestCase):
hs.get_v1auth().get_user_by_req = _get_user_by_req hs.get_v1auth().get_user_by_req = _get_user_by_req
hs.get_handlers().profile_handler = self.mock_handler
profile.register_servlets(hs, self.mock_resource) profile.register_servlets(hs, self.mock_resource)
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -1032,7 +1032,7 @@ class RoomMessageListTestCase(RestTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_topo_token_is_accepted(self): def test_topo_token_is_accepted(self):
token = "t1-0_0_0_0_0_0_0_0" token = "t1-0_0_0_0_0_0_0_0_0"
(code, response) = yield self.mock_resource.trigger_get( (code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=%s" % "/rooms/%s/messages?access_token=x&from=%s" %
(self.room_id, token)) (self.room_id, token))
@ -1044,7 +1044,7 @@ class RoomMessageListTestCase(RestTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_stream_token_is_accepted_for_fwd_pagianation(self): def test_stream_token_is_accepted_for_fwd_pagianation(self):
token = "s0_0_0_0_0_0_0_0" token = "s0_0_0_0_0_0_0_0_0"
(code, response) = yield self.mock_resource.trigger_get( (code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=%s" % "/rooms/%s/messages?access_token=x&from=%s" %
(self.room_id, token)) (self.room_id, token))

View File

@ -47,6 +47,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.hs.get_auth_handler = Mock(return_value=self.auth_handler) self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
self.hs.get_device_handler = Mock(return_value=self.device_handler) self.hs.get_device_handler = Mock(return_value=self.device_handler)
self.hs.config.enable_registration = True self.hs.config.enable_registration = True
self.hs.config.auto_join_rooms = []
# init the thing we're testing # init the thing we're testing
self.servlet = RegisterRestServlet(self.hs) self.servlet = RegisterRestServlet(self.hs)

View File

@ -1,76 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.constants import EventTypes
class EventInjector:
def __init__(self, hs):
self.hs = hs
self.store = hs.get_datastore()
self.message_handler = hs.get_handlers().message_handler
self.event_builder_factory = hs.get_event_builder_factory()
@defer.inlineCallbacks
def create_room(self, room, user):
builder = self.event_builder_factory.new({
"type": EventTypes.Create,
"sender": user.to_string(),
"room_id": room.to_string(),
"content": {},
})
event, context = yield self.message_handler._create_new_client_event(
builder
)
yield self.store.persist_event(event, context)
@defer.inlineCallbacks
def inject_room_member(self, room, user, membership):
builder = self.event_builder_factory.new({
"type": EventTypes.Member,
"sender": user.to_string(),
"state_key": user.to_string(),
"room_id": room.to_string(),
"content": {"membership": membership},
})
event, context = yield self.message_handler._create_new_client_event(
builder
)
yield self.store.persist_event(event, context)
defer.returnValue(event)
@defer.inlineCallbacks
def inject_message(self, room, user, body):
builder = self.event_builder_factory.new({
"type": EventTypes.Message,
"sender": user.to_string(),
"state_key": user.to_string(),
"room_id": room.to_string(),
"content": {"body": body, "msgtype": u"message"},
})
event, context = yield self.message_handler._create_new_client_event(
builder
)
yield self.store.persist_event(event, context)

View File

@ -12,8 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.util import async, logcontext
from tests import unittest from tests import unittest
from twisted.internet import defer from twisted.internet import defer
@ -38,7 +37,28 @@ class LinearizerTestCase(unittest.TestCase):
with cm1: with cm1:
self.assertFalse(d2.called) self.assertFalse(d2.called)
self.assertTrue(d2.called)
with (yield d2): with (yield d2):
pass pass
def test_lots_of_queued_things(self):
# we have one slow thing, and lots of fast things queued up behind it.
# it should *not* explode the stack.
linearizer = Linearizer()
@defer.inlineCallbacks
def func(i, sleep=False):
with logcontext.LoggingContext("func(%s)" % i) as lc:
with (yield linearizer.queue("")):
self.assertEqual(
logcontext.LoggingContext.current_context(), lc)
if sleep:
yield async.sleep(0)
self.assertEqual(
logcontext.LoggingContext.current_context(), lc)
func(0, sleep=True)
for i in xrange(1, 100):
func(i)
return func(1000)

View File

@ -94,3 +94,41 @@ class LoggingContextTestCase(unittest.TestCase):
yield defer.succeed(None) yield defer.succeed(None)
return self._test_preserve_fn(nonblocking_function) return self._test_preserve_fn(nonblocking_function)
@defer.inlineCallbacks
def test_make_deferred_yieldable(self):
# a function which retuns an incomplete deferred, but doesn't follow
# the synapse rules.
def blocking_function():
d = defer.Deferred()
reactor.callLater(0, d.callback, None)
return d
sentinel_context = LoggingContext.current_context()
with LoggingContext() as context_one:
context_one.test_key = "one"
d1 = logcontext.make_deferred_yieldable(blocking_function())
# make sure that the context was reset by make_deferred_yieldable
self.assertIs(LoggingContext.current_context(), sentinel_context)
yield d1
# now it should be restored
self._check_test_key("one")
@defer.inlineCallbacks
def test_make_deferred_yieldable_on_non_deferred(self):
"""Check that make_deferred_yieldable does the right thing when its
argument isn't actually a deferred"""
with LoggingContext() as context_one:
context_one.test_key = "one"
d1 = logcontext.make_deferred_yieldable("bum")
self._check_test_key("one")
r = yield d1
self.assertEqual(r, "bum")
self._check_test_key("one")