mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-01-26 13:05:54 -05:00
Merge branch 'release-v0.24.0' of github.com:matrix-org/synapse
This commit is contained in:
commit
ffba978077
42
CHANGES.rst
42
CHANGES.rst
@ -1,3 +1,45 @@
|
||||
Changes in synapse v0.24.0 (2017-10-23)
|
||||
=======================================
|
||||
|
||||
No changes since v0.24.0-rc1
|
||||
|
||||
|
||||
Changes in synapse v0.24.0-rc1 (2017-10-19)
|
||||
===========================================
|
||||
|
||||
Features:
|
||||
|
||||
* Add Group Server (PR #2352, #2363, #2374, #2377, #2378, #2382, #2410, #2426,
|
||||
#2430, #2454, #2471, #2472, #2544)
|
||||
* Add support for channel notifications (PR #2501)
|
||||
* Add basic implementation of backup media store (PR #2538)
|
||||
* Add config option to auto-join new users to rooms (PR #2545)
|
||||
|
||||
|
||||
Changes:
|
||||
|
||||
* Make the spam checker a module (PR #2474)
|
||||
* Delete expired url cache data (PR #2478)
|
||||
* Ignore incoming events for rooms that we have left (PR #2490)
|
||||
* Allow spam checker to reject invites too (PR #2492)
|
||||
* Add room creation checks to spam checker (PR #2495)
|
||||
* Spam checking: add the invitee to user_may_invite (PR #2502)
|
||||
* Process events from federation for different rooms in parallel (PR #2520)
|
||||
* Allow error strings from spam checker (PR #2531)
|
||||
* Improve error handling for missing files in config (PR #2551)
|
||||
|
||||
|
||||
Bug fixes:
|
||||
|
||||
* Fix handling SERVFAILs when doing AAAA lookups for federation (PR #2477)
|
||||
* Fix incompatibility with newer versions of ujson (PR #2483) Thanks to
|
||||
@jeremycline!
|
||||
* Fix notification keywords that start/end with non-word chars (PR #2500)
|
||||
* Fix stack overflow and logcontexts from linearizer (PR #2532)
|
||||
* Fix 500 error when fields missing from power_levels event (PR #2552)
|
||||
* Fix 500 error when we get an error handling a PDU (PR #2553)
|
||||
|
||||
|
||||
Changes in synapse v0.23.1 (2017-10-02)
|
||||
=======================================
|
||||
|
||||
|
@ -50,7 +50,7 @@ master_doc = 'index'
|
||||
|
||||
# General information about the project.
|
||||
project = u'Synapse'
|
||||
copyright = u'2014, TNG'
|
||||
copyright = u'Copyright 2014-2017 OpenMarket Ltd, 2017 Vector Creations Ltd, 2017 New Vector Ltd'
|
||||
|
||||
# The version info for the project you're documenting, acts as replacement for
|
||||
# |version| and |release|, also used in various other places throughout the
|
||||
|
@ -376,10 +376,13 @@ class Porter(object):
|
||||
" VALUES (?,?,?,?,to_tsvector('english', ?),?,?)"
|
||||
)
|
||||
|
||||
rows_dict = [
|
||||
dict(zip(headers, row))
|
||||
for row in rows
|
||||
]
|
||||
rows_dict = []
|
||||
for row in rows:
|
||||
d = dict(zip(headers, row))
|
||||
if "\0" in d['value']:
|
||||
logger.warn('dropping search row %s', d)
|
||||
else:
|
||||
rows_dict.append(d)
|
||||
|
||||
txn.executemany(sql, [
|
||||
(
|
||||
|
@ -16,4 +16,4 @@
|
||||
""" This is a reference implementation of a Matrix home server.
|
||||
"""
|
||||
|
||||
__version__ = "0.23.1"
|
||||
__version__ = "0.24.0"
|
||||
|
@ -40,6 +40,7 @@ from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
|
||||
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||
from synapse.replication.slave.storage.room import RoomStore
|
||||
from synapse.replication.slave.storage.groups import SlavedGroupServerStore
|
||||
from synapse.replication.tcp.client import ReplicationClientHandler
|
||||
from synapse.rest.client.v1 import events
|
||||
from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
|
||||
@ -69,6 +70,7 @@ class SynchrotronSlavedStore(
|
||||
SlavedRegistrationStore,
|
||||
SlavedFilteringStore,
|
||||
SlavedPresenceStore,
|
||||
SlavedGroupServerStore,
|
||||
SlavedDeviceInboxStore,
|
||||
SlavedDeviceStore,
|
||||
SlavedClientIpStore,
|
||||
@ -403,6 +405,10 @@ class SyncReplicationHandler(ReplicationClientHandler):
|
||||
)
|
||||
elif stream_name == "presence":
|
||||
yield self.presence_handler.process_replication_rows(token, rows)
|
||||
elif stream_name == "receipts":
|
||||
self.notifier.on_new_event(
|
||||
"groups_key", token, users=[row.user_id for row in rows],
|
||||
)
|
||||
|
||||
|
||||
def start(config_options):
|
||||
|
@ -81,22 +81,38 @@ class Config(object):
|
||||
def abspath(file_path):
|
||||
return os.path.abspath(file_path) if file_path else file_path
|
||||
|
||||
@classmethod
|
||||
def path_exists(cls, file_path):
|
||||
"""Check if a file exists
|
||||
|
||||
Unlike os.path.exists, this throws an exception if there is an error
|
||||
checking if the file exists (for example, if there is a perms error on
|
||||
the parent dir).
|
||||
|
||||
Returns:
|
||||
bool: True if the file exists; False if not.
|
||||
"""
|
||||
try:
|
||||
os.stat(file_path)
|
||||
return True
|
||||
except OSError as e:
|
||||
if e.errno != errno.ENOENT:
|
||||
raise e
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def check_file(cls, file_path, config_name):
|
||||
if file_path is None:
|
||||
raise ConfigError(
|
||||
"Missing config for %s."
|
||||
" You must specify a path for the config file. You can "
|
||||
"do this with the -c or --config-path option. "
|
||||
"Adding --generate-config along with --server-name "
|
||||
"<server name> will generate a config file at the given path."
|
||||
% (config_name,)
|
||||
)
|
||||
if not os.path.exists(file_path):
|
||||
try:
|
||||
os.stat(file_path)
|
||||
except OSError as e:
|
||||
raise ConfigError(
|
||||
"File %s config for %s doesn't exist."
|
||||
" Try running again with --generate-config"
|
||||
% (file_path, config_name,)
|
||||
"Error accessing file '%s' (config for %s): %s"
|
||||
% (file_path, config_name, e.strerror)
|
||||
)
|
||||
return cls.abspath(file_path)
|
||||
|
||||
@ -248,7 +264,7 @@ class Config(object):
|
||||
" -c CONFIG-FILE\""
|
||||
)
|
||||
(config_path,) = config_files
|
||||
if not os.path.exists(config_path):
|
||||
if not cls.path_exists(config_path):
|
||||
if config_args.keys_directory:
|
||||
config_dir_path = config_args.keys_directory
|
||||
else:
|
||||
@ -261,7 +277,7 @@ class Config(object):
|
||||
"Must specify a server_name to a generate config for."
|
||||
" Pass -H server.name."
|
||||
)
|
||||
if not os.path.exists(config_dir_path):
|
||||
if not cls.path_exists(config_dir_path):
|
||||
os.makedirs(config_dir_path)
|
||||
with open(config_path, "wb") as config_file:
|
||||
config_bytes, config = obj.generate_config(
|
||||
|
32
synapse/config/groups.py
Normal file
32
synapse/config/groups.py
Normal file
@ -0,0 +1,32 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import Config
|
||||
|
||||
|
||||
class GroupsConfig(Config):
|
||||
def read_config(self, config):
|
||||
self.enable_group_creation = config.get("enable_group_creation", False)
|
||||
self.group_creation_prefix = config.get("group_creation_prefix", "")
|
||||
|
||||
def default_config(self, **kwargs):
|
||||
return """\
|
||||
# Whether to allow non server admins to create groups on this server
|
||||
enable_group_creation: false
|
||||
|
||||
# If enabled, non server admins can only create groups with local parts
|
||||
# starting with this prefix
|
||||
# group_creation_prefix: "unofficial/"
|
||||
"""
|
@ -34,6 +34,8 @@ from .password_auth_providers import PasswordAuthProviderConfig
|
||||
from .emailconfig import EmailConfig
|
||||
from .workers import WorkerConfig
|
||||
from .push import PushConfig
|
||||
from .spam_checker import SpamCheckerConfig
|
||||
from .groups import GroupsConfig
|
||||
|
||||
|
||||
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
|
||||
@ -41,7 +43,8 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
|
||||
VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
|
||||
AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
|
||||
JWTConfig, PasswordConfig, EmailConfig,
|
||||
WorkerConfig, PasswordAuthProviderConfig, PushConfig,):
|
||||
WorkerConfig, PasswordAuthProviderConfig, PushConfig,
|
||||
SpamCheckerConfig, GroupsConfig,):
|
||||
pass
|
||||
|
||||
|
||||
|
@ -118,10 +118,9 @@ class KeyConfig(Config):
|
||||
signing_keys = self.read_file(signing_key_path, "signing_key")
|
||||
try:
|
||||
return read_signing_keys(signing_keys.splitlines(True))
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
raise ConfigError(
|
||||
"Error reading signing_key."
|
||||
" Try running again with --generate-config"
|
||||
"Error reading signing_key: %s" % (str(e))
|
||||
)
|
||||
|
||||
def read_old_signing_keys(self, old_signing_keys):
|
||||
@ -141,7 +140,8 @@ class KeyConfig(Config):
|
||||
|
||||
def generate_files(self, config):
|
||||
signing_key_path = config["signing_key_path"]
|
||||
if not os.path.exists(signing_key_path):
|
||||
|
||||
if not self.path_exists(signing_key_path):
|
||||
with open(signing_key_path, "w") as signing_key_file:
|
||||
key_id = "a_" + random_string(4)
|
||||
write_signing_keys(
|
||||
|
@ -15,13 +15,15 @@
|
||||
|
||||
from ._base import Config, ConfigError
|
||||
|
||||
import importlib
|
||||
from synapse.util.module_loader import load_module
|
||||
|
||||
|
||||
class PasswordAuthProviderConfig(Config):
|
||||
def read_config(self, config):
|
||||
self.password_providers = []
|
||||
|
||||
provider_config = None
|
||||
|
||||
# We want to be backwards compatible with the old `ldap_config`
|
||||
# param.
|
||||
ldap_config = config.get("ldap_config", {})
|
||||
@ -38,19 +40,15 @@ class PasswordAuthProviderConfig(Config):
|
||||
if provider['module'] == "synapse.util.ldap_auth_provider.LdapAuthProvider":
|
||||
from ldap_auth_provider import LdapAuthProvider
|
||||
provider_class = LdapAuthProvider
|
||||
try:
|
||||
provider_config = provider_class.parse_config(provider["config"])
|
||||
except Exception as e:
|
||||
raise ConfigError(
|
||||
"Failed to parse config for %r: %r" % (provider['module'], e)
|
||||
)
|
||||
else:
|
||||
# We need to import the module, and then pick the class out of
|
||||
# that, so we split based on the last dot.
|
||||
module, clz = provider['module'].rsplit(".", 1)
|
||||
module = importlib.import_module(module)
|
||||
provider_class = getattr(module, clz)
|
||||
(provider_class, provider_config) = load_module(provider)
|
||||
|
||||
try:
|
||||
provider_config = provider_class.parse_config(provider["config"])
|
||||
except Exception as e:
|
||||
raise ConfigError(
|
||||
"Failed to parse config for %r: %r" % (provider['module'], e)
|
||||
)
|
||||
self.password_providers.append((provider_class, provider_config))
|
||||
|
||||
def default_config(self, **kwargs):
|
||||
|
@ -41,6 +41,8 @@ class RegistrationConfig(Config):
|
||||
self.allow_guest_access and config.get("invite_3pid_guest", False)
|
||||
)
|
||||
|
||||
self.auto_join_rooms = config.get("auto_join_rooms", [])
|
||||
|
||||
def default_config(self, **kwargs):
|
||||
registration_shared_secret = random_string_with_symbols(50)
|
||||
|
||||
@ -70,6 +72,11 @@ class RegistrationConfig(Config):
|
||||
- matrix.org
|
||||
- vector.im
|
||||
- riot.im
|
||||
|
||||
# Users who register on this homeserver will automatically be joined
|
||||
# to these rooms
|
||||
#auto_join_rooms:
|
||||
# - "#example:example.com"
|
||||
""" % locals()
|
||||
|
||||
def add_arguments(self, parser):
|
||||
|
@ -70,7 +70,19 @@ class ContentRepositoryConfig(Config):
|
||||
self.max_upload_size = self.parse_size(config["max_upload_size"])
|
||||
self.max_image_pixels = self.parse_size(config["max_image_pixels"])
|
||||
self.max_spider_size = self.parse_size(config["max_spider_size"])
|
||||
|
||||
self.media_store_path = self.ensure_directory(config["media_store_path"])
|
||||
|
||||
self.backup_media_store_path = config.get("backup_media_store_path")
|
||||
if self.backup_media_store_path:
|
||||
self.backup_media_store_path = self.ensure_directory(
|
||||
self.backup_media_store_path
|
||||
)
|
||||
|
||||
self.synchronous_backup_media_store = config.get(
|
||||
"synchronous_backup_media_store", False
|
||||
)
|
||||
|
||||
self.uploads_path = self.ensure_directory(config["uploads_path"])
|
||||
self.dynamic_thumbnails = config["dynamic_thumbnails"]
|
||||
self.thumbnail_requirements = parse_thumbnail_requirements(
|
||||
@ -115,6 +127,14 @@ class ContentRepositoryConfig(Config):
|
||||
# Directory where uploaded images and attachments are stored.
|
||||
media_store_path: "%(media_store)s"
|
||||
|
||||
# A secondary directory where uploaded images and attachments are
|
||||
# stored as a backup.
|
||||
# backup_media_store_path: "%(media_store)s"
|
||||
|
||||
# Whether to wait for successful write to backup media store before
|
||||
# returning successfully.
|
||||
# synchronous_backup_media_store: false
|
||||
|
||||
# Directory where in-progress uploads are stored.
|
||||
uploads_path: "%(uploads_path)s"
|
||||
|
||||
|
35
synapse/config/spam_checker.py
Normal file
35
synapse/config/spam_checker.py
Normal file
@ -0,0 +1,35 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.util.module_loader import load_module
|
||||
|
||||
from ._base import Config
|
||||
|
||||
|
||||
class SpamCheckerConfig(Config):
|
||||
def read_config(self, config):
|
||||
self.spam_checker = None
|
||||
|
||||
provider = config.get("spam_checker", None)
|
||||
if provider is not None:
|
||||
self.spam_checker = load_module(provider)
|
||||
|
||||
def default_config(self, **kwargs):
|
||||
return """\
|
||||
# spam_checker:
|
||||
# module: "my_custom_project.SuperSpamChecker"
|
||||
# config:
|
||||
# example_option: 'things'
|
||||
"""
|
@ -126,7 +126,7 @@ class TlsConfig(Config):
|
||||
tls_private_key_path = config["tls_private_key_path"]
|
||||
tls_dh_params_path = config["tls_dh_params_path"]
|
||||
|
||||
if not os.path.exists(tls_private_key_path):
|
||||
if not self.path_exists(tls_private_key_path):
|
||||
with open(tls_private_key_path, "w") as private_key_file:
|
||||
tls_private_key = crypto.PKey()
|
||||
tls_private_key.generate_key(crypto.TYPE_RSA, 2048)
|
||||
@ -141,7 +141,7 @@ class TlsConfig(Config):
|
||||
crypto.FILETYPE_PEM, private_key_pem
|
||||
)
|
||||
|
||||
if not os.path.exists(tls_certificate_path):
|
||||
if not self.path_exists(tls_certificate_path):
|
||||
with open(tls_certificate_path, "w") as certificate_file:
|
||||
cert = crypto.X509()
|
||||
subject = cert.get_subject()
|
||||
@ -159,7 +159,7 @@ class TlsConfig(Config):
|
||||
|
||||
certificate_file.write(cert_pem)
|
||||
|
||||
if not os.path.exists(tls_dh_params_path):
|
||||
if not self.path_exists(tls_dh_params_path):
|
||||
if GENERATE_DH_PARAMS:
|
||||
subprocess.check_call([
|
||||
"openssl", "dhparam",
|
||||
|
@ -470,14 +470,14 @@ def _check_power_levels(event, auth_events):
|
||||
("invite", None),
|
||||
]
|
||||
|
||||
old_list = current_state.content.get("users")
|
||||
old_list = current_state.content.get("users", {})
|
||||
for user in set(old_list.keys() + user_list.keys()):
|
||||
levels_to_check.append(
|
||||
(user, "users")
|
||||
)
|
||||
|
||||
old_list = current_state.content.get("events")
|
||||
new_list = event.content.get("events")
|
||||
old_list = current_state.content.get("events", {})
|
||||
new_list = event.content.get("events", {})
|
||||
for ev_id in set(old_list.keys() + new_list.keys()):
|
||||
levels_to_check.append(
|
||||
(ev_id, "events")
|
||||
|
@ -14,25 +14,100 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
def check_event_for_spam(event):
|
||||
"""Checks if a given event is considered "spammy" by this server.
|
||||
class SpamChecker(object):
|
||||
def __init__(self, hs):
|
||||
self.spam_checker = None
|
||||
|
||||
If the server considers an event spammy, then it will be rejected if
|
||||
sent by a local user. If it is sent by a user on another server, then
|
||||
users receive a blank event.
|
||||
module = None
|
||||
config = None
|
||||
try:
|
||||
module, config = hs.config.spam_checker
|
||||
except:
|
||||
pass
|
||||
|
||||
Args:
|
||||
event (synapse.events.EventBase): the event to be checked
|
||||
if module is not None:
|
||||
self.spam_checker = module(config=config)
|
||||
|
||||
Returns:
|
||||
bool: True if the event is spammy.
|
||||
"""
|
||||
if not hasattr(event, "content") or "body" not in event.content:
|
||||
return False
|
||||
def check_event_for_spam(self, event):
|
||||
"""Checks if a given event is considered "spammy" by this server.
|
||||
|
||||
# for example:
|
||||
#
|
||||
# if "the third flower is green" in event.content["body"]:
|
||||
# return True
|
||||
If the server considers an event spammy, then it will be rejected if
|
||||
sent by a local user. If it is sent by a user on another server, then
|
||||
users receive a blank event.
|
||||
|
||||
return False
|
||||
Args:
|
||||
event (synapse.events.EventBase): the event to be checked
|
||||
|
||||
Returns:
|
||||
bool: True if the event is spammy.
|
||||
"""
|
||||
if self.spam_checker is None:
|
||||
return False
|
||||
|
||||
return self.spam_checker.check_event_for_spam(event)
|
||||
|
||||
def user_may_invite(self, inviter_userid, invitee_userid, room_id):
|
||||
"""Checks if a given user may send an invite
|
||||
|
||||
If this method returns false, the invite will be rejected.
|
||||
|
||||
Args:
|
||||
userid (string): The sender's user ID
|
||||
|
||||
Returns:
|
||||
bool: True if the user may send an invite, otherwise False
|
||||
"""
|
||||
if self.spam_checker is None:
|
||||
return True
|
||||
|
||||
return self.spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id)
|
||||
|
||||
def user_may_create_room(self, userid):
|
||||
"""Checks if a given user may create a room
|
||||
|
||||
If this method returns false, the creation request will be rejected.
|
||||
|
||||
Args:
|
||||
userid (string): The sender's user ID
|
||||
|
||||
Returns:
|
||||
bool: True if the user may create a room, otherwise False
|
||||
"""
|
||||
if self.spam_checker is None:
|
||||
return True
|
||||
|
||||
return self.spam_checker.user_may_create_room(userid)
|
||||
|
||||
def user_may_create_room_alias(self, userid, room_alias):
|
||||
"""Checks if a given user may create a room alias
|
||||
|
||||
If this method returns false, the association request will be rejected.
|
||||
|
||||
Args:
|
||||
userid (string): The sender's user ID
|
||||
room_alias (string): The alias to be created
|
||||
|
||||
Returns:
|
||||
bool: True if the user may create a room alias, otherwise False
|
||||
"""
|
||||
if self.spam_checker is None:
|
||||
return True
|
||||
|
||||
return self.spam_checker.user_may_create_room_alias(userid, room_alias)
|
||||
|
||||
def user_may_publish_room(self, userid, room_id):
|
||||
"""Checks if a given user may publish a room to the directory
|
||||
|
||||
If this method returns false, the publish request will be rejected.
|
||||
|
||||
Args:
|
||||
userid (string): The sender's user ID
|
||||
room_id (string): The ID of the room that would be published
|
||||
|
||||
Returns:
|
||||
bool: True if the user may publish the room, otherwise False
|
||||
"""
|
||||
if self.spam_checker is None:
|
||||
return True
|
||||
|
||||
return self.spam_checker.user_may_publish_room(userid, room_id)
|
||||
|
@ -16,7 +16,6 @@ import logging
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.crypto.event_signing import check_event_content_hash
|
||||
from synapse.events import spamcheck
|
||||
from synapse.events.utils import prune_event
|
||||
from synapse.util import unwrapFirstError, logcontext
|
||||
from twisted.internet import defer
|
||||
@ -26,7 +25,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class FederationBase(object):
|
||||
def __init__(self, hs):
|
||||
pass
|
||||
self.spam_checker = hs.get_spam_checker()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
|
||||
@ -144,7 +143,7 @@ class FederationBase(object):
|
||||
)
|
||||
return redacted
|
||||
|
||||
if spamcheck.check_event_for_spam(pdu):
|
||||
if self.spam_checker.check_event_for_spam(pdu):
|
||||
logger.warn(
|
||||
"Event contains spam, redacting %s: %s",
|
||||
pdu.event_id, pdu.get_pdu_json()
|
||||
|
@ -12,14 +12,12 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from .federation_base import FederationBase
|
||||
from .units import Transaction, Edu
|
||||
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.util import async
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
from synapse.events import FrozenEvent
|
||||
@ -33,6 +31,9 @@ from synapse.crypto.event_signing import compute_event_signature
|
||||
import simplejson as json
|
||||
import logging
|
||||
|
||||
# when processing incoming transactions, we try to handle multiple rooms in
|
||||
# parallel, up to this limit.
|
||||
TRANSACTION_CONCURRENCY_LIMIT = 10
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -52,7 +53,8 @@ class FederationServer(FederationBase):
|
||||
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
self._server_linearizer = Linearizer("fed_server")
|
||||
self._server_linearizer = async.Linearizer("fed_server")
|
||||
self._transaction_linearizer = async.Linearizer("fed_txn_handler")
|
||||
|
||||
# We cache responses to state queries, as they take a while and often
|
||||
# come in waves.
|
||||
@ -109,25 +111,41 @@ class FederationServer(FederationBase):
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def on_incoming_transaction(self, transaction_data):
|
||||
# keep this as early as possible to make the calculated origin ts as
|
||||
# accurate as possible.
|
||||
request_time = self._clock.time_msec()
|
||||
|
||||
transaction = Transaction(**transaction_data)
|
||||
|
||||
received_pdus_counter.inc_by(len(transaction.pdus))
|
||||
|
||||
for p in transaction.pdus:
|
||||
if "unsigned" in p:
|
||||
unsigned = p["unsigned"]
|
||||
if "age" in unsigned:
|
||||
p["age"] = unsigned["age"]
|
||||
if "age" in p:
|
||||
p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
|
||||
del p["age"]
|
||||
|
||||
pdu_list = [
|
||||
self.event_from_pdu_json(p) for p in transaction.pdus
|
||||
]
|
||||
if not transaction.transaction_id:
|
||||
raise Exception("Transaction missing transaction_id")
|
||||
if not transaction.origin:
|
||||
raise Exception("Transaction missing origin")
|
||||
|
||||
logger.debug("[%s] Got transaction", transaction.transaction_id)
|
||||
|
||||
# use a linearizer to ensure that we don't process the same transaction
|
||||
# multiple times in parallel.
|
||||
with (yield self._transaction_linearizer.queue(
|
||||
(transaction.origin, transaction.transaction_id),
|
||||
)):
|
||||
result = yield self._handle_incoming_transaction(
|
||||
transaction, request_time,
|
||||
)
|
||||
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _handle_incoming_transaction(self, transaction, request_time):
|
||||
""" Process an incoming transaction and return the HTTP response
|
||||
|
||||
Args:
|
||||
transaction (Transaction): incoming transaction
|
||||
request_time (int): timestamp that the HTTP request arrived at
|
||||
|
||||
Returns:
|
||||
Deferred[(int, object)]: http response code and body
|
||||
"""
|
||||
response = yield self.transaction_actions.have_responded(transaction)
|
||||
|
||||
if response:
|
||||
@ -140,42 +158,49 @@ class FederationServer(FederationBase):
|
||||
|
||||
logger.debug("[%s] Transaction is new", transaction.transaction_id)
|
||||
|
||||
results = []
|
||||
received_pdus_counter.inc_by(len(transaction.pdus))
|
||||
|
||||
for pdu in pdu_list:
|
||||
# check that it's actually being sent from a valid destination to
|
||||
# workaround bug #1753 in 0.18.5 and 0.18.6
|
||||
if transaction.origin != get_domain_from_id(pdu.event_id):
|
||||
# We continue to accept join events from any server; this is
|
||||
# necessary for the federation join dance to work correctly.
|
||||
# (When we join over federation, the "helper" server is
|
||||
# responsible for sending out the join event, rather than the
|
||||
# origin. See bug #1893).
|
||||
if not (
|
||||
pdu.type == 'm.room.member' and
|
||||
pdu.content and
|
||||
pdu.content.get("membership", None) == 'join'
|
||||
):
|
||||
logger.info(
|
||||
"Discarding PDU %s from invalid origin %s",
|
||||
pdu.event_id, transaction.origin
|
||||
)
|
||||
continue
|
||||
else:
|
||||
logger.info(
|
||||
"Accepting join PDU %s from %s",
|
||||
pdu.event_id, transaction.origin
|
||||
)
|
||||
pdus_by_room = {}
|
||||
|
||||
try:
|
||||
yield self._handle_received_pdu(transaction.origin, pdu)
|
||||
results.append({})
|
||||
except FederationError as e:
|
||||
self.send_failure(e, transaction.origin)
|
||||
results.append({"error": str(e)})
|
||||
except Exception as e:
|
||||
results.append({"error": str(e)})
|
||||
logger.exception("Failed to handle PDU")
|
||||
for p in transaction.pdus:
|
||||
if "unsigned" in p:
|
||||
unsigned = p["unsigned"]
|
||||
if "age" in unsigned:
|
||||
p["age"] = unsigned["age"]
|
||||
if "age" in p:
|
||||
p["age_ts"] = request_time - int(p["age"])
|
||||
del p["age"]
|
||||
|
||||
event = self.event_from_pdu_json(p)
|
||||
room_id = event.room_id
|
||||
pdus_by_room.setdefault(room_id, []).append(event)
|
||||
|
||||
pdu_results = {}
|
||||
|
||||
# we can process different rooms in parallel (which is useful if they
|
||||
# require callouts to other servers to fetch missing events), but
|
||||
# impose a limit to avoid going too crazy with ram/cpu.
|
||||
@defer.inlineCallbacks
|
||||
def process_pdus_for_room(room_id):
|
||||
logger.debug("Processing PDUs for %s", room_id)
|
||||
for pdu in pdus_by_room[room_id]:
|
||||
event_id = pdu.event_id
|
||||
try:
|
||||
yield self._handle_received_pdu(
|
||||
transaction.origin, pdu
|
||||
)
|
||||
pdu_results[event_id] = {}
|
||||
except FederationError as e:
|
||||
logger.warn("Error handling PDU %s: %s", event_id, e)
|
||||
pdu_results[event_id] = {"error": str(e)}
|
||||
except Exception as e:
|
||||
pdu_results[event_id] = {"error": str(e)}
|
||||
logger.exception("Failed to handle PDU %s", event_id)
|
||||
|
||||
yield async.concurrently_execute(
|
||||
process_pdus_for_room, pdus_by_room.keys(),
|
||||
TRANSACTION_CONCURRENCY_LIMIT,
|
||||
)
|
||||
|
||||
if hasattr(transaction, "edus"):
|
||||
for edu in (Edu(**x) for x in transaction.edus):
|
||||
@ -185,17 +210,16 @@ class FederationServer(FederationBase):
|
||||
edu.content
|
||||
)
|
||||
|
||||
for failure in getattr(transaction, "pdu_failures", []):
|
||||
logger.info("Got failure %r", failure)
|
||||
|
||||
logger.debug("Returning: %s", str(results))
|
||||
pdu_failures = getattr(transaction, "pdu_failures", [])
|
||||
for failure in pdu_failures:
|
||||
logger.info("Got failure %r", failure)
|
||||
|
||||
response = {
|
||||
"pdus": dict(zip(
|
||||
(p.event_id for p in pdu_list), results
|
||||
)),
|
||||
"pdus": pdu_results,
|
||||
}
|
||||
|
||||
logger.debug("Returning: %s", str(response))
|
||||
|
||||
yield self.transaction_actions.set_response(
|
||||
transaction,
|
||||
200, response
|
||||
@ -520,6 +544,30 @@ class FederationServer(FederationBase):
|
||||
Returns (Deferred): completes with None
|
||||
Raises: FederationError if the signatures / hash do not match
|
||||
"""
|
||||
# check that it's actually being sent from a valid destination to
|
||||
# workaround bug #1753 in 0.18.5 and 0.18.6
|
||||
if origin != get_domain_from_id(pdu.event_id):
|
||||
# We continue to accept join events from any server; this is
|
||||
# necessary for the federation join dance to work correctly.
|
||||
# (When we join over federation, the "helper" server is
|
||||
# responsible for sending out the join event, rather than the
|
||||
# origin. See bug #1893).
|
||||
if not (
|
||||
pdu.type == 'm.room.member' and
|
||||
pdu.content and
|
||||
pdu.content.get("membership", None) == 'join'
|
||||
):
|
||||
logger.info(
|
||||
"Discarding PDU %s from invalid origin %s",
|
||||
pdu.event_id, origin
|
||||
)
|
||||
return
|
||||
else:
|
||||
logger.info(
|
||||
"Accepting join PDU %s from %s",
|
||||
pdu.event_id, origin
|
||||
)
|
||||
|
||||
# Check signature.
|
||||
try:
|
||||
pdu = yield self._check_sigs_and_hash(pdu)
|
||||
|
@ -20,8 +20,8 @@ from .persistence import TransactionActions
|
||||
from .units import Transaction, Edu
|
||||
|
||||
from synapse.api.errors import HttpResponseException
|
||||
from synapse.util import logcontext
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.util.logcontext import preserve_context_over_fn, preserve_fn
|
||||
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
|
||||
from synapse.util.metrics import measure_func
|
||||
from synapse.handlers.presence import format_user_presence_state, get_interested_remotes
|
||||
@ -231,11 +231,9 @@ class TransactionQueue(object):
|
||||
(pdu, order)
|
||||
)
|
||||
|
||||
preserve_context_over_fn(
|
||||
self._attempt_new_transaction, destination
|
||||
)
|
||||
self._attempt_new_transaction(destination)
|
||||
|
||||
@preserve_fn # the caller should not yield on this
|
||||
@logcontext.preserve_fn # the caller should not yield on this
|
||||
@defer.inlineCallbacks
|
||||
def send_presence(self, states):
|
||||
"""Send the new presence states to the appropriate destinations.
|
||||
@ -299,7 +297,7 @@ class TransactionQueue(object):
|
||||
state.user_id: state for state in states
|
||||
})
|
||||
|
||||
preserve_fn(self._attempt_new_transaction)(destination)
|
||||
self._attempt_new_transaction(destination)
|
||||
|
||||
def send_edu(self, destination, edu_type, content, key=None):
|
||||
edu = Edu(
|
||||
@ -321,9 +319,7 @@ class TransactionQueue(object):
|
||||
else:
|
||||
self.pending_edus_by_dest.setdefault(destination, []).append(edu)
|
||||
|
||||
preserve_context_over_fn(
|
||||
self._attempt_new_transaction, destination
|
||||
)
|
||||
self._attempt_new_transaction(destination)
|
||||
|
||||
def send_failure(self, failure, destination):
|
||||
if destination == self.server_name or destination == "localhost":
|
||||
@ -336,9 +332,7 @@ class TransactionQueue(object):
|
||||
destination, []
|
||||
).append(failure)
|
||||
|
||||
preserve_context_over_fn(
|
||||
self._attempt_new_transaction, destination
|
||||
)
|
||||
self._attempt_new_transaction(destination)
|
||||
|
||||
def send_device_messages(self, destination):
|
||||
if destination == self.server_name or destination == "localhost":
|
||||
@ -347,15 +341,24 @@ class TransactionQueue(object):
|
||||
if not self.can_send_to(destination):
|
||||
return
|
||||
|
||||
preserve_context_over_fn(
|
||||
self._attempt_new_transaction, destination
|
||||
)
|
||||
self._attempt_new_transaction(destination)
|
||||
|
||||
def get_current_token(self):
|
||||
return 0
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _attempt_new_transaction(self, destination):
|
||||
"""Try to start a new transaction to this destination
|
||||
|
||||
If there is already a transaction in progress to this destination,
|
||||
returns immediately. Otherwise kicks off the process of sending a
|
||||
transaction in the background.
|
||||
|
||||
Args:
|
||||
destination (str):
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# list of (pending_pdu, deferred, order)
|
||||
if destination in self.pending_transactions:
|
||||
# XXX: pending_transactions can get stuck on by a never-ending
|
||||
@ -368,6 +371,19 @@ class TransactionQueue(object):
|
||||
)
|
||||
return
|
||||
|
||||
logger.debug("TX [%s] Starting transaction loop", destination)
|
||||
|
||||
# Drop the logcontext before starting the transaction. It doesn't
|
||||
# really make sense to log all the outbound transactions against
|
||||
# whatever path led us to this point: that's pretty arbitrary really.
|
||||
#
|
||||
# (this also means we can fire off _perform_transaction without
|
||||
# yielding)
|
||||
with logcontext.PreserveLoggingContext():
|
||||
self._transaction_transmission_loop(destination)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _transaction_transmission_loop(self, destination):
|
||||
pending_pdus = []
|
||||
try:
|
||||
self.pending_transactions[destination] = 1
|
||||
|
@ -471,3 +471,384 @@ class TransportLayerClient(object):
|
||||
)
|
||||
|
||||
defer.returnValue(content)
|
||||
|
||||
@log_function
|
||||
def get_group_profile(self, destination, group_id, requester_user_id):
|
||||
"""Get a group profile
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/profile" % (group_id,)
|
||||
|
||||
return self.client.get_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def get_group_summary(self, destination, group_id, requester_user_id):
|
||||
"""Get a group summary
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/summary" % (group_id,)
|
||||
|
||||
return self.client.get_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def get_rooms_in_group(self, destination, group_id, requester_user_id):
|
||||
"""Get all rooms in a group
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/rooms" % (group_id,)
|
||||
|
||||
return self.client.get_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
def add_room_to_group(self, destination, group_id, requester_user_id, room_id,
|
||||
content):
|
||||
"""Add a room to a group
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/room/%s" % (group_id, room_id,)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
data=content,
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
|
||||
"""Remove a room from a group
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/room/%s" % (group_id, room_id,)
|
||||
|
||||
return self.client.delete_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def get_users_in_group(self, destination, group_id, requester_user_id):
|
||||
"""Get users in a group
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/users" % (group_id,)
|
||||
|
||||
return self.client.get_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def get_invited_users_in_group(self, destination, group_id, requester_user_id):
|
||||
"""Get users that have been invited to a group
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/invited_users" % (group_id,)
|
||||
|
||||
return self.client.get_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def accept_group_invite(self, destination, group_id, user_id, content):
|
||||
"""Accept a group invite
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/users/%s/accept_invite" % (group_id, user_id)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
data=content,
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def invite_to_group(self, destination, group_id, user_id, requester_user_id, content):
|
||||
"""Invite a user to a group
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/users/%s/invite" % (group_id, user_id)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
data=content,
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def invite_to_group_notification(self, destination, group_id, user_id, content):
|
||||
"""Sent by group server to inform a user's server that they have been
|
||||
invited.
|
||||
"""
|
||||
|
||||
path = PREFIX + "/groups/local/%s/users/%s/invite" % (group_id, user_id)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
data=content,
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def remove_user_from_group(self, destination, group_id, requester_user_id,
|
||||
user_id, content):
|
||||
"""Remove a user fron a group
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/users/%s/remove" % (group_id, user_id)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
data=content,
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def remove_user_from_group_notification(self, destination, group_id, user_id,
|
||||
content):
|
||||
"""Sent by group server to inform a user's server that they have been
|
||||
kicked from the group.
|
||||
"""
|
||||
|
||||
path = PREFIX + "/groups/local/%s/users/%s/remove" % (group_id, user_id)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
data=content,
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def renew_group_attestation(self, destination, group_id, user_id, content):
|
||||
"""Sent by either a group server or a user's server to periodically update
|
||||
the attestations
|
||||
"""
|
||||
|
||||
path = PREFIX + "/groups/%s/renew_attestation/%s" % (group_id, user_id)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
data=content,
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def update_group_summary_room(self, destination, group_id, user_id, room_id,
|
||||
category_id, content):
|
||||
"""Update a room entry in a group summary
|
||||
"""
|
||||
if category_id:
|
||||
path = PREFIX + "/groups/%s/summary/categories/%s/rooms/%s" % (
|
||||
group_id, category_id, room_id,
|
||||
)
|
||||
else:
|
||||
path = PREFIX + "/groups/%s/summary/rooms/%s" % (group_id, room_id,)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": user_id},
|
||||
data=content,
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def delete_group_summary_room(self, destination, group_id, user_id, room_id,
|
||||
category_id):
|
||||
"""Delete a room entry in a group summary
|
||||
"""
|
||||
if category_id:
|
||||
path = PREFIX + "/groups/%s/summary/categories/%s/rooms/%s" % (
|
||||
group_id, category_id, room_id,
|
||||
)
|
||||
else:
|
||||
path = PREFIX + "/groups/%s/summary/rooms/%s" % (group_id, room_id,)
|
||||
|
||||
return self.client.delete_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": user_id},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def get_group_categories(self, destination, group_id, requester_user_id):
|
||||
"""Get all categories in a group
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/categories" % (group_id,)
|
||||
|
||||
return self.client.get_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def get_group_category(self, destination, group_id, requester_user_id, category_id):
|
||||
"""Get category info in a group
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/categories/%s" % (group_id, category_id,)
|
||||
|
||||
return self.client.get_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def update_group_category(self, destination, group_id, requester_user_id, category_id,
|
||||
content):
|
||||
"""Update a category in a group
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/categories/%s" % (group_id, category_id,)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
data=content,
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def delete_group_category(self, destination, group_id, requester_user_id,
|
||||
category_id):
|
||||
"""Delete a category in a group
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/categories/%s" % (group_id, category_id,)
|
||||
|
||||
return self.client.delete_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def get_group_roles(self, destination, group_id, requester_user_id):
|
||||
"""Get all roles in a group
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/roles" % (group_id,)
|
||||
|
||||
return self.client.get_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def get_group_role(self, destination, group_id, requester_user_id, role_id):
|
||||
"""Get a roles info
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/roles/%s" % (group_id, role_id,)
|
||||
|
||||
return self.client.get_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def update_group_role(self, destination, group_id, requester_user_id, role_id,
|
||||
content):
|
||||
"""Update a role in a group
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/roles/%s" % (group_id, role_id,)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
data=content,
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def delete_group_role(self, destination, group_id, requester_user_id, role_id):
|
||||
"""Delete a role in a group
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/roles/%s" % (group_id, role_id,)
|
||||
|
||||
return self.client.delete_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def update_group_summary_user(self, destination, group_id, requester_user_id,
|
||||
user_id, role_id, content):
|
||||
"""Update a users entry in a group
|
||||
"""
|
||||
if role_id:
|
||||
path = PREFIX + "/groups/%s/summary/roles/%s/users/%s" % (
|
||||
group_id, role_id, user_id,
|
||||
)
|
||||
else:
|
||||
path = PREFIX + "/groups/%s/summary/users/%s" % (group_id, user_id,)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
data=content,
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def delete_group_summary_user(self, destination, group_id, requester_user_id,
|
||||
user_id, role_id):
|
||||
"""Delete a users entry in a group
|
||||
"""
|
||||
if role_id:
|
||||
path = PREFIX + "/groups/%s/summary/roles/%s/users/%s" % (
|
||||
group_id, role_id, user_id,
|
||||
)
|
||||
else:
|
||||
path = PREFIX + "/groups/%s/summary/users/%s" % (group_id, user_id,)
|
||||
|
||||
return self.client.delete_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
def bulk_get_publicised_groups(self, destination, user_ids):
|
||||
"""Get the groups a list of users are publicising
|
||||
"""
|
||||
|
||||
path = PREFIX + "/get_groups_publicised"
|
||||
|
||||
content = {"user_ids": user_ids}
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
data=content,
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
@ -25,7 +25,7 @@ from synapse.http.servlet import (
|
||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||
from synapse.util.versionstring import get_version_string
|
||||
from synapse.util.logcontext import preserve_fn
|
||||
from synapse.types import ThirdPartyInstanceID
|
||||
from synapse.types import ThirdPartyInstanceID, get_domain_from_id
|
||||
|
||||
import functools
|
||||
import logging
|
||||
@ -609,6 +609,493 @@ class FederationVersionServlet(BaseFederationServlet):
|
||||
}))
|
||||
|
||||
|
||||
class FederationGroupsProfileServlet(BaseFederationServlet):
|
||||
"""Get the basic profile of a group on behalf of a user
|
||||
"""
|
||||
PATH = "/groups/(?P<group_id>[^/]*)/profile$"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, origin, content, query, group_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
new_content = yield self.handler.get_group_profile(
|
||||
group_id, requester_user_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, new_content))
|
||||
|
||||
|
||||
class FederationGroupsSummaryServlet(BaseFederationServlet):
|
||||
PATH = "/groups/(?P<group_id>[^/]*)/summary$"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, origin, content, query, group_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
new_content = yield self.handler.get_group_summary(
|
||||
group_id, requester_user_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, new_content))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, origin, content, query, group_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
new_content = yield self.handler.update_group_profile(
|
||||
group_id, requester_user_id, content
|
||||
)
|
||||
|
||||
defer.returnValue((200, new_content))
|
||||
|
||||
|
||||
class FederationGroupsRoomsServlet(BaseFederationServlet):
|
||||
"""Get the rooms in a group on behalf of a user
|
||||
"""
|
||||
PATH = "/groups/(?P<group_id>[^/]*)/rooms$"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, origin, content, query, group_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
new_content = yield self.handler.get_rooms_in_group(
|
||||
group_id, requester_user_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, new_content))
|
||||
|
||||
|
||||
class FederationGroupsAddRoomsServlet(BaseFederationServlet):
|
||||
"""Add/remove room from group
|
||||
"""
|
||||
PATH = "/groups/(?P<group_id>[^/]*)/room/(?<room_id>)$"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, origin, content, query, group_id, room_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
new_content = yield self.handler.add_room_to_group(
|
||||
group_id, requester_user_id, room_id, content
|
||||
)
|
||||
|
||||
defer.returnValue((200, new_content))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, origin, content, query, group_id, room_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
new_content = yield self.handler.remove_room_from_group(
|
||||
group_id, requester_user_id, room_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, new_content))
|
||||
|
||||
|
||||
class FederationGroupsUsersServlet(BaseFederationServlet):
|
||||
"""Get the users in a group on behalf of a user
|
||||
"""
|
||||
PATH = "/groups/(?P<group_id>[^/]*)/users$"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, origin, content, query, group_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
new_content = yield self.handler.get_users_in_group(
|
||||
group_id, requester_user_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, new_content))
|
||||
|
||||
|
||||
class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
|
||||
"""Get the users that have been invited to a group
|
||||
"""
|
||||
PATH = "/groups/(?P<group_id>[^/]*)/invited_users$"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, origin, content, query, group_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
new_content = yield self.handler.get_invited_users_in_group(
|
||||
group_id, requester_user_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, new_content))
|
||||
|
||||
|
||||
class FederationGroupsInviteServlet(BaseFederationServlet):
|
||||
"""Ask a group server to invite someone to the group
|
||||
"""
|
||||
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite$"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, origin, content, query, group_id, user_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
new_content = yield self.handler.invite_to_group(
|
||||
group_id, user_id, requester_user_id, content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, new_content))
|
||||
|
||||
|
||||
class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
|
||||
"""Accept an invitation from the group server
|
||||
"""
|
||||
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite$"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, origin, content, query, group_id, user_id):
|
||||
if get_domain_from_id(user_id) != origin:
|
||||
raise SynapseError(403, "user_id doesn't match origin")
|
||||
|
||||
new_content = yield self.handler.accept_invite(
|
||||
group_id, user_id, content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, new_content))
|
||||
|
||||
|
||||
class FederationGroupsRemoveUserServlet(BaseFederationServlet):
|
||||
"""Leave or kick a user from the group
|
||||
"""
|
||||
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove$"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, origin, content, query, group_id, user_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
new_content = yield self.handler.remove_user_from_group(
|
||||
group_id, user_id, requester_user_id, content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, new_content))
|
||||
|
||||
|
||||
class FederationGroupsLocalInviteServlet(BaseFederationServlet):
|
||||
"""A group server has invited a local user
|
||||
"""
|
||||
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite$"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, origin, content, query, group_id, user_id):
|
||||
if get_domain_from_id(group_id) != origin:
|
||||
raise SynapseError(403, "group_id doesn't match origin")
|
||||
|
||||
new_content = yield self.handler.on_invite(
|
||||
group_id, user_id, content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, new_content))
|
||||
|
||||
|
||||
class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
|
||||
"""A group server has removed a local user
|
||||
"""
|
||||
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove$"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, origin, content, query, group_id, user_id):
|
||||
if get_domain_from_id(group_id) != origin:
|
||||
raise SynapseError(403, "user_id doesn't match origin")
|
||||
|
||||
new_content = yield self.handler.user_removed_from_group(
|
||||
group_id, user_id, content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, new_content))
|
||||
|
||||
|
||||
class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
|
||||
"""A group or user's server renews their attestation
|
||||
"""
|
||||
PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)$"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, origin, content, query, group_id, user_id):
|
||||
# We don't need to check auth here as we check the attestation signatures
|
||||
|
||||
new_content = yield self.handler.on_renew_attestation(
|
||||
group_id, user_id, content
|
||||
)
|
||||
|
||||
defer.returnValue((200, new_content))
|
||||
|
||||
|
||||
class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
|
||||
"""Add/remove a room from the group summary, with optional category.
|
||||
|
||||
Matches both:
|
||||
- /groups/:group/summary/rooms/:room_id
|
||||
- /groups/:group/summary/categories/:category/rooms/:room_id
|
||||
"""
|
||||
PATH = (
|
||||
"/groups/(?P<group_id>[^/]*)/summary"
|
||||
"(/categories/(?P<category_id>[^/]+))?"
|
||||
"/rooms/(?P<room_id>[^/]*)$"
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, origin, content, query, group_id, category_id, room_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
if category_id == "":
|
||||
raise SynapseError(400, "category_id cannot be empty string")
|
||||
|
||||
resp = yield self.handler.update_group_summary_room(
|
||||
group_id, requester_user_id,
|
||||
room_id=room_id,
|
||||
category_id=category_id,
|
||||
content=content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, origin, content, query, group_id, category_id, room_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
if category_id == "":
|
||||
raise SynapseError(400, "category_id cannot be empty string")
|
||||
|
||||
resp = yield self.handler.delete_group_summary_room(
|
||||
group_id, requester_user_id,
|
||||
room_id=room_id,
|
||||
category_id=category_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
|
||||
class FederationGroupsCategoriesServlet(BaseFederationServlet):
|
||||
"""Get all categories for a group
|
||||
"""
|
||||
PATH = (
|
||||
"/groups/(?P<group_id>[^/]*)/categories/$"
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, origin, content, query, group_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
resp = yield self.handler.get_group_categories(
|
||||
group_id, requester_user_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
|
||||
class FederationGroupsCategoryServlet(BaseFederationServlet):
|
||||
"""Add/remove/get a category in a group
|
||||
"""
|
||||
PATH = (
|
||||
"/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, origin, content, query, group_id, category_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
resp = yield self.handler.get_group_category(
|
||||
group_id, requester_user_id, category_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, origin, content, query, group_id, category_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
if category_id == "":
|
||||
raise SynapseError(400, "category_id cannot be empty string")
|
||||
|
||||
resp = yield self.handler.upsert_group_category(
|
||||
group_id, requester_user_id, category_id, content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, origin, content, query, group_id, category_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
if category_id == "":
|
||||
raise SynapseError(400, "category_id cannot be empty string")
|
||||
|
||||
resp = yield self.handler.delete_group_category(
|
||||
group_id, requester_user_id, category_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
|
||||
class FederationGroupsRolesServlet(BaseFederationServlet):
|
||||
"""Get roles in a group
|
||||
"""
|
||||
PATH = (
|
||||
"/groups/(?P<group_id>[^/]*)/roles/$"
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, origin, content, query, group_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
resp = yield self.handler.get_group_roles(
|
||||
group_id, requester_user_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
|
||||
class FederationGroupsRoleServlet(BaseFederationServlet):
|
||||
"""Add/remove/get a role in a group
|
||||
"""
|
||||
PATH = (
|
||||
"/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$"
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, origin, content, query, group_id, role_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
resp = yield self.handler.get_group_role(
|
||||
group_id, requester_user_id, role_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, origin, content, query, group_id, role_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
if role_id == "":
|
||||
raise SynapseError(400, "role_id cannot be empty string")
|
||||
|
||||
resp = yield self.handler.update_group_role(
|
||||
group_id, requester_user_id, role_id, content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, origin, content, query, group_id, role_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
if role_id == "":
|
||||
raise SynapseError(400, "role_id cannot be empty string")
|
||||
|
||||
resp = yield self.handler.delete_group_role(
|
||||
group_id, requester_user_id, role_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
|
||||
class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
|
||||
"""Add/remove a user from the group summary, with optional role.
|
||||
|
||||
Matches both:
|
||||
- /groups/:group/summary/users/:user_id
|
||||
- /groups/:group/summary/roles/:role/users/:user_id
|
||||
"""
|
||||
PATH = (
|
||||
"/groups/(?P<group_id>[^/]*)/summary"
|
||||
"(/roles/(?P<role_id>[^/]+))?"
|
||||
"/users/(?P<user_id>[^/]*)$"
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, origin, content, query, group_id, role_id, user_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
if role_id == "":
|
||||
raise SynapseError(400, "role_id cannot be empty string")
|
||||
|
||||
resp = yield self.handler.update_group_summary_user(
|
||||
group_id, requester_user_id,
|
||||
user_id=user_id,
|
||||
role_id=role_id,
|
||||
content=content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, origin, content, query, group_id, role_id, user_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
if role_id == "":
|
||||
raise SynapseError(400, "role_id cannot be empty string")
|
||||
|
||||
resp = yield self.handler.delete_group_summary_user(
|
||||
group_id, requester_user_id,
|
||||
user_id=user_id,
|
||||
role_id=role_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
|
||||
class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
|
||||
"""Get roles in a group
|
||||
"""
|
||||
PATH = (
|
||||
"/get_groups_publicised$"
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, origin, content, query):
|
||||
resp = yield self.handler.bulk_get_publicised_groups(
|
||||
content["user_ids"], proxy=False,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
|
||||
FEDERATION_SERVLET_CLASSES = (
|
||||
FederationSendServlet,
|
||||
FederationPullServlet,
|
||||
@ -635,10 +1122,40 @@ FEDERATION_SERVLET_CLASSES = (
|
||||
FederationVersionServlet,
|
||||
)
|
||||
|
||||
|
||||
ROOM_LIST_CLASSES = (
|
||||
PublicRoomList,
|
||||
)
|
||||
|
||||
GROUP_SERVER_SERVLET_CLASSES = (
|
||||
FederationGroupsProfileServlet,
|
||||
FederationGroupsSummaryServlet,
|
||||
FederationGroupsRoomsServlet,
|
||||
FederationGroupsUsersServlet,
|
||||
FederationGroupsInvitedUsersServlet,
|
||||
FederationGroupsInviteServlet,
|
||||
FederationGroupsAcceptInviteServlet,
|
||||
FederationGroupsRemoveUserServlet,
|
||||
FederationGroupsSummaryRoomsServlet,
|
||||
FederationGroupsCategoriesServlet,
|
||||
FederationGroupsCategoryServlet,
|
||||
FederationGroupsRolesServlet,
|
||||
FederationGroupsRoleServlet,
|
||||
FederationGroupsSummaryUsersServlet,
|
||||
)
|
||||
|
||||
|
||||
GROUP_LOCAL_SERVLET_CLASSES = (
|
||||
FederationGroupsLocalInviteServlet,
|
||||
FederationGroupsRemoveLocalUserServlet,
|
||||
FederationGroupsBulkPublicisedServlet,
|
||||
)
|
||||
|
||||
|
||||
GROUP_ATTESTATION_SERVLET_CLASSES = (
|
||||
FederationGroupsRenewAttestaionServlet,
|
||||
)
|
||||
|
||||
|
||||
def register_servlets(hs, resource, authenticator, ratelimiter):
|
||||
for servletclass in FEDERATION_SERVLET_CLASSES:
|
||||
@ -656,3 +1173,27 @@ def register_servlets(hs, resource, authenticator, ratelimiter):
|
||||
ratelimiter=ratelimiter,
|
||||
server_name=hs.hostname,
|
||||
).register(resource)
|
||||
|
||||
for servletclass in GROUP_SERVER_SERVLET_CLASSES:
|
||||
servletclass(
|
||||
handler=hs.get_groups_server_handler(),
|
||||
authenticator=authenticator,
|
||||
ratelimiter=ratelimiter,
|
||||
server_name=hs.hostname,
|
||||
).register(resource)
|
||||
|
||||
for servletclass in GROUP_LOCAL_SERVLET_CLASSES:
|
||||
servletclass(
|
||||
handler=hs.get_groups_local_handler(),
|
||||
authenticator=authenticator,
|
||||
ratelimiter=ratelimiter,
|
||||
server_name=hs.hostname,
|
||||
).register(resource)
|
||||
|
||||
for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES:
|
||||
servletclass(
|
||||
handler=hs.get_groups_attestation_renewer(),
|
||||
authenticator=authenticator,
|
||||
ratelimiter=ratelimiter,
|
||||
server_name=hs.hostname,
|
||||
).register(resource)
|
||||
|
0
synapse/groups/__init__.py
Normal file
0
synapse/groups/__init__.py
Normal file
151
synapse/groups/attestations.py
Normal file
151
synapse/groups/attestations.py
Normal file
@ -0,0 +1,151 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 Vector Creations Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.types import get_domain_from_id
|
||||
from synapse.util.logcontext import preserve_fn
|
||||
|
||||
from signedjson.sign import sign_json
|
||||
|
||||
|
||||
# Default validity duration for new attestations we create
|
||||
DEFAULT_ATTESTATION_LENGTH_MS = 3 * 24 * 60 * 60 * 1000
|
||||
|
||||
# Start trying to update our attestations when they come this close to expiring
|
||||
UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000
|
||||
|
||||
|
||||
class GroupAttestationSigning(object):
|
||||
"""Creates and verifies group attestations.
|
||||
"""
|
||||
def __init__(self, hs):
|
||||
self.keyring = hs.get_keyring()
|
||||
self.clock = hs.get_clock()
|
||||
self.server_name = hs.hostname
|
||||
self.signing_key = hs.config.signing_key[0]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def verify_attestation(self, attestation, group_id, user_id, server_name=None):
|
||||
"""Verifies that the given attestation matches the given parameters.
|
||||
|
||||
An optional server_name can be supplied to explicitly set which server's
|
||||
signature is expected. Otherwise assumes that either the group_id or user_id
|
||||
is local and uses the other's server as the one to check.
|
||||
"""
|
||||
|
||||
if not server_name:
|
||||
if get_domain_from_id(group_id) == self.server_name:
|
||||
server_name = get_domain_from_id(user_id)
|
||||
elif get_domain_from_id(user_id) == self.server_name:
|
||||
server_name = get_domain_from_id(group_id)
|
||||
else:
|
||||
raise Exception("Expected either group_id or user_id to be local")
|
||||
|
||||
if user_id != attestation["user_id"]:
|
||||
raise SynapseError(400, "Attestation has incorrect user_id")
|
||||
|
||||
if group_id != attestation["group_id"]:
|
||||
raise SynapseError(400, "Attestation has incorrect group_id")
|
||||
valid_until_ms = attestation["valid_until_ms"]
|
||||
|
||||
# TODO: We also want to check that *new* attestations that people give
|
||||
# us to store are valid for at least a little while.
|
||||
if valid_until_ms < self.clock.time_msec():
|
||||
raise SynapseError(400, "Attestation expired")
|
||||
|
||||
yield self.keyring.verify_json_for_server(server_name, attestation)
|
||||
|
||||
def create_attestation(self, group_id, user_id):
|
||||
"""Create an attestation for the group_id and user_id with default
|
||||
validity length.
|
||||
"""
|
||||
return sign_json({
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
"valid_until_ms": self.clock.time_msec() + DEFAULT_ATTESTATION_LENGTH_MS,
|
||||
}, self.server_name, self.signing_key)
|
||||
|
||||
|
||||
class GroupAttestionRenewer(object):
|
||||
"""Responsible for sending and receiving attestation updates.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
self.assestations = hs.get_groups_attestation_signing()
|
||||
self.transport_client = hs.get_federation_transport_client()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.attestations = hs.get_groups_attestation_signing()
|
||||
|
||||
self._renew_attestations_loop = self.clock.looping_call(
|
||||
self._renew_attestations, 30 * 60 * 1000,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_renew_attestation(self, group_id, user_id, content):
|
||||
"""When a remote updates an attestation
|
||||
"""
|
||||
attestation = content["attestation"]
|
||||
|
||||
if not self.is_mine_id(group_id) and not self.is_mine_id(user_id):
|
||||
raise SynapseError(400, "Neither user not group are on this server")
|
||||
|
||||
yield self.attestations.verify_attestation(
|
||||
attestation,
|
||||
user_id=user_id,
|
||||
group_id=group_id,
|
||||
)
|
||||
|
||||
yield self.store.update_remote_attestion(group_id, user_id, attestation)
|
||||
|
||||
defer.returnValue({})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _renew_attestations(self):
|
||||
"""Called periodically to check if we need to update any of our attestations
|
||||
"""
|
||||
|
||||
now = self.clock.time_msec()
|
||||
|
||||
rows = yield self.store.get_attestations_need_renewals(
|
||||
now + UPDATE_ATTESTATION_TIME_MS
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _renew_attestation(group_id, user_id):
|
||||
attestation = self.attestations.create_attestation(group_id, user_id)
|
||||
|
||||
if self.is_mine_id(group_id):
|
||||
destination = get_domain_from_id(user_id)
|
||||
else:
|
||||
destination = get_domain_from_id(group_id)
|
||||
|
||||
yield self.transport_client.renew_group_attestation(
|
||||
destination, group_id, user_id,
|
||||
content={"attestation": attestation},
|
||||
)
|
||||
|
||||
yield self.store.update_attestation_renewal(
|
||||
group_id, user_id, attestation
|
||||
)
|
||||
|
||||
for row in rows:
|
||||
group_id = row["group_id"]
|
||||
user_id = row["user_id"]
|
||||
|
||||
preserve_fn(_renew_attestation)(group_id, user_id)
|
803
synapse/groups/groups_server.py
Normal file
803
synapse/groups/groups_server.py
Normal file
@ -0,0 +1,803 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 Vector Creations Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.types import UserID, get_domain_from_id, RoomID, GroupID
|
||||
|
||||
|
||||
import logging
|
||||
import urllib
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# TODO: Allow users to "knock" or simpkly join depending on rules
|
||||
# TODO: Federation admin APIs
|
||||
# TODO: is_priveged flag to users and is_public to users and rooms
|
||||
# TODO: Audit log for admins (profile updates, membership changes, users who tried
|
||||
# to join but were rejected, etc)
|
||||
# TODO: Flairs
|
||||
|
||||
|
||||
class GroupsServerHandler(object):
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.room_list_handler = hs.get_room_list_handler()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.keyring = hs.get_keyring()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.signing_key = hs.config.signing_key[0]
|
||||
self.server_name = hs.hostname
|
||||
self.attestations = hs.get_groups_attestation_signing()
|
||||
self.transport_client = hs.get_federation_transport_client()
|
||||
self.profile_handler = hs.get_profile_handler()
|
||||
|
||||
# Ensure attestations get renewed
|
||||
hs.get_groups_attestation_renewer()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_group_is_ours(self, group_id, and_exists=False, and_is_admin=None):
|
||||
"""Check that the group is ours, and optionally if it exists.
|
||||
|
||||
If group does exist then return group.
|
||||
|
||||
Args:
|
||||
group_id (str)
|
||||
and_exists (bool): whether to also check if group exists
|
||||
and_is_admin (str): whether to also check if given str is a user_id
|
||||
that is an admin
|
||||
"""
|
||||
if not self.is_mine_id(group_id):
|
||||
raise SynapseError(400, "Group not on this server")
|
||||
|
||||
group = yield self.store.get_group(group_id)
|
||||
if and_exists and not group:
|
||||
raise SynapseError(404, "Unknown group")
|
||||
|
||||
if and_is_admin:
|
||||
is_admin = yield self.store.is_user_admin_in_group(group_id, and_is_admin)
|
||||
if not is_admin:
|
||||
raise SynapseError(403, "User is not admin in group")
|
||||
|
||||
defer.returnValue(group)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_group_summary(self, group_id, requester_user_id):
|
||||
"""Get the summary for a group as seen by requester_user_id.
|
||||
|
||||
The group summary consists of the profile of the room, and a curated
|
||||
list of users and rooms. These list *may* be organised by role/category.
|
||||
The roles/categories are ordered, and so are the users/rooms within them.
|
||||
|
||||
A user/room may appear in multiple roles/categories.
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
|
||||
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
|
||||
|
||||
profile = yield self.get_group_profile(group_id, requester_user_id)
|
||||
|
||||
users, roles = yield self.store.get_users_for_summary_by_role(
|
||||
group_id, include_private=is_user_in_group,
|
||||
)
|
||||
|
||||
# TODO: Add profiles to users
|
||||
|
||||
rooms, categories = yield self.store.get_rooms_for_summary_by_category(
|
||||
group_id, include_private=is_user_in_group,
|
||||
)
|
||||
|
||||
for room_entry in rooms:
|
||||
room_id = room_entry["room_id"]
|
||||
joined_users = yield self.store.get_users_in_room(room_id)
|
||||
entry = yield self.room_list_handler.generate_room_entry(
|
||||
room_id, len(joined_users),
|
||||
with_alias=False, allow_private=True,
|
||||
)
|
||||
entry = dict(entry) # so we don't change whats cached
|
||||
entry.pop("room_id", None)
|
||||
|
||||
room_entry["profile"] = entry
|
||||
|
||||
rooms.sort(key=lambda e: e.get("order", 0))
|
||||
|
||||
for entry in users:
|
||||
user_id = entry["user_id"]
|
||||
|
||||
if not self.is_mine_id(requester_user_id):
|
||||
attestation = yield self.store.get_remote_attestation(group_id, user_id)
|
||||
if not attestation:
|
||||
continue
|
||||
|
||||
entry["attestation"] = attestation
|
||||
else:
|
||||
entry["attestation"] = self.attestations.create_attestation(
|
||||
group_id, user_id,
|
||||
)
|
||||
|
||||
user_profile = yield self.profile_handler.get_profile_from_cache(user_id)
|
||||
entry.update(user_profile)
|
||||
|
||||
users.sort(key=lambda e: e.get("order", 0))
|
||||
|
||||
membership_info = yield self.store.get_users_membership_info_in_group(
|
||||
group_id, requester_user_id,
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
"profile": profile,
|
||||
"users_section": {
|
||||
"users": users,
|
||||
"roles": roles,
|
||||
"total_user_count_estimate": 0, # TODO
|
||||
},
|
||||
"rooms_section": {
|
||||
"rooms": rooms,
|
||||
"categories": categories,
|
||||
"total_room_count_estimate": 0, # TODO
|
||||
},
|
||||
"user": membership_info,
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_group_summary_room(self, group_id, user_id, room_id, category_id, content):
|
||||
"""Add/update a room to the group summary
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
|
||||
|
||||
RoomID.from_string(room_id) # Ensure valid room id
|
||||
|
||||
order = content.get("order", None)
|
||||
|
||||
is_public = _parse_visibility_from_contents(content)
|
||||
|
||||
yield self.store.add_room_to_summary(
|
||||
group_id=group_id,
|
||||
room_id=room_id,
|
||||
category_id=category_id,
|
||||
order=order,
|
||||
is_public=is_public,
|
||||
)
|
||||
|
||||
defer.returnValue({})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_group_summary_room(self, group_id, user_id, room_id, category_id):
|
||||
"""Remove a room from the summary
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
|
||||
|
||||
yield self.store.remove_room_from_summary(
|
||||
group_id=group_id,
|
||||
room_id=room_id,
|
||||
category_id=category_id,
|
||||
)
|
||||
|
||||
defer.returnValue({})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_group_categories(self, group_id, user_id):
|
||||
"""Get all categories in a group (as seen by user)
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
|
||||
categories = yield self.store.get_group_categories(
|
||||
group_id=group_id,
|
||||
)
|
||||
defer.returnValue({"categories": categories})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_group_category(self, group_id, user_id, category_id):
|
||||
"""Get a specific category in a group (as seen by user)
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
|
||||
res = yield self.store.get_group_category(
|
||||
group_id=group_id,
|
||||
category_id=category_id,
|
||||
)
|
||||
|
||||
defer.returnValue(res)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_group_category(self, group_id, user_id, category_id, content):
|
||||
"""Add/Update a group category
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
|
||||
|
||||
is_public = _parse_visibility_from_contents(content)
|
||||
profile = content.get("profile")
|
||||
|
||||
yield self.store.upsert_group_category(
|
||||
group_id=group_id,
|
||||
category_id=category_id,
|
||||
is_public=is_public,
|
||||
profile=profile,
|
||||
)
|
||||
|
||||
defer.returnValue({})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_group_category(self, group_id, user_id, category_id):
|
||||
"""Delete a group category
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
|
||||
|
||||
yield self.store.remove_group_category(
|
||||
group_id=group_id,
|
||||
category_id=category_id,
|
||||
)
|
||||
|
||||
defer.returnValue({})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_group_roles(self, group_id, user_id):
|
||||
"""Get all roles in a group (as seen by user)
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
|
||||
roles = yield self.store.get_group_roles(
|
||||
group_id=group_id,
|
||||
)
|
||||
defer.returnValue({"roles": roles})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_group_role(self, group_id, user_id, role_id):
|
||||
"""Get a specific role in a group (as seen by user)
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
|
||||
res = yield self.store.get_group_role(
|
||||
group_id=group_id,
|
||||
role_id=role_id,
|
||||
)
|
||||
defer.returnValue(res)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_group_role(self, group_id, user_id, role_id, content):
|
||||
"""Add/update a role in a group
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
|
||||
|
||||
is_public = _parse_visibility_from_contents(content)
|
||||
|
||||
profile = content.get("profile")
|
||||
|
||||
yield self.store.upsert_group_role(
|
||||
group_id=group_id,
|
||||
role_id=role_id,
|
||||
is_public=is_public,
|
||||
profile=profile,
|
||||
)
|
||||
|
||||
defer.returnValue({})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_group_role(self, group_id, user_id, role_id):
|
||||
"""Remove role from group
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
|
||||
|
||||
yield self.store.remove_group_role(
|
||||
group_id=group_id,
|
||||
role_id=role_id,
|
||||
)
|
||||
|
||||
defer.returnValue({})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_group_summary_user(self, group_id, requester_user_id, user_id, role_id,
|
||||
content):
|
||||
"""Add/update a users entry in the group summary
|
||||
"""
|
||||
yield self.check_group_is_ours(
|
||||
group_id, and_exists=True, and_is_admin=requester_user_id,
|
||||
)
|
||||
|
||||
order = content.get("order", None)
|
||||
|
||||
is_public = _parse_visibility_from_contents(content)
|
||||
|
||||
yield self.store.add_user_to_summary(
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
role_id=role_id,
|
||||
order=order,
|
||||
is_public=is_public,
|
||||
)
|
||||
|
||||
defer.returnValue({})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_group_summary_user(self, group_id, requester_user_id, user_id, role_id):
|
||||
"""Remove a user from the group summary
|
||||
"""
|
||||
yield self.check_group_is_ours(
|
||||
group_id, and_exists=True, and_is_admin=requester_user_id,
|
||||
)
|
||||
|
||||
yield self.store.remove_user_from_summary(
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
role_id=role_id,
|
||||
)
|
||||
|
||||
defer.returnValue({})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_group_profile(self, group_id, requester_user_id):
|
||||
"""Get the group profile as seen by requester_user_id
|
||||
"""
|
||||
|
||||
yield self.check_group_is_ours(group_id)
|
||||
|
||||
group_description = yield self.store.get_group(group_id)
|
||||
|
||||
if group_description:
|
||||
defer.returnValue(group_description)
|
||||
else:
|
||||
raise SynapseError(404, "Unknown group")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_group_profile(self, group_id, requester_user_id, content):
|
||||
"""Update the group profile
|
||||
"""
|
||||
yield self.check_group_is_ours(
|
||||
group_id, and_exists=True, and_is_admin=requester_user_id,
|
||||
)
|
||||
|
||||
profile = {}
|
||||
for keyname in ("name", "avatar_url", "short_description",
|
||||
"long_description"):
|
||||
if keyname in content:
|
||||
value = content[keyname]
|
||||
if not isinstance(value, basestring):
|
||||
raise SynapseError(400, "%r value is not a string" % (keyname,))
|
||||
profile[keyname] = value
|
||||
|
||||
yield self.store.update_group_profile(group_id, profile)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_users_in_group(self, group_id, requester_user_id):
|
||||
"""Get the users in group as seen by requester_user_id.
|
||||
|
||||
The ordering is arbitrary at the moment
|
||||
"""
|
||||
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
|
||||
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
|
||||
|
||||
user_results = yield self.store.get_users_in_group(
|
||||
group_id, include_private=is_user_in_group,
|
||||
)
|
||||
|
||||
chunk = []
|
||||
for user_result in user_results:
|
||||
g_user_id = user_result["user_id"]
|
||||
is_public = user_result["is_public"]
|
||||
|
||||
entry = {"user_id": g_user_id}
|
||||
|
||||
profile = yield self.profile_handler.get_profile_from_cache(g_user_id)
|
||||
entry.update(profile)
|
||||
|
||||
if not is_public:
|
||||
entry["is_public"] = False
|
||||
|
||||
if not self.is_mine_id(g_user_id):
|
||||
attestation = yield self.store.get_remote_attestation(group_id, g_user_id)
|
||||
if not attestation:
|
||||
continue
|
||||
|
||||
entry["attestation"] = attestation
|
||||
else:
|
||||
entry["attestation"] = self.attestations.create_attestation(
|
||||
group_id, g_user_id,
|
||||
)
|
||||
|
||||
chunk.append(entry)
|
||||
|
||||
# TODO: If admin add lists of users whose attestations have timed out
|
||||
|
||||
defer.returnValue({
|
||||
"chunk": chunk,
|
||||
"total_user_count_estimate": len(user_results),
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_invited_users_in_group(self, group_id, requester_user_id):
|
||||
"""Get the users that have been invited to a group as seen by requester_user_id.
|
||||
|
||||
The ordering is arbitrary at the moment
|
||||
"""
|
||||
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
|
||||
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
|
||||
|
||||
if not is_user_in_group:
|
||||
raise SynapseError(403, "User not in group")
|
||||
|
||||
invited_users = yield self.store.get_invited_users_in_group(group_id)
|
||||
|
||||
user_profiles = []
|
||||
|
||||
for user_id in invited_users:
|
||||
user_profile = {
|
||||
"user_id": user_id
|
||||
}
|
||||
try:
|
||||
profile = yield self.profile_handler.get_profile_from_cache(user_id)
|
||||
user_profile.update(profile)
|
||||
except Exception as e:
|
||||
logger.warn("Error getting profile for %s: %s", user_id, e)
|
||||
user_profiles.append(user_profile)
|
||||
|
||||
defer.returnValue({
|
||||
"chunk": user_profiles,
|
||||
"total_user_count_estimate": len(invited_users),
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_rooms_in_group(self, group_id, requester_user_id):
|
||||
"""Get the rooms in group as seen by requester_user_id
|
||||
|
||||
This returns rooms in order of decreasing number of joined users
|
||||
"""
|
||||
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
|
||||
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
|
||||
|
||||
room_results = yield self.store.get_rooms_in_group(
|
||||
group_id, include_private=is_user_in_group,
|
||||
)
|
||||
|
||||
chunk = []
|
||||
for room_result in room_results:
|
||||
room_id = room_result["room_id"]
|
||||
is_public = room_result["is_public"]
|
||||
|
||||
joined_users = yield self.store.get_users_in_room(room_id)
|
||||
entry = yield self.room_list_handler.generate_room_entry(
|
||||
room_id, len(joined_users),
|
||||
with_alias=False, allow_private=True,
|
||||
)
|
||||
|
||||
if not entry:
|
||||
continue
|
||||
|
||||
if not is_public:
|
||||
entry["is_public"] = False
|
||||
|
||||
chunk.append(entry)
|
||||
|
||||
chunk.sort(key=lambda e: -e["num_joined_members"])
|
||||
|
||||
defer.returnValue({
|
||||
"chunk": chunk,
|
||||
"total_room_count_estimate": len(room_results),
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_room_to_group(self, group_id, requester_user_id, room_id, content):
|
||||
"""Add room to group
|
||||
"""
|
||||
RoomID.from_string(room_id) # Ensure valid room id
|
||||
|
||||
yield self.check_group_is_ours(
|
||||
group_id, and_exists=True, and_is_admin=requester_user_id
|
||||
)
|
||||
|
||||
is_public = _parse_visibility_from_contents(content)
|
||||
|
||||
yield self.store.add_room_to_group(group_id, room_id, is_public=is_public)
|
||||
|
||||
defer.returnValue({})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def remove_room_from_group(self, group_id, requester_user_id, room_id):
|
||||
"""Remove room from group
|
||||
"""
|
||||
yield self.check_group_is_ours(
|
||||
group_id, and_exists=True, and_is_admin=requester_user_id
|
||||
)
|
||||
|
||||
yield self.store.remove_room_from_group(group_id, room_id)
|
||||
|
||||
defer.returnValue({})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def invite_to_group(self, group_id, user_id, requester_user_id, content):
|
||||
"""Invite user to group
|
||||
"""
|
||||
|
||||
group = yield self.check_group_is_ours(
|
||||
group_id, and_exists=True, and_is_admin=requester_user_id
|
||||
)
|
||||
|
||||
# TODO: Check if user knocked
|
||||
# TODO: Check if user is already invited
|
||||
|
||||
content = {
|
||||
"profile": {
|
||||
"name": group["name"],
|
||||
"avatar_url": group["avatar_url"],
|
||||
},
|
||||
"inviter": requester_user_id,
|
||||
}
|
||||
|
||||
if self.hs.is_mine_id(user_id):
|
||||
groups_local = self.hs.get_groups_local_handler()
|
||||
res = yield groups_local.on_invite(group_id, user_id, content)
|
||||
local_attestation = None
|
||||
else:
|
||||
local_attestation = self.attestations.create_attestation(group_id, user_id)
|
||||
content.update({
|
||||
"attestation": local_attestation,
|
||||
})
|
||||
|
||||
res = yield self.transport_client.invite_to_group_notification(
|
||||
get_domain_from_id(user_id), group_id, user_id, content
|
||||
)
|
||||
|
||||
user_profile = res.get("user_profile", {})
|
||||
yield self.store.add_remote_profile_cache(
|
||||
user_id,
|
||||
displayname=user_profile.get("displayname"),
|
||||
avatar_url=user_profile.get("avatar_url"),
|
||||
)
|
||||
|
||||
if res["state"] == "join":
|
||||
if not self.hs.is_mine_id(user_id):
|
||||
remote_attestation = res["attestation"]
|
||||
|
||||
yield self.attestations.verify_attestation(
|
||||
remote_attestation,
|
||||
user_id=user_id,
|
||||
group_id=group_id,
|
||||
)
|
||||
else:
|
||||
remote_attestation = None
|
||||
|
||||
yield self.store.add_user_to_group(
|
||||
group_id, user_id,
|
||||
is_admin=False,
|
||||
is_public=False, # TODO
|
||||
local_attestation=local_attestation,
|
||||
remote_attestation=remote_attestation,
|
||||
)
|
||||
elif res["state"] == "invite":
|
||||
yield self.store.add_group_invite(
|
||||
group_id, user_id,
|
||||
)
|
||||
defer.returnValue({
|
||||
"state": "invite"
|
||||
})
|
||||
elif res["state"] == "reject":
|
||||
defer.returnValue({
|
||||
"state": "reject"
|
||||
})
|
||||
else:
|
||||
raise SynapseError(502, "Unknown state returned by HS")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def accept_invite(self, group_id, user_id, content):
|
||||
"""User tries to accept an invite to the group.
|
||||
|
||||
This is different from them asking to join, and so should error if no
|
||||
invite exists (and they're not a member of the group)
|
||||
"""
|
||||
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
|
||||
if not self.store.is_user_invited_to_local_group(group_id, user_id):
|
||||
raise SynapseError(403, "User not invited to group")
|
||||
|
||||
if not self.hs.is_mine_id(user_id):
|
||||
remote_attestation = content["attestation"]
|
||||
|
||||
yield self.attestations.verify_attestation(
|
||||
remote_attestation,
|
||||
user_id=user_id,
|
||||
group_id=group_id,
|
||||
)
|
||||
else:
|
||||
remote_attestation = None
|
||||
|
||||
local_attestation = self.attestations.create_attestation(group_id, user_id)
|
||||
|
||||
is_public = _parse_visibility_from_contents(content)
|
||||
|
||||
yield self.store.add_user_to_group(
|
||||
group_id, user_id,
|
||||
is_admin=False,
|
||||
is_public=is_public,
|
||||
local_attestation=local_attestation,
|
||||
remote_attestation=remote_attestation,
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
"state": "join",
|
||||
"attestation": local_attestation,
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def knock(self, group_id, user_id, content):
|
||||
"""A user requests becoming a member of the group
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
|
||||
raise NotImplementedError()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def accept_knock(self, group_id, user_id, content):
|
||||
"""Accept a users knock to the room.
|
||||
|
||||
Errors if the user hasn't knocked, rather than inviting them.
|
||||
"""
|
||||
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
|
||||
raise NotImplementedError()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
|
||||
"""Remove a user from the group; either a user is leaving or and admin
|
||||
kicked htem.
|
||||
"""
|
||||
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
|
||||
is_kick = False
|
||||
if requester_user_id != user_id:
|
||||
is_admin = yield self.store.is_user_admin_in_group(
|
||||
group_id, requester_user_id
|
||||
)
|
||||
if not is_admin:
|
||||
raise SynapseError(403, "User is not admin in group")
|
||||
|
||||
is_kick = True
|
||||
|
||||
yield self.store.remove_user_from_group(
|
||||
group_id, user_id,
|
||||
)
|
||||
|
||||
if is_kick:
|
||||
if self.hs.is_mine_id(user_id):
|
||||
groups_local = self.hs.get_groups_local_handler()
|
||||
yield groups_local.user_removed_from_group(group_id, user_id, {})
|
||||
else:
|
||||
yield self.transport_client.remove_user_from_group_notification(
|
||||
get_domain_from_id(user_id), group_id, user_id, {}
|
||||
)
|
||||
|
||||
if not self.hs.is_mine_id(user_id):
|
||||
yield self.store.maybe_delete_remote_profile_cache(user_id)
|
||||
|
||||
defer.returnValue({})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_group(self, group_id, user_id, content):
|
||||
group = yield self.check_group_is_ours(group_id)
|
||||
|
||||
_validate_group_id(group_id)
|
||||
|
||||
logger.info("Attempting to create group with ID: %r", group_id)
|
||||
if group:
|
||||
raise SynapseError(400, "Group already exists")
|
||||
|
||||
is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id))
|
||||
if not is_admin:
|
||||
if not self.hs.config.enable_group_creation:
|
||||
raise SynapseError(
|
||||
403, "Only server admin can create group on this server",
|
||||
)
|
||||
localpart = GroupID.from_string(group_id).localpart
|
||||
if not localpart.startswith(self.hs.config.group_creation_prefix):
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Can only create groups with prefix %r on this server" % (
|
||||
self.hs.config.group_creation_prefix,
|
||||
),
|
||||
)
|
||||
|
||||
profile = content.get("profile", {})
|
||||
name = profile.get("name")
|
||||
avatar_url = profile.get("avatar_url")
|
||||
short_description = profile.get("short_description")
|
||||
long_description = profile.get("long_description")
|
||||
user_profile = content.get("user_profile", {})
|
||||
|
||||
yield self.store.create_group(
|
||||
group_id,
|
||||
user_id,
|
||||
name=name,
|
||||
avatar_url=avatar_url,
|
||||
short_description=short_description,
|
||||
long_description=long_description,
|
||||
)
|
||||
|
||||
if not self.hs.is_mine_id(user_id):
|
||||
remote_attestation = content["attestation"]
|
||||
|
||||
yield self.attestations.verify_attestation(
|
||||
remote_attestation,
|
||||
user_id=user_id,
|
||||
group_id=group_id,
|
||||
)
|
||||
|
||||
local_attestation = self.attestations.create_attestation(group_id, user_id)
|
||||
else:
|
||||
local_attestation = None
|
||||
remote_attestation = None
|
||||
|
||||
yield self.store.add_user_to_group(
|
||||
group_id, user_id,
|
||||
is_admin=True,
|
||||
is_public=True, # TODO
|
||||
local_attestation=local_attestation,
|
||||
remote_attestation=remote_attestation,
|
||||
)
|
||||
|
||||
if not self.hs.is_mine_id(user_id):
|
||||
yield self.store.add_remote_profile_cache(
|
||||
user_id,
|
||||
displayname=user_profile.get("displayname"),
|
||||
avatar_url=user_profile.get("avatar_url"),
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
"group_id": group_id,
|
||||
})
|
||||
|
||||
|
||||
def _parse_visibility_from_contents(content):
|
||||
"""Given a content for a request parse out whether the entity should be
|
||||
public or not
|
||||
"""
|
||||
|
||||
visibility = content.get("visibility")
|
||||
if visibility:
|
||||
vis_type = visibility["type"]
|
||||
if vis_type not in ("public", "private"):
|
||||
raise SynapseError(
|
||||
400, "Synapse only supports 'public'/'private' visibility"
|
||||
)
|
||||
is_public = vis_type == "public"
|
||||
else:
|
||||
is_public = True
|
||||
|
||||
return is_public
|
||||
|
||||
|
||||
def _validate_group_id(group_id):
|
||||
"""Validates the group ID is valid for creation on this home server
|
||||
"""
|
||||
localpart = GroupID.from_string(group_id).localpart
|
||||
|
||||
if localpart.lower() != localpart:
|
||||
raise SynapseError(400, "Group ID must be lower case")
|
||||
|
||||
if urllib.quote(localpart.encode('utf-8')) != localpart:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Group ID can only contain characters a-z, 0-9, or '_-./'",
|
||||
)
|
@ -20,7 +20,6 @@ from .room import (
|
||||
from .room_member import RoomMemberHandler
|
||||
from .message import MessageHandler
|
||||
from .federation import FederationHandler
|
||||
from .profile import ProfileHandler
|
||||
from .directory import DirectoryHandler
|
||||
from .admin import AdminHandler
|
||||
from .identity import IdentityHandler
|
||||
@ -52,7 +51,6 @@ class Handlers(object):
|
||||
self.room_creation_handler = RoomCreationHandler(hs)
|
||||
self.room_member_handler = RoomMemberHandler(hs)
|
||||
self.federation_handler = FederationHandler(hs)
|
||||
self.profile_handler = ProfileHandler(hs)
|
||||
self.directory_handler = DirectoryHandler(hs)
|
||||
self.admin_handler = AdminHandler(hs)
|
||||
self.identity_handler = IdentityHandler(hs)
|
||||
|
@ -40,6 +40,8 @@ class DirectoryHandler(BaseHandler):
|
||||
"directory", self.on_directory_query
|
||||
)
|
||||
|
||||
self.spam_checker = hs.get_spam_checker()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _create_association(self, room_alias, room_id, servers=None, creator=None):
|
||||
# general association creation for both human users and app services
|
||||
@ -73,6 +75,11 @@ class DirectoryHandler(BaseHandler):
|
||||
# association creation for human users
|
||||
# TODO(erikj): Do user auth.
|
||||
|
||||
if not self.spam_checker.user_may_create_room_alias(user_id, room_alias):
|
||||
raise SynapseError(
|
||||
403, "This user is not permitted to create this alias",
|
||||
)
|
||||
|
||||
can_create = yield self.can_modify_alias(
|
||||
room_alias,
|
||||
user_id=user_id
|
||||
@ -327,6 +334,14 @@ class DirectoryHandler(BaseHandler):
|
||||
room_id (str)
|
||||
visibility (str): "public" or "private"
|
||||
"""
|
||||
if not self.spam_checker.user_may_publish_room(
|
||||
requester.user.to_string(), room_id
|
||||
):
|
||||
raise AuthError(
|
||||
403,
|
||||
"This user is not permitted to publish rooms to the room list"
|
||||
)
|
||||
|
||||
if requester.is_guest:
|
||||
raise AuthError(403, "Guests cannot edit the published room list")
|
||||
|
||||
|
@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""Contains handlers for federation events."""
|
||||
import synapse.util.logcontext
|
||||
from signedjson.key import decode_verify_key_bytes
|
||||
from signedjson.sign import verify_signed_json
|
||||
from unpaddedbase64 import decode_base64
|
||||
@ -26,10 +25,7 @@ from synapse.api.errors import (
|
||||
)
|
||||
from synapse.api.constants import EventTypes, Membership, RejectedReason
|
||||
from synapse.events.validator import EventValidator
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util.logcontext import (
|
||||
preserve_fn, preserve_context_over_deferred
|
||||
)
|
||||
from synapse.util import unwrapFirstError, logcontext
|
||||
from synapse.util.metrics import measure_func
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.async import run_on_reactor, Linearizer
|
||||
@ -77,6 +73,7 @@ class FederationHandler(BaseHandler):
|
||||
self.action_generator = hs.get_action_generator()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.pusher_pool = hs.get_pusherpool()
|
||||
self.spam_checker = hs.get_spam_checker()
|
||||
|
||||
self.replication_layer.set_handler(self)
|
||||
|
||||
@ -125,6 +122,28 @@ class FederationHandler(BaseHandler):
|
||||
self.room_queues[pdu.room_id].append((pdu, origin))
|
||||
return
|
||||
|
||||
# If we're no longer in the room just ditch the event entirely. This
|
||||
# is probably an old server that has come back and thinks we're still
|
||||
# in the room (or we've been rejoined to the room by a state reset).
|
||||
#
|
||||
# If we were never in the room then maybe our database got vaped and
|
||||
# we should check if we *are* in fact in the room. If we are then we
|
||||
# can magically rejoin the room.
|
||||
is_in_room = yield self.auth.check_host_in_room(
|
||||
pdu.room_id,
|
||||
self.server_name
|
||||
)
|
||||
if not is_in_room:
|
||||
was_in_room = yield self.store.was_host_joined(
|
||||
pdu.room_id, self.server_name,
|
||||
)
|
||||
if was_in_room:
|
||||
logger.info(
|
||||
"Ignoring PDU %s for room %s from %s as we've left the room!",
|
||||
pdu.event_id, pdu.room_id, origin,
|
||||
)
|
||||
return
|
||||
|
||||
state = None
|
||||
|
||||
auth_chain = []
|
||||
@ -591,9 +610,9 @@ class FederationHandler(BaseHandler):
|
||||
missing_auth - failed_to_fetch
|
||||
)
|
||||
|
||||
results = yield preserve_context_over_deferred(defer.gatherResults(
|
||||
results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
|
||||
[
|
||||
preserve_fn(self.replication_layer.get_pdu)(
|
||||
logcontext.preserve_fn(self.replication_layer.get_pdu)(
|
||||
[dest],
|
||||
event_id,
|
||||
outlier=True,
|
||||
@ -785,10 +804,14 @@ class FederationHandler(BaseHandler):
|
||||
event_ids = list(extremities.keys())
|
||||
|
||||
logger.debug("calling resolve_state_groups in _maybe_backfill")
|
||||
states = yield preserve_context_over_deferred(defer.gatherResults([
|
||||
preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e])
|
||||
for e in event_ids
|
||||
]))
|
||||
states = yield logcontext.make_deferred_yieldable(defer.gatherResults(
|
||||
[
|
||||
logcontext.preserve_fn(self.state_handler.resolve_state_groups)(
|
||||
room_id, [e]
|
||||
)
|
||||
for e in event_ids
|
||||
], consumeErrors=True,
|
||||
))
|
||||
states = dict(zip(event_ids, [s.state for s in states]))
|
||||
|
||||
state_map = yield self.store.get_events(
|
||||
@ -941,9 +964,7 @@ class FederationHandler(BaseHandler):
|
||||
# lots of requests for missing prev_events which we do actually
|
||||
# have. Hence we fire off the deferred, but don't wait for it.
|
||||
|
||||
synapse.util.logcontext.preserve_fn(self._handle_queued_pdus)(
|
||||
room_queue
|
||||
)
|
||||
logcontext.preserve_fn(self._handle_queued_pdus)(room_queue)
|
||||
|
||||
defer.returnValue(True)
|
||||
|
||||
@ -1070,6 +1091,9 @@ class FederationHandler(BaseHandler):
|
||||
"""
|
||||
event = pdu
|
||||
|
||||
if event.state_key is None:
|
||||
raise SynapseError(400, "The invite event did not have a state key")
|
||||
|
||||
is_blocked = yield self.store.is_room_blocked(event.room_id)
|
||||
if is_blocked:
|
||||
raise SynapseError(403, "This room has been blocked on this server")
|
||||
@ -1077,6 +1101,13 @@ class FederationHandler(BaseHandler):
|
||||
if self.hs.config.block_non_admin_invites:
|
||||
raise SynapseError(403, "This server does not accept room invites")
|
||||
|
||||
if not self.spam_checker.user_may_invite(
|
||||
event.sender, event.state_key, event.room_id,
|
||||
):
|
||||
raise SynapseError(
|
||||
403, "This user is not permitted to send invites to this server/user"
|
||||
)
|
||||
|
||||
membership = event.content.get("membership")
|
||||
if event.type != EventTypes.Member or membership != Membership.INVITE:
|
||||
raise SynapseError(400, "The event was not an m.room.member invite event")
|
||||
@ -1085,9 +1116,6 @@ class FederationHandler(BaseHandler):
|
||||
if sender_domain != origin:
|
||||
raise SynapseError(400, "The invite event was not from the server sending it")
|
||||
|
||||
if event.state_key is None:
|
||||
raise SynapseError(400, "The invite event did not have a state key")
|
||||
|
||||
if not self.is_mine_id(event.state_key):
|
||||
raise SynapseError(400, "The invite event must be for this server")
|
||||
|
||||
@ -1430,7 +1458,7 @@ class FederationHandler(BaseHandler):
|
||||
if not backfilled:
|
||||
# this intentionally does not yield: we don't care about the result
|
||||
# and don't need to wait for it.
|
||||
preserve_fn(self.pusher_pool.on_new_notifications)(
|
||||
logcontext.preserve_fn(self.pusher_pool.on_new_notifications)(
|
||||
event_stream_id, max_stream_id
|
||||
)
|
||||
|
||||
@ -1443,16 +1471,16 @@ class FederationHandler(BaseHandler):
|
||||
a bunch of outliers, but not a chunk of individual events that depend
|
||||
on each other for state calculations.
|
||||
"""
|
||||
contexts = yield preserve_context_over_deferred(defer.gatherResults(
|
||||
contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults(
|
||||
[
|
||||
preserve_fn(self._prep_event)(
|
||||
logcontext.preserve_fn(self._prep_event)(
|
||||
origin,
|
||||
ev_info["event"],
|
||||
state=ev_info.get("state"),
|
||||
auth_events=ev_info.get("auth_events"),
|
||||
)
|
||||
for ev_info in event_infos
|
||||
]
|
||||
], consumeErrors=True,
|
||||
))
|
||||
|
||||
yield self.store.persist_events(
|
||||
@ -1760,18 +1788,17 @@ class FederationHandler(BaseHandler):
|
||||
# Do auth conflict res.
|
||||
logger.info("Different auth: %s", different_auth)
|
||||
|
||||
different_events = yield preserve_context_over_deferred(defer.gatherResults(
|
||||
[
|
||||
preserve_fn(self.store.get_event)(
|
||||
different_events = yield logcontext.make_deferred_yieldable(
|
||||
defer.gatherResults([
|
||||
logcontext.preserve_fn(self.store.get_event)(
|
||||
d,
|
||||
allow_none=True,
|
||||
allow_rejected=False,
|
||||
)
|
||||
for d in different_auth
|
||||
if d in have_events and not have_events[d]
|
||||
],
|
||||
consumeErrors=True
|
||||
)).addErrback(unwrapFirstError)
|
||||
], consumeErrors=True)
|
||||
).addErrback(unwrapFirstError)
|
||||
|
||||
if different_events:
|
||||
local_view = dict(auth_events)
|
||||
|
417
synapse/handlers/groups_local.py
Normal file
417
synapse/handlers/groups_local.py
Normal file
@ -0,0 +1,417 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 Vector Creations Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.types import get_domain_from_id
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _create_rerouter(func_name):
|
||||
"""Returns a function that looks at the group id and calls the function
|
||||
on federation or the local group server if the group is local
|
||||
"""
|
||||
def f(self, group_id, *args, **kwargs):
|
||||
if self.is_mine_id(group_id):
|
||||
return getattr(self.groups_server_handler, func_name)(
|
||||
group_id, *args, **kwargs
|
||||
)
|
||||
else:
|
||||
destination = get_domain_from_id(group_id)
|
||||
return getattr(self.transport_client, func_name)(
|
||||
destination, group_id, *args, **kwargs
|
||||
)
|
||||
return f
|
||||
|
||||
|
||||
class GroupsLocalHandler(object):
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.room_list_handler = hs.get_room_list_handler()
|
||||
self.groups_server_handler = hs.get_groups_server_handler()
|
||||
self.transport_client = hs.get_federation_transport_client()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.keyring = hs.get_keyring()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.signing_key = hs.config.signing_key[0]
|
||||
self.server_name = hs.hostname
|
||||
self.notifier = hs.get_notifier()
|
||||
self.attestations = hs.get_groups_attestation_signing()
|
||||
|
||||
self.profile_handler = hs.get_profile_handler()
|
||||
|
||||
# Ensure attestations get renewed
|
||||
hs.get_groups_attestation_renewer()
|
||||
|
||||
# The following functions merely route the query to the local groups server
|
||||
# or federation depending on if the group is local or remote
|
||||
|
||||
get_group_profile = _create_rerouter("get_group_profile")
|
||||
update_group_profile = _create_rerouter("update_group_profile")
|
||||
get_rooms_in_group = _create_rerouter("get_rooms_in_group")
|
||||
|
||||
get_invited_users_in_group = _create_rerouter("get_invited_users_in_group")
|
||||
|
||||
add_room_to_group = _create_rerouter("add_room_to_group")
|
||||
remove_room_from_group = _create_rerouter("remove_room_from_group")
|
||||
|
||||
update_group_summary_room = _create_rerouter("update_group_summary_room")
|
||||
delete_group_summary_room = _create_rerouter("delete_group_summary_room")
|
||||
|
||||
update_group_category = _create_rerouter("update_group_category")
|
||||
delete_group_category = _create_rerouter("delete_group_category")
|
||||
get_group_category = _create_rerouter("get_group_category")
|
||||
get_group_categories = _create_rerouter("get_group_categories")
|
||||
|
||||
update_group_summary_user = _create_rerouter("update_group_summary_user")
|
||||
delete_group_summary_user = _create_rerouter("delete_group_summary_user")
|
||||
|
||||
update_group_role = _create_rerouter("update_group_role")
|
||||
delete_group_role = _create_rerouter("delete_group_role")
|
||||
get_group_role = _create_rerouter("get_group_role")
|
||||
get_group_roles = _create_rerouter("get_group_roles")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_group_summary(self, group_id, requester_user_id):
|
||||
"""Get the group summary for a group.
|
||||
|
||||
If the group is remote we check that the users have valid attestations.
|
||||
"""
|
||||
if self.is_mine_id(group_id):
|
||||
res = yield self.groups_server_handler.get_group_summary(
|
||||
group_id, requester_user_id
|
||||
)
|
||||
else:
|
||||
res = yield self.transport_client.get_group_summary(
|
||||
get_domain_from_id(group_id), group_id, requester_user_id,
|
||||
)
|
||||
|
||||
group_server_name = get_domain_from_id(group_id)
|
||||
|
||||
# Loop through the users and validate the attestations.
|
||||
chunk = res["users_section"]["users"]
|
||||
valid_users = []
|
||||
for entry in chunk:
|
||||
g_user_id = entry["user_id"]
|
||||
attestation = entry.pop("attestation", {})
|
||||
try:
|
||||
if get_domain_from_id(g_user_id) != group_server_name:
|
||||
yield self.attestations.verify_attestation(
|
||||
attestation,
|
||||
group_id=group_id,
|
||||
user_id=g_user_id,
|
||||
server_name=get_domain_from_id(g_user_id),
|
||||
)
|
||||
valid_users.append(entry)
|
||||
except Exception as e:
|
||||
logger.info("Failed to verify user is in group: %s", e)
|
||||
|
||||
res["users_section"]["users"] = valid_users
|
||||
|
||||
res["users_section"]["users"].sort(key=lambda e: e.get("order", 0))
|
||||
res["rooms_section"]["rooms"].sort(key=lambda e: e.get("order", 0))
|
||||
|
||||
# Add `is_publicised` flag to indicate whether the user has publicised their
|
||||
# membership of the group on their profile
|
||||
result = yield self.store.get_publicised_groups_for_user(requester_user_id)
|
||||
is_publicised = group_id in result
|
||||
|
||||
res.setdefault("user", {})["is_publicised"] = is_publicised
|
||||
|
||||
defer.returnValue(res)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_group(self, group_id, user_id, content):
|
||||
"""Create a group
|
||||
"""
|
||||
|
||||
logger.info("Asking to create group with ID: %r", group_id)
|
||||
|
||||
if self.is_mine_id(group_id):
|
||||
res = yield self.groups_server_handler.create_group(
|
||||
group_id, user_id, content
|
||||
)
|
||||
local_attestation = None
|
||||
remote_attestation = None
|
||||
else:
|
||||
local_attestation = self.attestations.create_attestation(group_id, user_id)
|
||||
content["attestation"] = local_attestation
|
||||
|
||||
content["user_profile"] = yield self.profile_handler.get_profile(user_id)
|
||||
|
||||
res = yield self.transport_client.create_group(
|
||||
get_domain_from_id(group_id), group_id, user_id, content,
|
||||
)
|
||||
|
||||
remote_attestation = res["attestation"]
|
||||
yield self.attestations.verify_attestation(
|
||||
remote_attestation,
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
server_name=get_domain_from_id(group_id),
|
||||
)
|
||||
|
||||
is_publicised = content.get("publicise", False)
|
||||
token = yield self.store.register_user_group_membership(
|
||||
group_id, user_id,
|
||||
membership="join",
|
||||
is_admin=True,
|
||||
local_attestation=local_attestation,
|
||||
remote_attestation=remote_attestation,
|
||||
is_publicised=is_publicised,
|
||||
)
|
||||
self.notifier.on_new_event(
|
||||
"groups_key", token, users=[user_id],
|
||||
)
|
||||
|
||||
defer.returnValue(res)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_users_in_group(self, group_id, requester_user_id):
|
||||
"""Get users in a group
|
||||
"""
|
||||
if self.is_mine_id(group_id):
|
||||
res = yield self.groups_server_handler.get_users_in_group(
|
||||
group_id, requester_user_id
|
||||
)
|
||||
defer.returnValue(res)
|
||||
|
||||
group_server_name = get_domain_from_id(group_id)
|
||||
|
||||
res = yield self.transport_client.get_users_in_group(
|
||||
get_domain_from_id(group_id), group_id, requester_user_id,
|
||||
)
|
||||
|
||||
chunk = res["chunk"]
|
||||
valid_entries = []
|
||||
for entry in chunk:
|
||||
g_user_id = entry["user_id"]
|
||||
attestation = entry.pop("attestation", {})
|
||||
try:
|
||||
if get_domain_from_id(g_user_id) != group_server_name:
|
||||
yield self.attestations.verify_attestation(
|
||||
attestation,
|
||||
group_id=group_id,
|
||||
user_id=g_user_id,
|
||||
server_name=get_domain_from_id(g_user_id),
|
||||
)
|
||||
valid_entries.append(entry)
|
||||
except Exception as e:
|
||||
logger.info("Failed to verify user is in group: %s", e)
|
||||
|
||||
res["chunk"] = valid_entries
|
||||
|
||||
defer.returnValue(res)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def join_group(self, group_id, user_id, content):
|
||||
"""Request to join a group
|
||||
"""
|
||||
raise NotImplementedError() # TODO
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def accept_invite(self, group_id, user_id, content):
|
||||
"""Accept an invite to a group
|
||||
"""
|
||||
if self.is_mine_id(group_id):
|
||||
yield self.groups_server_handler.accept_invite(
|
||||
group_id, user_id, content
|
||||
)
|
||||
local_attestation = None
|
||||
remote_attestation = None
|
||||
else:
|
||||
local_attestation = self.attestations.create_attestation(group_id, user_id)
|
||||
content["attestation"] = local_attestation
|
||||
|
||||
res = yield self.transport_client.accept_group_invite(
|
||||
get_domain_from_id(group_id), group_id, user_id, content,
|
||||
)
|
||||
|
||||
remote_attestation = res["attestation"]
|
||||
|
||||
yield self.attestations.verify_attestation(
|
||||
remote_attestation,
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
server_name=get_domain_from_id(group_id),
|
||||
)
|
||||
|
||||
# TODO: Check that the group is public and we're being added publically
|
||||
is_publicised = content.get("publicise", False)
|
||||
|
||||
token = yield self.store.register_user_group_membership(
|
||||
group_id, user_id,
|
||||
membership="join",
|
||||
is_admin=False,
|
||||
local_attestation=local_attestation,
|
||||
remote_attestation=remote_attestation,
|
||||
is_publicised=is_publicised,
|
||||
)
|
||||
self.notifier.on_new_event(
|
||||
"groups_key", token, users=[user_id],
|
||||
)
|
||||
|
||||
defer.returnValue({})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def invite(self, group_id, user_id, requester_user_id, config):
|
||||
"""Invite a user to a group
|
||||
"""
|
||||
content = {
|
||||
"requester_user_id": requester_user_id,
|
||||
"config": config,
|
||||
}
|
||||
if self.is_mine_id(group_id):
|
||||
res = yield self.groups_server_handler.invite_to_group(
|
||||
group_id, user_id, requester_user_id, content,
|
||||
)
|
||||
else:
|
||||
res = yield self.transport_client.invite_to_group(
|
||||
get_domain_from_id(group_id), group_id, user_id, requester_user_id,
|
||||
content,
|
||||
)
|
||||
|
||||
defer.returnValue(res)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_invite(self, group_id, user_id, content):
|
||||
"""One of our users were invited to a group
|
||||
"""
|
||||
# TODO: Support auto join and rejection
|
||||
|
||||
if not self.is_mine_id(user_id):
|
||||
raise SynapseError(400, "User not on this server")
|
||||
|
||||
local_profile = {}
|
||||
if "profile" in content:
|
||||
if "name" in content["profile"]:
|
||||
local_profile["name"] = content["profile"]["name"]
|
||||
if "avatar_url" in content["profile"]:
|
||||
local_profile["avatar_url"] = content["profile"]["avatar_url"]
|
||||
|
||||
token = yield self.store.register_user_group_membership(
|
||||
group_id, user_id,
|
||||
membership="invite",
|
||||
content={"profile": local_profile, "inviter": content["inviter"]},
|
||||
)
|
||||
self.notifier.on_new_event(
|
||||
"groups_key", token, users=[user_id],
|
||||
)
|
||||
try:
|
||||
user_profile = yield self.profile_handler.get_profile(user_id)
|
||||
except Exception as e:
|
||||
logger.warn("No profile for user %s: %s", user_id, e)
|
||||
user_profile = {}
|
||||
|
||||
defer.returnValue({"state": "invite", "user_profile": user_profile})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
|
||||
"""Remove a user from a group
|
||||
"""
|
||||
if user_id == requester_user_id:
|
||||
token = yield self.store.register_user_group_membership(
|
||||
group_id, user_id,
|
||||
membership="leave",
|
||||
)
|
||||
self.notifier.on_new_event(
|
||||
"groups_key", token, users=[user_id],
|
||||
)
|
||||
|
||||
# TODO: Should probably remember that we tried to leave so that we can
|
||||
# retry if the group server is currently down.
|
||||
|
||||
if self.is_mine_id(group_id):
|
||||
res = yield self.groups_server_handler.remove_user_from_group(
|
||||
group_id, user_id, requester_user_id, content,
|
||||
)
|
||||
else:
|
||||
content["requester_user_id"] = requester_user_id
|
||||
res = yield self.transport_client.remove_user_from_group(
|
||||
get_domain_from_id(group_id), group_id, requester_user_id,
|
||||
user_id, content,
|
||||
)
|
||||
|
||||
defer.returnValue(res)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def user_removed_from_group(self, group_id, user_id, content):
|
||||
"""One of our users was removed/kicked from a group
|
||||
"""
|
||||
# TODO: Check if user in group
|
||||
token = yield self.store.register_user_group_membership(
|
||||
group_id, user_id,
|
||||
membership="leave",
|
||||
)
|
||||
self.notifier.on_new_event(
|
||||
"groups_key", token, users=[user_id],
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_joined_groups(self, user_id):
|
||||
group_ids = yield self.store.get_joined_groups(user_id)
|
||||
defer.returnValue({"groups": group_ids})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_publicised_groups_for_user(self, user_id):
|
||||
if self.hs.is_mine_id(user_id):
|
||||
result = yield self.store.get_publicised_groups_for_user(user_id)
|
||||
defer.returnValue({"groups": result})
|
||||
else:
|
||||
result = yield self.transport_client.get_publicised_groups_for_user(
|
||||
get_domain_from_id(user_id), user_id
|
||||
)
|
||||
# TODO: Verify attestations
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def bulk_get_publicised_groups(self, user_ids, proxy=True):
|
||||
destinations = {}
|
||||
local_users = set()
|
||||
|
||||
for user_id in user_ids:
|
||||
if self.hs.is_mine_id(user_id):
|
||||
local_users.add(user_id)
|
||||
else:
|
||||
destinations.setdefault(
|
||||
get_domain_from_id(user_id), set()
|
||||
).add(user_id)
|
||||
|
||||
if not proxy and destinations:
|
||||
raise SynapseError(400, "Some user_ids are not local")
|
||||
|
||||
results = {}
|
||||
failed_results = []
|
||||
for destination, dest_user_ids in destinations.iteritems():
|
||||
try:
|
||||
r = yield self.transport_client.bulk_get_publicised_groups(
|
||||
destination, list(dest_user_ids),
|
||||
)
|
||||
results.update(r["users"])
|
||||
except Exception:
|
||||
failed_results.extend(dest_user_ids)
|
||||
|
||||
for uid in local_users:
|
||||
results[uid] = yield self.store.get_publicised_groups_for_user(
|
||||
uid
|
||||
)
|
||||
|
||||
defer.returnValue({"users": results})
|
@ -1,5 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 - 2016 OpenMarket Ltd
|
||||
# Copyright 2017 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -12,7 +13,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from synapse.events import spamcheck
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
@ -26,6 +26,7 @@ from synapse.types import (
|
||||
from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter
|
||||
from synapse.util.logcontext import preserve_fn
|
||||
from synapse.util.metrics import measure_func
|
||||
from synapse.util.frozenutils import unfreeze
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
||||
from ._base import BaseHandler
|
||||
@ -47,6 +48,7 @@ class MessageHandler(BaseHandler):
|
||||
self.state = hs.get_state_handler()
|
||||
self.clock = hs.get_clock()
|
||||
self.validator = EventValidator()
|
||||
self.profile_handler = hs.get_profile_handler()
|
||||
|
||||
self.pagination_lock = ReadWriteLock()
|
||||
|
||||
@ -58,6 +60,8 @@ class MessageHandler(BaseHandler):
|
||||
|
||||
self.action_generator = hs.get_action_generator()
|
||||
|
||||
self.spam_checker = hs.get_spam_checker()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def purge_history(self, room_id, event_id):
|
||||
event = yield self.store.get_event(event_id)
|
||||
@ -210,7 +214,7 @@ class MessageHandler(BaseHandler):
|
||||
|
||||
if membership in {Membership.JOIN, Membership.INVITE}:
|
||||
# If event doesn't include a display name, add one.
|
||||
profile = self.hs.get_handlers().profile_handler
|
||||
profile = self.profile_handler
|
||||
content = builder.content
|
||||
|
||||
try:
|
||||
@ -322,9 +326,12 @@ class MessageHandler(BaseHandler):
|
||||
txn_id=txn_id
|
||||
)
|
||||
|
||||
if spamcheck.check_event_for_spam(event):
|
||||
spam_error = self.spam_checker.check_event_for_spam(event)
|
||||
if spam_error:
|
||||
if not isinstance(spam_error, basestring):
|
||||
spam_error = "Spam is not permitted here"
|
||||
raise SynapseError(
|
||||
403, "Spam is not permitted here", Codes.FORBIDDEN
|
||||
403, spam_error, Codes.FORBIDDEN
|
||||
)
|
||||
|
||||
yield self.send_nonmember_event(
|
||||
@ -418,6 +425,51 @@ class MessageHandler(BaseHandler):
|
||||
[serialize_event(c, now) for c in room_state.values()]
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_joined_members(self, requester, room_id):
|
||||
"""Get all the joined members in the room and their profile information.
|
||||
|
||||
If the user has left the room return the state events from when they left.
|
||||
|
||||
Args:
|
||||
requester(Requester): The user requesting state events.
|
||||
room_id(str): The room ID to get all state events from.
|
||||
Returns:
|
||||
A dict of user_id to profile info
|
||||
"""
|
||||
user_id = requester.user.to_string()
|
||||
if not requester.app_service:
|
||||
# We check AS auth after fetching the room membership, as it
|
||||
# requires us to pull out all joined members anyway.
|
||||
membership, _ = yield self._check_in_room_or_world_readable(
|
||||
room_id, user_id
|
||||
)
|
||||
if membership != Membership.JOIN:
|
||||
raise NotImplementedError(
|
||||
"Getting joined members after leaving is not implemented"
|
||||
)
|
||||
|
||||
users_with_profile = yield self.state.get_current_user_in_room(room_id)
|
||||
|
||||
# If this is an AS, double check that they are allowed to see the members.
|
||||
# This can either be because the AS user is in the room or becuase there
|
||||
# is a user in the room that the AS is "interested in"
|
||||
if requester.app_service and user_id not in users_with_profile:
|
||||
for uid in users_with_profile:
|
||||
if requester.app_service.is_interested_in_user(uid):
|
||||
break
|
||||
else:
|
||||
# Loop fell through, AS has no interested users in room
|
||||
raise AuthError(403, "Appservice not in room")
|
||||
|
||||
defer.returnValue({
|
||||
user_id: {
|
||||
"avatar_url": profile.avatar_url,
|
||||
"display_name": profile.display_name,
|
||||
}
|
||||
for user_id, profile in users_with_profile.iteritems()
|
||||
})
|
||||
|
||||
@measure_func("_create_new_client_event")
|
||||
@defer.inlineCallbacks
|
||||
def _create_new_client_event(self, builder, requester=None, prev_event_ids=None):
|
||||
@ -509,7 +561,7 @@ class MessageHandler(BaseHandler):
|
||||
|
||||
# Ensure that we can round trip before trying to persist in db
|
||||
try:
|
||||
dump = ujson.dumps(event.content)
|
||||
dump = ujson.dumps(unfreeze(event.content))
|
||||
ujson.loads(dump)
|
||||
except:
|
||||
logger.exception("Failed to encode content: %r", event.content)
|
||||
|
@ -19,14 +19,15 @@ from twisted.internet import defer
|
||||
|
||||
import synapse.types
|
||||
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
|
||||
from synapse.types import UserID
|
||||
from synapse.types import UserID, get_domain_from_id
|
||||
from ._base import BaseHandler
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProfileHandler(BaseHandler):
|
||||
PROFILE_UPDATE_MS = 60 * 1000
|
||||
PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ProfileHandler, self).__init__(hs)
|
||||
@ -36,6 +37,63 @@ class ProfileHandler(BaseHandler):
|
||||
"profile", self.on_profile_query
|
||||
)
|
||||
|
||||
self.clock.looping_call(self._update_remote_profile_cache, self.PROFILE_UPDATE_MS)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_profile(self, user_id):
|
||||
target_user = UserID.from_string(user_id)
|
||||
if self.hs.is_mine(target_user):
|
||||
displayname = yield self.store.get_profile_displayname(
|
||||
target_user.localpart
|
||||
)
|
||||
avatar_url = yield self.store.get_profile_avatar_url(
|
||||
target_user.localpart
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
"displayname": displayname,
|
||||
"avatar_url": avatar_url,
|
||||
})
|
||||
else:
|
||||
try:
|
||||
result = yield self.federation.make_query(
|
||||
destination=target_user.domain,
|
||||
query_type="profile",
|
||||
args={
|
||||
"user_id": user_id,
|
||||
},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
defer.returnValue(result)
|
||||
except CodeMessageException as e:
|
||||
if e.code != 404:
|
||||
logger.exception("Failed to get displayname")
|
||||
|
||||
raise
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_profile_from_cache(self, user_id):
|
||||
"""Get the profile information from our local cache. If the user is
|
||||
ours then the profile information will always be corect. Otherwise,
|
||||
it may be out of date/missing.
|
||||
"""
|
||||
target_user = UserID.from_string(user_id)
|
||||
if self.hs.is_mine(target_user):
|
||||
displayname = yield self.store.get_profile_displayname(
|
||||
target_user.localpart
|
||||
)
|
||||
avatar_url = yield self.store.get_profile_avatar_url(
|
||||
target_user.localpart
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
"displayname": displayname,
|
||||
"avatar_url": avatar_url,
|
||||
})
|
||||
else:
|
||||
profile = yield self.store.get_from_remote_profile_cache(user_id)
|
||||
defer.returnValue(profile or {})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_displayname(self, target_user):
|
||||
if self.hs.is_mine(target_user):
|
||||
@ -182,3 +240,44 @@ class ProfileHandler(BaseHandler):
|
||||
"Failed to update join event for room %s - %s",
|
||||
room_id, str(e.message)
|
||||
)
|
||||
|
||||
def _update_remote_profile_cache(self):
|
||||
"""Called periodically to check profiles of remote users we haven't
|
||||
checked in a while.
|
||||
"""
|
||||
entries = yield self.store.get_remote_profile_cache_entries_that_expire(
|
||||
last_checked=self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS
|
||||
)
|
||||
|
||||
for user_id, displayname, avatar_url in entries:
|
||||
is_subscribed = yield self.store.is_subscribed_remote_profile_for_user(
|
||||
user_id,
|
||||
)
|
||||
if not is_subscribed:
|
||||
yield self.store.maybe_delete_remote_profile_cache(user_id)
|
||||
continue
|
||||
|
||||
try:
|
||||
profile = yield self.federation.make_query(
|
||||
destination=get_domain_from_id(user_id),
|
||||
query_type="profile",
|
||||
args={
|
||||
"user_id": user_id,
|
||||
},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
except:
|
||||
logger.exception("Failed to get avatar_url")
|
||||
|
||||
yield self.store.update_remote_profile_cache(
|
||||
user_id, displayname, avatar_url
|
||||
)
|
||||
continue
|
||||
|
||||
new_name = profile.get("displayname")
|
||||
new_avatar = profile.get("avatar_url")
|
||||
|
||||
# We always hit update to update the last_check timestamp
|
||||
yield self.store.update_remote_profile_cache(
|
||||
user_id, new_name, new_avatar
|
||||
)
|
||||
|
@ -12,6 +12,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from synapse.util import logcontext
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
@ -59,6 +60,8 @@ class ReceiptsHandler(BaseHandler):
|
||||
is_new = yield self._handle_new_receipts([receipt])
|
||||
|
||||
if is_new:
|
||||
# fire off a process in the background to send the receipt to
|
||||
# remote servers
|
||||
self._push_remotes([receipt])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@ -126,6 +129,7 @@ class ReceiptsHandler(BaseHandler):
|
||||
|
||||
defer.returnValue(True)
|
||||
|
||||
@logcontext.preserve_fn # caller should not yield on this
|
||||
@defer.inlineCallbacks
|
||||
def _push_remotes(self, receipts):
|
||||
"""Given a list of receipts, works out which remote servers should be
|
||||
|
@ -36,6 +36,7 @@ class RegistrationHandler(BaseHandler):
|
||||
super(RegistrationHandler, self).__init__(hs)
|
||||
|
||||
self.auth = hs.get_auth()
|
||||
self.profile_handler = hs.get_profile_handler()
|
||||
self.captcha_client = CaptchaServerHttpClient(hs)
|
||||
|
||||
self._next_generated_user_id = None
|
||||
@ -423,8 +424,7 @@ class RegistrationHandler(BaseHandler):
|
||||
|
||||
if displayname is not None:
|
||||
logger.info("setting user display name: %s -> %s", user_id, displayname)
|
||||
profile_handler = self.hs.get_handlers().profile_handler
|
||||
yield profile_handler.set_displayname(
|
||||
yield self.profile_handler.set_displayname(
|
||||
user, requester, displayname, by_admin=True,
|
||||
)
|
||||
|
||||
|
@ -60,6 +60,11 @@ class RoomCreationHandler(BaseHandler):
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RoomCreationHandler, self).__init__(hs)
|
||||
|
||||
self.spam_checker = hs.get_spam_checker()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_room(self, requester, config, ratelimit=True):
|
||||
""" Creates a new room.
|
||||
@ -75,6 +80,9 @@ class RoomCreationHandler(BaseHandler):
|
||||
"""
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
if not self.spam_checker.user_may_create_room(user_id):
|
||||
raise SynapseError(403, "You are not permitted to create rooms")
|
||||
|
||||
if ratelimit:
|
||||
yield self.ratelimit(requester)
|
||||
|
||||
|
@ -276,13 +276,14 @@ class RoomListHandler(BaseHandler):
|
||||
# We've already got enough, so lets just drop it.
|
||||
return
|
||||
|
||||
result = yield self._generate_room_entry(room_id, num_joined_users)
|
||||
result = yield self.generate_room_entry(room_id, num_joined_users)
|
||||
|
||||
if result and _matches_room_entry(result, search_filter):
|
||||
chunk.append(result)
|
||||
|
||||
@cachedInlineCallbacks(num_args=1, cache_context=True)
|
||||
def _generate_room_entry(self, room_id, num_joined_users, cache_context):
|
||||
def generate_room_entry(self, room_id, num_joined_users, cache_context,
|
||||
with_alias=True, allow_private=False):
|
||||
"""Returns the entry for a room
|
||||
"""
|
||||
result = {
|
||||
@ -316,14 +317,15 @@ class RoomListHandler(BaseHandler):
|
||||
join_rules_event = current_state.get((EventTypes.JoinRules, ""))
|
||||
if join_rules_event:
|
||||
join_rule = join_rules_event.content.get("join_rule", None)
|
||||
if join_rule and join_rule != JoinRules.PUBLIC:
|
||||
if not allow_private and join_rule and join_rule != JoinRules.PUBLIC:
|
||||
defer.returnValue(None)
|
||||
|
||||
aliases = yield self.store.get_aliases_for_room(
|
||||
room_id, on_invalidate=cache_context.invalidate
|
||||
)
|
||||
if aliases:
|
||||
result["aliases"] = aliases
|
||||
if with_alias:
|
||||
aliases = yield self.store.get_aliases_for_room(
|
||||
room_id, on_invalidate=cache_context.invalidate
|
||||
)
|
||||
if aliases:
|
||||
result["aliases"] = aliases
|
||||
|
||||
name_event = yield current_state.get((EventTypes.Name, ""))
|
||||
if name_event:
|
||||
|
@ -45,9 +45,12 @@ class RoomMemberHandler(BaseHandler):
|
||||
def __init__(self, hs):
|
||||
super(RoomMemberHandler, self).__init__(hs)
|
||||
|
||||
self.profile_handler = hs.get_profile_handler()
|
||||
|
||||
self.member_linearizer = Linearizer(name="member")
|
||||
|
||||
self.clock = hs.get_clock()
|
||||
self.spam_checker = hs.get_spam_checker()
|
||||
|
||||
self.distributor = hs.get_distributor()
|
||||
self.distributor.declare("user_joined_room")
|
||||
@ -210,12 +213,26 @@ class RoomMemberHandler(BaseHandler):
|
||||
if is_blocked:
|
||||
raise SynapseError(403, "This room has been blocked on this server")
|
||||
|
||||
if (effective_membership_state == "invite" and
|
||||
self.hs.config.block_non_admin_invites):
|
||||
if effective_membership_state == "invite":
|
||||
block_invite = False
|
||||
is_requester_admin = yield self.auth.is_server_admin(
|
||||
requester.user,
|
||||
)
|
||||
if not is_requester_admin:
|
||||
if self.hs.config.block_non_admin_invites:
|
||||
logger.info(
|
||||
"Blocking invite: user is not admin and non-admin "
|
||||
"invites disabled"
|
||||
)
|
||||
block_invite = True
|
||||
|
||||
if not self.spam_checker.user_may_invite(
|
||||
requester.user.to_string(), target.to_string(), room_id,
|
||||
):
|
||||
logger.info("Blocking invite due to spam checker")
|
||||
block_invite = True
|
||||
|
||||
if block_invite:
|
||||
raise SynapseError(
|
||||
403, "Invites have been disabled on this server",
|
||||
)
|
||||
@ -267,7 +284,7 @@ class RoomMemberHandler(BaseHandler):
|
||||
|
||||
content["membership"] = Membership.JOIN
|
||||
|
||||
profile = self.hs.get_handlers().profile_handler
|
||||
profile = self.profile_handler
|
||||
if not content_specified:
|
||||
content["displayname"] = yield profile.get_displayname(target)
|
||||
content["avatar_url"] = yield profile.get_avatar_url(target)
|
||||
|
@ -108,6 +108,17 @@ class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [
|
||||
return True
|
||||
|
||||
|
||||
class GroupsSyncResult(collections.namedtuple("GroupsSyncResult", [
|
||||
"join",
|
||||
"invite",
|
||||
"leave",
|
||||
])):
|
||||
__slots__ = []
|
||||
|
||||
def __nonzero__(self):
|
||||
return bool(self.join or self.invite or self.leave)
|
||||
|
||||
|
||||
class DeviceLists(collections.namedtuple("DeviceLists", [
|
||||
"changed", # list of user_ids whose devices may have changed
|
||||
"left", # list of user_ids whose devices we no longer track
|
||||
@ -129,6 +140,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
|
||||
"device_lists", # List of user_ids whose devices have chanegd
|
||||
"device_one_time_keys_count", # Dict of algorithm to count for one time keys
|
||||
# for this device
|
||||
"groups",
|
||||
])):
|
||||
__slots__ = []
|
||||
|
||||
@ -144,7 +156,8 @@ class SyncResult(collections.namedtuple("SyncResult", [
|
||||
self.archived or
|
||||
self.account_data or
|
||||
self.to_device or
|
||||
self.device_lists
|
||||
self.device_lists or
|
||||
self.groups
|
||||
)
|
||||
|
||||
|
||||
@ -595,6 +608,8 @@ class SyncHandler(object):
|
||||
user_id, device_id
|
||||
)
|
||||
|
||||
yield self._generate_sync_entry_for_groups(sync_result_builder)
|
||||
|
||||
defer.returnValue(SyncResult(
|
||||
presence=sync_result_builder.presence,
|
||||
account_data=sync_result_builder.account_data,
|
||||
@ -603,10 +618,57 @@ class SyncHandler(object):
|
||||
archived=sync_result_builder.archived,
|
||||
to_device=sync_result_builder.to_device,
|
||||
device_lists=device_lists,
|
||||
groups=sync_result_builder.groups,
|
||||
device_one_time_keys_count=one_time_key_counts,
|
||||
next_batch=sync_result_builder.now_token,
|
||||
))
|
||||
|
||||
@measure_func("_generate_sync_entry_for_groups")
|
||||
@defer.inlineCallbacks
|
||||
def _generate_sync_entry_for_groups(self, sync_result_builder):
|
||||
user_id = sync_result_builder.sync_config.user.to_string()
|
||||
since_token = sync_result_builder.since_token
|
||||
now_token = sync_result_builder.now_token
|
||||
|
||||
if since_token and since_token.groups_key:
|
||||
results = yield self.store.get_groups_changes_for_user(
|
||||
user_id, since_token.groups_key, now_token.groups_key,
|
||||
)
|
||||
else:
|
||||
results = yield self.store.get_all_groups_for_user(
|
||||
user_id, now_token.groups_key,
|
||||
)
|
||||
|
||||
invited = {}
|
||||
joined = {}
|
||||
left = {}
|
||||
for result in results:
|
||||
membership = result["membership"]
|
||||
group_id = result["group_id"]
|
||||
gtype = result["type"]
|
||||
content = result["content"]
|
||||
|
||||
if membership == "join":
|
||||
if gtype == "membership":
|
||||
# TODO: Add profile
|
||||
content.pop("membership", None)
|
||||
joined[group_id] = content["content"]
|
||||
else:
|
||||
joined.setdefault(group_id, {})[gtype] = content
|
||||
elif membership == "invite":
|
||||
if gtype == "membership":
|
||||
content.pop("membership", None)
|
||||
invited[group_id] = content["content"]
|
||||
else:
|
||||
if gtype == "membership":
|
||||
left[group_id] = content["content"]
|
||||
|
||||
sync_result_builder.groups = GroupsSyncResult(
|
||||
join=joined,
|
||||
invite=invited,
|
||||
leave=left,
|
||||
)
|
||||
|
||||
@measure_func("_generate_sync_entry_for_device_list")
|
||||
@defer.inlineCallbacks
|
||||
def _generate_sync_entry_for_device_list(self, sync_result_builder,
|
||||
@ -1368,6 +1430,7 @@ class SyncResultBuilder(object):
|
||||
self.invited = []
|
||||
self.archived = []
|
||||
self.device = []
|
||||
self.groups = None
|
||||
self.to_device = []
|
||||
|
||||
|
||||
|
@ -354,16 +354,28 @@ def _get_hosts_for_srv_record(dns_client, host):
|
||||
|
||||
return res[0]
|
||||
|
||||
def eb(res):
|
||||
res.trap(DNSNameError)
|
||||
return []
|
||||
def eb(res, record_type):
|
||||
if res.check(DNSNameError):
|
||||
return []
|
||||
logger.warn("Error looking up %s for %s: %s",
|
||||
record_type, host, res, res.value)
|
||||
return res
|
||||
|
||||
# no logcontexts here, so we can safely fire these off and gatherResults
|
||||
d1 = dns_client.lookupAddress(host).addCallbacks(cb, eb)
|
||||
d2 = dns_client.lookupIPV6Address(host).addCallbacks(cb, eb)
|
||||
results = yield defer.gatherResults([d1, d2], consumeErrors=True)
|
||||
results = yield defer.DeferredList(
|
||||
[d1, d2], consumeErrors=True)
|
||||
|
||||
# if all of the lookups failed, raise an exception rather than blowing out
|
||||
# the cache with an empty result.
|
||||
if results and all(s == defer.FAILURE for (s, _) in results):
|
||||
defer.returnValue(results[0][1])
|
||||
|
||||
for (success, result) in results:
|
||||
if success == defer.FAILURE:
|
||||
continue
|
||||
|
||||
for result in results:
|
||||
for answer in result:
|
||||
if not answer.payload:
|
||||
continue
|
||||
|
@ -204,18 +204,15 @@ class MatrixFederationHttpClient(object):
|
||||
raise
|
||||
|
||||
logger.warn(
|
||||
"{%s} Sending request failed to %s: %s %s: %s - %s",
|
||||
"{%s} Sending request failed to %s: %s %s: %s",
|
||||
txn_id,
|
||||
destination,
|
||||
method,
|
||||
url_bytes,
|
||||
type(e).__name__,
|
||||
_flatten_response_never_received(e),
|
||||
)
|
||||
|
||||
log_result = "%s - %s" % (
|
||||
type(e).__name__, _flatten_response_never_received(e),
|
||||
)
|
||||
log_result = _flatten_response_never_received(e)
|
||||
|
||||
if retries_left and not timeout:
|
||||
if long_retries:
|
||||
@ -347,7 +344,7 @@ class MatrixFederationHttpClient(object):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def post_json(self, destination, path, data={}, long_retries=False,
|
||||
timeout=None, ignore_backoff=False):
|
||||
timeout=None, ignore_backoff=False, args={}):
|
||||
""" Sends the specifed json data using POST
|
||||
|
||||
Args:
|
||||
@ -383,6 +380,7 @@ class MatrixFederationHttpClient(object):
|
||||
destination,
|
||||
"POST",
|
||||
path,
|
||||
query_bytes=encode_query_args(args),
|
||||
body_callback=body_callback,
|
||||
headers_dict={"Content-Type": ["application/json"]},
|
||||
long_retries=long_retries,
|
||||
@ -427,13 +425,6 @@ class MatrixFederationHttpClient(object):
|
||||
"""
|
||||
logger.debug("get_json args: %s", args)
|
||||
|
||||
encoded_args = {}
|
||||
for k, vs in args.items():
|
||||
if isinstance(vs, basestring):
|
||||
vs = [vs]
|
||||
encoded_args[k] = [v.encode("UTF-8") for v in vs]
|
||||
|
||||
query_bytes = urllib.urlencode(encoded_args, True)
|
||||
logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
|
||||
|
||||
def body_callback(method, url_bytes, headers_dict):
|
||||
@ -444,7 +435,7 @@ class MatrixFederationHttpClient(object):
|
||||
destination,
|
||||
"GET",
|
||||
path,
|
||||
query_bytes=query_bytes,
|
||||
query_bytes=encode_query_args(args),
|
||||
body_callback=body_callback,
|
||||
retry_on_dns_fail=retry_on_dns_fail,
|
||||
timeout=timeout,
|
||||
@ -460,6 +451,52 @@ class MatrixFederationHttpClient(object):
|
||||
|
||||
defer.returnValue(json.loads(body))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_json(self, destination, path, long_retries=False,
|
||||
timeout=None, ignore_backoff=False, args={}):
|
||||
"""Send a DELETE request to the remote expecting some json response
|
||||
|
||||
Args:
|
||||
destination (str): The remote server to send the HTTP request
|
||||
to.
|
||||
path (str): The HTTP path.
|
||||
long_retries (bool): A boolean that indicates whether we should
|
||||
retry for a short or long time.
|
||||
timeout(int): How long to try (in ms) the destination for before
|
||||
giving up. None indicates no timeout.
|
||||
ignore_backoff (bool): true to ignore the historical backoff data and
|
||||
try the request anyway.
|
||||
Returns:
|
||||
Deferred: Succeeds when we get a 2xx HTTP response. The result
|
||||
will be the decoded JSON body.
|
||||
|
||||
Fails with ``HTTPRequestException`` if we get an HTTP response
|
||||
code >= 300.
|
||||
|
||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||
to retry this server.
|
||||
"""
|
||||
|
||||
response = yield self._request(
|
||||
destination,
|
||||
"DELETE",
|
||||
path,
|
||||
query_bytes=encode_query_args(args),
|
||||
headers_dict={"Content-Type": ["application/json"]},
|
||||
long_retries=long_retries,
|
||||
timeout=timeout,
|
||||
ignore_backoff=ignore_backoff,
|
||||
)
|
||||
|
||||
if 200 <= response.code < 300:
|
||||
# We need to update the transactions table to say it was sent?
|
||||
check_content_type_is_json(response.headers)
|
||||
|
||||
with logcontext.PreserveLoggingContext():
|
||||
body = yield readBody(response)
|
||||
|
||||
defer.returnValue(json.loads(body))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_file(self, destination, path, output_stream, args={},
|
||||
retry_on_dns_fail=True, max_size=None,
|
||||
@ -578,12 +615,14 @@ class _JsonProducer(object):
|
||||
|
||||
def _flatten_response_never_received(e):
|
||||
if hasattr(e, "reasons"):
|
||||
return ", ".join(
|
||||
reasons = ", ".join(
|
||||
_flatten_response_never_received(f.value)
|
||||
for f in e.reasons
|
||||
)
|
||||
|
||||
return "%s:[%s]" % (type(e).__name__, reasons)
|
||||
else:
|
||||
return "%s: %s" % (type(e).__name__, e.message,)
|
||||
return repr(e)
|
||||
|
||||
|
||||
def check_content_type_is_json(headers):
|
||||
@ -610,3 +649,15 @@ def check_content_type_is_json(headers):
|
||||
raise RuntimeError(
|
||||
"Content-Type not application/json: was '%s'" % c_type
|
||||
)
|
||||
|
||||
|
||||
def encode_query_args(args):
|
||||
encoded_args = {}
|
||||
for k, vs in args.items():
|
||||
if isinstance(vs, basestring):
|
||||
vs = [vs]
|
||||
encoded_args[k] = [v.encode("UTF-8") for v in vs]
|
||||
|
||||
query_bytes = urllib.urlencode(encoded_args, True)
|
||||
|
||||
return query_bytes
|
||||
|
@ -145,7 +145,9 @@ def wrap_request_handler(request_handler, include_metrics=False):
|
||||
"error": "Internal server error",
|
||||
"errcode": Codes.UNKNOWN,
|
||||
},
|
||||
send_cors=True
|
||||
send_cors=True,
|
||||
pretty_print=_request_user_agent_is_curl(request),
|
||||
version_string=self.version_string,
|
||||
)
|
||||
finally:
|
||||
try:
|
||||
|
@ -1,4 +1,5 @@
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
# Copyright 2017 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -238,6 +239,28 @@ BASE_APPEND_OVERRIDE_RULES = [
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
'rule_id': 'global/override/.m.rule.roomnotif',
|
||||
'conditions': [
|
||||
{
|
||||
'kind': 'event_match',
|
||||
'key': 'content.body',
|
||||
'pattern': '@room',
|
||||
'_id': '_roomnotif_content',
|
||||
},
|
||||
{
|
||||
'kind': 'sender_notification_permission',
|
||||
'key': 'room',
|
||||
'_id': '_roomnotif_pl',
|
||||
},
|
||||
],
|
||||
'actions': [
|
||||
'notify', {
|
||||
'set_tweak': 'highlight',
|
||||
'value': True,
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015 OpenMarket Ltd
|
||||
# Copyright 2017 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -19,11 +20,13 @@ from twisted.internet import defer
|
||||
|
||||
from .push_rule_evaluator import PushRuleEvaluatorForEvent
|
||||
|
||||
from synapse.event_auth import get_user_power_level
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.metrics import get_metrics_for
|
||||
from synapse.util.caches import metrics as cache_metrics
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.state import POWER_KEY
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
@ -59,6 +62,7 @@ class BulkPushRuleEvaluator(object):
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
self.room_push_rule_cache_metrics = cache_metrics.register_cache(
|
||||
"cache",
|
||||
@ -108,6 +112,29 @@ class BulkPushRuleEvaluator(object):
|
||||
self.room_push_rule_cache_metrics,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_power_levels_and_sender_level(self, event, context):
|
||||
pl_event_id = context.prev_state_ids.get(POWER_KEY)
|
||||
if pl_event_id:
|
||||
# fastpath: if there's a power level event, that's all we need, and
|
||||
# not having a power level event is an extreme edge case
|
||||
pl_event = yield self.store.get_event(pl_event_id)
|
||||
auth_events = {POWER_KEY: pl_event}
|
||||
else:
|
||||
auth_events_ids = yield self.auth.compute_auth_events(
|
||||
event, context.prev_state_ids, for_verification=False,
|
||||
)
|
||||
auth_events = yield self.store.get_events(auth_events_ids)
|
||||
auth_events = {
|
||||
(e.type, e.state_key): e for e in auth_events.itervalues()
|
||||
}
|
||||
|
||||
sender_level = get_user_power_level(event.sender, auth_events)
|
||||
|
||||
pl_event = auth_events.get(POWER_KEY)
|
||||
|
||||
defer.returnValue((pl_event.content if pl_event else {}, sender_level))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def action_for_event_by_user(self, event, context):
|
||||
"""Given an event and context, evaluate the push rules and return
|
||||
@ -123,7 +150,13 @@ class BulkPushRuleEvaluator(object):
|
||||
event, context
|
||||
)
|
||||
|
||||
evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
|
||||
(power_levels, sender_power_level) = (
|
||||
yield self._get_power_levels_and_sender_level(event, context)
|
||||
)
|
||||
|
||||
evaluator = PushRuleEvaluatorForEvent(
|
||||
event, len(room_members), sender_power_level, power_levels,
|
||||
)
|
||||
|
||||
condition_cache = {}
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
# Copyright 2017 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -29,6 +30,21 @@ INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
|
||||
|
||||
|
||||
def _room_member_count(ev, condition, room_member_count):
|
||||
return _test_ineq_condition(condition, room_member_count)
|
||||
|
||||
|
||||
def _sender_notification_permission(ev, condition, sender_power_level, power_levels):
|
||||
notif_level_key = condition.get('key')
|
||||
if notif_level_key is None:
|
||||
return False
|
||||
|
||||
notif_levels = power_levels.get('notifications', {})
|
||||
room_notif_level = notif_levels.get(notif_level_key, 50)
|
||||
|
||||
return sender_power_level >= room_notif_level
|
||||
|
||||
|
||||
def _test_ineq_condition(condition, number):
|
||||
if 'is' not in condition:
|
||||
return False
|
||||
m = INEQUALITY_EXPR.match(condition['is'])
|
||||
@ -41,15 +57,15 @@ def _room_member_count(ev, condition, room_member_count):
|
||||
rhs = int(rhs)
|
||||
|
||||
if ineq == '' or ineq == '==':
|
||||
return room_member_count == rhs
|
||||
return number == rhs
|
||||
elif ineq == '<':
|
||||
return room_member_count < rhs
|
||||
return number < rhs
|
||||
elif ineq == '>':
|
||||
return room_member_count > rhs
|
||||
return number > rhs
|
||||
elif ineq == '>=':
|
||||
return room_member_count >= rhs
|
||||
return number >= rhs
|
||||
elif ineq == '<=':
|
||||
return room_member_count <= rhs
|
||||
return number <= rhs
|
||||
else:
|
||||
return False
|
||||
|
||||
@ -65,9 +81,11 @@ def tweaks_for_actions(actions):
|
||||
|
||||
|
||||
class PushRuleEvaluatorForEvent(object):
|
||||
def __init__(self, event, room_member_count):
|
||||
def __init__(self, event, room_member_count, sender_power_level, power_levels):
|
||||
self._event = event
|
||||
self._room_member_count = room_member_count
|
||||
self._sender_power_level = sender_power_level
|
||||
self._power_levels = power_levels
|
||||
|
||||
# Maps strings of e.g. 'content.body' -> event["content"]["body"]
|
||||
self._value_cache = _flatten_dict(event)
|
||||
@ -81,6 +99,10 @@ class PushRuleEvaluatorForEvent(object):
|
||||
return _room_member_count(
|
||||
self._event, condition, self._room_member_count
|
||||
)
|
||||
elif condition['kind'] == 'sender_notification_permission':
|
||||
return _sender_notification_permission(
|
||||
self._event, condition, self._sender_power_level, self._power_levels,
|
||||
)
|
||||
else:
|
||||
return True
|
||||
|
||||
@ -183,7 +205,7 @@ def _glob_to_re(glob, word_boundary):
|
||||
r,
|
||||
)
|
||||
if word_boundary:
|
||||
r = r"\b%s\b" % (r,)
|
||||
r = _re_word_boundary(r)
|
||||
|
||||
return re.compile(r, flags=re.IGNORECASE)
|
||||
else:
|
||||
@ -192,7 +214,7 @@ def _glob_to_re(glob, word_boundary):
|
||||
return re.compile(r, flags=re.IGNORECASE)
|
||||
elif word_boundary:
|
||||
r = re.escape(glob)
|
||||
r = r"\b%s\b" % (r,)
|
||||
r = _re_word_boundary(r)
|
||||
|
||||
return re.compile(r, flags=re.IGNORECASE)
|
||||
else:
|
||||
@ -200,6 +222,18 @@ def _glob_to_re(glob, word_boundary):
|
||||
return re.compile(r, flags=re.IGNORECASE)
|
||||
|
||||
|
||||
def _re_word_boundary(r):
|
||||
"""
|
||||
Adds word boundary characters to the start and end of an
|
||||
expression to require that the match occur as a whole word,
|
||||
but do so respecting the fact that strings starting or ending
|
||||
with non-word characters will change word boundaries.
|
||||
"""
|
||||
# we can't use \b as it chokes on unicode. however \W seems to be okay
|
||||
# as shorthand for [^0-9A-Za-z_].
|
||||
return r"(^|\W)%s(\W|$)" % (r,)
|
||||
|
||||
|
||||
def _flatten_dict(d, prefix=[], result=None):
|
||||
if result is None:
|
||||
result = {}
|
||||
|
54
synapse/replication/slave/storage/groups.py
Normal file
54
synapse/replication/slave/storage/groups.py
Normal file
@ -0,0 +1,54 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import BaseSlavedStore
|
||||
from ._slaved_id_tracker import SlavedIdTracker
|
||||
from synapse.storage import DataStore
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
|
||||
|
||||
class SlavedGroupServerStore(BaseSlavedStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
super(SlavedGroupServerStore, self).__init__(db_conn, hs)
|
||||
|
||||
self.hs = hs
|
||||
|
||||
self._group_updates_id_gen = SlavedIdTracker(
|
||||
db_conn, "local_group_updates", "stream_id",
|
||||
)
|
||||
self._group_updates_stream_cache = StreamChangeCache(
|
||||
"_group_updates_stream_cache", self._group_updates_id_gen.get_current_token(),
|
||||
)
|
||||
|
||||
get_groups_changes_for_user = DataStore.get_groups_changes_for_user.__func__
|
||||
get_group_stream_token = DataStore.get_group_stream_token.__func__
|
||||
get_all_groups_for_user = DataStore.get_all_groups_for_user.__func__
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedGroupServerStore, self).stream_positions()
|
||||
result["groups"] = self._group_updates_id_gen.get_current_token()
|
||||
return result
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
if stream_name == "groups":
|
||||
self._group_updates_id_gen.advance(token)
|
||||
for row in rows:
|
||||
self._group_updates_stream_cache.entity_has_changed(
|
||||
row.user_id, token
|
||||
)
|
||||
|
||||
return super(SlavedGroupServerStore, self).process_replication_rows(
|
||||
stream_name, token, rows
|
||||
)
|
@ -160,7 +160,11 @@ class ReplicationStreamer(object):
|
||||
"Getting stream: %s: %s -> %s",
|
||||
stream.NAME, stream.last_token, stream.upto_token
|
||||
)
|
||||
updates, current_token = yield stream.get_updates()
|
||||
try:
|
||||
updates, current_token = yield stream.get_updates()
|
||||
except:
|
||||
logger.info("Failed to handle stream %s", stream.NAME)
|
||||
raise
|
||||
|
||||
logger.debug(
|
||||
"Sending %d updates to %d connections",
|
||||
|
@ -118,6 +118,12 @@ CurrentStateDeltaStreamRow = namedtuple("CurrentStateDeltaStream", (
|
||||
"state_key", # str
|
||||
"event_id", # str, optional
|
||||
))
|
||||
GroupsStreamRow = namedtuple("GroupsStreamRow", (
|
||||
"group_id", # str
|
||||
"user_id", # str
|
||||
"type", # str
|
||||
"content", # dict
|
||||
))
|
||||
|
||||
|
||||
class Stream(object):
|
||||
@ -464,6 +470,19 @@ class CurrentStateDeltaStream(Stream):
|
||||
super(CurrentStateDeltaStream, self).__init__(hs)
|
||||
|
||||
|
||||
class GroupServerStream(Stream):
|
||||
NAME = "groups"
|
||||
ROW_TYPE = GroupsStreamRow
|
||||
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_group_stream_token
|
||||
self.update_function = store.get_all_groups_changes
|
||||
|
||||
super(GroupServerStream, self).__init__(hs)
|
||||
|
||||
|
||||
STREAMS_MAP = {
|
||||
stream.NAME: stream
|
||||
for stream in (
|
||||
@ -482,5 +501,6 @@ STREAMS_MAP = {
|
||||
TagAccountDataStream,
|
||||
AccountDataStream,
|
||||
CurrentStateDeltaStream,
|
||||
GroupServerStream,
|
||||
)
|
||||
}
|
||||
|
@ -52,6 +52,7 @@ from synapse.rest.client.v2_alpha import (
|
||||
thirdparty,
|
||||
sendtodevice,
|
||||
user_directory,
|
||||
groups,
|
||||
)
|
||||
|
||||
from synapse.http.server import JsonResource
|
||||
@ -102,3 +103,4 @@ class ClientRestResource(JsonResource):
|
||||
thirdparty.register_servlets(hs, client_resource)
|
||||
sendtodevice.register_servlets(hs, client_resource)
|
||||
user_directory.register_servlets(hs, client_resource)
|
||||
groups.register_servlets(hs, client_resource)
|
||||
|
@ -26,13 +26,13 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ProfileDisplaynameRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
self.profile_handler = hs.get_profile_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
user = UserID.from_string(user_id)
|
||||
|
||||
displayname = yield self.handlers.profile_handler.get_displayname(
|
||||
displayname = yield self.profile_handler.get_displayname(
|
||||
user,
|
||||
)
|
||||
|
||||
@ -55,7 +55,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
|
||||
except:
|
||||
defer.returnValue((400, "Unable to parse name"))
|
||||
|
||||
yield self.handlers.profile_handler.set_displayname(
|
||||
yield self.profile_handler.set_displayname(
|
||||
user, requester, new_name, is_admin)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
@ -69,13 +69,13 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ProfileAvatarURLRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
self.profile_handler = hs.get_profile_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
user = UserID.from_string(user_id)
|
||||
|
||||
avatar_url = yield self.handlers.profile_handler.get_avatar_url(
|
||||
avatar_url = yield self.profile_handler.get_avatar_url(
|
||||
user,
|
||||
)
|
||||
|
||||
@ -97,7 +97,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
|
||||
except:
|
||||
defer.returnValue((400, "Unable to parse name"))
|
||||
|
||||
yield self.handlers.profile_handler.set_avatar_url(
|
||||
yield self.profile_handler.set_avatar_url(
|
||||
user, requester, new_name, is_admin)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
@ -111,16 +111,16 @@ class ProfileRestServlet(ClientV1RestServlet):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ProfileRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
self.profile_handler = hs.get_profile_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
user = UserID.from_string(user_id)
|
||||
|
||||
displayname = yield self.handlers.profile_handler.get_displayname(
|
||||
displayname = yield self.profile_handler.get_displayname(
|
||||
user,
|
||||
)
|
||||
avatar_url = yield self.handlers.profile_handler.get_avatar_url(
|
||||
avatar_url = yield self.profile_handler.get_avatar_url(
|
||||
user,
|
||||
)
|
||||
|
||||
|
@ -398,22 +398,18 @@ class JoinedRoomMemberListRestServlet(ClientV1RestServlet):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(JoinedRoomMemberListRestServlet, self).__init__(hs)
|
||||
self.state = hs.get_state_handler()
|
||||
self.message_handler = hs.get_handlers().message_handler
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
users_with_profile = yield self.state.get_current_user_in_room(room_id)
|
||||
users_with_profile = yield self.message_handler.get_joined_members(
|
||||
requester, room_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, {
|
||||
"joined": {
|
||||
user_id: {
|
||||
"avatar_url": profile.avatar_url,
|
||||
"display_name": profile.display_name,
|
||||
}
|
||||
for user_id, profile in users_with_profile.iteritems()
|
||||
}
|
||||
"joined": users_with_profile,
|
||||
}))
|
||||
|
||||
|
||||
|
717
synapse/rest/client/v2_alpha/groups.py
Normal file
717
synapse/rest/client/v2_alpha/groups.py
Normal file
@ -0,0 +1,717 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 Vector Creations Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||
from synapse.types import GroupID
|
||||
|
||||
from ._base import client_v2_patterns
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GroupServlet(RestServlet):
|
||||
"""Get the group profile
|
||||
"""
|
||||
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/profile$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
group_description = yield self.groups_handler.get_group_profile(group_id, user_id)
|
||||
|
||||
defer.returnValue((200, group_description))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
yield self.groups_handler.update_group_profile(
|
||||
group_id, user_id, content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
class GroupSummaryServlet(RestServlet):
|
||||
"""Get the full group summary
|
||||
"""
|
||||
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/summary$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupSummaryServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
get_group_summary = yield self.groups_handler.get_group_summary(group_id, user_id)
|
||||
|
||||
defer.returnValue((200, get_group_summary))
|
||||
|
||||
|
||||
class GroupSummaryRoomsCatServlet(RestServlet):
|
||||
"""Update/delete a rooms entry in the summary.
|
||||
|
||||
Matches both:
|
||||
- /groups/:group/summary/rooms/:room_id
|
||||
- /groups/:group/summary/categories/:category/rooms/:room_id
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/summary"
|
||||
"(/categories/(?P<category_id>[^/]+))?"
|
||||
"/rooms/(?P<room_id>[^/]*)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupSummaryRoomsCatServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, group_id, category_id, room_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
resp = yield self.groups_handler.update_group_summary_room(
|
||||
group_id, user_id,
|
||||
room_id=room_id,
|
||||
category_id=category_id,
|
||||
content=content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, request, group_id, category_id, room_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
resp = yield self.groups_handler.delete_group_summary_room(
|
||||
group_id, user_id,
|
||||
room_id=room_id,
|
||||
category_id=category_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
|
||||
class GroupCategoryServlet(RestServlet):
|
||||
"""Get/add/update/delete a group category
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupCategoryServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id, category_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
category = yield self.groups_handler.get_group_category(
|
||||
group_id, user_id,
|
||||
category_id=category_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, category))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, group_id, category_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
resp = yield self.groups_handler.update_group_category(
|
||||
group_id, user_id,
|
||||
category_id=category_id,
|
||||
content=content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, request, group_id, category_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
resp = yield self.groups_handler.delete_group_category(
|
||||
group_id, user_id,
|
||||
category_id=category_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
|
||||
class GroupCategoriesServlet(RestServlet):
|
||||
"""Get all group categories
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/categories/$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupCategoriesServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
category = yield self.groups_handler.get_group_categories(
|
||||
group_id, user_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, category))
|
||||
|
||||
|
||||
class GroupRoleServlet(RestServlet):
|
||||
"""Get/add/update/delete a group role
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupRoleServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id, role_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
category = yield self.groups_handler.get_group_role(
|
||||
group_id, user_id,
|
||||
role_id=role_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, category))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, group_id, role_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
resp = yield self.groups_handler.update_group_role(
|
||||
group_id, user_id,
|
||||
role_id=role_id,
|
||||
content=content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, request, group_id, role_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
resp = yield self.groups_handler.delete_group_role(
|
||||
group_id, user_id,
|
||||
role_id=role_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
|
||||
class GroupRolesServlet(RestServlet):
|
||||
"""Get all group roles
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/roles/$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupRolesServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
category = yield self.groups_handler.get_group_roles(
|
||||
group_id, user_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, category))
|
||||
|
||||
|
||||
class GroupSummaryUsersRoleServlet(RestServlet):
|
||||
"""Update/delete a user's entry in the summary.
|
||||
|
||||
Matches both:
|
||||
- /groups/:group/summary/users/:room_id
|
||||
- /groups/:group/summary/roles/:role/users/:user_id
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/summary"
|
||||
"(/roles/(?P<role_id>[^/]+))?"
|
||||
"/users/(?P<user_id>[^/]*)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupSummaryUsersRoleServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, group_id, role_id, user_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
resp = yield self.groups_handler.update_group_summary_user(
|
||||
group_id, requester_user_id,
|
||||
user_id=user_id,
|
||||
role_id=role_id,
|
||||
content=content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, request, group_id, role_id, user_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
resp = yield self.groups_handler.delete_group_summary_user(
|
||||
group_id, requester_user_id,
|
||||
user_id=user_id,
|
||||
role_id=role_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
|
||||
class GroupRoomServlet(RestServlet):
|
||||
"""Get all rooms in a group
|
||||
"""
|
||||
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupRoomServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
result = yield self.groups_handler.get_rooms_in_group(group_id, user_id)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
class GroupUsersServlet(RestServlet):
|
||||
"""Get all users in a group
|
||||
"""
|
||||
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/users$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupUsersServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
result = yield self.groups_handler.get_users_in_group(group_id, user_id)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
class GroupInvitedUsersServlet(RestServlet):
|
||||
"""Get users invited to a group
|
||||
"""
|
||||
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupInvitedUsersServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
result = yield self.groups_handler.get_invited_users_in_group(group_id, user_id)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
class GroupCreateServlet(RestServlet):
|
||||
"""Create a group
|
||||
"""
|
||||
PATTERNS = client_v2_patterns("/create_group$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupCreateServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
self.server_name = hs.hostname
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
# TODO: Create group on remote server
|
||||
content = parse_json_object_from_request(request)
|
||||
localpart = content.pop("localpart")
|
||||
group_id = GroupID.create(localpart, self.server_name).to_string()
|
||||
|
||||
result = yield self.groups_handler.create_group(group_id, user_id, content)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
class GroupAdminRoomsServlet(RestServlet):
|
||||
"""Add a room to the group
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupAdminRoomsServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, group_id, room_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
result = yield self.groups_handler.add_room_to_group(
|
||||
group_id, user_id, room_id, content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, request, group_id, room_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
result = yield self.groups_handler.remove_room_from_group(
|
||||
group_id, user_id, room_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
class GroupAdminUsersInviteServlet(RestServlet):
|
||||
"""Invite a user to the group
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupAdminUsersInviteServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
self.store = hs.get_datastore()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, group_id, user_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
config = content.get("config", {})
|
||||
result = yield self.groups_handler.invite(
|
||||
group_id, user_id, requester_user_id, config,
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
class GroupAdminUsersKickServlet(RestServlet):
|
||||
"""Kick a user from the group
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupAdminUsersKickServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, group_id, user_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
result = yield self.groups_handler.remove_user_from_group(
|
||||
group_id, user_id, requester_user_id, content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
class GroupSelfLeaveServlet(RestServlet):
|
||||
"""Leave a joined group
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/self/leave$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupSelfLeaveServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
result = yield self.groups_handler.remove_user_from_group(
|
||||
group_id, requester_user_id, requester_user_id, content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
class GroupSelfJoinServlet(RestServlet):
|
||||
"""Attempt to join a group, or knock
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/self/join$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupSelfJoinServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
result = yield self.groups_handler.join_group(
|
||||
group_id, requester_user_id, content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
class GroupSelfAcceptInviteServlet(RestServlet):
|
||||
"""Accept a group invite
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/self/accept_invite$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupSelfAcceptInviteServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
result = yield self.groups_handler.accept_invite(
|
||||
group_id, requester_user_id, content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
class GroupSelfUpdatePublicityServlet(RestServlet):
|
||||
"""Update whether we publicise a users membership of a group
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/self/update_publicity$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupSelfUpdatePublicityServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
publicise = content["publicise"]
|
||||
yield self.store.update_group_publicity(
|
||||
group_id, requester_user_id, publicise,
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
class PublicisedGroupsForUserServlet(RestServlet):
|
||||
"""Get the list of groups a user is advertising
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/publicised_groups/(?P<user_id>[^/]*)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(PublicisedGroupsForUserServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
yield self.auth.get_user_by_req(request)
|
||||
|
||||
result = yield self.groups_handler.get_publicised_groups_for_user(
|
||||
user_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
class PublicisedGroupsForUsersServlet(RestServlet):
|
||||
"""Get the list of groups a user is advertising
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/publicised_groups$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(PublicisedGroupsForUsersServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
yield self.auth.get_user_by_req(request)
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
user_ids = content["user_ids"]
|
||||
|
||||
result = yield self.groups_handler.bulk_get_publicised_groups(
|
||||
user_ids
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
class GroupsForUserServlet(RestServlet):
|
||||
"""Get all groups the logged in user is joined to
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/joined_groups$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupsForUserServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
result = yield self.groups_handler.get_joined_groups(user_id)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
GroupServlet(hs).register(http_server)
|
||||
GroupSummaryServlet(hs).register(http_server)
|
||||
GroupInvitedUsersServlet(hs).register(http_server)
|
||||
GroupUsersServlet(hs).register(http_server)
|
||||
GroupRoomServlet(hs).register(http_server)
|
||||
GroupCreateServlet(hs).register(http_server)
|
||||
GroupAdminRoomsServlet(hs).register(http_server)
|
||||
GroupAdminUsersInviteServlet(hs).register(http_server)
|
||||
GroupAdminUsersKickServlet(hs).register(http_server)
|
||||
GroupSelfLeaveServlet(hs).register(http_server)
|
||||
GroupSelfJoinServlet(hs).register(http_server)
|
||||
GroupSelfAcceptInviteServlet(hs).register(http_server)
|
||||
GroupsForUserServlet(hs).register(http_server)
|
||||
GroupCategoryServlet(hs).register(http_server)
|
||||
GroupCategoriesServlet(hs).register(http_server)
|
||||
GroupSummaryRoomsCatServlet(hs).register(http_server)
|
||||
GroupRoleServlet(hs).register(http_server)
|
||||
GroupRolesServlet(hs).register(http_server)
|
||||
GroupSelfUpdatePublicityServlet(hs).register(http_server)
|
||||
GroupSummaryUsersRoleServlet(hs).register(http_server)
|
||||
PublicisedGroupsForUserServlet(hs).register(http_server)
|
||||
PublicisedGroupsForUsersServlet(hs).register(http_server)
|
@ -17,8 +17,10 @@
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse
|
||||
import synapse.types
|
||||
from synapse.api.auth import get_access_token_from_request, has_access_token
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.types import RoomID, RoomAlias
|
||||
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
|
||||
from synapse.http.servlet import (
|
||||
RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string
|
||||
@ -170,6 +172,7 @@ class RegisterRestServlet(RestServlet):
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
self.registration_handler = hs.get_handlers().registration_handler
|
||||
self.identity_handler = hs.get_handlers().identity_handler
|
||||
self.room_member_handler = hs.get_handlers().room_member_handler
|
||||
self.device_handler = hs.get_device_handler()
|
||||
self.macaroon_gen = hs.get_macaroon_generator()
|
||||
|
||||
@ -340,6 +343,14 @@ class RegisterRestServlet(RestServlet):
|
||||
generate_token=False,
|
||||
)
|
||||
|
||||
# auto-join the user to any rooms we're supposed to dump them into
|
||||
fake_requester = synapse.types.create_requester(registered_user_id)
|
||||
for r in self.hs.config.auto_join_rooms:
|
||||
try:
|
||||
yield self._join_user_to_room(fake_requester, r)
|
||||
except Exception as e:
|
||||
logger.error("Failed to join new user to %r: %r", r, e)
|
||||
|
||||
# remember that we've now registered that user account, and with
|
||||
# what user ID (since the user may not have specified)
|
||||
self.auth_handler.set_session_data(
|
||||
@ -372,6 +383,29 @@ class RegisterRestServlet(RestServlet):
|
||||
def on_OPTIONS(self, _):
|
||||
return 200, {}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _join_user_to_room(self, requester, room_identifier):
|
||||
room_id = None
|
||||
if RoomID.is_valid(room_identifier):
|
||||
room_id = room_identifier
|
||||
elif RoomAlias.is_valid(room_identifier):
|
||||
room_alias = RoomAlias.from_string(room_identifier)
|
||||
room_id, remote_room_hosts = (
|
||||
yield self.room_member_handler.lookup_room_alias(room_alias)
|
||||
)
|
||||
room_id = room_id.to_string()
|
||||
else:
|
||||
raise SynapseError(400, "%s was not legal room ID or room alias" % (
|
||||
room_identifier,
|
||||
))
|
||||
|
||||
yield self.room_member_handler.update_membership(
|
||||
requester=requester,
|
||||
target=requester.user,
|
||||
room_id=room_id,
|
||||
action="join",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_appservice_registration(self, username, as_token, body):
|
||||
user_id = yield self.registration_handler.appservice_register(
|
||||
|
@ -200,6 +200,11 @@ class SyncRestServlet(RestServlet):
|
||||
"invite": invited,
|
||||
"leave": archived,
|
||||
},
|
||||
"groups": {
|
||||
"join": sync_result.groups.join,
|
||||
"invite": sync_result.groups.invite,
|
||||
"leave": sync_result.groups.leave,
|
||||
},
|
||||
"device_one_time_keys_count": sync_result.device_one_time_keys_count,
|
||||
"next_batch": sync_result.next_batch.to_string(),
|
||||
}
|
||||
|
@ -14,78 +14,200 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import re
|
||||
import functools
|
||||
|
||||
NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
|
||||
|
||||
|
||||
def _wrap_in_base_path(func):
|
||||
"""Takes a function that returns a relative path and turns it into an
|
||||
absolute path based on the location of the primary media store
|
||||
"""
|
||||
@functools.wraps(func)
|
||||
def _wrapped(self, *args, **kwargs):
|
||||
path = func(self, *args, **kwargs)
|
||||
return os.path.join(self.base_path, path)
|
||||
|
||||
return _wrapped
|
||||
|
||||
|
||||
class MediaFilePaths(object):
|
||||
"""Describes where files are stored on disk.
|
||||
|
||||
def __init__(self, base_path):
|
||||
self.base_path = base_path
|
||||
Most of the functions have a `*_rel` variant which returns a file path that
|
||||
is relative to the base media store path. This is mainly used when we want
|
||||
to write to the backup media store (when one is configured)
|
||||
"""
|
||||
|
||||
def default_thumbnail(self, default_top_level, default_sub_type, width,
|
||||
height, content_type, method):
|
||||
def __init__(self, primary_base_path):
|
||||
self.base_path = primary_base_path
|
||||
|
||||
def default_thumbnail_rel(self, default_top_level, default_sub_type, width,
|
||||
height, content_type, method):
|
||||
top_level_type, sub_type = content_type.split("/")
|
||||
file_name = "%i-%i-%s-%s-%s" % (
|
||||
width, height, top_level_type, sub_type, method
|
||||
)
|
||||
return os.path.join(
|
||||
self.base_path, "default_thumbnails", default_top_level,
|
||||
"default_thumbnails", default_top_level,
|
||||
default_sub_type, file_name
|
||||
)
|
||||
|
||||
def local_media_filepath(self, media_id):
|
||||
default_thumbnail = _wrap_in_base_path(default_thumbnail_rel)
|
||||
|
||||
def local_media_filepath_rel(self, media_id):
|
||||
return os.path.join(
|
||||
self.base_path, "local_content",
|
||||
"local_content",
|
||||
media_id[0:2], media_id[2:4], media_id[4:]
|
||||
)
|
||||
|
||||
def local_media_thumbnail(self, media_id, width, height, content_type,
|
||||
method):
|
||||
local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)
|
||||
|
||||
def local_media_thumbnail_rel(self, media_id, width, height, content_type,
|
||||
method):
|
||||
top_level_type, sub_type = content_type.split("/")
|
||||
file_name = "%i-%i-%s-%s-%s" % (
|
||||
width, height, top_level_type, sub_type, method
|
||||
)
|
||||
return os.path.join(
|
||||
self.base_path, "local_thumbnails",
|
||||
"local_thumbnails",
|
||||
media_id[0:2], media_id[2:4], media_id[4:],
|
||||
file_name
|
||||
)
|
||||
|
||||
def remote_media_filepath(self, server_name, file_id):
|
||||
local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel)
|
||||
|
||||
def remote_media_filepath_rel(self, server_name, file_id):
|
||||
return os.path.join(
|
||||
self.base_path, "remote_content", server_name,
|
||||
"remote_content", server_name,
|
||||
file_id[0:2], file_id[2:4], file_id[4:]
|
||||
)
|
||||
|
||||
def remote_media_thumbnail(self, server_name, file_id, width, height,
|
||||
content_type, method):
|
||||
remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)
|
||||
|
||||
def remote_media_thumbnail_rel(self, server_name, file_id, width, height,
|
||||
content_type, method):
|
||||
top_level_type, sub_type = content_type.split("/")
|
||||
file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
|
||||
return os.path.join(
|
||||
self.base_path, "remote_thumbnail", server_name,
|
||||
"remote_thumbnail", server_name,
|
||||
file_id[0:2], file_id[2:4], file_id[4:],
|
||||
file_name
|
||||
)
|
||||
|
||||
remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel)
|
||||
|
||||
def remote_media_thumbnail_dir(self, server_name, file_id):
|
||||
return os.path.join(
|
||||
self.base_path, "remote_thumbnail", server_name,
|
||||
file_id[0:2], file_id[2:4], file_id[4:],
|
||||
)
|
||||
|
||||
def url_cache_filepath(self, media_id):
|
||||
return os.path.join(
|
||||
self.base_path, "url_cache",
|
||||
media_id[0:2], media_id[2:4], media_id[4:]
|
||||
)
|
||||
def url_cache_filepath_rel(self, media_id):
|
||||
if NEW_FORMAT_ID_RE.match(media_id):
|
||||
# Media id is of the form <DATE><RANDOM_STRING>
|
||||
# E.g.: 2017-09-28-fsdRDt24DS234dsf
|
||||
return os.path.join(
|
||||
"url_cache",
|
||||
media_id[:10], media_id[11:]
|
||||
)
|
||||
else:
|
||||
return os.path.join(
|
||||
"url_cache",
|
||||
media_id[0:2], media_id[2:4], media_id[4:],
|
||||
)
|
||||
|
||||
url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)
|
||||
|
||||
def url_cache_filepath_dirs_to_delete(self, media_id):
|
||||
"The dirs to try and remove if we delete the media_id file"
|
||||
if NEW_FORMAT_ID_RE.match(media_id):
|
||||
return [
|
||||
os.path.join(
|
||||
self.base_path, "url_cache",
|
||||
media_id[:10],
|
||||
),
|
||||
]
|
||||
else:
|
||||
return [
|
||||
os.path.join(
|
||||
self.base_path, "url_cache",
|
||||
media_id[0:2], media_id[2:4],
|
||||
),
|
||||
os.path.join(
|
||||
self.base_path, "url_cache",
|
||||
media_id[0:2],
|
||||
),
|
||||
]
|
||||
|
||||
def url_cache_thumbnail_rel(self, media_id, width, height, content_type,
|
||||
method):
|
||||
# Media id is of the form <DATE><RANDOM_STRING>
|
||||
# E.g.: 2017-09-28-fsdRDt24DS234dsf
|
||||
|
||||
def url_cache_thumbnail(self, media_id, width, height, content_type,
|
||||
method):
|
||||
top_level_type, sub_type = content_type.split("/")
|
||||
file_name = "%i-%i-%s-%s-%s" % (
|
||||
width, height, top_level_type, sub_type, method
|
||||
)
|
||||
return os.path.join(
|
||||
self.base_path, "url_cache_thumbnails",
|
||||
media_id[0:2], media_id[2:4], media_id[4:],
|
||||
file_name
|
||||
)
|
||||
|
||||
if NEW_FORMAT_ID_RE.match(media_id):
|
||||
return os.path.join(
|
||||
"url_cache_thumbnails",
|
||||
media_id[:10], media_id[11:],
|
||||
file_name
|
||||
)
|
||||
else:
|
||||
return os.path.join(
|
||||
"url_cache_thumbnails",
|
||||
media_id[0:2], media_id[2:4], media_id[4:],
|
||||
file_name
|
||||
)
|
||||
|
||||
url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)
|
||||
|
||||
def url_cache_thumbnail_directory(self, media_id):
|
||||
# Media id is of the form <DATE><RANDOM_STRING>
|
||||
# E.g.: 2017-09-28-fsdRDt24DS234dsf
|
||||
|
||||
if NEW_FORMAT_ID_RE.match(media_id):
|
||||
return os.path.join(
|
||||
self.base_path, "url_cache_thumbnails",
|
||||
media_id[:10], media_id[11:],
|
||||
)
|
||||
else:
|
||||
return os.path.join(
|
||||
self.base_path, "url_cache_thumbnails",
|
||||
media_id[0:2], media_id[2:4], media_id[4:],
|
||||
)
|
||||
|
||||
def url_cache_thumbnail_dirs_to_delete(self, media_id):
|
||||
"The dirs to try and remove if we delete the media_id thumbnails"
|
||||
# Media id is of the form <DATE><RANDOM_STRING>
|
||||
# E.g.: 2017-09-28-fsdRDt24DS234dsf
|
||||
if NEW_FORMAT_ID_RE.match(media_id):
|
||||
return [
|
||||
os.path.join(
|
||||
self.base_path, "url_cache_thumbnails",
|
||||
media_id[:10], media_id[11:],
|
||||
),
|
||||
os.path.join(
|
||||
self.base_path, "url_cache_thumbnails",
|
||||
media_id[:10],
|
||||
),
|
||||
]
|
||||
else:
|
||||
return [
|
||||
os.path.join(
|
||||
self.base_path, "url_cache_thumbnails",
|
||||
media_id[0:2], media_id[2:4], media_id[4:],
|
||||
),
|
||||
os.path.join(
|
||||
self.base_path, "url_cache_thumbnails",
|
||||
media_id[0:2], media_id[2:4],
|
||||
),
|
||||
os.path.join(
|
||||
self.base_path, "url_cache_thumbnails",
|
||||
media_id[0:2],
|
||||
),
|
||||
]
|
||||
|
@ -33,7 +33,7 @@ from synapse.api.errors import SynapseError, HttpResponseException, \
|
||||
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.util.stringutils import is_ascii
|
||||
from synapse.util.logcontext import preserve_context_over_fn
|
||||
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||
from synapse.util.retryutils import NotRetryingDestination
|
||||
|
||||
import os
|
||||
@ -59,7 +59,14 @@ class MediaRepository(object):
|
||||
self.store = hs.get_datastore()
|
||||
self.max_upload_size = hs.config.max_upload_size
|
||||
self.max_image_pixels = hs.config.max_image_pixels
|
||||
self.filepaths = MediaFilePaths(hs.config.media_store_path)
|
||||
|
||||
self.primary_base_path = hs.config.media_store_path
|
||||
self.filepaths = MediaFilePaths(self.primary_base_path)
|
||||
|
||||
self.backup_base_path = hs.config.backup_media_store_path
|
||||
|
||||
self.synchronous_backup_media_store = hs.config.synchronous_backup_media_store
|
||||
|
||||
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
|
||||
self.thumbnail_requirements = hs.config.thumbnail_requirements
|
||||
|
||||
@ -87,18 +94,86 @@ class MediaRepository(object):
|
||||
if not os.path.exists(dirname):
|
||||
os.makedirs(dirname)
|
||||
|
||||
@staticmethod
|
||||
def _write_file_synchronously(source, fname):
|
||||
"""Write `source` to the path `fname` synchronously. Should be called
|
||||
from a thread.
|
||||
|
||||
Args:
|
||||
source: A file like object to be written
|
||||
fname (str): Path to write to
|
||||
"""
|
||||
MediaRepository._makedirs(fname)
|
||||
source.seek(0) # Ensure we read from the start of the file
|
||||
with open(fname, "wb") as f:
|
||||
shutil.copyfileobj(source, f)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def write_to_file_and_backup(self, source, path):
|
||||
"""Write `source` to the on disk media store, and also the backup store
|
||||
if configured.
|
||||
|
||||
Args:
|
||||
source: A file like object that should be written
|
||||
path (str): Relative path to write file to
|
||||
|
||||
Returns:
|
||||
Deferred[str]: the file path written to in the primary media store
|
||||
"""
|
||||
fname = os.path.join(self.primary_base_path, path)
|
||||
|
||||
# Write to the main repository
|
||||
yield make_deferred_yieldable(threads.deferToThread(
|
||||
self._write_file_synchronously, source, fname,
|
||||
))
|
||||
|
||||
# Write to backup repository
|
||||
yield self.copy_to_backup(path)
|
||||
|
||||
defer.returnValue(fname)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def copy_to_backup(self, path):
|
||||
"""Copy a file from the primary to backup media store, if configured.
|
||||
|
||||
Args:
|
||||
path(str): Relative path to write file to
|
||||
"""
|
||||
if self.backup_base_path:
|
||||
primary_fname = os.path.join(self.primary_base_path, path)
|
||||
backup_fname = os.path.join(self.backup_base_path, path)
|
||||
|
||||
# We can either wait for successful writing to the backup repository
|
||||
# or write in the background and immediately return
|
||||
if self.synchronous_backup_media_store:
|
||||
yield make_deferred_yieldable(threads.deferToThread(
|
||||
shutil.copyfile, primary_fname, backup_fname,
|
||||
))
|
||||
else:
|
||||
preserve_fn(threads.deferToThread)(
|
||||
shutil.copyfile, primary_fname, backup_fname,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_content(self, media_type, upload_name, content, content_length,
|
||||
auth_user):
|
||||
"""Store uploaded content for a local user and return the mxc URL
|
||||
|
||||
Args:
|
||||
media_type(str): The content type of the file
|
||||
upload_name(str): The name of the file
|
||||
content: A file like object that is the content to store
|
||||
content_length(int): The length of the content
|
||||
auth_user(str): The user_id of the uploader
|
||||
|
||||
Returns:
|
||||
Deferred[str]: The mxc url of the stored content
|
||||
"""
|
||||
media_id = random_string(24)
|
||||
|
||||
fname = self.filepaths.local_media_filepath(media_id)
|
||||
self._makedirs(fname)
|
||||
|
||||
# This shouldn't block for very long because the content will have
|
||||
# already been uploaded at this point.
|
||||
with open(fname, "wb") as f:
|
||||
f.write(content)
|
||||
fname = yield self.write_to_file_and_backup(
|
||||
content, self.filepaths.local_media_filepath_rel(media_id)
|
||||
)
|
||||
|
||||
logger.info("Stored local media in file %r", fname)
|
||||
|
||||
@ -115,7 +190,7 @@ class MediaRepository(object):
|
||||
"media_length": content_length,
|
||||
}
|
||||
|
||||
yield self._generate_local_thumbnails(media_id, media_info)
|
||||
yield self._generate_thumbnails(None, media_id, media_info)
|
||||
|
||||
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
|
||||
|
||||
@ -148,9 +223,10 @@ class MediaRepository(object):
|
||||
def _download_remote_file(self, server_name, media_id):
|
||||
file_id = random_string(24)
|
||||
|
||||
fname = self.filepaths.remote_media_filepath(
|
||||
fpath = self.filepaths.remote_media_filepath_rel(
|
||||
server_name, file_id
|
||||
)
|
||||
fname = os.path.join(self.primary_base_path, fpath)
|
||||
self._makedirs(fname)
|
||||
|
||||
try:
|
||||
@ -192,6 +268,8 @@ class MediaRepository(object):
|
||||
server_name, media_id)
|
||||
raise SynapseError(502, "Failed to fetch remote media")
|
||||
|
||||
yield self.copy_to_backup(fpath)
|
||||
|
||||
media_type = headers["Content-Type"][0]
|
||||
time_now_ms = self.clock.time_msec()
|
||||
|
||||
@ -244,7 +322,7 @@ class MediaRepository(object):
|
||||
"filesystem_id": file_id,
|
||||
}
|
||||
|
||||
yield self._generate_remote_thumbnails(
|
||||
yield self._generate_thumbnails(
|
||||
server_name, media_id, media_info
|
||||
)
|
||||
|
||||
@ -253,9 +331,8 @@ class MediaRepository(object):
|
||||
def _get_thumbnail_requirements(self, media_type):
|
||||
return self.thumbnail_requirements.get(media_type, ())
|
||||
|
||||
def _generate_thumbnail(self, input_path, t_path, t_width, t_height,
|
||||
def _generate_thumbnail(self, thumbnailer, t_width, t_height,
|
||||
t_method, t_type):
|
||||
thumbnailer = Thumbnailer(input_path)
|
||||
m_width = thumbnailer.width
|
||||
m_height = thumbnailer.height
|
||||
|
||||
@ -267,72 +344,105 @@ class MediaRepository(object):
|
||||
return
|
||||
|
||||
if t_method == "crop":
|
||||
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
|
||||
t_byte_source = thumbnailer.crop(t_width, t_height, t_type)
|
||||
elif t_method == "scale":
|
||||
t_width, t_height = thumbnailer.aspect(t_width, t_height)
|
||||
t_width = min(m_width, t_width)
|
||||
t_height = min(m_height, t_height)
|
||||
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
|
||||
t_byte_source = thumbnailer.scale(t_width, t_height, t_type)
|
||||
else:
|
||||
t_len = None
|
||||
t_byte_source = None
|
||||
|
||||
return t_len
|
||||
return t_byte_source
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
|
||||
t_method, t_type):
|
||||
input_path = self.filepaths.local_media_filepath(media_id)
|
||||
|
||||
t_path = self.filepaths.local_media_thumbnail(
|
||||
media_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
self._makedirs(t_path)
|
||||
|
||||
t_len = yield preserve_context_over_fn(
|
||||
threads.deferToThread,
|
||||
thumbnailer = Thumbnailer(input_path)
|
||||
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
|
||||
self._generate_thumbnail,
|
||||
input_path, t_path, t_width, t_height, t_method, t_type
|
||||
)
|
||||
thumbnailer, t_width, t_height, t_method, t_type
|
||||
))
|
||||
|
||||
if t_byte_source:
|
||||
try:
|
||||
output_path = yield self.write_to_file_and_backup(
|
||||
t_byte_source,
|
||||
self.filepaths.local_media_thumbnail_rel(
|
||||
media_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
)
|
||||
finally:
|
||||
t_byte_source.close()
|
||||
|
||||
logger.info("Stored thumbnail in file %r", output_path)
|
||||
|
||||
t_len = os.path.getsize(output_path)
|
||||
|
||||
if t_len:
|
||||
yield self.store.store_local_thumbnail(
|
||||
media_id, t_width, t_height, t_type, t_method, t_len
|
||||
)
|
||||
|
||||
defer.returnValue(t_path)
|
||||
defer.returnValue(output_path)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
|
||||
t_width, t_height, t_method, t_type):
|
||||
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
|
||||
|
||||
t_path = self.filepaths.remote_media_thumbnail(
|
||||
server_name, file_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
self._makedirs(t_path)
|
||||
|
||||
t_len = yield preserve_context_over_fn(
|
||||
threads.deferToThread,
|
||||
thumbnailer = Thumbnailer(input_path)
|
||||
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
|
||||
self._generate_thumbnail,
|
||||
input_path, t_path, t_width, t_height, t_method, t_type
|
||||
)
|
||||
thumbnailer, t_width, t_height, t_method, t_type
|
||||
))
|
||||
|
||||
if t_byte_source:
|
||||
try:
|
||||
output_path = yield self.write_to_file_and_backup(
|
||||
t_byte_source,
|
||||
self.filepaths.remote_media_thumbnail_rel(
|
||||
server_name, file_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
)
|
||||
finally:
|
||||
t_byte_source.close()
|
||||
|
||||
logger.info("Stored thumbnail in file %r", output_path)
|
||||
|
||||
t_len = os.path.getsize(output_path)
|
||||
|
||||
if t_len:
|
||||
yield self.store.store_remote_media_thumbnail(
|
||||
server_name, media_id, file_id,
|
||||
t_width, t_height, t_type, t_method, t_len
|
||||
)
|
||||
|
||||
defer.returnValue(t_path)
|
||||
defer.returnValue(output_path)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _generate_local_thumbnails(self, media_id, media_info, url_cache=False):
|
||||
def _generate_thumbnails(self, server_name, media_id, media_info, url_cache=False):
|
||||
"""Generate and store thumbnails for an image.
|
||||
|
||||
Args:
|
||||
server_name(str|None): The server name if remote media, else None if local
|
||||
media_id(str)
|
||||
media_info(dict)
|
||||
url_cache(bool): If we are thumbnailing images downloaded for the URL cache,
|
||||
used exclusively by the url previewer
|
||||
|
||||
Returns:
|
||||
Deferred[dict]: Dict with "width" and "height" keys of original image
|
||||
"""
|
||||
media_type = media_info["media_type"]
|
||||
file_id = media_info.get("filesystem_id")
|
||||
requirements = self._get_thumbnail_requirements(media_type)
|
||||
if not requirements:
|
||||
return
|
||||
|
||||
if url_cache:
|
||||
if server_name:
|
||||
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
|
||||
elif url_cache:
|
||||
input_path = self.filepaths.url_cache_filepath(media_id)
|
||||
else:
|
||||
input_path = self.filepaths.local_media_filepath(media_id)
|
||||
@ -348,135 +458,72 @@ class MediaRepository(object):
|
||||
)
|
||||
return
|
||||
|
||||
local_thumbnails = []
|
||||
# We deduplicate the thumbnail sizes by ignoring the cropped versions if
|
||||
# they have the same dimensions of a scaled one.
|
||||
thumbnails = {}
|
||||
for r_width, r_height, r_method, r_type in requirements:
|
||||
if r_method == "crop":
|
||||
thumbnails.setdefault((r_width, r_height, r_type), r_method)
|
||||
elif r_method == "scale":
|
||||
t_width, t_height = thumbnailer.aspect(r_width, r_height)
|
||||
t_width = min(m_width, t_width)
|
||||
t_height = min(m_height, t_height)
|
||||
thumbnails[(t_width, t_height, r_type)] = r_method
|
||||
|
||||
def generate_thumbnails():
|
||||
scales = set()
|
||||
crops = set()
|
||||
for r_width, r_height, r_method, r_type in requirements:
|
||||
if r_method == "scale":
|
||||
t_width, t_height = thumbnailer.aspect(r_width, r_height)
|
||||
scales.add((
|
||||
min(m_width, t_width), min(m_height, t_height), r_type,
|
||||
))
|
||||
elif r_method == "crop":
|
||||
crops.add((r_width, r_height, r_type))
|
||||
|
||||
for t_width, t_height, t_type in scales:
|
||||
t_method = "scale"
|
||||
if url_cache:
|
||||
t_path = self.filepaths.url_cache_thumbnail(
|
||||
media_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
else:
|
||||
t_path = self.filepaths.local_media_thumbnail(
|
||||
media_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
self._makedirs(t_path)
|
||||
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
|
||||
|
||||
local_thumbnails.append((
|
||||
media_id, t_width, t_height, t_type, t_method, t_len
|
||||
))
|
||||
|
||||
for t_width, t_height, t_type in crops:
|
||||
if (t_width, t_height, t_type) in scales:
|
||||
# If the aspect ratio of the cropped thumbnail matches a purely
|
||||
# scaled one then there is no point in calculating a separate
|
||||
# thumbnail.
|
||||
continue
|
||||
t_method = "crop"
|
||||
if url_cache:
|
||||
t_path = self.filepaths.url_cache_thumbnail(
|
||||
media_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
else:
|
||||
t_path = self.filepaths.local_media_thumbnail(
|
||||
media_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
self._makedirs(t_path)
|
||||
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
|
||||
local_thumbnails.append((
|
||||
media_id, t_width, t_height, t_type, t_method, t_len
|
||||
))
|
||||
|
||||
yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
|
||||
|
||||
for l in local_thumbnails:
|
||||
yield self.store.store_local_thumbnail(*l)
|
||||
|
||||
defer.returnValue({
|
||||
"width": m_width,
|
||||
"height": m_height,
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _generate_remote_thumbnails(self, server_name, media_id, media_info):
|
||||
media_type = media_info["media_type"]
|
||||
file_id = media_info["filesystem_id"]
|
||||
requirements = self._get_thumbnail_requirements(media_type)
|
||||
if not requirements:
|
||||
return
|
||||
|
||||
remote_thumbnails = []
|
||||
|
||||
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
|
||||
thumbnailer = Thumbnailer(input_path)
|
||||
m_width = thumbnailer.width
|
||||
m_height = thumbnailer.height
|
||||
|
||||
def generate_thumbnails():
|
||||
if m_width * m_height >= self.max_image_pixels:
|
||||
logger.info(
|
||||
"Image too large to thumbnail %r x %r > %r",
|
||||
m_width, m_height, self.max_image_pixels
|
||||
)
|
||||
return
|
||||
|
||||
scales = set()
|
||||
crops = set()
|
||||
for r_width, r_height, r_method, r_type in requirements:
|
||||
if r_method == "scale":
|
||||
t_width, t_height = thumbnailer.aspect(r_width, r_height)
|
||||
scales.add((
|
||||
min(m_width, t_width), min(m_height, t_height), r_type,
|
||||
))
|
||||
elif r_method == "crop":
|
||||
crops.add((r_width, r_height, r_type))
|
||||
|
||||
for t_width, t_height, t_type in scales:
|
||||
t_method = "scale"
|
||||
t_path = self.filepaths.remote_media_thumbnail(
|
||||
# Now we generate the thumbnails for each dimension, store it
|
||||
for (t_width, t_height, t_type), t_method in thumbnails.iteritems():
|
||||
# Work out the correct file name for thumbnail
|
||||
if server_name:
|
||||
file_path = self.filepaths.remote_media_thumbnail_rel(
|
||||
server_name, file_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
self._makedirs(t_path)
|
||||
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
|
||||
remote_thumbnails.append([
|
||||
server_name, media_id, file_id,
|
||||
t_width, t_height, t_type, t_method, t_len
|
||||
])
|
||||
|
||||
for t_width, t_height, t_type in crops:
|
||||
if (t_width, t_height, t_type) in scales:
|
||||
# If the aspect ratio of the cropped thumbnail matches a purely
|
||||
# scaled one then there is no point in calculating a separate
|
||||
# thumbnail.
|
||||
continue
|
||||
t_method = "crop"
|
||||
t_path = self.filepaths.remote_media_thumbnail(
|
||||
server_name, file_id, t_width, t_height, t_type, t_method
|
||||
elif url_cache:
|
||||
file_path = self.filepaths.url_cache_thumbnail_rel(
|
||||
media_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
self._makedirs(t_path)
|
||||
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
|
||||
remote_thumbnails.append([
|
||||
else:
|
||||
file_path = self.filepaths.local_media_thumbnail_rel(
|
||||
media_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
|
||||
# Generate the thumbnail
|
||||
if t_method == "crop":
|
||||
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
|
||||
thumbnailer.crop,
|
||||
t_width, t_height, t_type,
|
||||
))
|
||||
elif t_method == "scale":
|
||||
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
|
||||
thumbnailer.scale,
|
||||
t_width, t_height, t_type,
|
||||
))
|
||||
else:
|
||||
logger.error("Unrecognized method: %r", t_method)
|
||||
continue
|
||||
|
||||
if not t_byte_source:
|
||||
continue
|
||||
|
||||
try:
|
||||
# Write to disk
|
||||
output_path = yield self.write_to_file_and_backup(
|
||||
t_byte_source, file_path,
|
||||
)
|
||||
finally:
|
||||
t_byte_source.close()
|
||||
|
||||
t_len = os.path.getsize(output_path)
|
||||
|
||||
# Write to database
|
||||
if server_name:
|
||||
yield self.store.store_remote_media_thumbnail(
|
||||
server_name, media_id, file_id,
|
||||
t_width, t_height, t_type, t_method, t_len
|
||||
])
|
||||
|
||||
yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
|
||||
|
||||
for r in remote_thumbnails:
|
||||
yield self.store.store_remote_media_thumbnail(*r)
|
||||
)
|
||||
else:
|
||||
yield self.store.store_local_thumbnail(
|
||||
media_id, t_width, t_height, t_type, t_method, t_len
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
"width": m_width,
|
||||
@ -497,6 +544,8 @@ class MediaRepository(object):
|
||||
|
||||
logger.info("Deleting: %r", key)
|
||||
|
||||
# TODO: Should we delete from the backup store
|
||||
|
||||
with (yield self.remote_media_linearizer.queue(key)):
|
||||
full_path = self.filepaths.remote_media_filepath(origin, file_id)
|
||||
try:
|
||||
|
@ -36,6 +36,9 @@ import cgi
|
||||
import ujson as json
|
||||
import urlparse
|
||||
import itertools
|
||||
import datetime
|
||||
import errno
|
||||
import shutil
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -56,6 +59,7 @@ class PreviewUrlResource(Resource):
|
||||
self.store = hs.get_datastore()
|
||||
self.client = SpiderHttpClient(hs)
|
||||
self.media_repo = media_repo
|
||||
self.primary_base_path = media_repo.primary_base_path
|
||||
|
||||
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
|
||||
|
||||
@ -70,6 +74,10 @@ class PreviewUrlResource(Resource):
|
||||
|
||||
self.downloads = {}
|
||||
|
||||
self._cleaner_loop = self.clock.looping_call(
|
||||
self._expire_url_cache_data, 10 * 1000
|
||||
)
|
||||
|
||||
def render_GET(self, request):
|
||||
self._async_render_GET(request)
|
||||
return NOT_DONE_YET
|
||||
@ -130,7 +138,7 @@ class PreviewUrlResource(Resource):
|
||||
cache_result = yield self.store.get_url_cache(url, ts)
|
||||
if (
|
||||
cache_result and
|
||||
cache_result["download_ts"] + cache_result["expires"] > ts and
|
||||
cache_result["expires_ts"] > ts and
|
||||
cache_result["response_code"] / 100 == 2
|
||||
):
|
||||
respond_with_json_bytes(
|
||||
@ -163,8 +171,8 @@ class PreviewUrlResource(Resource):
|
||||
logger.debug("got media_info of '%s'" % media_info)
|
||||
|
||||
if _is_media(media_info['media_type']):
|
||||
dims = yield self.media_repo._generate_local_thumbnails(
|
||||
media_info['filesystem_id'], media_info, url_cache=True,
|
||||
dims = yield self.media_repo._generate_thumbnails(
|
||||
None, media_info['filesystem_id'], media_info, url_cache=True,
|
||||
)
|
||||
|
||||
og = {
|
||||
@ -209,8 +217,8 @@ class PreviewUrlResource(Resource):
|
||||
|
||||
if _is_media(image_info['media_type']):
|
||||
# TODO: make sure we don't choke on white-on-transparent images
|
||||
dims = yield self.media_repo._generate_local_thumbnails(
|
||||
image_info['filesystem_id'], image_info, url_cache=True,
|
||||
dims = yield self.media_repo._generate_thumbnails(
|
||||
None, image_info['filesystem_id'], image_info, url_cache=True,
|
||||
)
|
||||
if dims:
|
||||
og["og:image:width"] = dims['width']
|
||||
@ -239,7 +247,7 @@ class PreviewUrlResource(Resource):
|
||||
url,
|
||||
media_info["response_code"],
|
||||
media_info["etag"],
|
||||
media_info["expires"],
|
||||
media_info["expires"] + media_info["created_ts"],
|
||||
json.dumps(og),
|
||||
media_info["filesystem_id"],
|
||||
media_info["created_ts"],
|
||||
@ -253,10 +261,10 @@ class PreviewUrlResource(Resource):
|
||||
# we're most likely being explicitly triggered by a human rather than a
|
||||
# bot, so are we really a robot?
|
||||
|
||||
# XXX: horrible duplication with base_resource's _download_remote_file()
|
||||
file_id = random_string(24)
|
||||
file_id = datetime.date.today().isoformat() + '_' + random_string(16)
|
||||
|
||||
fname = self.filepaths.url_cache_filepath(file_id)
|
||||
fpath = self.filepaths.url_cache_filepath_rel(file_id)
|
||||
fname = os.path.join(self.primary_base_path, fpath)
|
||||
self.media_repo._makedirs(fname)
|
||||
|
||||
try:
|
||||
@ -267,6 +275,8 @@ class PreviewUrlResource(Resource):
|
||||
)
|
||||
# FIXME: pass through 404s and other error messages nicely
|
||||
|
||||
yield self.media_repo.copy_to_backup(fpath)
|
||||
|
||||
media_type = headers["Content-Type"][0]
|
||||
time_now_ms = self.clock.time_msec()
|
||||
|
||||
@ -328,6 +338,91 @@ class PreviewUrlResource(Resource):
|
||||
"etag": headers["ETag"][0] if "ETag" in headers else None,
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _expire_url_cache_data(self):
|
||||
"""Clean up expired url cache content, media and thumbnails.
|
||||
"""
|
||||
|
||||
# TODO: Delete from backup media store
|
||||
|
||||
now = self.clock.time_msec()
|
||||
|
||||
# First we delete expired url cache entries
|
||||
media_ids = yield self.store.get_expired_url_cache(now)
|
||||
|
||||
removed_media = []
|
||||
for media_id in media_ids:
|
||||
fname = self.filepaths.url_cache_filepath(media_id)
|
||||
try:
|
||||
os.remove(fname)
|
||||
except OSError as e:
|
||||
# If the path doesn't exist, meh
|
||||
if e.errno != errno.ENOENT:
|
||||
logger.warn("Failed to remove media: %r: %s", media_id, e)
|
||||
continue
|
||||
|
||||
removed_media.append(media_id)
|
||||
|
||||
try:
|
||||
dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
|
||||
for dir in dirs:
|
||||
os.rmdir(dir)
|
||||
except:
|
||||
pass
|
||||
|
||||
yield self.store.delete_url_cache(removed_media)
|
||||
|
||||
if removed_media:
|
||||
logger.info("Deleted %d entries from url cache", len(removed_media))
|
||||
|
||||
# Now we delete old images associated with the url cache.
|
||||
# These may be cached for a bit on the client (i.e., they
|
||||
# may have a room open with a preview url thing open).
|
||||
# So we wait a couple of days before deleting, just in case.
|
||||
expire_before = now - 2 * 24 * 60 * 60 * 1000
|
||||
media_ids = yield self.store.get_url_cache_media_before(expire_before)
|
||||
|
||||
removed_media = []
|
||||
for media_id in media_ids:
|
||||
fname = self.filepaths.url_cache_filepath(media_id)
|
||||
try:
|
||||
os.remove(fname)
|
||||
except OSError as e:
|
||||
# If the path doesn't exist, meh
|
||||
if e.errno != errno.ENOENT:
|
||||
logger.warn("Failed to remove media: %r: %s", media_id, e)
|
||||
continue
|
||||
|
||||
try:
|
||||
dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
|
||||
for dir in dirs:
|
||||
os.rmdir(dir)
|
||||
except:
|
||||
pass
|
||||
|
||||
thumbnail_dir = self.filepaths.url_cache_thumbnail_directory(media_id)
|
||||
try:
|
||||
shutil.rmtree(thumbnail_dir)
|
||||
except OSError as e:
|
||||
# If the path doesn't exist, meh
|
||||
if e.errno != errno.ENOENT:
|
||||
logger.warn("Failed to remove media: %r: %s", media_id, e)
|
||||
continue
|
||||
|
||||
removed_media.append(media_id)
|
||||
|
||||
try:
|
||||
dirs = self.filepaths.url_cache_thumbnail_dirs_to_delete(media_id)
|
||||
for dir in dirs:
|
||||
os.rmdir(dir)
|
||||
except:
|
||||
pass
|
||||
|
||||
yield self.store.delete_url_cache_media(removed_media)
|
||||
|
||||
if removed_media:
|
||||
logger.info("Deleted %d media from url cache", len(removed_media))
|
||||
|
||||
|
||||
def decode_and_calc_og(body, media_uri, request_encoding=None):
|
||||
from lxml import etree
|
||||
|
@ -50,12 +50,16 @@ class Thumbnailer(object):
|
||||
else:
|
||||
return ((max_height * self.width) // self.height, max_height)
|
||||
|
||||
def scale(self, output_path, width, height, output_type):
|
||||
"""Rescales the image to the given dimensions"""
|
||||
scaled = self.image.resize((width, height), Image.ANTIALIAS)
|
||||
return self.save_image(scaled, output_type, output_path)
|
||||
def scale(self, width, height, output_type):
|
||||
"""Rescales the image to the given dimensions.
|
||||
|
||||
def crop(self, output_path, width, height, output_type):
|
||||
Returns:
|
||||
BytesIO: the bytes of the encoded image ready to be written to disk
|
||||
"""
|
||||
scaled = self.image.resize((width, height), Image.ANTIALIAS)
|
||||
return self._encode_image(scaled, output_type)
|
||||
|
||||
def crop(self, width, height, output_type):
|
||||
"""Rescales and crops the image to the given dimensions preserving
|
||||
aspect::
|
||||
(w_in / h_in) = (w_scaled / h_scaled)
|
||||
@ -65,6 +69,9 @@ class Thumbnailer(object):
|
||||
Args:
|
||||
max_width: The largest possible width.
|
||||
max_height: The larget possible height.
|
||||
|
||||
Returns:
|
||||
BytesIO: the bytes of the encoded image ready to be written to disk
|
||||
"""
|
||||
if width * self.height > height * self.width:
|
||||
scaled_height = (width * self.height) // self.width
|
||||
@ -82,13 +89,9 @@ class Thumbnailer(object):
|
||||
crop_left = (scaled_width - width) // 2
|
||||
crop_right = width + crop_left
|
||||
cropped = scaled_image.crop((crop_left, 0, crop_right, height))
|
||||
return self.save_image(cropped, output_type, output_path)
|
||||
return self._encode_image(cropped, output_type)
|
||||
|
||||
def save_image(self, output_image, output_type, output_path):
|
||||
def _encode_image(self, output_image, output_type):
|
||||
output_bytes_io = BytesIO()
|
||||
output_image.save(output_bytes_io, self.FORMATS[output_type], quality=80)
|
||||
output_bytes = output_bytes_io.getvalue()
|
||||
with open(output_path, "wb") as output_file:
|
||||
output_file.write(output_bytes)
|
||||
logger.info("Stored thumbnail in file %r", output_path)
|
||||
return len(output_bytes)
|
||||
return output_bytes_io
|
||||
|
@ -93,7 +93,7 @@ class UploadResource(Resource):
|
||||
# TODO(markjh): parse content-dispostion
|
||||
|
||||
content_uri = yield self.media_repo.create_content(
|
||||
media_type, upload_name, request.content.read(),
|
||||
media_type, upload_name, request.content,
|
||||
content_length, requester.user
|
||||
)
|
||||
|
||||
|
@ -31,6 +31,7 @@ from synapse.appservice.api import ApplicationServiceApi
|
||||
from synapse.appservice.scheduler import ApplicationServiceScheduler
|
||||
from synapse.crypto.keyring import Keyring
|
||||
from synapse.events.builder import EventBuilderFactory
|
||||
from synapse.events.spamcheck import SpamChecker
|
||||
from synapse.federation import initialize_http_replication
|
||||
from synapse.federation.send_queue import FederationRemoteSendQueue
|
||||
from synapse.federation.transport.client import TransportLayerClient
|
||||
@ -50,6 +51,10 @@ from synapse.handlers.initial_sync import InitialSyncHandler
|
||||
from synapse.handlers.receipts import ReceiptsHandler
|
||||
from synapse.handlers.read_marker import ReadMarkerHandler
|
||||
from synapse.handlers.user_directory import UserDirectoyHandler
|
||||
from synapse.handlers.groups_local import GroupsLocalHandler
|
||||
from synapse.handlers.profile import ProfileHandler
|
||||
from synapse.groups.groups_server import GroupsServerHandler
|
||||
from synapse.groups.attestations import GroupAttestionRenewer, GroupAttestationSigning
|
||||
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
|
||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||
from synapse.notifier import Notifier
|
||||
@ -111,6 +116,7 @@ class HomeServer(object):
|
||||
'application_service_scheduler',
|
||||
'application_service_handler',
|
||||
'device_message_handler',
|
||||
'profile_handler',
|
||||
'notifier',
|
||||
'distributor',
|
||||
'client_resource',
|
||||
@ -139,6 +145,11 @@ class HomeServer(object):
|
||||
'read_marker_handler',
|
||||
'action_generator',
|
||||
'user_directory_handler',
|
||||
'groups_local_handler',
|
||||
'groups_server_handler',
|
||||
'groups_attestation_signing',
|
||||
'groups_attestation_renewer',
|
||||
'spam_checker',
|
||||
]
|
||||
|
||||
def __init__(self, hostname, **kwargs):
|
||||
@ -251,6 +262,9 @@ class HomeServer(object):
|
||||
def build_initial_sync_handler(self):
|
||||
return InitialSyncHandler(self)
|
||||
|
||||
def build_profile_handler(self):
|
||||
return ProfileHandler(self)
|
||||
|
||||
def build_event_sources(self):
|
||||
return EventSources(self)
|
||||
|
||||
@ -309,6 +323,21 @@ class HomeServer(object):
|
||||
def build_user_directory_handler(self):
|
||||
return UserDirectoyHandler(self)
|
||||
|
||||
def build_groups_local_handler(self):
|
||||
return GroupsLocalHandler(self)
|
||||
|
||||
def build_groups_server_handler(self):
|
||||
return GroupsServerHandler(self)
|
||||
|
||||
def build_groups_attestation_signing(self):
|
||||
return GroupAttestationSigning(self)
|
||||
|
||||
def build_groups_attestation_renewer(self):
|
||||
return GroupAttestionRenewer(self)
|
||||
|
||||
def build_spam_checker(self):
|
||||
return SpamChecker(self)
|
||||
|
||||
def remove_pusher(self, app_id, push_key, user_id):
|
||||
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
|
||||
|
||||
|
@ -1,4 +1,6 @@
|
||||
import synapse.api.auth
|
||||
import synapse.federation.transaction_queue
|
||||
import synapse.federation.transport.client
|
||||
import synapse.handlers
|
||||
import synapse.handlers.auth
|
||||
import synapse.handlers.device
|
||||
@ -27,3 +29,9 @@ class HomeServer(object):
|
||||
|
||||
def get_state_handler(self) -> synapse.state.StateHandler:
|
||||
pass
|
||||
|
||||
def get_federation_sender(self) -> synapse.federation.transaction_queue.TransactionQueue:
|
||||
pass
|
||||
|
||||
def get_federation_transport_client(self) -> synapse.federation.transport.client.TransportLayerClient:
|
||||
pass
|
||||
|
@ -288,6 +288,9 @@ class StateHandler(object):
|
||||
"""
|
||||
logger.debug("resolve_state_groups event_ids %s", event_ids)
|
||||
|
||||
# map from state group id to the state in that state group (where
|
||||
# 'state' is a map from state key to event id)
|
||||
# dict[int, dict[(str, str), str]]
|
||||
state_groups_ids = yield self.store.get_state_groups_ids(
|
||||
room_id, event_ids
|
||||
)
|
||||
@ -320,11 +323,15 @@ class StateHandler(object):
|
||||
"Resolving state for %s with %d groups", room_id, len(state_groups_ids)
|
||||
)
|
||||
|
||||
# build a map from state key to the event_ids which set that state.
|
||||
# dict[(str, str), set[str])
|
||||
state = {}
|
||||
for st in state_groups_ids.values():
|
||||
for key, e_id in st.items():
|
||||
state.setdefault(key, set()).add(e_id)
|
||||
|
||||
# build a map from state key to the event_ids which set that state,
|
||||
# including only those where there are state keys in conflict.
|
||||
conflicted_state = {
|
||||
k: list(v)
|
||||
for k, v in state.items()
|
||||
@ -494,8 +501,14 @@ def _resolve_with_state_fac(unconflicted_state, conflicted_state,
|
||||
|
||||
logger.info("Asking for %d conflicted events", len(needed_events))
|
||||
|
||||
# dict[str, FrozenEvent]: a map from state event id to event. Only includes
|
||||
# the state events which are in conflict.
|
||||
state_map = yield state_map_factory(needed_events)
|
||||
|
||||
# get the ids of the auth events which allow us to authenticate the
|
||||
# conflicted state, picking only from the unconflicting state.
|
||||
#
|
||||
# dict[(str, str), str]: a map from state key to event id
|
||||
auth_events = _create_auth_events_from_maps(
|
||||
unconflicted_state, conflicted_state, state_map
|
||||
)
|
||||
|
@ -37,7 +37,7 @@ from .media_repository import MediaRepositoryStore
|
||||
from .rejections import RejectionsStore
|
||||
from .event_push_actions import EventPushActionsStore
|
||||
from .deviceinbox import DeviceInboxStore
|
||||
|
||||
from .group_server import GroupServerStore
|
||||
from .state import StateStore
|
||||
from .signatures import SignatureStore
|
||||
from .filtering import FilteringStore
|
||||
@ -88,6 +88,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
DeviceStore,
|
||||
DeviceInboxStore,
|
||||
UserDirectoryStore,
|
||||
GroupServerStore,
|
||||
):
|
||||
|
||||
def __init__(self, db_conn, hs):
|
||||
@ -135,6 +136,9 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
db_conn, "pushers", "id",
|
||||
extra_tables=[("deleted_pushers", "stream_id")],
|
||||
)
|
||||
self._group_updates_id_gen = StreamIdGenerator(
|
||||
db_conn, "local_group_updates", "stream_id",
|
||||
)
|
||||
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
self._cache_id_gen = StreamIdGenerator(
|
||||
@ -235,6 +239,18 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
prefilled_cache=curr_state_delta_prefill,
|
||||
)
|
||||
|
||||
_group_updates_prefill, min_group_updates_id = self._get_cache_dict(
|
||||
db_conn, "local_group_updates",
|
||||
entity_column="user_id",
|
||||
stream_column="stream_id",
|
||||
max_value=self._group_updates_id_gen.get_current_token(),
|
||||
limit=1000,
|
||||
)
|
||||
self._group_updates_stream_cache = StreamChangeCache(
|
||||
"_group_updates_stream_cache", min_group_updates_id,
|
||||
prefilled_cache=_group_updates_prefill,
|
||||
)
|
||||
|
||||
cur = LoggingTransaction(
|
||||
db_conn.cursor(),
|
||||
name="_find_stream_orderings_for_times_txn",
|
||||
|
@ -743,6 +743,33 @@ class SQLBaseStore(object):
|
||||
txn.execute(sql, values)
|
||||
return cls.cursor_to_dict(txn)
|
||||
|
||||
def _simple_update(self, table, keyvalues, updatevalues, desc):
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
self._simple_update_txn,
|
||||
table, keyvalues, updatevalues,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _simple_update_txn(txn, table, keyvalues, updatevalues):
|
||||
if keyvalues:
|
||||
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
|
||||
else:
|
||||
where = ""
|
||||
|
||||
update_sql = "UPDATE %s SET %s %s" % (
|
||||
table,
|
||||
", ".join("%s = ?" % (k,) for k in updatevalues),
|
||||
where,
|
||||
)
|
||||
|
||||
txn.execute(
|
||||
update_sql,
|
||||
updatevalues.values() + keyvalues.values()
|
||||
)
|
||||
|
||||
return txn.rowcount
|
||||
|
||||
def _simple_update_one(self, table, keyvalues, updatevalues,
|
||||
desc="_simple_update_one"):
|
||||
"""Executes an UPDATE query on the named table, setting new values for
|
||||
@ -768,27 +795,13 @@ class SQLBaseStore(object):
|
||||
table, keyvalues, updatevalues,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
|
||||
if keyvalues:
|
||||
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
|
||||
else:
|
||||
where = ""
|
||||
@classmethod
|
||||
def _simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
|
||||
rowcount = cls._simple_update_txn(txn, table, keyvalues, updatevalues)
|
||||
|
||||
update_sql = "UPDATE %s SET %s %s" % (
|
||||
table,
|
||||
", ".join("%s = ?" % (k,) for k in updatevalues),
|
||||
where,
|
||||
)
|
||||
|
||||
txn.execute(
|
||||
update_sql,
|
||||
updatevalues.values() + keyvalues.values()
|
||||
)
|
||||
|
||||
if txn.rowcount == 0:
|
||||
if rowcount == 0:
|
||||
raise StoreError(404, "No row found")
|
||||
if txn.rowcount > 1:
|
||||
if rowcount > 1:
|
||||
raise StoreError(500, "More than one row matched")
|
||||
|
||||
@staticmethod
|
||||
|
@ -21,7 +21,7 @@ from synapse.events.utils import prune_event
|
||||
|
||||
from synapse.util.async import ObservableDeferred
|
||||
from synapse.util.logcontext import (
|
||||
preserve_fn, PreserveLoggingContext, preserve_context_over_deferred
|
||||
preserve_fn, PreserveLoggingContext, make_deferred_yieldable
|
||||
)
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.metrics import Measure
|
||||
@ -88,13 +88,23 @@ class _EventPeristenceQueue(object):
|
||||
def add_to_queue(self, room_id, events_and_contexts, backfilled):
|
||||
"""Add events to the queue, with the given persist_event options.
|
||||
|
||||
NB: due to the normal usage pattern of this method, it does *not*
|
||||
follow the synapse logcontext rules, and leaves the logcontext in
|
||||
place whether or not the returned deferred is ready.
|
||||
|
||||
Args:
|
||||
room_id (str):
|
||||
events_and_contexts (list[(EventBase, EventContext)]):
|
||||
backfilled (bool):
|
||||
|
||||
Returns:
|
||||
defer.Deferred: a deferred which will resolve once the events are
|
||||
persisted. Runs its callbacks *without* a logcontext.
|
||||
"""
|
||||
queue = self._event_persist_queues.setdefault(room_id, deque())
|
||||
if queue:
|
||||
# if the last item in the queue has the same `backfilled` setting,
|
||||
# we can just add these new events to that item.
|
||||
end_item = queue[-1]
|
||||
if end_item.backfilled == backfilled:
|
||||
end_item.events_and_contexts.extend(events_and_contexts)
|
||||
@ -113,11 +123,11 @@ class _EventPeristenceQueue(object):
|
||||
def handle_queue(self, room_id, per_item_callback):
|
||||
"""Attempts to handle the queue for a room if not already being handled.
|
||||
|
||||
The given callback will be invoked with for each item in the queue,1
|
||||
The given callback will be invoked with for each item in the queue,
|
||||
of type _EventPersistQueueItem. The per_item_callback will continuously
|
||||
be called with new items, unless the queue becomnes empty. The return
|
||||
value of the function will be given to the deferreds waiting on the item,
|
||||
exceptions will be passed to the deferres as well.
|
||||
exceptions will be passed to the deferreds as well.
|
||||
|
||||
This function should therefore be called whenever anything is added
|
||||
to the queue.
|
||||
@ -233,7 +243,7 @@ class EventsStore(SQLBaseStore):
|
||||
|
||||
deferreds = []
|
||||
for room_id, evs_ctxs in partitioned.iteritems():
|
||||
d = preserve_fn(self._event_persist_queue.add_to_queue)(
|
||||
d = self._event_persist_queue.add_to_queue(
|
||||
room_id, evs_ctxs,
|
||||
backfilled=backfilled,
|
||||
)
|
||||
@ -242,7 +252,7 @@ class EventsStore(SQLBaseStore):
|
||||
for room_id in partitioned:
|
||||
self._maybe_start_persisting(room_id)
|
||||
|
||||
return preserve_context_over_deferred(
|
||||
return make_deferred_yieldable(
|
||||
defer.gatherResults(deferreds, consumeErrors=True)
|
||||
)
|
||||
|
||||
@ -267,7 +277,7 @@ class EventsStore(SQLBaseStore):
|
||||
|
||||
self._maybe_start_persisting(event.room_id)
|
||||
|
||||
yield preserve_context_over_deferred(deferred)
|
||||
yield make_deferred_yieldable(deferred)
|
||||
|
||||
max_persisted_id = yield self._stream_id_gen.get_current_token()
|
||||
defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id))
|
||||
@ -784,6 +794,9 @@ class EventsStore(SQLBaseStore):
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.is_host_joined, (room_id, host)
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.was_host_joined, (room_id, host)
|
||||
)
|
||||
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_users_in_room, (room_id,)
|
||||
@ -1523,7 +1536,7 @@ class EventsStore(SQLBaseStore):
|
||||
if not allow_rejected:
|
||||
rows[:] = [r for r in rows if not r["rejects"]]
|
||||
|
||||
res = yield preserve_context_over_deferred(defer.gatherResults(
|
||||
res = yield make_deferred_yieldable(defer.gatherResults(
|
||||
[
|
||||
preserve_fn(self._get_event_from_row)(
|
||||
row["internal_metadata"], row["json"], row["redacts"],
|
||||
|
1199
synapse/storage/group_server.py
Normal file
1199
synapse/storage/group_server.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -62,7 +62,7 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||
def get_url_cache_txn(txn):
|
||||
# get the most recently cached result (relative to the given ts)
|
||||
sql = (
|
||||
"SELECT response_code, etag, expires, og, media_id, download_ts"
|
||||
"SELECT response_code, etag, expires_ts, og, media_id, download_ts"
|
||||
" FROM local_media_repository_url_cache"
|
||||
" WHERE url = ? AND download_ts <= ?"
|
||||
" ORDER BY download_ts DESC LIMIT 1"
|
||||
@ -74,7 +74,7 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||
# ...or if we've requested a timestamp older than the oldest
|
||||
# copy in the cache, return the oldest copy (if any)
|
||||
sql = (
|
||||
"SELECT response_code, etag, expires, og, media_id, download_ts"
|
||||
"SELECT response_code, etag, expires_ts, og, media_id, download_ts"
|
||||
" FROM local_media_repository_url_cache"
|
||||
" WHERE url = ? AND download_ts > ?"
|
||||
" ORDER BY download_ts ASC LIMIT 1"
|
||||
@ -86,14 +86,14 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||
return None
|
||||
|
||||
return dict(zip((
|
||||
'response_code', 'etag', 'expires', 'og', 'media_id', 'download_ts'
|
||||
'response_code', 'etag', 'expires_ts', 'og', 'media_id', 'download_ts'
|
||||
), row))
|
||||
|
||||
return self.runInteraction(
|
||||
"get_url_cache", get_url_cache_txn
|
||||
)
|
||||
|
||||
def store_url_cache(self, url, response_code, etag, expires, og, media_id,
|
||||
def store_url_cache(self, url, response_code, etag, expires_ts, og, media_id,
|
||||
download_ts):
|
||||
return self._simple_insert(
|
||||
"local_media_repository_url_cache",
|
||||
@ -101,7 +101,7 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||
"url": url,
|
||||
"response_code": response_code,
|
||||
"etag": etag,
|
||||
"expires": expires,
|
||||
"expires_ts": expires_ts,
|
||||
"og": og,
|
||||
"media_id": media_id,
|
||||
"download_ts": download_ts,
|
||||
@ -238,3 +238,64 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||
},
|
||||
)
|
||||
return self.runInteraction("delete_remote_media", delete_remote_media_txn)
|
||||
|
||||
def get_expired_url_cache(self, now_ts):
|
||||
sql = (
|
||||
"SELECT media_id FROM local_media_repository_url_cache"
|
||||
" WHERE expires_ts < ?"
|
||||
" ORDER BY expires_ts ASC"
|
||||
" LIMIT 500"
|
||||
)
|
||||
|
||||
def _get_expired_url_cache_txn(txn):
|
||||
txn.execute(sql, (now_ts,))
|
||||
return [row[0] for row in txn]
|
||||
|
||||
return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn)
|
||||
|
||||
def delete_url_cache(self, media_ids):
|
||||
sql = (
|
||||
"DELETE FROM local_media_repository_url_cache"
|
||||
" WHERE media_id = ?"
|
||||
)
|
||||
|
||||
def _delete_url_cache_txn(txn):
|
||||
txn.executemany(sql, [(media_id,) for media_id in media_ids])
|
||||
|
||||
return self.runInteraction("delete_url_cache", _delete_url_cache_txn)
|
||||
|
||||
def get_url_cache_media_before(self, before_ts):
|
||||
sql = (
|
||||
"SELECT media_id FROM local_media_repository"
|
||||
" WHERE created_ts < ? AND url_cache IS NOT NULL"
|
||||
" ORDER BY created_ts ASC"
|
||||
" LIMIT 500"
|
||||
)
|
||||
|
||||
def _get_url_cache_media_before_txn(txn):
|
||||
txn.execute(sql, (before_ts,))
|
||||
return [row[0] for row in txn]
|
||||
|
||||
return self.runInteraction(
|
||||
"get_url_cache_media_before", _get_url_cache_media_before_txn,
|
||||
)
|
||||
|
||||
def delete_url_cache_media(self, media_ids):
|
||||
def _delete_url_cache_media_txn(txn):
|
||||
sql = (
|
||||
"DELETE FROM local_media_repository"
|
||||
" WHERE media_id = ?"
|
||||
)
|
||||
|
||||
txn.executemany(sql, [(media_id,) for media_id in media_ids])
|
||||
|
||||
sql = (
|
||||
"DELETE FROM local_media_repository_thumbnails"
|
||||
" WHERE media_id = ?"
|
||||
)
|
||||
|
||||
txn.executemany(sql, [(media_id,) for media_id in media_ids])
|
||||
|
||||
return self.runInteraction(
|
||||
"delete_url_cache_media", _delete_url_cache_media_txn,
|
||||
)
|
||||
|
@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# Remember to update this number every time a change is made to database
|
||||
# schema files, so the users will be informed on server restarts.
|
||||
SCHEMA_VERSION = 43
|
||||
SCHEMA_VERSION = 45
|
||||
|
||||
dir_path = os.path.abspath(os.path.dirname(__file__))
|
||||
|
||||
|
@ -13,6 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
|
||||
|
||||
@ -55,3 +57,99 @@ class ProfileStore(SQLBaseStore):
|
||||
updatevalues={"avatar_url": new_avatar_url},
|
||||
desc="set_profile_avatar_url",
|
||||
)
|
||||
|
||||
def get_from_remote_profile_cache(self, user_id):
|
||||
return self._simple_select_one(
|
||||
table="remote_profile_cache",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=("displayname", "avatar_url",),
|
||||
allow_none=True,
|
||||
desc="get_from_remote_profile_cache",
|
||||
)
|
||||
|
||||
def add_remote_profile_cache(self, user_id, displayname, avatar_url):
|
||||
"""Ensure we are caching the remote user's profiles.
|
||||
|
||||
This should only be called when `is_subscribed_remote_profile_for_user`
|
||||
would return true for the user.
|
||||
"""
|
||||
return self._simple_upsert(
|
||||
table="remote_profile_cache",
|
||||
keyvalues={"user_id": user_id},
|
||||
values={
|
||||
"displayname": displayname,
|
||||
"avatar_url": avatar_url,
|
||||
"last_check": self._clock.time_msec(),
|
||||
},
|
||||
desc="add_remote_profile_cache",
|
||||
)
|
||||
|
||||
def update_remote_profile_cache(self, user_id, displayname, avatar_url):
|
||||
return self._simple_update(
|
||||
table="remote_profile_cache",
|
||||
keyvalues={"user_id": user_id},
|
||||
values={
|
||||
"displayname": displayname,
|
||||
"avatar_url": avatar_url,
|
||||
"last_check": self._clock.time_msec(),
|
||||
},
|
||||
desc="update_remote_profile_cache",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def maybe_delete_remote_profile_cache(self, user_id):
|
||||
"""Check if we still care about the remote user's profile, and if we
|
||||
don't then remove their profile from the cache
|
||||
"""
|
||||
subscribed = yield self.is_subscribed_remote_profile_for_user(user_id)
|
||||
if not subscribed:
|
||||
yield self._simple_delete(
|
||||
table="remote_profile_cache",
|
||||
keyvalues={"user_id": user_id},
|
||||
desc="delete_remote_profile_cache",
|
||||
)
|
||||
|
||||
def get_remote_profile_cache_entries_that_expire(self, last_checked):
|
||||
"""Get all users who haven't been checked since `last_checked`
|
||||
"""
|
||||
def _get_remote_profile_cache_entries_that_expire_txn(txn):
|
||||
sql = """
|
||||
SELECT user_id, displayname, avatar_url
|
||||
FROM remote_profile_cache
|
||||
WHERE last_check < ?
|
||||
"""
|
||||
|
||||
txn.execute(sql, (last_checked,))
|
||||
|
||||
return self.cursor_to_dict(txn)
|
||||
|
||||
return self.runInteraction(
|
||||
"get_remote_profile_cache_entries_that_expire",
|
||||
_get_remote_profile_cache_entries_that_expire_txn,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def is_subscribed_remote_profile_for_user(self, user_id):
|
||||
"""Check whether we are interested in a remote user's profile.
|
||||
"""
|
||||
res = yield self._simple_select_one_onecol(
|
||||
table="group_users",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcol="user_id",
|
||||
allow_none=True,
|
||||
desc="should_update_remote_profile_cache_for_user",
|
||||
)
|
||||
|
||||
if res:
|
||||
defer.returnValue(True)
|
||||
|
||||
res = yield self._simple_select_one_onecol(
|
||||
table="group_invites",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcol="user_id",
|
||||
allow_none=True,
|
||||
desc="should_update_remote_profile_cache_for_user",
|
||||
)
|
||||
|
||||
if res:
|
||||
defer.returnValue(True)
|
||||
|
@ -533,6 +533,46 @@ class RoomMemberStore(SQLBaseStore):
|
||||
|
||||
defer.returnValue(True)
|
||||
|
||||
@cachedInlineCallbacks()
|
||||
def was_host_joined(self, room_id, host):
|
||||
"""Check whether the server is or ever was in the room.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
host (str)
|
||||
|
||||
Returns:
|
||||
Deferred: Resolves to True if the host is/was in the room, otherwise
|
||||
False.
|
||||
"""
|
||||
if '%' in host or '_' in host:
|
||||
raise Exception("Invalid host name")
|
||||
|
||||
sql = """
|
||||
SELECT user_id FROM room_memberships
|
||||
WHERE room_id = ?
|
||||
AND user_id LIKE ?
|
||||
AND membership = 'join'
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
# We do need to be careful to ensure that host doesn't have any wild cards
|
||||
# in it, but we checked above for known ones and we'll check below that
|
||||
# the returned user actually has the correct domain.
|
||||
like_clause = "%:" + host
|
||||
|
||||
rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause)
|
||||
|
||||
if not rows:
|
||||
defer.returnValue(False)
|
||||
|
||||
user_id = rows[0][0]
|
||||
if get_domain_from_id(user_id) != host:
|
||||
# This can only happen if the host name has something funky in it
|
||||
raise Exception("Invalid host name")
|
||||
|
||||
defer.returnValue(True)
|
||||
|
||||
def get_joined_hosts(self, room_id, state_entry):
|
||||
state_group = state_entry.state_group
|
||||
if not state_group:
|
||||
|
38
synapse/storage/schema/delta/44/expire_url_cache.sql
Normal file
38
synapse/storage/schema/delta/44/expire_url_cache.sql
Normal file
@ -0,0 +1,38 @@
|
||||
/* Copyright 2017 New Vector Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
CREATE INDEX local_media_repository_url_idx ON local_media_repository(created_ts) WHERE url_cache IS NOT NULL;
|
||||
|
||||
-- we need to change `expires` to `expires_ts` so that we can index on it. SQLite doesn't support
|
||||
-- indices on expressions until 3.9.
|
||||
CREATE TABLE local_media_repository_url_cache_new(
|
||||
url TEXT,
|
||||
response_code INTEGER,
|
||||
etag TEXT,
|
||||
expires_ts BIGINT,
|
||||
og TEXT,
|
||||
media_id TEXT,
|
||||
download_ts BIGINT
|
||||
);
|
||||
|
||||
INSERT INTO local_media_repository_url_cache_new
|
||||
SELECT url, response_code, etag, expires + download_ts, og, media_id, download_ts FROM local_media_repository_url_cache;
|
||||
|
||||
DROP TABLE local_media_repository_url_cache;
|
||||
ALTER TABLE local_media_repository_url_cache_new RENAME TO local_media_repository_url_cache;
|
||||
|
||||
CREATE INDEX local_media_repository_url_cache_expires_idx ON local_media_repository_url_cache(expires_ts);
|
||||
CREATE INDEX local_media_repository_url_cache_by_url_download_ts ON local_media_repository_url_cache(url, download_ts);
|
||||
CREATE INDEX local_media_repository_url_cache_media_idx ON local_media_repository_url_cache(media_id);
|
167
synapse/storage/schema/delta/45/group_server.sql
Normal file
167
synapse/storage/schema/delta/45/group_server.sql
Normal file
@ -0,0 +1,167 @@
|
||||
/* Copyright 2017 Vector Creations Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
CREATE TABLE groups (
|
||||
group_id TEXT NOT NULL,
|
||||
name TEXT, -- the display name of the room
|
||||
avatar_url TEXT,
|
||||
short_description TEXT,
|
||||
long_description TEXT
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX groups_idx ON groups(group_id);
|
||||
|
||||
|
||||
-- list of users the group server thinks are joined
|
||||
CREATE TABLE group_users (
|
||||
group_id TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
is_admin BOOLEAN NOT NULL,
|
||||
is_public BOOLEAN NOT NULL -- whether the users membership can be seen by everyone
|
||||
);
|
||||
|
||||
|
||||
CREATE INDEX groups_users_g_idx ON group_users(group_id, user_id);
|
||||
CREATE INDEX groups_users_u_idx ON group_users(user_id);
|
||||
|
||||
-- list of users the group server thinks are invited
|
||||
CREATE TABLE group_invites (
|
||||
group_id TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX groups_invites_g_idx ON group_invites(group_id, user_id);
|
||||
CREATE INDEX groups_invites_u_idx ON group_invites(user_id);
|
||||
|
||||
|
||||
CREATE TABLE group_rooms (
|
||||
group_id TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
is_public BOOLEAN NOT NULL -- whether the room can be seen by everyone
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX groups_rooms_g_idx ON group_rooms(group_id, room_id);
|
||||
CREATE INDEX groups_rooms_r_idx ON group_rooms(room_id);
|
||||
|
||||
|
||||
-- Rooms to include in the summary
|
||||
CREATE TABLE group_summary_rooms (
|
||||
group_id TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
category_id TEXT NOT NULL,
|
||||
room_order BIGINT NOT NULL,
|
||||
is_public BOOLEAN NOT NULL, -- whether the room should be show to everyone
|
||||
UNIQUE (group_id, category_id, room_id, room_order),
|
||||
CHECK (room_order > 0)
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX group_summary_rooms_g_idx ON group_summary_rooms(group_id, room_id, category_id);
|
||||
|
||||
|
||||
-- Categories to include in the summary
|
||||
CREATE TABLE group_summary_room_categories (
|
||||
group_id TEXT NOT NULL,
|
||||
category_id TEXT NOT NULL,
|
||||
cat_order BIGINT NOT NULL,
|
||||
UNIQUE (group_id, category_id, cat_order),
|
||||
CHECK (cat_order > 0)
|
||||
);
|
||||
|
||||
-- The categories in the group
|
||||
CREATE TABLE group_room_categories (
|
||||
group_id TEXT NOT NULL,
|
||||
category_id TEXT NOT NULL,
|
||||
profile TEXT NOT NULL,
|
||||
is_public BOOLEAN NOT NULL, -- whether the category should be show to everyone
|
||||
UNIQUE (group_id, category_id)
|
||||
);
|
||||
|
||||
-- The users to include in the group summary
|
||||
CREATE TABLE group_summary_users (
|
||||
group_id TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
role_id TEXT NOT NULL,
|
||||
user_order BIGINT NOT NULL,
|
||||
is_public BOOLEAN NOT NULL -- whether the user should be show to everyone
|
||||
);
|
||||
|
||||
CREATE INDEX group_summary_users_g_idx ON group_summary_users(group_id);
|
||||
|
||||
-- The roles to include in the group summary
|
||||
CREATE TABLE group_summary_roles (
|
||||
group_id TEXT NOT NULL,
|
||||
role_id TEXT NOT NULL,
|
||||
role_order BIGINT NOT NULL,
|
||||
UNIQUE (group_id, role_id, role_order),
|
||||
CHECK (role_order > 0)
|
||||
);
|
||||
|
||||
|
||||
-- The roles in a groups
|
||||
CREATE TABLE group_roles (
|
||||
group_id TEXT NOT NULL,
|
||||
role_id TEXT NOT NULL,
|
||||
profile TEXT NOT NULL,
|
||||
is_public BOOLEAN NOT NULL, -- whether the role should be show to everyone
|
||||
UNIQUE (group_id, role_id)
|
||||
);
|
||||
|
||||
|
||||
-- List of attestations we've given out and need to renew
|
||||
CREATE TABLE group_attestations_renewals (
|
||||
group_id TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
valid_until_ms BIGINT NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX group_attestations_renewals_g_idx ON group_attestations_renewals(group_id, user_id);
|
||||
CREATE INDEX group_attestations_renewals_u_idx ON group_attestations_renewals(user_id);
|
||||
CREATE INDEX group_attestations_renewals_v_idx ON group_attestations_renewals(valid_until_ms);
|
||||
|
||||
|
||||
-- List of attestations we've received from remotes and are interested in.
|
||||
CREATE TABLE group_attestations_remote (
|
||||
group_id TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
valid_until_ms BIGINT NOT NULL,
|
||||
attestation_json TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX group_attestations_remote_g_idx ON group_attestations_remote(group_id, user_id);
|
||||
CREATE INDEX group_attestations_remote_u_idx ON group_attestations_remote(user_id);
|
||||
CREATE INDEX group_attestations_remote_v_idx ON group_attestations_remote(valid_until_ms);
|
||||
|
||||
|
||||
-- The group membership for the HS's users
|
||||
CREATE TABLE local_group_membership (
|
||||
group_id TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
is_admin BOOLEAN NOT NULL,
|
||||
membership TEXT NOT NULL,
|
||||
is_publicised BOOLEAN NOT NULL, -- if the user is publicising their membership
|
||||
content TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX local_group_membership_u_idx ON local_group_membership(user_id, group_id);
|
||||
CREATE INDEX local_group_membership_g_idx ON local_group_membership(group_id);
|
||||
|
||||
|
||||
CREATE TABLE local_group_updates (
|
||||
stream_id BIGINT NOT NULL,
|
||||
group_id TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
type TEXT NOT NULL,
|
||||
content TEXT NOT NULL
|
||||
);
|
28
synapse/storage/schema/delta/45/profile_cache.sql
Normal file
28
synapse/storage/schema/delta/45/profile_cache.sql
Normal file
@ -0,0 +1,28 @@
|
||||
/* Copyright 2017 New Vector Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
|
||||
-- A subset of remote users whose profiles we have cached.
|
||||
-- Whether a user is in this table or not is defined by the storage function
|
||||
-- `is_subscribed_remote_profile_for_user`
|
||||
CREATE TABLE remote_profile_cache (
|
||||
user_id TEXT NOT NULL,
|
||||
displayname TEXT,
|
||||
avatar_url TEXT,
|
||||
last_check BIGINT NOT NULL
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX remote_profile_cache_user_id ON remote_profile_cache(user_id);
|
||||
CREATE INDEX remote_profile_cache_time ON remote_profile_cache(last_check);
|
@ -45,6 +45,7 @@ class EventSources(object):
|
||||
push_rules_key, _ = self.store.get_push_rules_stream_token()
|
||||
to_device_key = self.store.get_to_device_stream_token()
|
||||
device_list_key = self.store.get_device_stream_token()
|
||||
groups_key = self.store.get_group_stream_token()
|
||||
|
||||
token = StreamToken(
|
||||
room_key=(
|
||||
@ -65,6 +66,7 @@ class EventSources(object):
|
||||
push_rules_key=push_rules_key,
|
||||
to_device_key=to_device_key,
|
||||
device_list_key=device_list_key,
|
||||
groups_key=groups_key,
|
||||
)
|
||||
defer.returnValue(token)
|
||||
|
||||
@ -73,6 +75,7 @@ class EventSources(object):
|
||||
push_rules_key, _ = self.store.get_push_rules_stream_token()
|
||||
to_device_key = self.store.get_to_device_stream_token()
|
||||
device_list_key = self.store.get_device_stream_token()
|
||||
groups_key = self.store.get_group_stream_token()
|
||||
|
||||
token = StreamToken(
|
||||
room_key=(
|
||||
@ -93,5 +96,6 @@ class EventSources(object):
|
||||
push_rules_key=push_rules_key,
|
||||
to_device_key=to_device_key,
|
||||
device_list_key=device_list_key,
|
||||
groups_key=groups_key,
|
||||
)
|
||||
defer.returnValue(token)
|
||||
|
@ -156,6 +156,11 @@ class EventID(DomainSpecificString):
|
||||
SIGIL = "$"
|
||||
|
||||
|
||||
class GroupID(DomainSpecificString):
|
||||
"""Structure representing a group ID."""
|
||||
SIGIL = "+"
|
||||
|
||||
|
||||
class StreamToken(
|
||||
namedtuple("Token", (
|
||||
"room_key",
|
||||
@ -166,6 +171,7 @@ class StreamToken(
|
||||
"push_rules_key",
|
||||
"to_device_key",
|
||||
"device_list_key",
|
||||
"groups_key",
|
||||
))
|
||||
):
|
||||
_SEPARATOR = "_"
|
||||
@ -204,6 +210,7 @@ class StreamToken(
|
||||
or (int(other.push_rules_key) < int(self.push_rules_key))
|
||||
or (int(other.to_device_key) < int(self.to_device_key))
|
||||
or (int(other.device_list_key) < int(self.device_list_key))
|
||||
or (int(other.groups_key) < int(self.groups_key))
|
||||
)
|
||||
|
||||
def copy_and_advance(self, key, new_value):
|
||||
|
@ -19,7 +19,7 @@ from twisted.internet import defer, reactor
|
||||
from .logcontext import (
|
||||
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
|
||||
)
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util import logcontext, unwrapFirstError
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
@ -53,6 +53,11 @@ class ObservableDeferred(object):
|
||||
|
||||
Cancelling or otherwise resolving an observer will not affect the original
|
||||
ObservableDeferred.
|
||||
|
||||
NB that it does not attempt to do anything with logcontexts; in general
|
||||
you should probably make_deferred_yieldable the deferreds
|
||||
returned by `observe`, and ensure that the original deferred runs its
|
||||
callbacks in the sentinel logcontext.
|
||||
"""
|
||||
|
||||
__slots__ = ["_deferred", "_observers", "_result"]
|
||||
@ -155,7 +160,7 @@ def concurrently_execute(func, args, limit):
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
return preserve_context_over_deferred(defer.gatherResults([
|
||||
return logcontext.make_deferred_yieldable(defer.gatherResults([
|
||||
preserve_fn(_concurrently_execute_inner)()
|
||||
for _ in xrange(limit)
|
||||
], consumeErrors=True)).addErrback(unwrapFirstError)
|
||||
@ -203,7 +208,26 @@ class Linearizer(object):
|
||||
except:
|
||||
logger.exception("Unexpected exception in Linearizer")
|
||||
|
||||
logger.info("Acquired linearizer lock %r for key %r", self.name, key)
|
||||
logger.info("Acquired linearizer lock %r for key %r", self.name,
|
||||
key)
|
||||
|
||||
# if the code holding the lock completes synchronously, then it
|
||||
# will recursively run the next claimant on the list. That can
|
||||
# relatively rapidly lead to stack exhaustion. This is essentially
|
||||
# the same problem as http://twistedmatrix.com/trac/ticket/9304.
|
||||
#
|
||||
# In order to break the cycle, we add a cheeky sleep(0) here to
|
||||
# ensure that we fall back to the reactor between each iteration.
|
||||
#
|
||||
# (There's no particular need for it to happen before we return
|
||||
# the context manager, but it needs to happen while we hold the
|
||||
# lock, and the context manager's exit code must be synchronous,
|
||||
# so actually this is the only sensible place.
|
||||
yield run_on_reactor()
|
||||
|
||||
else:
|
||||
logger.info("Acquired uncontended linearizer lock %r for key %r",
|
||||
self.name, key)
|
||||
|
||||
@contextmanager
|
||||
def _ctx_manager():
|
||||
@ -211,7 +235,8 @@ class Linearizer(object):
|
||||
yield
|
||||
finally:
|
||||
logger.info("Releasing linearizer lock %r for key %r", self.name, key)
|
||||
new_defer.callback(None)
|
||||
with PreserveLoggingContext():
|
||||
new_defer.callback(None)
|
||||
current_d = self.key_to_defer.get(key)
|
||||
if current_d is new_defer:
|
||||
self.key_to_defer.pop(key, None)
|
||||
|
51
synapse/util/logformatter.py
Normal file
51
synapse/util/logformatter.py
Normal file
@ -0,0 +1,51 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import StringIO
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
|
||||
class LogFormatter(logging.Formatter):
|
||||
"""Log formatter which gives more detail for exceptions
|
||||
|
||||
This is the same as the standard log formatter, except that when logging
|
||||
exceptions [typically via log.foo("msg", exc_info=1)], it prints the
|
||||
sequence that led up to the point at which the exception was caught.
|
||||
(Normally only stack frames between the point the exception was raised and
|
||||
where it was caught are logged).
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(LogFormatter, self).__init__(*args, **kwargs)
|
||||
|
||||
def formatException(self, ei):
|
||||
sio = StringIO.StringIO()
|
||||
(typ, val, tb) = ei
|
||||
|
||||
# log the stack above the exception capture point if possible, but
|
||||
# check that we actually have an f_back attribute to work around
|
||||
# https://twistedmatrix.com/trac/ticket/9305
|
||||
|
||||
if tb and hasattr(tb.tb_frame, 'f_back'):
|
||||
sio.write("Capture point (most recent call last):\n")
|
||||
traceback.print_stack(tb.tb_frame.f_back, None, sio)
|
||||
|
||||
traceback.print_exception(typ, val, tb, None, sio)
|
||||
s = sio.getvalue()
|
||||
sio.close()
|
||||
if s[-1:] == "\n":
|
||||
s = s[:-1]
|
||||
return s
|
42
synapse/util/module_loader.py
Normal file
42
synapse/util/module_loader.py
Normal file
@ -0,0 +1,42 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
|
||||
from synapse.config._base import ConfigError
|
||||
|
||||
|
||||
def load_module(provider):
|
||||
""" Loads a module with its config
|
||||
Take a dict with keys 'module' (the module name) and 'config'
|
||||
(the config dict).
|
||||
|
||||
Returns
|
||||
Tuple of (provider class, parsed config object)
|
||||
"""
|
||||
# We need to import the module, and then pick the class out of
|
||||
# that, so we split based on the last dot.
|
||||
module, clz = provider['module'].rsplit(".", 1)
|
||||
module = importlib.import_module(module)
|
||||
provider_class = getattr(module, clz)
|
||||
|
||||
try:
|
||||
provider_config = provider_class.parse_config(provider["config"])
|
||||
except Exception as e:
|
||||
raise ConfigError(
|
||||
"Failed to parse config for %r: %r" % (provider['module'], e)
|
||||
)
|
||||
|
||||
return provider_class, provider_config
|
@ -62,8 +62,6 @@ class ProfileTestCase(unittest.TestCase):
|
||||
self.ratelimiter = hs.get_ratelimiter()
|
||||
self.ratelimiter.send_message.return_value = (True, 0)
|
||||
|
||||
hs.handlers = ProfileHandlers(hs)
|
||||
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
self.frank = UserID.from_string("@1234ABCD:test")
|
||||
@ -72,7 +70,7 @@ class ProfileTestCase(unittest.TestCase):
|
||||
|
||||
yield self.store.create_profile(self.frank.localpart)
|
||||
|
||||
self.handler = hs.get_handlers().profile_handler
|
||||
self.handler = hs.get_profile_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_my_name(self):
|
||||
|
@ -40,13 +40,14 @@ class RegistrationTestCase(unittest.TestCase):
|
||||
self.hs = yield setup_test_homeserver(
|
||||
handlers=None,
|
||||
http_client=None,
|
||||
expire_access_token=True)
|
||||
expire_access_token=True,
|
||||
profile_handler=Mock(),
|
||||
)
|
||||
self.macaroon_generator = Mock(
|
||||
generate_access_token=Mock(return_value='secret'))
|
||||
self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
|
||||
self.hs.handlers = RegistrationHandlers(self.hs)
|
||||
self.handler = self.hs.get_handlers().registration_handler
|
||||
self.hs.get_handlers().profile_handler = Mock()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_user_is_created_and_logged_in_if_doesnt_exist(self):
|
||||
|
@ -46,6 +46,7 @@ class ProfileTestCase(unittest.TestCase):
|
||||
resource_for_client=self.mock_resource,
|
||||
federation=Mock(),
|
||||
replication_layer=Mock(),
|
||||
profile_handler=self.mock_handler
|
||||
)
|
||||
|
||||
def _get_user_by_req(request=None, allow_guest=False):
|
||||
@ -53,8 +54,6 @@ class ProfileTestCase(unittest.TestCase):
|
||||
|
||||
hs.get_v1auth().get_user_by_req = _get_user_by_req
|
||||
|
||||
hs.get_handlers().profile_handler = self.mock_handler
|
||||
|
||||
profile.register_servlets(hs, self.mock_resource)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -1032,7 +1032,7 @@ class RoomMessageListTestCase(RestTestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_topo_token_is_accepted(self):
|
||||
token = "t1-0_0_0_0_0_0_0_0"
|
||||
token = "t1-0_0_0_0_0_0_0_0_0"
|
||||
(code, response) = yield self.mock_resource.trigger_get(
|
||||
"/rooms/%s/messages?access_token=x&from=%s" %
|
||||
(self.room_id, token))
|
||||
@ -1044,7 +1044,7 @@ class RoomMessageListTestCase(RestTestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_stream_token_is_accepted_for_fwd_pagianation(self):
|
||||
token = "s0_0_0_0_0_0_0_0"
|
||||
token = "s0_0_0_0_0_0_0_0_0"
|
||||
(code, response) = yield self.mock_resource.trigger_get(
|
||||
"/rooms/%s/messages?access_token=x&from=%s" %
|
||||
(self.room_id, token))
|
||||
|
@ -47,6 +47,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
|
||||
self.hs.get_device_handler = Mock(return_value=self.device_handler)
|
||||
self.hs.config.enable_registration = True
|
||||
self.hs.config.auto_join_rooms = []
|
||||
|
||||
# init the thing we're testing
|
||||
self.servlet = RegisterRestServlet(self.hs)
|
||||
|
@ -1,76 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
|
||||
|
||||
class EventInjector:
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.message_handler = hs.get_handlers().message_handler
|
||||
self.event_builder_factory = hs.get_event_builder_factory()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_room(self, room, user):
|
||||
builder = self.event_builder_factory.new({
|
||||
"type": EventTypes.Create,
|
||||
"sender": user.to_string(),
|
||||
"room_id": room.to_string(),
|
||||
"content": {},
|
||||
})
|
||||
|
||||
event, context = yield self.message_handler._create_new_client_event(
|
||||
builder
|
||||
)
|
||||
|
||||
yield self.store.persist_event(event, context)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def inject_room_member(self, room, user, membership):
|
||||
builder = self.event_builder_factory.new({
|
||||
"type": EventTypes.Member,
|
||||
"sender": user.to_string(),
|
||||
"state_key": user.to_string(),
|
||||
"room_id": room.to_string(),
|
||||
"content": {"membership": membership},
|
||||
})
|
||||
|
||||
event, context = yield self.message_handler._create_new_client_event(
|
||||
builder
|
||||
)
|
||||
|
||||
yield self.store.persist_event(event, context)
|
||||
|
||||
defer.returnValue(event)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def inject_message(self, room, user, body):
|
||||
builder = self.event_builder_factory.new({
|
||||
"type": EventTypes.Message,
|
||||
"sender": user.to_string(),
|
||||
"state_key": user.to_string(),
|
||||
"room_id": room.to_string(),
|
||||
"content": {"body": body, "msgtype": u"message"},
|
||||
})
|
||||
|
||||
event, context = yield self.message_handler._create_new_client_event(
|
||||
builder
|
||||
)
|
||||
|
||||
yield self.store.persist_event(event, context)
|
@ -12,8 +12,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from synapse.util import async, logcontext
|
||||
from tests import unittest
|
||||
|
||||
from twisted.internet import defer
|
||||
@ -38,7 +37,28 @@ class LinearizerTestCase(unittest.TestCase):
|
||||
with cm1:
|
||||
self.assertFalse(d2.called)
|
||||
|
||||
self.assertTrue(d2.called)
|
||||
|
||||
with (yield d2):
|
||||
pass
|
||||
|
||||
def test_lots_of_queued_things(self):
|
||||
# we have one slow thing, and lots of fast things queued up behind it.
|
||||
# it should *not* explode the stack.
|
||||
linearizer = Linearizer()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def func(i, sleep=False):
|
||||
with logcontext.LoggingContext("func(%s)" % i) as lc:
|
||||
with (yield linearizer.queue("")):
|
||||
self.assertEqual(
|
||||
logcontext.LoggingContext.current_context(), lc)
|
||||
if sleep:
|
||||
yield async.sleep(0)
|
||||
|
||||
self.assertEqual(
|
||||
logcontext.LoggingContext.current_context(), lc)
|
||||
|
||||
func(0, sleep=True)
|
||||
for i in xrange(1, 100):
|
||||
func(i)
|
||||
|
||||
return func(1000)
|
||||
|
@ -94,3 +94,41 @@ class LoggingContextTestCase(unittest.TestCase):
|
||||
yield defer.succeed(None)
|
||||
|
||||
return self._test_preserve_fn(nonblocking_function)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_make_deferred_yieldable(self):
|
||||
# a function which retuns an incomplete deferred, but doesn't follow
|
||||
# the synapse rules.
|
||||
def blocking_function():
|
||||
d = defer.Deferred()
|
||||
reactor.callLater(0, d.callback, None)
|
||||
return d
|
||||
|
||||
sentinel_context = LoggingContext.current_context()
|
||||
|
||||
with LoggingContext() as context_one:
|
||||
context_one.test_key = "one"
|
||||
|
||||
d1 = logcontext.make_deferred_yieldable(blocking_function())
|
||||
# make sure that the context was reset by make_deferred_yieldable
|
||||
self.assertIs(LoggingContext.current_context(), sentinel_context)
|
||||
|
||||
yield d1
|
||||
|
||||
# now it should be restored
|
||||
self._check_test_key("one")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_make_deferred_yieldable_on_non_deferred(self):
|
||||
"""Check that make_deferred_yieldable does the right thing when its
|
||||
argument isn't actually a deferred"""
|
||||
|
||||
with LoggingContext() as context_one:
|
||||
context_one.test_key = "one"
|
||||
|
||||
d1 = logcontext.make_deferred_yieldable("bum")
|
||||
self._check_test_key("one")
|
||||
|
||||
r = yield d1
|
||||
self.assertEqual(r, "bum")
|
||||
self._check_test_key("one")
|
Loading…
x
Reference in New Issue
Block a user