mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2025-01-16 16:47:10 -05:00
Merge branch 'release-v0.24.0' of github.com:matrix-org/synapse
This commit is contained in:
commit
ffba978077
42
CHANGES.rst
42
CHANGES.rst
@ -1,3 +1,45 @@
|
|||||||
|
Changes in synapse v0.24.0 (2017-10-23)
|
||||||
|
=======================================
|
||||||
|
|
||||||
|
No changes since v0.24.0-rc1
|
||||||
|
|
||||||
|
|
||||||
|
Changes in synapse v0.24.0-rc1 (2017-10-19)
|
||||||
|
===========================================
|
||||||
|
|
||||||
|
Features:
|
||||||
|
|
||||||
|
* Add Group Server (PR #2352, #2363, #2374, #2377, #2378, #2382, #2410, #2426,
|
||||||
|
#2430, #2454, #2471, #2472, #2544)
|
||||||
|
* Add support for channel notifications (PR #2501)
|
||||||
|
* Add basic implementation of backup media store (PR #2538)
|
||||||
|
* Add config option to auto-join new users to rooms (PR #2545)
|
||||||
|
|
||||||
|
|
||||||
|
Changes:
|
||||||
|
|
||||||
|
* Make the spam checker a module (PR #2474)
|
||||||
|
* Delete expired url cache data (PR #2478)
|
||||||
|
* Ignore incoming events for rooms that we have left (PR #2490)
|
||||||
|
* Allow spam checker to reject invites too (PR #2492)
|
||||||
|
* Add room creation checks to spam checker (PR #2495)
|
||||||
|
* Spam checking: add the invitee to user_may_invite (PR #2502)
|
||||||
|
* Process events from federation for different rooms in parallel (PR #2520)
|
||||||
|
* Allow error strings from spam checker (PR #2531)
|
||||||
|
* Improve error handling for missing files in config (PR #2551)
|
||||||
|
|
||||||
|
|
||||||
|
Bug fixes:
|
||||||
|
|
||||||
|
* Fix handling SERVFAILs when doing AAAA lookups for federation (PR #2477)
|
||||||
|
* Fix incompatibility with newer versions of ujson (PR #2483) Thanks to
|
||||||
|
@jeremycline!
|
||||||
|
* Fix notification keywords that start/end with non-word chars (PR #2500)
|
||||||
|
* Fix stack overflow and logcontexts from linearizer (PR #2532)
|
||||||
|
* Fix 500 error when fields missing from power_levels event (PR #2552)
|
||||||
|
* Fix 500 error when we get an error handling a PDU (PR #2553)
|
||||||
|
|
||||||
|
|
||||||
Changes in synapse v0.23.1 (2017-10-02)
|
Changes in synapse v0.23.1 (2017-10-02)
|
||||||
=======================================
|
=======================================
|
||||||
|
|
||||||
|
@ -50,7 +50,7 @@ master_doc = 'index'
|
|||||||
|
|
||||||
# General information about the project.
|
# General information about the project.
|
||||||
project = u'Synapse'
|
project = u'Synapse'
|
||||||
copyright = u'2014, TNG'
|
copyright = u'Copyright 2014-2017 OpenMarket Ltd, 2017 Vector Creations Ltd, 2017 New Vector Ltd'
|
||||||
|
|
||||||
# The version info for the project you're documenting, acts as replacement for
|
# The version info for the project you're documenting, acts as replacement for
|
||||||
# |version| and |release|, also used in various other places throughout the
|
# |version| and |release|, also used in various other places throughout the
|
||||||
|
@ -376,10 +376,13 @@ class Porter(object):
|
|||||||
" VALUES (?,?,?,?,to_tsvector('english', ?),?,?)"
|
" VALUES (?,?,?,?,to_tsvector('english', ?),?,?)"
|
||||||
)
|
)
|
||||||
|
|
||||||
rows_dict = [
|
rows_dict = []
|
||||||
dict(zip(headers, row))
|
for row in rows:
|
||||||
for row in rows
|
d = dict(zip(headers, row))
|
||||||
]
|
if "\0" in d['value']:
|
||||||
|
logger.warn('dropping search row %s', d)
|
||||||
|
else:
|
||||||
|
rows_dict.append(d)
|
||||||
|
|
||||||
txn.executemany(sql, [
|
txn.executemany(sql, [
|
||||||
(
|
(
|
||||||
|
@ -16,4 +16,4 @@
|
|||||||
""" This is a reference implementation of a Matrix home server.
|
""" This is a reference implementation of a Matrix home server.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__version__ = "0.23.1"
|
__version__ = "0.24.0"
|
||||||
|
@ -40,6 +40,7 @@ from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
|
|||||||
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
||||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||||
from synapse.replication.slave.storage.room import RoomStore
|
from synapse.replication.slave.storage.room import RoomStore
|
||||||
|
from synapse.replication.slave.storage.groups import SlavedGroupServerStore
|
||||||
from synapse.replication.tcp.client import ReplicationClientHandler
|
from synapse.replication.tcp.client import ReplicationClientHandler
|
||||||
from synapse.rest.client.v1 import events
|
from synapse.rest.client.v1 import events
|
||||||
from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
|
from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
|
||||||
@ -69,6 +70,7 @@ class SynchrotronSlavedStore(
|
|||||||
SlavedRegistrationStore,
|
SlavedRegistrationStore,
|
||||||
SlavedFilteringStore,
|
SlavedFilteringStore,
|
||||||
SlavedPresenceStore,
|
SlavedPresenceStore,
|
||||||
|
SlavedGroupServerStore,
|
||||||
SlavedDeviceInboxStore,
|
SlavedDeviceInboxStore,
|
||||||
SlavedDeviceStore,
|
SlavedDeviceStore,
|
||||||
SlavedClientIpStore,
|
SlavedClientIpStore,
|
||||||
@ -403,6 +405,10 @@ class SyncReplicationHandler(ReplicationClientHandler):
|
|||||||
)
|
)
|
||||||
elif stream_name == "presence":
|
elif stream_name == "presence":
|
||||||
yield self.presence_handler.process_replication_rows(token, rows)
|
yield self.presence_handler.process_replication_rows(token, rows)
|
||||||
|
elif stream_name == "receipts":
|
||||||
|
self.notifier.on_new_event(
|
||||||
|
"groups_key", token, users=[row.user_id for row in rows],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def start(config_options):
|
def start(config_options):
|
||||||
|
@ -81,22 +81,38 @@ class Config(object):
|
|||||||
def abspath(file_path):
|
def abspath(file_path):
|
||||||
return os.path.abspath(file_path) if file_path else file_path
|
return os.path.abspath(file_path) if file_path else file_path
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def path_exists(cls, file_path):
|
||||||
|
"""Check if a file exists
|
||||||
|
|
||||||
|
Unlike os.path.exists, this throws an exception if there is an error
|
||||||
|
checking if the file exists (for example, if there is a perms error on
|
||||||
|
the parent dir).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the file exists; False if not.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
os.stat(file_path)
|
||||||
|
return True
|
||||||
|
except OSError as e:
|
||||||
|
if e.errno != errno.ENOENT:
|
||||||
|
raise e
|
||||||
|
return False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_file(cls, file_path, config_name):
|
def check_file(cls, file_path, config_name):
|
||||||
if file_path is None:
|
if file_path is None:
|
||||||
raise ConfigError(
|
raise ConfigError(
|
||||||
"Missing config for %s."
|
"Missing config for %s."
|
||||||
" You must specify a path for the config file. You can "
|
|
||||||
"do this with the -c or --config-path option. "
|
|
||||||
"Adding --generate-config along with --server-name "
|
|
||||||
"<server name> will generate a config file at the given path."
|
|
||||||
% (config_name,)
|
% (config_name,)
|
||||||
)
|
)
|
||||||
if not os.path.exists(file_path):
|
try:
|
||||||
|
os.stat(file_path)
|
||||||
|
except OSError as e:
|
||||||
raise ConfigError(
|
raise ConfigError(
|
||||||
"File %s config for %s doesn't exist."
|
"Error accessing file '%s' (config for %s): %s"
|
||||||
" Try running again with --generate-config"
|
% (file_path, config_name, e.strerror)
|
||||||
% (file_path, config_name,)
|
|
||||||
)
|
)
|
||||||
return cls.abspath(file_path)
|
return cls.abspath(file_path)
|
||||||
|
|
||||||
@ -248,7 +264,7 @@ class Config(object):
|
|||||||
" -c CONFIG-FILE\""
|
" -c CONFIG-FILE\""
|
||||||
)
|
)
|
||||||
(config_path,) = config_files
|
(config_path,) = config_files
|
||||||
if not os.path.exists(config_path):
|
if not cls.path_exists(config_path):
|
||||||
if config_args.keys_directory:
|
if config_args.keys_directory:
|
||||||
config_dir_path = config_args.keys_directory
|
config_dir_path = config_args.keys_directory
|
||||||
else:
|
else:
|
||||||
@ -261,7 +277,7 @@ class Config(object):
|
|||||||
"Must specify a server_name to a generate config for."
|
"Must specify a server_name to a generate config for."
|
||||||
" Pass -H server.name."
|
" Pass -H server.name."
|
||||||
)
|
)
|
||||||
if not os.path.exists(config_dir_path):
|
if not cls.path_exists(config_dir_path):
|
||||||
os.makedirs(config_dir_path)
|
os.makedirs(config_dir_path)
|
||||||
with open(config_path, "wb") as config_file:
|
with open(config_path, "wb") as config_file:
|
||||||
config_bytes, config = obj.generate_config(
|
config_bytes, config = obj.generate_config(
|
||||||
|
32
synapse/config/groups.py
Normal file
32
synapse/config/groups.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2017 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from ._base import Config
|
||||||
|
|
||||||
|
|
||||||
|
class GroupsConfig(Config):
|
||||||
|
def read_config(self, config):
|
||||||
|
self.enable_group_creation = config.get("enable_group_creation", False)
|
||||||
|
self.group_creation_prefix = config.get("group_creation_prefix", "")
|
||||||
|
|
||||||
|
def default_config(self, **kwargs):
|
||||||
|
return """\
|
||||||
|
# Whether to allow non server admins to create groups on this server
|
||||||
|
enable_group_creation: false
|
||||||
|
|
||||||
|
# If enabled, non server admins can only create groups with local parts
|
||||||
|
# starting with this prefix
|
||||||
|
# group_creation_prefix: "unofficial/"
|
||||||
|
"""
|
@ -34,6 +34,8 @@ from .password_auth_providers import PasswordAuthProviderConfig
|
|||||||
from .emailconfig import EmailConfig
|
from .emailconfig import EmailConfig
|
||||||
from .workers import WorkerConfig
|
from .workers import WorkerConfig
|
||||||
from .push import PushConfig
|
from .push import PushConfig
|
||||||
|
from .spam_checker import SpamCheckerConfig
|
||||||
|
from .groups import GroupsConfig
|
||||||
|
|
||||||
|
|
||||||
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
|
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
|
||||||
@ -41,7 +43,8 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
|
|||||||
VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
|
VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
|
||||||
AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
|
AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
|
||||||
JWTConfig, PasswordConfig, EmailConfig,
|
JWTConfig, PasswordConfig, EmailConfig,
|
||||||
WorkerConfig, PasswordAuthProviderConfig, PushConfig,):
|
WorkerConfig, PasswordAuthProviderConfig, PushConfig,
|
||||||
|
SpamCheckerConfig, GroupsConfig,):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@ -118,10 +118,9 @@ class KeyConfig(Config):
|
|||||||
signing_keys = self.read_file(signing_key_path, "signing_key")
|
signing_keys = self.read_file(signing_key_path, "signing_key")
|
||||||
try:
|
try:
|
||||||
return read_signing_keys(signing_keys.splitlines(True))
|
return read_signing_keys(signing_keys.splitlines(True))
|
||||||
except Exception:
|
except Exception as e:
|
||||||
raise ConfigError(
|
raise ConfigError(
|
||||||
"Error reading signing_key."
|
"Error reading signing_key: %s" % (str(e))
|
||||||
" Try running again with --generate-config"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def read_old_signing_keys(self, old_signing_keys):
|
def read_old_signing_keys(self, old_signing_keys):
|
||||||
@ -141,7 +140,8 @@ class KeyConfig(Config):
|
|||||||
|
|
||||||
def generate_files(self, config):
|
def generate_files(self, config):
|
||||||
signing_key_path = config["signing_key_path"]
|
signing_key_path = config["signing_key_path"]
|
||||||
if not os.path.exists(signing_key_path):
|
|
||||||
|
if not self.path_exists(signing_key_path):
|
||||||
with open(signing_key_path, "w") as signing_key_file:
|
with open(signing_key_path, "w") as signing_key_file:
|
||||||
key_id = "a_" + random_string(4)
|
key_id = "a_" + random_string(4)
|
||||||
write_signing_keys(
|
write_signing_keys(
|
||||||
|
@ -15,13 +15,15 @@
|
|||||||
|
|
||||||
from ._base import Config, ConfigError
|
from ._base import Config, ConfigError
|
||||||
|
|
||||||
import importlib
|
from synapse.util.module_loader import load_module
|
||||||
|
|
||||||
|
|
||||||
class PasswordAuthProviderConfig(Config):
|
class PasswordAuthProviderConfig(Config):
|
||||||
def read_config(self, config):
|
def read_config(self, config):
|
||||||
self.password_providers = []
|
self.password_providers = []
|
||||||
|
|
||||||
|
provider_config = None
|
||||||
|
|
||||||
# We want to be backwards compatible with the old `ldap_config`
|
# We want to be backwards compatible with the old `ldap_config`
|
||||||
# param.
|
# param.
|
||||||
ldap_config = config.get("ldap_config", {})
|
ldap_config = config.get("ldap_config", {})
|
||||||
@ -38,19 +40,15 @@ class PasswordAuthProviderConfig(Config):
|
|||||||
if provider['module'] == "synapse.util.ldap_auth_provider.LdapAuthProvider":
|
if provider['module'] == "synapse.util.ldap_auth_provider.LdapAuthProvider":
|
||||||
from ldap_auth_provider import LdapAuthProvider
|
from ldap_auth_provider import LdapAuthProvider
|
||||||
provider_class = LdapAuthProvider
|
provider_class = LdapAuthProvider
|
||||||
|
try:
|
||||||
|
provider_config = provider_class.parse_config(provider["config"])
|
||||||
|
except Exception as e:
|
||||||
|
raise ConfigError(
|
||||||
|
"Failed to parse config for %r: %r" % (provider['module'], e)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# We need to import the module, and then pick the class out of
|
(provider_class, provider_config) = load_module(provider)
|
||||||
# that, so we split based on the last dot.
|
|
||||||
module, clz = provider['module'].rsplit(".", 1)
|
|
||||||
module = importlib.import_module(module)
|
|
||||||
provider_class = getattr(module, clz)
|
|
||||||
|
|
||||||
try:
|
|
||||||
provider_config = provider_class.parse_config(provider["config"])
|
|
||||||
except Exception as e:
|
|
||||||
raise ConfigError(
|
|
||||||
"Failed to parse config for %r: %r" % (provider['module'], e)
|
|
||||||
)
|
|
||||||
self.password_providers.append((provider_class, provider_config))
|
self.password_providers.append((provider_class, provider_config))
|
||||||
|
|
||||||
def default_config(self, **kwargs):
|
def default_config(self, **kwargs):
|
||||||
|
@ -41,6 +41,8 @@ class RegistrationConfig(Config):
|
|||||||
self.allow_guest_access and config.get("invite_3pid_guest", False)
|
self.allow_guest_access and config.get("invite_3pid_guest", False)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.auto_join_rooms = config.get("auto_join_rooms", [])
|
||||||
|
|
||||||
def default_config(self, **kwargs):
|
def default_config(self, **kwargs):
|
||||||
registration_shared_secret = random_string_with_symbols(50)
|
registration_shared_secret = random_string_with_symbols(50)
|
||||||
|
|
||||||
@ -70,6 +72,11 @@ class RegistrationConfig(Config):
|
|||||||
- matrix.org
|
- matrix.org
|
||||||
- vector.im
|
- vector.im
|
||||||
- riot.im
|
- riot.im
|
||||||
|
|
||||||
|
# Users who register on this homeserver will automatically be joined
|
||||||
|
# to these rooms
|
||||||
|
#auto_join_rooms:
|
||||||
|
# - "#example:example.com"
|
||||||
""" % locals()
|
""" % locals()
|
||||||
|
|
||||||
def add_arguments(self, parser):
|
def add_arguments(self, parser):
|
||||||
|
@ -70,7 +70,19 @@ class ContentRepositoryConfig(Config):
|
|||||||
self.max_upload_size = self.parse_size(config["max_upload_size"])
|
self.max_upload_size = self.parse_size(config["max_upload_size"])
|
||||||
self.max_image_pixels = self.parse_size(config["max_image_pixels"])
|
self.max_image_pixels = self.parse_size(config["max_image_pixels"])
|
||||||
self.max_spider_size = self.parse_size(config["max_spider_size"])
|
self.max_spider_size = self.parse_size(config["max_spider_size"])
|
||||||
|
|
||||||
self.media_store_path = self.ensure_directory(config["media_store_path"])
|
self.media_store_path = self.ensure_directory(config["media_store_path"])
|
||||||
|
|
||||||
|
self.backup_media_store_path = config.get("backup_media_store_path")
|
||||||
|
if self.backup_media_store_path:
|
||||||
|
self.backup_media_store_path = self.ensure_directory(
|
||||||
|
self.backup_media_store_path
|
||||||
|
)
|
||||||
|
|
||||||
|
self.synchronous_backup_media_store = config.get(
|
||||||
|
"synchronous_backup_media_store", False
|
||||||
|
)
|
||||||
|
|
||||||
self.uploads_path = self.ensure_directory(config["uploads_path"])
|
self.uploads_path = self.ensure_directory(config["uploads_path"])
|
||||||
self.dynamic_thumbnails = config["dynamic_thumbnails"]
|
self.dynamic_thumbnails = config["dynamic_thumbnails"]
|
||||||
self.thumbnail_requirements = parse_thumbnail_requirements(
|
self.thumbnail_requirements = parse_thumbnail_requirements(
|
||||||
@ -115,6 +127,14 @@ class ContentRepositoryConfig(Config):
|
|||||||
# Directory where uploaded images and attachments are stored.
|
# Directory where uploaded images and attachments are stored.
|
||||||
media_store_path: "%(media_store)s"
|
media_store_path: "%(media_store)s"
|
||||||
|
|
||||||
|
# A secondary directory where uploaded images and attachments are
|
||||||
|
# stored as a backup.
|
||||||
|
# backup_media_store_path: "%(media_store)s"
|
||||||
|
|
||||||
|
# Whether to wait for successful write to backup media store before
|
||||||
|
# returning successfully.
|
||||||
|
# synchronous_backup_media_store: false
|
||||||
|
|
||||||
# Directory where in-progress uploads are stored.
|
# Directory where in-progress uploads are stored.
|
||||||
uploads_path: "%(uploads_path)s"
|
uploads_path: "%(uploads_path)s"
|
||||||
|
|
||||||
|
35
synapse/config/spam_checker.py
Normal file
35
synapse/config/spam_checker.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2017 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from synapse.util.module_loader import load_module
|
||||||
|
|
||||||
|
from ._base import Config
|
||||||
|
|
||||||
|
|
||||||
|
class SpamCheckerConfig(Config):
|
||||||
|
def read_config(self, config):
|
||||||
|
self.spam_checker = None
|
||||||
|
|
||||||
|
provider = config.get("spam_checker", None)
|
||||||
|
if provider is not None:
|
||||||
|
self.spam_checker = load_module(provider)
|
||||||
|
|
||||||
|
def default_config(self, **kwargs):
|
||||||
|
return """\
|
||||||
|
# spam_checker:
|
||||||
|
# module: "my_custom_project.SuperSpamChecker"
|
||||||
|
# config:
|
||||||
|
# example_option: 'things'
|
||||||
|
"""
|
@ -126,7 +126,7 @@ class TlsConfig(Config):
|
|||||||
tls_private_key_path = config["tls_private_key_path"]
|
tls_private_key_path = config["tls_private_key_path"]
|
||||||
tls_dh_params_path = config["tls_dh_params_path"]
|
tls_dh_params_path = config["tls_dh_params_path"]
|
||||||
|
|
||||||
if not os.path.exists(tls_private_key_path):
|
if not self.path_exists(tls_private_key_path):
|
||||||
with open(tls_private_key_path, "w") as private_key_file:
|
with open(tls_private_key_path, "w") as private_key_file:
|
||||||
tls_private_key = crypto.PKey()
|
tls_private_key = crypto.PKey()
|
||||||
tls_private_key.generate_key(crypto.TYPE_RSA, 2048)
|
tls_private_key.generate_key(crypto.TYPE_RSA, 2048)
|
||||||
@ -141,7 +141,7 @@ class TlsConfig(Config):
|
|||||||
crypto.FILETYPE_PEM, private_key_pem
|
crypto.FILETYPE_PEM, private_key_pem
|
||||||
)
|
)
|
||||||
|
|
||||||
if not os.path.exists(tls_certificate_path):
|
if not self.path_exists(tls_certificate_path):
|
||||||
with open(tls_certificate_path, "w") as certificate_file:
|
with open(tls_certificate_path, "w") as certificate_file:
|
||||||
cert = crypto.X509()
|
cert = crypto.X509()
|
||||||
subject = cert.get_subject()
|
subject = cert.get_subject()
|
||||||
@ -159,7 +159,7 @@ class TlsConfig(Config):
|
|||||||
|
|
||||||
certificate_file.write(cert_pem)
|
certificate_file.write(cert_pem)
|
||||||
|
|
||||||
if not os.path.exists(tls_dh_params_path):
|
if not self.path_exists(tls_dh_params_path):
|
||||||
if GENERATE_DH_PARAMS:
|
if GENERATE_DH_PARAMS:
|
||||||
subprocess.check_call([
|
subprocess.check_call([
|
||||||
"openssl", "dhparam",
|
"openssl", "dhparam",
|
||||||
|
@ -470,14 +470,14 @@ def _check_power_levels(event, auth_events):
|
|||||||
("invite", None),
|
("invite", None),
|
||||||
]
|
]
|
||||||
|
|
||||||
old_list = current_state.content.get("users")
|
old_list = current_state.content.get("users", {})
|
||||||
for user in set(old_list.keys() + user_list.keys()):
|
for user in set(old_list.keys() + user_list.keys()):
|
||||||
levels_to_check.append(
|
levels_to_check.append(
|
||||||
(user, "users")
|
(user, "users")
|
||||||
)
|
)
|
||||||
|
|
||||||
old_list = current_state.content.get("events")
|
old_list = current_state.content.get("events", {})
|
||||||
new_list = event.content.get("events")
|
new_list = event.content.get("events", {})
|
||||||
for ev_id in set(old_list.keys() + new_list.keys()):
|
for ev_id in set(old_list.keys() + new_list.keys()):
|
||||||
levels_to_check.append(
|
levels_to_check.append(
|
||||||
(ev_id, "events")
|
(ev_id, "events")
|
||||||
|
@ -14,25 +14,100 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
def check_event_for_spam(event):
|
class SpamChecker(object):
|
||||||
"""Checks if a given event is considered "spammy" by this server.
|
def __init__(self, hs):
|
||||||
|
self.spam_checker = None
|
||||||
|
|
||||||
If the server considers an event spammy, then it will be rejected if
|
module = None
|
||||||
sent by a local user. If it is sent by a user on another server, then
|
config = None
|
||||||
users receive a blank event.
|
try:
|
||||||
|
module, config = hs.config.spam_checker
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
Args:
|
if module is not None:
|
||||||
event (synapse.events.EventBase): the event to be checked
|
self.spam_checker = module(config=config)
|
||||||
|
|
||||||
Returns:
|
def check_event_for_spam(self, event):
|
||||||
bool: True if the event is spammy.
|
"""Checks if a given event is considered "spammy" by this server.
|
||||||
"""
|
|
||||||
if not hasattr(event, "content") or "body" not in event.content:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# for example:
|
If the server considers an event spammy, then it will be rejected if
|
||||||
#
|
sent by a local user. If it is sent by a user on another server, then
|
||||||
# if "the third flower is green" in event.content["body"]:
|
users receive a blank event.
|
||||||
# return True
|
|
||||||
|
|
||||||
return False
|
Args:
|
||||||
|
event (synapse.events.EventBase): the event to be checked
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the event is spammy.
|
||||||
|
"""
|
||||||
|
if self.spam_checker is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return self.spam_checker.check_event_for_spam(event)
|
||||||
|
|
||||||
|
def user_may_invite(self, inviter_userid, invitee_userid, room_id):
|
||||||
|
"""Checks if a given user may send an invite
|
||||||
|
|
||||||
|
If this method returns false, the invite will be rejected.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
userid (string): The sender's user ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the user may send an invite, otherwise False
|
||||||
|
"""
|
||||||
|
if self.spam_checker is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return self.spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id)
|
||||||
|
|
||||||
|
def user_may_create_room(self, userid):
|
||||||
|
"""Checks if a given user may create a room
|
||||||
|
|
||||||
|
If this method returns false, the creation request will be rejected.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
userid (string): The sender's user ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the user may create a room, otherwise False
|
||||||
|
"""
|
||||||
|
if self.spam_checker is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return self.spam_checker.user_may_create_room(userid)
|
||||||
|
|
||||||
|
def user_may_create_room_alias(self, userid, room_alias):
|
||||||
|
"""Checks if a given user may create a room alias
|
||||||
|
|
||||||
|
If this method returns false, the association request will be rejected.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
userid (string): The sender's user ID
|
||||||
|
room_alias (string): The alias to be created
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the user may create a room alias, otherwise False
|
||||||
|
"""
|
||||||
|
if self.spam_checker is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return self.spam_checker.user_may_create_room_alias(userid, room_alias)
|
||||||
|
|
||||||
|
def user_may_publish_room(self, userid, room_id):
|
||||||
|
"""Checks if a given user may publish a room to the directory
|
||||||
|
|
||||||
|
If this method returns false, the publish request will be rejected.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
userid (string): The sender's user ID
|
||||||
|
room_id (string): The ID of the room that would be published
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the user may publish the room, otherwise False
|
||||||
|
"""
|
||||||
|
if self.spam_checker is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return self.spam_checker.user_may_publish_room(userid, room_id)
|
||||||
|
@ -16,7 +16,6 @@ import logging
|
|||||||
|
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.crypto.event_signing import check_event_content_hash
|
from synapse.crypto.event_signing import check_event_content_hash
|
||||||
from synapse.events import spamcheck
|
|
||||||
from synapse.events.utils import prune_event
|
from synapse.events.utils import prune_event
|
||||||
from synapse.util import unwrapFirstError, logcontext
|
from synapse.util import unwrapFirstError, logcontext
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
@ -26,7 +25,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class FederationBase(object):
|
class FederationBase(object):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
pass
|
self.spam_checker = hs.get_spam_checker()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
|
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
|
||||||
@ -144,7 +143,7 @@ class FederationBase(object):
|
|||||||
)
|
)
|
||||||
return redacted
|
return redacted
|
||||||
|
|
||||||
if spamcheck.check_event_for_spam(pdu):
|
if self.spam_checker.check_event_for_spam(pdu):
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"Event contains spam, redacting %s: %s",
|
"Event contains spam, redacting %s: %s",
|
||||||
pdu.event_id, pdu.get_pdu_json()
|
pdu.event_id, pdu.get_pdu_json()
|
||||||
|
@ -12,14 +12,12 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from .federation_base import FederationBase
|
from .federation_base import FederationBase
|
||||||
from .units import Transaction, Edu
|
from .units import Transaction, Edu
|
||||||
|
|
||||||
from synapse.util.async import Linearizer
|
from synapse.util import async
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.util.caches.response_cache import ResponseCache
|
from synapse.util.caches.response_cache import ResponseCache
|
||||||
from synapse.events import FrozenEvent
|
from synapse.events import FrozenEvent
|
||||||
@ -33,6 +31,9 @@ from synapse.crypto.event_signing import compute_event_signature
|
|||||||
import simplejson as json
|
import simplejson as json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
# when processing incoming transactions, we try to handle multiple rooms in
|
||||||
|
# parallel, up to this limit.
|
||||||
|
TRANSACTION_CONCURRENCY_LIMIT = 10
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -52,7 +53,8 @@ class FederationServer(FederationBase):
|
|||||||
|
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
self._server_linearizer = Linearizer("fed_server")
|
self._server_linearizer = async.Linearizer("fed_server")
|
||||||
|
self._transaction_linearizer = async.Linearizer("fed_txn_handler")
|
||||||
|
|
||||||
# We cache responses to state queries, as they take a while and often
|
# We cache responses to state queries, as they take a while and often
|
||||||
# come in waves.
|
# come in waves.
|
||||||
@ -109,25 +111,41 @@ class FederationServer(FederationBase):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def on_incoming_transaction(self, transaction_data):
|
def on_incoming_transaction(self, transaction_data):
|
||||||
|
# keep this as early as possible to make the calculated origin ts as
|
||||||
|
# accurate as possible.
|
||||||
|
request_time = self._clock.time_msec()
|
||||||
|
|
||||||
transaction = Transaction(**transaction_data)
|
transaction = Transaction(**transaction_data)
|
||||||
|
|
||||||
received_pdus_counter.inc_by(len(transaction.pdus))
|
if not transaction.transaction_id:
|
||||||
|
raise Exception("Transaction missing transaction_id")
|
||||||
for p in transaction.pdus:
|
if not transaction.origin:
|
||||||
if "unsigned" in p:
|
raise Exception("Transaction missing origin")
|
||||||
unsigned = p["unsigned"]
|
|
||||||
if "age" in unsigned:
|
|
||||||
p["age"] = unsigned["age"]
|
|
||||||
if "age" in p:
|
|
||||||
p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
|
|
||||||
del p["age"]
|
|
||||||
|
|
||||||
pdu_list = [
|
|
||||||
self.event_from_pdu_json(p) for p in transaction.pdus
|
|
||||||
]
|
|
||||||
|
|
||||||
logger.debug("[%s] Got transaction", transaction.transaction_id)
|
logger.debug("[%s] Got transaction", transaction.transaction_id)
|
||||||
|
|
||||||
|
# use a linearizer to ensure that we don't process the same transaction
|
||||||
|
# multiple times in parallel.
|
||||||
|
with (yield self._transaction_linearizer.queue(
|
||||||
|
(transaction.origin, transaction.transaction_id),
|
||||||
|
)):
|
||||||
|
result = yield self._handle_incoming_transaction(
|
||||||
|
transaction, request_time,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue(result)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _handle_incoming_transaction(self, transaction, request_time):
|
||||||
|
""" Process an incoming transaction and return the HTTP response
|
||||||
|
|
||||||
|
Args:
|
||||||
|
transaction (Transaction): incoming transaction
|
||||||
|
request_time (int): timestamp that the HTTP request arrived at
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[(int, object)]: http response code and body
|
||||||
|
"""
|
||||||
response = yield self.transaction_actions.have_responded(transaction)
|
response = yield self.transaction_actions.have_responded(transaction)
|
||||||
|
|
||||||
if response:
|
if response:
|
||||||
@ -140,42 +158,49 @@ class FederationServer(FederationBase):
|
|||||||
|
|
||||||
logger.debug("[%s] Transaction is new", transaction.transaction_id)
|
logger.debug("[%s] Transaction is new", transaction.transaction_id)
|
||||||
|
|
||||||
results = []
|
received_pdus_counter.inc_by(len(transaction.pdus))
|
||||||
|
|
||||||
for pdu in pdu_list:
|
pdus_by_room = {}
|
||||||
# check that it's actually being sent from a valid destination to
|
|
||||||
# workaround bug #1753 in 0.18.5 and 0.18.6
|
|
||||||
if transaction.origin != get_domain_from_id(pdu.event_id):
|
|
||||||
# We continue to accept join events from any server; this is
|
|
||||||
# necessary for the federation join dance to work correctly.
|
|
||||||
# (When we join over federation, the "helper" server is
|
|
||||||
# responsible for sending out the join event, rather than the
|
|
||||||
# origin. See bug #1893).
|
|
||||||
if not (
|
|
||||||
pdu.type == 'm.room.member' and
|
|
||||||
pdu.content and
|
|
||||||
pdu.content.get("membership", None) == 'join'
|
|
||||||
):
|
|
||||||
logger.info(
|
|
||||||
"Discarding PDU %s from invalid origin %s",
|
|
||||||
pdu.event_id, transaction.origin
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
logger.info(
|
|
||||||
"Accepting join PDU %s from %s",
|
|
||||||
pdu.event_id, transaction.origin
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
for p in transaction.pdus:
|
||||||
yield self._handle_received_pdu(transaction.origin, pdu)
|
if "unsigned" in p:
|
||||||
results.append({})
|
unsigned = p["unsigned"]
|
||||||
except FederationError as e:
|
if "age" in unsigned:
|
||||||
self.send_failure(e, transaction.origin)
|
p["age"] = unsigned["age"]
|
||||||
results.append({"error": str(e)})
|
if "age" in p:
|
||||||
except Exception as e:
|
p["age_ts"] = request_time - int(p["age"])
|
||||||
results.append({"error": str(e)})
|
del p["age"]
|
||||||
logger.exception("Failed to handle PDU")
|
|
||||||
|
event = self.event_from_pdu_json(p)
|
||||||
|
room_id = event.room_id
|
||||||
|
pdus_by_room.setdefault(room_id, []).append(event)
|
||||||
|
|
||||||
|
pdu_results = {}
|
||||||
|
|
||||||
|
# we can process different rooms in parallel (which is useful if they
|
||||||
|
# require callouts to other servers to fetch missing events), but
|
||||||
|
# impose a limit to avoid going too crazy with ram/cpu.
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def process_pdus_for_room(room_id):
|
||||||
|
logger.debug("Processing PDUs for %s", room_id)
|
||||||
|
for pdu in pdus_by_room[room_id]:
|
||||||
|
event_id = pdu.event_id
|
||||||
|
try:
|
||||||
|
yield self._handle_received_pdu(
|
||||||
|
transaction.origin, pdu
|
||||||
|
)
|
||||||
|
pdu_results[event_id] = {}
|
||||||
|
except FederationError as e:
|
||||||
|
logger.warn("Error handling PDU %s: %s", event_id, e)
|
||||||
|
pdu_results[event_id] = {"error": str(e)}
|
||||||
|
except Exception as e:
|
||||||
|
pdu_results[event_id] = {"error": str(e)}
|
||||||
|
logger.exception("Failed to handle PDU %s", event_id)
|
||||||
|
|
||||||
|
yield async.concurrently_execute(
|
||||||
|
process_pdus_for_room, pdus_by_room.keys(),
|
||||||
|
TRANSACTION_CONCURRENCY_LIMIT,
|
||||||
|
)
|
||||||
|
|
||||||
if hasattr(transaction, "edus"):
|
if hasattr(transaction, "edus"):
|
||||||
for edu in (Edu(**x) for x in transaction.edus):
|
for edu in (Edu(**x) for x in transaction.edus):
|
||||||
@ -185,17 +210,16 @@ class FederationServer(FederationBase):
|
|||||||
edu.content
|
edu.content
|
||||||
)
|
)
|
||||||
|
|
||||||
for failure in getattr(transaction, "pdu_failures", []):
|
pdu_failures = getattr(transaction, "pdu_failures", [])
|
||||||
logger.info("Got failure %r", failure)
|
for failure in pdu_failures:
|
||||||
|
logger.info("Got failure %r", failure)
|
||||||
logger.debug("Returning: %s", str(results))
|
|
||||||
|
|
||||||
response = {
|
response = {
|
||||||
"pdus": dict(zip(
|
"pdus": pdu_results,
|
||||||
(p.event_id for p in pdu_list), results
|
|
||||||
)),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.debug("Returning: %s", str(response))
|
||||||
|
|
||||||
yield self.transaction_actions.set_response(
|
yield self.transaction_actions.set_response(
|
||||||
transaction,
|
transaction,
|
||||||
200, response
|
200, response
|
||||||
@ -520,6 +544,30 @@ class FederationServer(FederationBase):
|
|||||||
Returns (Deferred): completes with None
|
Returns (Deferred): completes with None
|
||||||
Raises: FederationError if the signatures / hash do not match
|
Raises: FederationError if the signatures / hash do not match
|
||||||
"""
|
"""
|
||||||
|
# check that it's actually being sent from a valid destination to
|
||||||
|
# workaround bug #1753 in 0.18.5 and 0.18.6
|
||||||
|
if origin != get_domain_from_id(pdu.event_id):
|
||||||
|
# We continue to accept join events from any server; this is
|
||||||
|
# necessary for the federation join dance to work correctly.
|
||||||
|
# (When we join over federation, the "helper" server is
|
||||||
|
# responsible for sending out the join event, rather than the
|
||||||
|
# origin. See bug #1893).
|
||||||
|
if not (
|
||||||
|
pdu.type == 'm.room.member' and
|
||||||
|
pdu.content and
|
||||||
|
pdu.content.get("membership", None) == 'join'
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
"Discarding PDU %s from invalid origin %s",
|
||||||
|
pdu.event_id, origin
|
||||||
|
)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"Accepting join PDU %s from %s",
|
||||||
|
pdu.event_id, origin
|
||||||
|
)
|
||||||
|
|
||||||
# Check signature.
|
# Check signature.
|
||||||
try:
|
try:
|
||||||
pdu = yield self._check_sigs_and_hash(pdu)
|
pdu = yield self._check_sigs_and_hash(pdu)
|
||||||
|
@ -20,8 +20,8 @@ from .persistence import TransactionActions
|
|||||||
from .units import Transaction, Edu
|
from .units import Transaction, Edu
|
||||||
|
|
||||||
from synapse.api.errors import HttpResponseException
|
from synapse.api.errors import HttpResponseException
|
||||||
|
from synapse.util import logcontext
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
from synapse.util.logcontext import preserve_context_over_fn, preserve_fn
|
|
||||||
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
|
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
|
||||||
from synapse.util.metrics import measure_func
|
from synapse.util.metrics import measure_func
|
||||||
from synapse.handlers.presence import format_user_presence_state, get_interested_remotes
|
from synapse.handlers.presence import format_user_presence_state, get_interested_remotes
|
||||||
@ -231,11 +231,9 @@ class TransactionQueue(object):
|
|||||||
(pdu, order)
|
(pdu, order)
|
||||||
)
|
)
|
||||||
|
|
||||||
preserve_context_over_fn(
|
self._attempt_new_transaction(destination)
|
||||||
self._attempt_new_transaction, destination
|
|
||||||
)
|
|
||||||
|
|
||||||
@preserve_fn # the caller should not yield on this
|
@logcontext.preserve_fn # the caller should not yield on this
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def send_presence(self, states):
|
def send_presence(self, states):
|
||||||
"""Send the new presence states to the appropriate destinations.
|
"""Send the new presence states to the appropriate destinations.
|
||||||
@ -299,7 +297,7 @@ class TransactionQueue(object):
|
|||||||
state.user_id: state for state in states
|
state.user_id: state for state in states
|
||||||
})
|
})
|
||||||
|
|
||||||
preserve_fn(self._attempt_new_transaction)(destination)
|
self._attempt_new_transaction(destination)
|
||||||
|
|
||||||
def send_edu(self, destination, edu_type, content, key=None):
|
def send_edu(self, destination, edu_type, content, key=None):
|
||||||
edu = Edu(
|
edu = Edu(
|
||||||
@ -321,9 +319,7 @@ class TransactionQueue(object):
|
|||||||
else:
|
else:
|
||||||
self.pending_edus_by_dest.setdefault(destination, []).append(edu)
|
self.pending_edus_by_dest.setdefault(destination, []).append(edu)
|
||||||
|
|
||||||
preserve_context_over_fn(
|
self._attempt_new_transaction(destination)
|
||||||
self._attempt_new_transaction, destination
|
|
||||||
)
|
|
||||||
|
|
||||||
def send_failure(self, failure, destination):
|
def send_failure(self, failure, destination):
|
||||||
if destination == self.server_name or destination == "localhost":
|
if destination == self.server_name or destination == "localhost":
|
||||||
@ -336,9 +332,7 @@ class TransactionQueue(object):
|
|||||||
destination, []
|
destination, []
|
||||||
).append(failure)
|
).append(failure)
|
||||||
|
|
||||||
preserve_context_over_fn(
|
self._attempt_new_transaction(destination)
|
||||||
self._attempt_new_transaction, destination
|
|
||||||
)
|
|
||||||
|
|
||||||
def send_device_messages(self, destination):
|
def send_device_messages(self, destination):
|
||||||
if destination == self.server_name or destination == "localhost":
|
if destination == self.server_name or destination == "localhost":
|
||||||
@ -347,15 +341,24 @@ class TransactionQueue(object):
|
|||||||
if not self.can_send_to(destination):
|
if not self.can_send_to(destination):
|
||||||
return
|
return
|
||||||
|
|
||||||
preserve_context_over_fn(
|
self._attempt_new_transaction(destination)
|
||||||
self._attempt_new_transaction, destination
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_current_token(self):
|
def get_current_token(self):
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _attempt_new_transaction(self, destination):
|
def _attempt_new_transaction(self, destination):
|
||||||
|
"""Try to start a new transaction to this destination
|
||||||
|
|
||||||
|
If there is already a transaction in progress to this destination,
|
||||||
|
returns immediately. Otherwise kicks off the process of sending a
|
||||||
|
transaction in the background.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
destination (str):
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
# list of (pending_pdu, deferred, order)
|
# list of (pending_pdu, deferred, order)
|
||||||
if destination in self.pending_transactions:
|
if destination in self.pending_transactions:
|
||||||
# XXX: pending_transactions can get stuck on by a never-ending
|
# XXX: pending_transactions can get stuck on by a never-ending
|
||||||
@ -368,6 +371,19 @@ class TransactionQueue(object):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
logger.debug("TX [%s] Starting transaction loop", destination)
|
||||||
|
|
||||||
|
# Drop the logcontext before starting the transaction. It doesn't
|
||||||
|
# really make sense to log all the outbound transactions against
|
||||||
|
# whatever path led us to this point: that's pretty arbitrary really.
|
||||||
|
#
|
||||||
|
# (this also means we can fire off _perform_transaction without
|
||||||
|
# yielding)
|
||||||
|
with logcontext.PreserveLoggingContext():
|
||||||
|
self._transaction_transmission_loop(destination)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _transaction_transmission_loop(self, destination):
|
||||||
pending_pdus = []
|
pending_pdus = []
|
||||||
try:
|
try:
|
||||||
self.pending_transactions[destination] = 1
|
self.pending_transactions[destination] = 1
|
||||||
|
@ -471,3 +471,384 @@ class TransportLayerClient(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue(content)
|
defer.returnValue(content)
|
||||||
|
|
||||||
|
@log_function
|
||||||
|
def get_group_profile(self, destination, group_id, requester_user_id):
|
||||||
|
"""Get a group profile
|
||||||
|
"""
|
||||||
|
path = PREFIX + "/groups/%s/profile" % (group_id,)
|
||||||
|
|
||||||
|
return self.client.get_json(
|
||||||
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
args={"requester_user_id": requester_user_id},
|
||||||
|
ignore_backoff=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@log_function
|
||||||
|
def get_group_summary(self, destination, group_id, requester_user_id):
|
||||||
|
"""Get a group summary
|
||||||
|
"""
|
||||||
|
path = PREFIX + "/groups/%s/summary" % (group_id,)
|
||||||
|
|
||||||
|
return self.client.get_json(
|
||||||
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
args={"requester_user_id": requester_user_id},
|
||||||
|
ignore_backoff=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@log_function
|
||||||
|
def get_rooms_in_group(self, destination, group_id, requester_user_id):
|
||||||
|
"""Get all rooms in a group
|
||||||
|
"""
|
||||||
|
path = PREFIX + "/groups/%s/rooms" % (group_id,)
|
||||||
|
|
||||||
|
return self.client.get_json(
|
||||||
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
args={"requester_user_id": requester_user_id},
|
||||||
|
ignore_backoff=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def add_room_to_group(self, destination, group_id, requester_user_id, room_id,
|
||||||
|
content):
|
||||||
|
"""Add a room to a group
|
||||||
|
"""
|
||||||
|
path = PREFIX + "/groups/%s/room/%s" % (group_id, room_id,)
|
||||||
|
|
||||||
|
return self.client.post_json(
|
||||||
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
args={"requester_user_id": requester_user_id},
|
||||||
|
data=content,
|
||||||
|
ignore_backoff=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
|
||||||
|
"""Remove a room from a group
|
||||||
|
"""
|
||||||
|
path = PREFIX + "/groups/%s/room/%s" % (group_id, room_id,)
|
||||||
|
|
||||||
|
return self.client.delete_json(
|
||||||
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
args={"requester_user_id": requester_user_id},
|
||||||
|
ignore_backoff=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@log_function
|
||||||
|
def get_users_in_group(self, destination, group_id, requester_user_id):
|
||||||
|
"""Get users in a group
|
||||||
|
"""
|
||||||
|
path = PREFIX + "/groups/%s/users" % (group_id,)
|
||||||
|
|
||||||
|
return self.client.get_json(
|
||||||
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
args={"requester_user_id": requester_user_id},
|
||||||
|
ignore_backoff=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@log_function
|
||||||
|
def get_invited_users_in_group(self, destination, group_id, requester_user_id):
|
||||||
|
"""Get users that have been invited to a group
|
||||||
|
"""
|
||||||
|
path = PREFIX + "/groups/%s/invited_users" % (group_id,)
|
||||||
|
|
||||||
|
return self.client.get_json(
|
||||||
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
args={"requester_user_id": requester_user_id},
|
||||||
|
ignore_backoff=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@log_function
|
||||||
|
def accept_group_invite(self, destination, group_id, user_id, content):
|
||||||
|
"""Accept a group invite
|
||||||
|
"""
|
||||||
|
path = PREFIX + "/groups/%s/users/%s/accept_invite" % (group_id, user_id)
|
||||||
|
|
||||||
|
return self.client.post_json(
|
||||||
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
data=content,
|
||||||
|
ignore_backoff=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@log_function
|
||||||
|
def invite_to_group(self, destination, group_id, user_id, requester_user_id, content):
|
||||||
|
"""Invite a user to a group
|
||||||
|
"""
|
||||||
|
path = PREFIX + "/groups/%s/users/%s/invite" % (group_id, user_id)
|
||||||
|
|
||||||
|
return self.client.post_json(
|
||||||
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
args={"requester_user_id": requester_user_id},
|
||||||
|
data=content,
|
||||||
|
ignore_backoff=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@log_function
|
||||||
|
def invite_to_group_notification(self, destination, group_id, user_id, content):
|
||||||
|
"""Sent by group server to inform a user's server that they have been
|
||||||
|
invited.
|
||||||
|
"""
|
||||||
|
|
||||||
|
path = PREFIX + "/groups/local/%s/users/%s/invite" % (group_id, user_id)
|
||||||
|
|
||||||
|
return self.client.post_json(
|
||||||
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
data=content,
|
||||||
|
ignore_backoff=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@log_function
|
||||||
|
def remove_user_from_group(self, destination, group_id, requester_user_id,
|
||||||
|
user_id, content):
|
||||||
|
"""Remove a user fron a group
|
||||||
|
"""
|
||||||
|
path = PREFIX + "/groups/%s/users/%s/remove" % (group_id, user_id)
|
||||||
|
|
||||||
|
return self.client.post_json(
|
||||||
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
args={"requester_user_id": requester_user_id},
|
||||||
|
data=content,
|
||||||
|
ignore_backoff=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@log_function
|
||||||
|
def remove_user_from_group_notification(self, destination, group_id, user_id,
|
||||||
|
content):
|
||||||
|
"""Sent by group server to inform a user's server that they have been
|
||||||
|
kicked from the group.
|
||||||
|
"""
|
||||||
|
|
||||||
|
path = PREFIX + "/groups/local/%s/users/%s/remove" % (group_id, user_id)
|
||||||
|
|
||||||
|
return self.client.post_json(
|
||||||
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
data=content,
|
||||||
|
ignore_backoff=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@log_function
|
||||||
|
def renew_group_attestation(self, destination, group_id, user_id, content):
|
||||||
|
"""Sent by either a group server or a user's server to periodically update
|
||||||
|
the attestations
|
||||||
|
"""
|
||||||
|
|
||||||
|
path = PREFIX + "/groups/%s/renew_attestation/%s" % (group_id, user_id)
|
||||||
|
|
||||||
|
return self.client.post_json(
|
||||||
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
data=content,
|
||||||
|
ignore_backoff=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@log_function
|
||||||
|
def update_group_summary_room(self, destination, group_id, user_id, room_id,
|
||||||
|
category_id, content):
|
||||||
|
"""Update a room entry in a group summary
|
||||||
|
"""
|
||||||
|
if category_id:
|
||||||
|
path = PREFIX + "/groups/%s/summary/categories/%s/rooms/%s" % (
|
||||||
|
group_id, category_id, room_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
path = PREFIX + "/groups/%s/summary/rooms/%s" % (group_id, room_id,)
|
||||||
|
|
||||||
|
return self.client.post_json(
|
||||||
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
args={"requester_user_id": user_id},
|
||||||
|
data=content,
|
||||||
|
ignore_backoff=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@log_function
|
||||||
|
def delete_group_summary_room(self, destination, group_id, user_id, room_id,
|
||||||
|
category_id):
|
||||||
|
"""Delete a room entry in a group summary
|
||||||
|
"""
|
||||||
|
if category_id:
|
||||||
|
path = PREFIX + "/groups/%s/summary/categories/%s/rooms/%s" % (
|
||||||
|
group_id, category_id, room_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
path = PREFIX + "/groups/%s/summary/rooms/%s" % (group_id, room_id,)
|
||||||
|
|
||||||
|
return self.client.delete_json(
|
||||||
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
args={"requester_user_id": user_id},
|
||||||
|
ignore_backoff=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@log_function
|
||||||
|
def get_group_categories(self, destination, group_id, requester_user_id):
|
||||||
|
"""Get all categories in a group
|
||||||
|
"""
|
||||||
|
path = PREFIX + "/groups/%s/categories" % (group_id,)
|
||||||
|
|
||||||
|
return self.client.get_json(
|
||||||
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
args={"requester_user_id": requester_user_id},
|
||||||
|
ignore_backoff=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@log_function
|
||||||
|
def get_group_category(self, destination, group_id, requester_user_id, category_id):
|
||||||
|
"""Get category info in a group
|
||||||
|
"""
|
||||||
|
path = PREFIX + "/groups/%s/categories/%s" % (group_id, category_id,)
|
||||||
|
|
||||||
|
return self.client.get_json(
|
||||||
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
args={"requester_user_id": requester_user_id},
|
||||||
|
ignore_backoff=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@log_function
|
||||||
|
def update_group_category(self, destination, group_id, requester_user_id, category_id,
|
||||||
|
content):
|
||||||
|
"""Update a category in a group
|
||||||
|
"""
|
||||||
|
path = PREFIX + "/groups/%s/categories/%s" % (group_id, category_id,)
|
||||||
|
|
||||||
|
return self.client.post_json(
|
||||||
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
args={"requester_user_id": requester_user_id},
|
||||||
|
data=content,
|
||||||
|
ignore_backoff=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@log_function
|
||||||
|
def delete_group_category(self, destination, group_id, requester_user_id,
|
||||||
|
category_id):
|
||||||
|
"""Delete a category in a group
|
||||||
|
"""
|
||||||
|
path = PREFIX + "/groups/%s/categories/%s" % (group_id, category_id,)
|
||||||
|
|
||||||
|
return self.client.delete_json(
|
||||||
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
args={"requester_user_id": requester_user_id},
|
||||||
|
ignore_backoff=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@log_function
|
||||||
|
def get_group_roles(self, destination, group_id, requester_user_id):
|
||||||
|
"""Get all roles in a group
|
||||||
|
"""
|
||||||
|
path = PREFIX + "/groups/%s/roles" % (group_id,)
|
||||||
|
|
||||||
|
return self.client.get_json(
|
||||||
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
args={"requester_user_id": requester_user_id},
|
||||||
|
ignore_backoff=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@log_function
|
||||||
|
def get_group_role(self, destination, group_id, requester_user_id, role_id):
|
||||||
|
"""Get a roles info
|
||||||
|
"""
|
||||||
|
path = PREFIX + "/groups/%s/roles/%s" % (group_id, role_id,)
|
||||||
|
|
||||||
|
return self.client.get_json(
|
||||||
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
args={"requester_user_id": requester_user_id},
|
||||||
|
ignore_backoff=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@log_function
|
||||||
|
def update_group_role(self, destination, group_id, requester_user_id, role_id,
|
||||||
|
content):
|
||||||
|
"""Update a role in a group
|
||||||
|
"""
|
||||||
|
path = PREFIX + "/groups/%s/roles/%s" % (group_id, role_id,)
|
||||||
|
|
||||||
|
return self.client.post_json(
|
||||||
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
args={"requester_user_id": requester_user_id},
|
||||||
|
data=content,
|
||||||
|
ignore_backoff=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@log_function
|
||||||
|
def delete_group_role(self, destination, group_id, requester_user_id, role_id):
|
||||||
|
"""Delete a role in a group
|
||||||
|
"""
|
||||||
|
path = PREFIX + "/groups/%s/roles/%s" % (group_id, role_id,)
|
||||||
|
|
||||||
|
return self.client.delete_json(
|
||||||
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
args={"requester_user_id": requester_user_id},
|
||||||
|
ignore_backoff=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@log_function
|
||||||
|
def update_group_summary_user(self, destination, group_id, requester_user_id,
|
||||||
|
user_id, role_id, content):
|
||||||
|
"""Update a users entry in a group
|
||||||
|
"""
|
||||||
|
if role_id:
|
||||||
|
path = PREFIX + "/groups/%s/summary/roles/%s/users/%s" % (
|
||||||
|
group_id, role_id, user_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
path = PREFIX + "/groups/%s/summary/users/%s" % (group_id, user_id,)
|
||||||
|
|
||||||
|
return self.client.post_json(
|
||||||
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
args={"requester_user_id": requester_user_id},
|
||||||
|
data=content,
|
||||||
|
ignore_backoff=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@log_function
|
||||||
|
def delete_group_summary_user(self, destination, group_id, requester_user_id,
|
||||||
|
user_id, role_id):
|
||||||
|
"""Delete a users entry in a group
|
||||||
|
"""
|
||||||
|
if role_id:
|
||||||
|
path = PREFIX + "/groups/%s/summary/roles/%s/users/%s" % (
|
||||||
|
group_id, role_id, user_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
path = PREFIX + "/groups/%s/summary/users/%s" % (group_id, user_id,)
|
||||||
|
|
||||||
|
return self.client.delete_json(
|
||||||
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
args={"requester_user_id": requester_user_id},
|
||||||
|
ignore_backoff=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def bulk_get_publicised_groups(self, destination, user_ids):
|
||||||
|
"""Get the groups a list of users are publicising
|
||||||
|
"""
|
||||||
|
|
||||||
|
path = PREFIX + "/get_groups_publicised"
|
||||||
|
|
||||||
|
content = {"user_ids": user_ids}
|
||||||
|
|
||||||
|
return self.client.post_json(
|
||||||
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
data=content,
|
||||||
|
ignore_backoff=True,
|
||||||
|
)
|
||||||
|
@ -25,7 +25,7 @@ from synapse.http.servlet import (
|
|||||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||||
from synapse.util.versionstring import get_version_string
|
from synapse.util.versionstring import get_version_string
|
||||||
from synapse.util.logcontext import preserve_fn
|
from synapse.util.logcontext import preserve_fn
|
||||||
from synapse.types import ThirdPartyInstanceID
|
from synapse.types import ThirdPartyInstanceID, get_domain_from_id
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
@ -609,6 +609,493 @@ class FederationVersionServlet(BaseFederationServlet):
|
|||||||
}))
|
}))
|
||||||
|
|
||||||
|
|
||||||
|
class FederationGroupsProfileServlet(BaseFederationServlet):
|
||||||
|
"""Get the basic profile of a group on behalf of a user
|
||||||
|
"""
|
||||||
|
PATH = "/groups/(?P<group_id>[^/]*)/profile$"
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, origin, content, query, group_id):
|
||||||
|
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||||
|
if get_domain_from_id(requester_user_id) != origin:
|
||||||
|
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||||
|
|
||||||
|
new_content = yield self.handler.get_group_profile(
|
||||||
|
group_id, requester_user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, new_content))
|
||||||
|
|
||||||
|
|
||||||
|
class FederationGroupsSummaryServlet(BaseFederationServlet):
|
||||||
|
PATH = "/groups/(?P<group_id>[^/]*)/summary$"
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, origin, content, query, group_id):
|
||||||
|
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||||
|
if get_domain_from_id(requester_user_id) != origin:
|
||||||
|
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||||
|
|
||||||
|
new_content = yield self.handler.get_group_summary(
|
||||||
|
group_id, requester_user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, new_content))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, origin, content, query, group_id):
|
||||||
|
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||||
|
if get_domain_from_id(requester_user_id) != origin:
|
||||||
|
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||||
|
|
||||||
|
new_content = yield self.handler.update_group_profile(
|
||||||
|
group_id, requester_user_id, content
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, new_content))
|
||||||
|
|
||||||
|
|
||||||
|
class FederationGroupsRoomsServlet(BaseFederationServlet):
|
||||||
|
"""Get the rooms in a group on behalf of a user
|
||||||
|
"""
|
||||||
|
PATH = "/groups/(?P<group_id>[^/]*)/rooms$"
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, origin, content, query, group_id):
|
||||||
|
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||||
|
if get_domain_from_id(requester_user_id) != origin:
|
||||||
|
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||||
|
|
||||||
|
new_content = yield self.handler.get_rooms_in_group(
|
||||||
|
group_id, requester_user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, new_content))
|
||||||
|
|
||||||
|
|
||||||
|
class FederationGroupsAddRoomsServlet(BaseFederationServlet):
|
||||||
|
"""Add/remove room from group
|
||||||
|
"""
|
||||||
|
PATH = "/groups/(?P<group_id>[^/]*)/room/(?<room_id>)$"
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, origin, content, query, group_id, room_id):
|
||||||
|
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||||
|
if get_domain_from_id(requester_user_id) != origin:
|
||||||
|
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||||
|
|
||||||
|
new_content = yield self.handler.add_room_to_group(
|
||||||
|
group_id, requester_user_id, room_id, content
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, new_content))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_DELETE(self, origin, content, query, group_id, room_id):
|
||||||
|
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||||
|
if get_domain_from_id(requester_user_id) != origin:
|
||||||
|
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||||
|
|
||||||
|
new_content = yield self.handler.remove_room_from_group(
|
||||||
|
group_id, requester_user_id, room_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, new_content))
|
||||||
|
|
||||||
|
|
||||||
|
class FederationGroupsUsersServlet(BaseFederationServlet):
|
||||||
|
"""Get the users in a group on behalf of a user
|
||||||
|
"""
|
||||||
|
PATH = "/groups/(?P<group_id>[^/]*)/users$"
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, origin, content, query, group_id):
|
||||||
|
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||||
|
if get_domain_from_id(requester_user_id) != origin:
|
||||||
|
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||||
|
|
||||||
|
new_content = yield self.handler.get_users_in_group(
|
||||||
|
group_id, requester_user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, new_content))
|
||||||
|
|
||||||
|
|
||||||
|
class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
|
||||||
|
"""Get the users that have been invited to a group
|
||||||
|
"""
|
||||||
|
PATH = "/groups/(?P<group_id>[^/]*)/invited_users$"
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, origin, content, query, group_id):
|
||||||
|
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||||
|
if get_domain_from_id(requester_user_id) != origin:
|
||||||
|
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||||
|
|
||||||
|
new_content = yield self.handler.get_invited_users_in_group(
|
||||||
|
group_id, requester_user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, new_content))
|
||||||
|
|
||||||
|
|
||||||
|
class FederationGroupsInviteServlet(BaseFederationServlet):
|
||||||
|
"""Ask a group server to invite someone to the group
|
||||||
|
"""
|
||||||
|
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite$"
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, origin, content, query, group_id, user_id):
|
||||||
|
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||||
|
if get_domain_from_id(requester_user_id) != origin:
|
||||||
|
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||||
|
|
||||||
|
new_content = yield self.handler.invite_to_group(
|
||||||
|
group_id, user_id, requester_user_id, content,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, new_content))
|
||||||
|
|
||||||
|
|
||||||
|
class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
|
||||||
|
"""Accept an invitation from the group server
|
||||||
|
"""
|
||||||
|
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite$"
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, origin, content, query, group_id, user_id):
|
||||||
|
if get_domain_from_id(user_id) != origin:
|
||||||
|
raise SynapseError(403, "user_id doesn't match origin")
|
||||||
|
|
||||||
|
new_content = yield self.handler.accept_invite(
|
||||||
|
group_id, user_id, content,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, new_content))
|
||||||
|
|
||||||
|
|
||||||
|
class FederationGroupsRemoveUserServlet(BaseFederationServlet):
|
||||||
|
"""Leave or kick a user from the group
|
||||||
|
"""
|
||||||
|
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove$"
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, origin, content, query, group_id, user_id):
|
||||||
|
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||||
|
if get_domain_from_id(requester_user_id) != origin:
|
||||||
|
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||||
|
|
||||||
|
new_content = yield self.handler.remove_user_from_group(
|
||||||
|
group_id, user_id, requester_user_id, content,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, new_content))
|
||||||
|
|
||||||
|
|
||||||
|
class FederationGroupsLocalInviteServlet(BaseFederationServlet):
|
||||||
|
"""A group server has invited a local user
|
||||||
|
"""
|
||||||
|
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite$"
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, origin, content, query, group_id, user_id):
|
||||||
|
if get_domain_from_id(group_id) != origin:
|
||||||
|
raise SynapseError(403, "group_id doesn't match origin")
|
||||||
|
|
||||||
|
new_content = yield self.handler.on_invite(
|
||||||
|
group_id, user_id, content,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, new_content))
|
||||||
|
|
||||||
|
|
||||||
|
class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
|
||||||
|
"""A group server has removed a local user
|
||||||
|
"""
|
||||||
|
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove$"
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, origin, content, query, group_id, user_id):
|
||||||
|
if get_domain_from_id(group_id) != origin:
|
||||||
|
raise SynapseError(403, "user_id doesn't match origin")
|
||||||
|
|
||||||
|
new_content = yield self.handler.user_removed_from_group(
|
||||||
|
group_id, user_id, content,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, new_content))
|
||||||
|
|
||||||
|
|
||||||
|
class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
|
||||||
|
"""A group or user's server renews their attestation
|
||||||
|
"""
|
||||||
|
PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)$"
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, origin, content, query, group_id, user_id):
|
||||||
|
# We don't need to check auth here as we check the attestation signatures
|
||||||
|
|
||||||
|
new_content = yield self.handler.on_renew_attestation(
|
||||||
|
group_id, user_id, content
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, new_content))
|
||||||
|
|
||||||
|
|
||||||
|
class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
|
||||||
|
"""Add/remove a room from the group summary, with optional category.
|
||||||
|
|
||||||
|
Matches both:
|
||||||
|
- /groups/:group/summary/rooms/:room_id
|
||||||
|
- /groups/:group/summary/categories/:category/rooms/:room_id
|
||||||
|
"""
|
||||||
|
PATH = (
|
||||||
|
"/groups/(?P<group_id>[^/]*)/summary"
|
||||||
|
"(/categories/(?P<category_id>[^/]+))?"
|
||||||
|
"/rooms/(?P<room_id>[^/]*)$"
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, origin, content, query, group_id, category_id, room_id):
|
||||||
|
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||||
|
if get_domain_from_id(requester_user_id) != origin:
|
||||||
|
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||||
|
|
||||||
|
if category_id == "":
|
||||||
|
raise SynapseError(400, "category_id cannot be empty string")
|
||||||
|
|
||||||
|
resp = yield self.handler.update_group_summary_room(
|
||||||
|
group_id, requester_user_id,
|
||||||
|
room_id=room_id,
|
||||||
|
category_id=category_id,
|
||||||
|
content=content,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, resp))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_DELETE(self, origin, content, query, group_id, category_id, room_id):
|
||||||
|
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||||
|
if get_domain_from_id(requester_user_id) != origin:
|
||||||
|
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||||
|
|
||||||
|
if category_id == "":
|
||||||
|
raise SynapseError(400, "category_id cannot be empty string")
|
||||||
|
|
||||||
|
resp = yield self.handler.delete_group_summary_room(
|
||||||
|
group_id, requester_user_id,
|
||||||
|
room_id=room_id,
|
||||||
|
category_id=category_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, resp))
|
||||||
|
|
||||||
|
|
||||||
|
class FederationGroupsCategoriesServlet(BaseFederationServlet):
|
||||||
|
"""Get all categories for a group
|
||||||
|
"""
|
||||||
|
PATH = (
|
||||||
|
"/groups/(?P<group_id>[^/]*)/categories/$"
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, origin, content, query, group_id):
|
||||||
|
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||||
|
if get_domain_from_id(requester_user_id) != origin:
|
||||||
|
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||||
|
|
||||||
|
resp = yield self.handler.get_group_categories(
|
||||||
|
group_id, requester_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, resp))
|
||||||
|
|
||||||
|
|
||||||
|
class FederationGroupsCategoryServlet(BaseFederationServlet):
|
||||||
|
"""Add/remove/get a category in a group
|
||||||
|
"""
|
||||||
|
PATH = (
|
||||||
|
"/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, origin, content, query, group_id, category_id):
|
||||||
|
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||||
|
if get_domain_from_id(requester_user_id) != origin:
|
||||||
|
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||||
|
|
||||||
|
resp = yield self.handler.get_group_category(
|
||||||
|
group_id, requester_user_id, category_id
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, resp))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, origin, content, query, group_id, category_id):
|
||||||
|
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||||
|
if get_domain_from_id(requester_user_id) != origin:
|
||||||
|
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||||
|
|
||||||
|
if category_id == "":
|
||||||
|
raise SynapseError(400, "category_id cannot be empty string")
|
||||||
|
|
||||||
|
resp = yield self.handler.upsert_group_category(
|
||||||
|
group_id, requester_user_id, category_id, content,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, resp))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_DELETE(self, origin, content, query, group_id, category_id):
|
||||||
|
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||||
|
if get_domain_from_id(requester_user_id) != origin:
|
||||||
|
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||||
|
|
||||||
|
if category_id == "":
|
||||||
|
raise SynapseError(400, "category_id cannot be empty string")
|
||||||
|
|
||||||
|
resp = yield self.handler.delete_group_category(
|
||||||
|
group_id, requester_user_id, category_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, resp))
|
||||||
|
|
||||||
|
|
||||||
|
class FederationGroupsRolesServlet(BaseFederationServlet):
|
||||||
|
"""Get roles in a group
|
||||||
|
"""
|
||||||
|
PATH = (
|
||||||
|
"/groups/(?P<group_id>[^/]*)/roles/$"
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, origin, content, query, group_id):
|
||||||
|
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||||
|
if get_domain_from_id(requester_user_id) != origin:
|
||||||
|
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||||
|
|
||||||
|
resp = yield self.handler.get_group_roles(
|
||||||
|
group_id, requester_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, resp))
|
||||||
|
|
||||||
|
|
||||||
|
class FederationGroupsRoleServlet(BaseFederationServlet):
|
||||||
|
"""Add/remove/get a role in a group
|
||||||
|
"""
|
||||||
|
PATH = (
|
||||||
|
"/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$"
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, origin, content, query, group_id, role_id):
|
||||||
|
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||||
|
if get_domain_from_id(requester_user_id) != origin:
|
||||||
|
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||||
|
|
||||||
|
resp = yield self.handler.get_group_role(
|
||||||
|
group_id, requester_user_id, role_id
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, resp))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, origin, content, query, group_id, role_id):
|
||||||
|
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||||
|
if get_domain_from_id(requester_user_id) != origin:
|
||||||
|
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||||
|
|
||||||
|
if role_id == "":
|
||||||
|
raise SynapseError(400, "role_id cannot be empty string")
|
||||||
|
|
||||||
|
resp = yield self.handler.update_group_role(
|
||||||
|
group_id, requester_user_id, role_id, content,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, resp))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_DELETE(self, origin, content, query, group_id, role_id):
|
||||||
|
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||||
|
if get_domain_from_id(requester_user_id) != origin:
|
||||||
|
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||||
|
|
||||||
|
if role_id == "":
|
||||||
|
raise SynapseError(400, "role_id cannot be empty string")
|
||||||
|
|
||||||
|
resp = yield self.handler.delete_group_role(
|
||||||
|
group_id, requester_user_id, role_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, resp))
|
||||||
|
|
||||||
|
|
||||||
|
class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
|
||||||
|
"""Add/remove a user from the group summary, with optional role.
|
||||||
|
|
||||||
|
Matches both:
|
||||||
|
- /groups/:group/summary/users/:user_id
|
||||||
|
- /groups/:group/summary/roles/:role/users/:user_id
|
||||||
|
"""
|
||||||
|
PATH = (
|
||||||
|
"/groups/(?P<group_id>[^/]*)/summary"
|
||||||
|
"(/roles/(?P<role_id>[^/]+))?"
|
||||||
|
"/users/(?P<user_id>[^/]*)$"
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, origin, content, query, group_id, role_id, user_id):
|
||||||
|
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||||
|
if get_domain_from_id(requester_user_id) != origin:
|
||||||
|
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||||
|
|
||||||
|
if role_id == "":
|
||||||
|
raise SynapseError(400, "role_id cannot be empty string")
|
||||||
|
|
||||||
|
resp = yield self.handler.update_group_summary_user(
|
||||||
|
group_id, requester_user_id,
|
||||||
|
user_id=user_id,
|
||||||
|
role_id=role_id,
|
||||||
|
content=content,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, resp))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_DELETE(self, origin, content, query, group_id, role_id, user_id):
|
||||||
|
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||||
|
if get_domain_from_id(requester_user_id) != origin:
|
||||||
|
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||||
|
|
||||||
|
if role_id == "":
|
||||||
|
raise SynapseError(400, "role_id cannot be empty string")
|
||||||
|
|
||||||
|
resp = yield self.handler.delete_group_summary_user(
|
||||||
|
group_id, requester_user_id,
|
||||||
|
user_id=user_id,
|
||||||
|
role_id=role_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, resp))
|
||||||
|
|
||||||
|
|
||||||
|
class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
|
||||||
|
"""Get roles in a group
|
||||||
|
"""
|
||||||
|
PATH = (
|
||||||
|
"/get_groups_publicised$"
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, origin, content, query):
|
||||||
|
resp = yield self.handler.bulk_get_publicised_groups(
|
||||||
|
content["user_ids"], proxy=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, resp))
|
||||||
|
|
||||||
|
|
||||||
FEDERATION_SERVLET_CLASSES = (
|
FEDERATION_SERVLET_CLASSES = (
|
||||||
FederationSendServlet,
|
FederationSendServlet,
|
||||||
FederationPullServlet,
|
FederationPullServlet,
|
||||||
@ -635,10 +1122,40 @@ FEDERATION_SERVLET_CLASSES = (
|
|||||||
FederationVersionServlet,
|
FederationVersionServlet,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
ROOM_LIST_CLASSES = (
|
ROOM_LIST_CLASSES = (
|
||||||
PublicRoomList,
|
PublicRoomList,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
GROUP_SERVER_SERVLET_CLASSES = (
|
||||||
|
FederationGroupsProfileServlet,
|
||||||
|
FederationGroupsSummaryServlet,
|
||||||
|
FederationGroupsRoomsServlet,
|
||||||
|
FederationGroupsUsersServlet,
|
||||||
|
FederationGroupsInvitedUsersServlet,
|
||||||
|
FederationGroupsInviteServlet,
|
||||||
|
FederationGroupsAcceptInviteServlet,
|
||||||
|
FederationGroupsRemoveUserServlet,
|
||||||
|
FederationGroupsSummaryRoomsServlet,
|
||||||
|
FederationGroupsCategoriesServlet,
|
||||||
|
FederationGroupsCategoryServlet,
|
||||||
|
FederationGroupsRolesServlet,
|
||||||
|
FederationGroupsRoleServlet,
|
||||||
|
FederationGroupsSummaryUsersServlet,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
GROUP_LOCAL_SERVLET_CLASSES = (
|
||||||
|
FederationGroupsLocalInviteServlet,
|
||||||
|
FederationGroupsRemoveLocalUserServlet,
|
||||||
|
FederationGroupsBulkPublicisedServlet,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
GROUP_ATTESTATION_SERVLET_CLASSES = (
|
||||||
|
FederationGroupsRenewAttestaionServlet,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs, resource, authenticator, ratelimiter):
|
def register_servlets(hs, resource, authenticator, ratelimiter):
|
||||||
for servletclass in FEDERATION_SERVLET_CLASSES:
|
for servletclass in FEDERATION_SERVLET_CLASSES:
|
||||||
@ -656,3 +1173,27 @@ def register_servlets(hs, resource, authenticator, ratelimiter):
|
|||||||
ratelimiter=ratelimiter,
|
ratelimiter=ratelimiter,
|
||||||
server_name=hs.hostname,
|
server_name=hs.hostname,
|
||||||
).register(resource)
|
).register(resource)
|
||||||
|
|
||||||
|
for servletclass in GROUP_SERVER_SERVLET_CLASSES:
|
||||||
|
servletclass(
|
||||||
|
handler=hs.get_groups_server_handler(),
|
||||||
|
authenticator=authenticator,
|
||||||
|
ratelimiter=ratelimiter,
|
||||||
|
server_name=hs.hostname,
|
||||||
|
).register(resource)
|
||||||
|
|
||||||
|
for servletclass in GROUP_LOCAL_SERVLET_CLASSES:
|
||||||
|
servletclass(
|
||||||
|
handler=hs.get_groups_local_handler(),
|
||||||
|
authenticator=authenticator,
|
||||||
|
ratelimiter=ratelimiter,
|
||||||
|
server_name=hs.hostname,
|
||||||
|
).register(resource)
|
||||||
|
|
||||||
|
for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES:
|
||||||
|
servletclass(
|
||||||
|
handler=hs.get_groups_attestation_renewer(),
|
||||||
|
authenticator=authenticator,
|
||||||
|
ratelimiter=ratelimiter,
|
||||||
|
server_name=hs.hostname,
|
||||||
|
).register(resource)
|
||||||
|
0
synapse/groups/__init__.py
Normal file
0
synapse/groups/__init__.py
Normal file
151
synapse/groups/attestations.py
Normal file
151
synapse/groups/attestations.py
Normal file
@ -0,0 +1,151 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2017 Vector Creations Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.api.errors import SynapseError
|
||||||
|
from synapse.types import get_domain_from_id
|
||||||
|
from synapse.util.logcontext import preserve_fn
|
||||||
|
|
||||||
|
from signedjson.sign import sign_json
|
||||||
|
|
||||||
|
|
||||||
|
# Default validity duration for new attestations we create
|
||||||
|
DEFAULT_ATTESTATION_LENGTH_MS = 3 * 24 * 60 * 60 * 1000
|
||||||
|
|
||||||
|
# Start trying to update our attestations when they come this close to expiring
|
||||||
|
UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000
|
||||||
|
|
||||||
|
|
||||||
|
class GroupAttestationSigning(object):
|
||||||
|
"""Creates and verifies group attestations.
|
||||||
|
"""
|
||||||
|
def __init__(self, hs):
|
||||||
|
self.keyring = hs.get_keyring()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.server_name = hs.hostname
|
||||||
|
self.signing_key = hs.config.signing_key[0]
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def verify_attestation(self, attestation, group_id, user_id, server_name=None):
|
||||||
|
"""Verifies that the given attestation matches the given parameters.
|
||||||
|
|
||||||
|
An optional server_name can be supplied to explicitly set which server's
|
||||||
|
signature is expected. Otherwise assumes that either the group_id or user_id
|
||||||
|
is local and uses the other's server as the one to check.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not server_name:
|
||||||
|
if get_domain_from_id(group_id) == self.server_name:
|
||||||
|
server_name = get_domain_from_id(user_id)
|
||||||
|
elif get_domain_from_id(user_id) == self.server_name:
|
||||||
|
server_name = get_domain_from_id(group_id)
|
||||||
|
else:
|
||||||
|
raise Exception("Expected either group_id or user_id to be local")
|
||||||
|
|
||||||
|
if user_id != attestation["user_id"]:
|
||||||
|
raise SynapseError(400, "Attestation has incorrect user_id")
|
||||||
|
|
||||||
|
if group_id != attestation["group_id"]:
|
||||||
|
raise SynapseError(400, "Attestation has incorrect group_id")
|
||||||
|
valid_until_ms = attestation["valid_until_ms"]
|
||||||
|
|
||||||
|
# TODO: We also want to check that *new* attestations that people give
|
||||||
|
# us to store are valid for at least a little while.
|
||||||
|
if valid_until_ms < self.clock.time_msec():
|
||||||
|
raise SynapseError(400, "Attestation expired")
|
||||||
|
|
||||||
|
yield self.keyring.verify_json_for_server(server_name, attestation)
|
||||||
|
|
||||||
|
def create_attestation(self, group_id, user_id):
|
||||||
|
"""Create an attestation for the group_id and user_id with default
|
||||||
|
validity length.
|
||||||
|
"""
|
||||||
|
return sign_json({
|
||||||
|
"group_id": group_id,
|
||||||
|
"user_id": user_id,
|
||||||
|
"valid_until_ms": self.clock.time_msec() + DEFAULT_ATTESTATION_LENGTH_MS,
|
||||||
|
}, self.server_name, self.signing_key)
|
||||||
|
|
||||||
|
|
||||||
|
class GroupAttestionRenewer(object):
|
||||||
|
"""Responsible for sending and receiving attestation updates.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.assestations = hs.get_groups_attestation_signing()
|
||||||
|
self.transport_client = hs.get_federation_transport_client()
|
||||||
|
self.is_mine_id = hs.is_mine_id
|
||||||
|
self.attestations = hs.get_groups_attestation_signing()
|
||||||
|
|
||||||
|
self._renew_attestations_loop = self.clock.looping_call(
|
||||||
|
self._renew_attestations, 30 * 60 * 1000,
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_renew_attestation(self, group_id, user_id, content):
|
||||||
|
"""When a remote updates an attestation
|
||||||
|
"""
|
||||||
|
attestation = content["attestation"]
|
||||||
|
|
||||||
|
if not self.is_mine_id(group_id) and not self.is_mine_id(user_id):
|
||||||
|
raise SynapseError(400, "Neither user not group are on this server")
|
||||||
|
|
||||||
|
yield self.attestations.verify_attestation(
|
||||||
|
attestation,
|
||||||
|
user_id=user_id,
|
||||||
|
group_id=group_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
yield self.store.update_remote_attestion(group_id, user_id, attestation)
|
||||||
|
|
||||||
|
defer.returnValue({})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _renew_attestations(self):
|
||||||
|
"""Called periodically to check if we need to update any of our attestations
|
||||||
|
"""
|
||||||
|
|
||||||
|
now = self.clock.time_msec()
|
||||||
|
|
||||||
|
rows = yield self.store.get_attestations_need_renewals(
|
||||||
|
now + UPDATE_ATTESTATION_TIME_MS
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _renew_attestation(group_id, user_id):
|
||||||
|
attestation = self.attestations.create_attestation(group_id, user_id)
|
||||||
|
|
||||||
|
if self.is_mine_id(group_id):
|
||||||
|
destination = get_domain_from_id(user_id)
|
||||||
|
else:
|
||||||
|
destination = get_domain_from_id(group_id)
|
||||||
|
|
||||||
|
yield self.transport_client.renew_group_attestation(
|
||||||
|
destination, group_id, user_id,
|
||||||
|
content={"attestation": attestation},
|
||||||
|
)
|
||||||
|
|
||||||
|
yield self.store.update_attestation_renewal(
|
||||||
|
group_id, user_id, attestation
|
||||||
|
)
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
group_id = row["group_id"]
|
||||||
|
user_id = row["user_id"]
|
||||||
|
|
||||||
|
preserve_fn(_renew_attestation)(group_id, user_id)
|
803
synapse/groups/groups_server.py
Normal file
803
synapse/groups/groups_server.py
Normal file
@ -0,0 +1,803 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2017 Vector Creations Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.api.errors import SynapseError
|
||||||
|
from synapse.types import UserID, get_domain_from_id, RoomID, GroupID
|
||||||
|
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import urllib
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Allow users to "knock" or simpkly join depending on rules
|
||||||
|
# TODO: Federation admin APIs
|
||||||
|
# TODO: is_priveged flag to users and is_public to users and rooms
|
||||||
|
# TODO: Audit log for admins (profile updates, membership changes, users who tried
|
||||||
|
# to join but were rejected, etc)
|
||||||
|
# TODO: Flairs
|
||||||
|
|
||||||
|
|
||||||
|
class GroupsServerHandler(object):
|
||||||
|
def __init__(self, hs):
|
||||||
|
self.hs = hs
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.room_list_handler = hs.get_room_list_handler()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.keyring = hs.get_keyring()
|
||||||
|
self.is_mine_id = hs.is_mine_id
|
||||||
|
self.signing_key = hs.config.signing_key[0]
|
||||||
|
self.server_name = hs.hostname
|
||||||
|
self.attestations = hs.get_groups_attestation_signing()
|
||||||
|
self.transport_client = hs.get_federation_transport_client()
|
||||||
|
self.profile_handler = hs.get_profile_handler()
|
||||||
|
|
||||||
|
# Ensure attestations get renewed
|
||||||
|
hs.get_groups_attestation_renewer()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def check_group_is_ours(self, group_id, and_exists=False, and_is_admin=None):
|
||||||
|
"""Check that the group is ours, and optionally if it exists.
|
||||||
|
|
||||||
|
If group does exist then return group.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group_id (str)
|
||||||
|
and_exists (bool): whether to also check if group exists
|
||||||
|
and_is_admin (str): whether to also check if given str is a user_id
|
||||||
|
that is an admin
|
||||||
|
"""
|
||||||
|
if not self.is_mine_id(group_id):
|
||||||
|
raise SynapseError(400, "Group not on this server")
|
||||||
|
|
||||||
|
group = yield self.store.get_group(group_id)
|
||||||
|
if and_exists and not group:
|
||||||
|
raise SynapseError(404, "Unknown group")
|
||||||
|
|
||||||
|
if and_is_admin:
|
||||||
|
is_admin = yield self.store.is_user_admin_in_group(group_id, and_is_admin)
|
||||||
|
if not is_admin:
|
||||||
|
raise SynapseError(403, "User is not admin in group")
|
||||||
|
|
||||||
|
defer.returnValue(group)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_group_summary(self, group_id, requester_user_id):
|
||||||
|
"""Get the summary for a group as seen by requester_user_id.
|
||||||
|
|
||||||
|
The group summary consists of the profile of the room, and a curated
|
||||||
|
list of users and rooms. These list *may* be organised by role/category.
|
||||||
|
The roles/categories are ordered, and so are the users/rooms within them.
|
||||||
|
|
||||||
|
A user/room may appear in multiple roles/categories.
|
||||||
|
"""
|
||||||
|
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||||
|
|
||||||
|
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
|
||||||
|
|
||||||
|
profile = yield self.get_group_profile(group_id, requester_user_id)
|
||||||
|
|
||||||
|
users, roles = yield self.store.get_users_for_summary_by_role(
|
||||||
|
group_id, include_private=is_user_in_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: Add profiles to users
|
||||||
|
|
||||||
|
rooms, categories = yield self.store.get_rooms_for_summary_by_category(
|
||||||
|
group_id, include_private=is_user_in_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
for room_entry in rooms:
|
||||||
|
room_id = room_entry["room_id"]
|
||||||
|
joined_users = yield self.store.get_users_in_room(room_id)
|
||||||
|
entry = yield self.room_list_handler.generate_room_entry(
|
||||||
|
room_id, len(joined_users),
|
||||||
|
with_alias=False, allow_private=True,
|
||||||
|
)
|
||||||
|
entry = dict(entry) # so we don't change whats cached
|
||||||
|
entry.pop("room_id", None)
|
||||||
|
|
||||||
|
room_entry["profile"] = entry
|
||||||
|
|
||||||
|
rooms.sort(key=lambda e: e.get("order", 0))
|
||||||
|
|
||||||
|
for entry in users:
|
||||||
|
user_id = entry["user_id"]
|
||||||
|
|
||||||
|
if not self.is_mine_id(requester_user_id):
|
||||||
|
attestation = yield self.store.get_remote_attestation(group_id, user_id)
|
||||||
|
if not attestation:
|
||||||
|
continue
|
||||||
|
|
||||||
|
entry["attestation"] = attestation
|
||||||
|
else:
|
||||||
|
entry["attestation"] = self.attestations.create_attestation(
|
||||||
|
group_id, user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
user_profile = yield self.profile_handler.get_profile_from_cache(user_id)
|
||||||
|
entry.update(user_profile)
|
||||||
|
|
||||||
|
users.sort(key=lambda e: e.get("order", 0))
|
||||||
|
|
||||||
|
membership_info = yield self.store.get_users_membership_info_in_group(
|
||||||
|
group_id, requester_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue({
|
||||||
|
"profile": profile,
|
||||||
|
"users_section": {
|
||||||
|
"users": users,
|
||||||
|
"roles": roles,
|
||||||
|
"total_user_count_estimate": 0, # TODO
|
||||||
|
},
|
||||||
|
"rooms_section": {
|
||||||
|
"rooms": rooms,
|
||||||
|
"categories": categories,
|
||||||
|
"total_room_count_estimate": 0, # TODO
|
||||||
|
},
|
||||||
|
"user": membership_info,
|
||||||
|
})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def update_group_summary_room(self, group_id, user_id, room_id, category_id, content):
|
||||||
|
"""Add/update a room to the group summary
|
||||||
|
"""
|
||||||
|
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
|
||||||
|
|
||||||
|
RoomID.from_string(room_id) # Ensure valid room id
|
||||||
|
|
||||||
|
order = content.get("order", None)
|
||||||
|
|
||||||
|
is_public = _parse_visibility_from_contents(content)
|
||||||
|
|
||||||
|
yield self.store.add_room_to_summary(
|
||||||
|
group_id=group_id,
|
||||||
|
room_id=room_id,
|
||||||
|
category_id=category_id,
|
||||||
|
order=order,
|
||||||
|
is_public=is_public,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue({})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def delete_group_summary_room(self, group_id, user_id, room_id, category_id):
|
||||||
|
"""Remove a room from the summary
|
||||||
|
"""
|
||||||
|
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
|
||||||
|
|
||||||
|
yield self.store.remove_room_from_summary(
|
||||||
|
group_id=group_id,
|
||||||
|
room_id=room_id,
|
||||||
|
category_id=category_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue({})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_group_categories(self, group_id, user_id):
|
||||||
|
"""Get all categories in a group (as seen by user)
|
||||||
|
"""
|
||||||
|
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||||
|
|
||||||
|
categories = yield self.store.get_group_categories(
|
||||||
|
group_id=group_id,
|
||||||
|
)
|
||||||
|
defer.returnValue({"categories": categories})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_group_category(self, group_id, user_id, category_id):
|
||||||
|
"""Get a specific category in a group (as seen by user)
|
||||||
|
"""
|
||||||
|
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||||
|
|
||||||
|
res = yield self.store.get_group_category(
|
||||||
|
group_id=group_id,
|
||||||
|
category_id=category_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue(res)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def update_group_category(self, group_id, user_id, category_id, content):
|
||||||
|
"""Add/Update a group category
|
||||||
|
"""
|
||||||
|
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
|
||||||
|
|
||||||
|
is_public = _parse_visibility_from_contents(content)
|
||||||
|
profile = content.get("profile")
|
||||||
|
|
||||||
|
yield self.store.upsert_group_category(
|
||||||
|
group_id=group_id,
|
||||||
|
category_id=category_id,
|
||||||
|
is_public=is_public,
|
||||||
|
profile=profile,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue({})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def delete_group_category(self, group_id, user_id, category_id):
|
||||||
|
"""Delete a group category
|
||||||
|
"""
|
||||||
|
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
|
||||||
|
|
||||||
|
yield self.store.remove_group_category(
|
||||||
|
group_id=group_id,
|
||||||
|
category_id=category_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue({})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_group_roles(self, group_id, user_id):
|
||||||
|
"""Get all roles in a group (as seen by user)
|
||||||
|
"""
|
||||||
|
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||||
|
|
||||||
|
roles = yield self.store.get_group_roles(
|
||||||
|
group_id=group_id,
|
||||||
|
)
|
||||||
|
defer.returnValue({"roles": roles})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_group_role(self, group_id, user_id, role_id):
|
||||||
|
"""Get a specific role in a group (as seen by user)
|
||||||
|
"""
|
||||||
|
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||||
|
|
||||||
|
res = yield self.store.get_group_role(
|
||||||
|
group_id=group_id,
|
||||||
|
role_id=role_id,
|
||||||
|
)
|
||||||
|
defer.returnValue(res)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def update_group_role(self, group_id, user_id, role_id, content):
|
||||||
|
"""Add/update a role in a group
|
||||||
|
"""
|
||||||
|
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
|
||||||
|
|
||||||
|
is_public = _parse_visibility_from_contents(content)
|
||||||
|
|
||||||
|
profile = content.get("profile")
|
||||||
|
|
||||||
|
yield self.store.upsert_group_role(
|
||||||
|
group_id=group_id,
|
||||||
|
role_id=role_id,
|
||||||
|
is_public=is_public,
|
||||||
|
profile=profile,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue({})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def delete_group_role(self, group_id, user_id, role_id):
|
||||||
|
"""Remove role from group
|
||||||
|
"""
|
||||||
|
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
|
||||||
|
|
||||||
|
yield self.store.remove_group_role(
|
||||||
|
group_id=group_id,
|
||||||
|
role_id=role_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue({})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def update_group_summary_user(self, group_id, requester_user_id, user_id, role_id,
|
||||||
|
content):
|
||||||
|
"""Add/update a users entry in the group summary
|
||||||
|
"""
|
||||||
|
yield self.check_group_is_ours(
|
||||||
|
group_id, and_exists=True, and_is_admin=requester_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
order = content.get("order", None)
|
||||||
|
|
||||||
|
is_public = _parse_visibility_from_contents(content)
|
||||||
|
|
||||||
|
yield self.store.add_user_to_summary(
|
||||||
|
group_id=group_id,
|
||||||
|
user_id=user_id,
|
||||||
|
role_id=role_id,
|
||||||
|
order=order,
|
||||||
|
is_public=is_public,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue({})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def delete_group_summary_user(self, group_id, requester_user_id, user_id, role_id):
|
||||||
|
"""Remove a user from the group summary
|
||||||
|
"""
|
||||||
|
yield self.check_group_is_ours(
|
||||||
|
group_id, and_exists=True, and_is_admin=requester_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
yield self.store.remove_user_from_summary(
|
||||||
|
group_id=group_id,
|
||||||
|
user_id=user_id,
|
||||||
|
role_id=role_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue({})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_group_profile(self, group_id, requester_user_id):
|
||||||
|
"""Get the group profile as seen by requester_user_id
|
||||||
|
"""
|
||||||
|
|
||||||
|
yield self.check_group_is_ours(group_id)
|
||||||
|
|
||||||
|
group_description = yield self.store.get_group(group_id)
|
||||||
|
|
||||||
|
if group_description:
|
||||||
|
defer.returnValue(group_description)
|
||||||
|
else:
|
||||||
|
raise SynapseError(404, "Unknown group")
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def update_group_profile(self, group_id, requester_user_id, content):
|
||||||
|
"""Update the group profile
|
||||||
|
"""
|
||||||
|
yield self.check_group_is_ours(
|
||||||
|
group_id, and_exists=True, and_is_admin=requester_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
profile = {}
|
||||||
|
for keyname in ("name", "avatar_url", "short_description",
|
||||||
|
"long_description"):
|
||||||
|
if keyname in content:
|
||||||
|
value = content[keyname]
|
||||||
|
if not isinstance(value, basestring):
|
||||||
|
raise SynapseError(400, "%r value is not a string" % (keyname,))
|
||||||
|
profile[keyname] = value
|
||||||
|
|
||||||
|
yield self.store.update_group_profile(group_id, profile)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_users_in_group(self, group_id, requester_user_id):
|
||||||
|
"""Get the users in group as seen by requester_user_id.
|
||||||
|
|
||||||
|
The ordering is arbitrary at the moment
|
||||||
|
"""
|
||||||
|
|
||||||
|
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||||
|
|
||||||
|
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
|
||||||
|
|
||||||
|
user_results = yield self.store.get_users_in_group(
|
||||||
|
group_id, include_private=is_user_in_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
chunk = []
|
||||||
|
for user_result in user_results:
|
||||||
|
g_user_id = user_result["user_id"]
|
||||||
|
is_public = user_result["is_public"]
|
||||||
|
|
||||||
|
entry = {"user_id": g_user_id}
|
||||||
|
|
||||||
|
profile = yield self.profile_handler.get_profile_from_cache(g_user_id)
|
||||||
|
entry.update(profile)
|
||||||
|
|
||||||
|
if not is_public:
|
||||||
|
entry["is_public"] = False
|
||||||
|
|
||||||
|
if not self.is_mine_id(g_user_id):
|
||||||
|
attestation = yield self.store.get_remote_attestation(group_id, g_user_id)
|
||||||
|
if not attestation:
|
||||||
|
continue
|
||||||
|
|
||||||
|
entry["attestation"] = attestation
|
||||||
|
else:
|
||||||
|
entry["attestation"] = self.attestations.create_attestation(
|
||||||
|
group_id, g_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
chunk.append(entry)
|
||||||
|
|
||||||
|
# TODO: If admin add lists of users whose attestations have timed out
|
||||||
|
|
||||||
|
defer.returnValue({
|
||||||
|
"chunk": chunk,
|
||||||
|
"total_user_count_estimate": len(user_results),
|
||||||
|
})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_invited_users_in_group(self, group_id, requester_user_id):
|
||||||
|
"""Get the users that have been invited to a group as seen by requester_user_id.
|
||||||
|
|
||||||
|
The ordering is arbitrary at the moment
|
||||||
|
"""
|
||||||
|
|
||||||
|
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||||
|
|
||||||
|
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
|
||||||
|
|
||||||
|
if not is_user_in_group:
|
||||||
|
raise SynapseError(403, "User not in group")
|
||||||
|
|
||||||
|
invited_users = yield self.store.get_invited_users_in_group(group_id)
|
||||||
|
|
||||||
|
user_profiles = []
|
||||||
|
|
||||||
|
for user_id in invited_users:
|
||||||
|
user_profile = {
|
||||||
|
"user_id": user_id
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
profile = yield self.profile_handler.get_profile_from_cache(user_id)
|
||||||
|
user_profile.update(profile)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warn("Error getting profile for %s: %s", user_id, e)
|
||||||
|
user_profiles.append(user_profile)
|
||||||
|
|
||||||
|
defer.returnValue({
|
||||||
|
"chunk": user_profiles,
|
||||||
|
"total_user_count_estimate": len(invited_users),
|
||||||
|
})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_rooms_in_group(self, group_id, requester_user_id):
|
||||||
|
"""Get the rooms in group as seen by requester_user_id
|
||||||
|
|
||||||
|
This returns rooms in order of decreasing number of joined users
|
||||||
|
"""
|
||||||
|
|
||||||
|
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||||
|
|
||||||
|
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
|
||||||
|
|
||||||
|
room_results = yield self.store.get_rooms_in_group(
|
||||||
|
group_id, include_private=is_user_in_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
chunk = []
|
||||||
|
for room_result in room_results:
|
||||||
|
room_id = room_result["room_id"]
|
||||||
|
is_public = room_result["is_public"]
|
||||||
|
|
||||||
|
joined_users = yield self.store.get_users_in_room(room_id)
|
||||||
|
entry = yield self.room_list_handler.generate_room_entry(
|
||||||
|
room_id, len(joined_users),
|
||||||
|
with_alias=False, allow_private=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not entry:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not is_public:
|
||||||
|
entry["is_public"] = False
|
||||||
|
|
||||||
|
chunk.append(entry)
|
||||||
|
|
||||||
|
chunk.sort(key=lambda e: -e["num_joined_members"])
|
||||||
|
|
||||||
|
defer.returnValue({
|
||||||
|
"chunk": chunk,
|
||||||
|
"total_room_count_estimate": len(room_results),
|
||||||
|
})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def add_room_to_group(self, group_id, requester_user_id, room_id, content):
|
||||||
|
"""Add room to group
|
||||||
|
"""
|
||||||
|
RoomID.from_string(room_id) # Ensure valid room id
|
||||||
|
|
||||||
|
yield self.check_group_is_ours(
|
||||||
|
group_id, and_exists=True, and_is_admin=requester_user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
is_public = _parse_visibility_from_contents(content)
|
||||||
|
|
||||||
|
yield self.store.add_room_to_group(group_id, room_id, is_public=is_public)
|
||||||
|
|
||||||
|
defer.returnValue({})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def remove_room_from_group(self, group_id, requester_user_id, room_id):
|
||||||
|
"""Remove room from group
|
||||||
|
"""
|
||||||
|
yield self.check_group_is_ours(
|
||||||
|
group_id, and_exists=True, and_is_admin=requester_user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
yield self.store.remove_room_from_group(group_id, room_id)
|
||||||
|
|
||||||
|
defer.returnValue({})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def invite_to_group(self, group_id, user_id, requester_user_id, content):
|
||||||
|
"""Invite user to group
|
||||||
|
"""
|
||||||
|
|
||||||
|
group = yield self.check_group_is_ours(
|
||||||
|
group_id, and_exists=True, and_is_admin=requester_user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: Check if user knocked
|
||||||
|
# TODO: Check if user is already invited
|
||||||
|
|
||||||
|
content = {
|
||||||
|
"profile": {
|
||||||
|
"name": group["name"],
|
||||||
|
"avatar_url": group["avatar_url"],
|
||||||
|
},
|
||||||
|
"inviter": requester_user_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.hs.is_mine_id(user_id):
|
||||||
|
groups_local = self.hs.get_groups_local_handler()
|
||||||
|
res = yield groups_local.on_invite(group_id, user_id, content)
|
||||||
|
local_attestation = None
|
||||||
|
else:
|
||||||
|
local_attestation = self.attestations.create_attestation(group_id, user_id)
|
||||||
|
content.update({
|
||||||
|
"attestation": local_attestation,
|
||||||
|
})
|
||||||
|
|
||||||
|
res = yield self.transport_client.invite_to_group_notification(
|
||||||
|
get_domain_from_id(user_id), group_id, user_id, content
|
||||||
|
)
|
||||||
|
|
||||||
|
user_profile = res.get("user_profile", {})
|
||||||
|
yield self.store.add_remote_profile_cache(
|
||||||
|
user_id,
|
||||||
|
displayname=user_profile.get("displayname"),
|
||||||
|
avatar_url=user_profile.get("avatar_url"),
|
||||||
|
)
|
||||||
|
|
||||||
|
if res["state"] == "join":
|
||||||
|
if not self.hs.is_mine_id(user_id):
|
||||||
|
remote_attestation = res["attestation"]
|
||||||
|
|
||||||
|
yield self.attestations.verify_attestation(
|
||||||
|
remote_attestation,
|
||||||
|
user_id=user_id,
|
||||||
|
group_id=group_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
remote_attestation = None
|
||||||
|
|
||||||
|
yield self.store.add_user_to_group(
|
||||||
|
group_id, user_id,
|
||||||
|
is_admin=False,
|
||||||
|
is_public=False, # TODO
|
||||||
|
local_attestation=local_attestation,
|
||||||
|
remote_attestation=remote_attestation,
|
||||||
|
)
|
||||||
|
elif res["state"] == "invite":
|
||||||
|
yield self.store.add_group_invite(
|
||||||
|
group_id, user_id,
|
||||||
|
)
|
||||||
|
defer.returnValue({
|
||||||
|
"state": "invite"
|
||||||
|
})
|
||||||
|
elif res["state"] == "reject":
|
||||||
|
defer.returnValue({
|
||||||
|
"state": "reject"
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
raise SynapseError(502, "Unknown state returned by HS")
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def accept_invite(self, group_id, user_id, content):
|
||||||
|
"""User tries to accept an invite to the group.
|
||||||
|
|
||||||
|
This is different from them asking to join, and so should error if no
|
||||||
|
invite exists (and they're not a member of the group)
|
||||||
|
"""
|
||||||
|
|
||||||
|
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||||
|
|
||||||
|
if not self.store.is_user_invited_to_local_group(group_id, user_id):
|
||||||
|
raise SynapseError(403, "User not invited to group")
|
||||||
|
|
||||||
|
if not self.hs.is_mine_id(user_id):
|
||||||
|
remote_attestation = content["attestation"]
|
||||||
|
|
||||||
|
yield self.attestations.verify_attestation(
|
||||||
|
remote_attestation,
|
||||||
|
user_id=user_id,
|
||||||
|
group_id=group_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
remote_attestation = None
|
||||||
|
|
||||||
|
local_attestation = self.attestations.create_attestation(group_id, user_id)
|
||||||
|
|
||||||
|
is_public = _parse_visibility_from_contents(content)
|
||||||
|
|
||||||
|
yield self.store.add_user_to_group(
|
||||||
|
group_id, user_id,
|
||||||
|
is_admin=False,
|
||||||
|
is_public=is_public,
|
||||||
|
local_attestation=local_attestation,
|
||||||
|
remote_attestation=remote_attestation,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue({
|
||||||
|
"state": "join",
|
||||||
|
"attestation": local_attestation,
|
||||||
|
})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def knock(self, group_id, user_id, content):
|
||||||
|
"""A user requests becoming a member of the group
|
||||||
|
"""
|
||||||
|
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||||
|
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def accept_knock(self, group_id, user_id, content):
|
||||||
|
"""Accept a users knock to the room.
|
||||||
|
|
||||||
|
Errors if the user hasn't knocked, rather than inviting them.
|
||||||
|
"""
|
||||||
|
|
||||||
|
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||||
|
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
|
||||||
|
"""Remove a user from the group; either a user is leaving or and admin
|
||||||
|
kicked htem.
|
||||||
|
"""
|
||||||
|
|
||||||
|
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||||
|
|
||||||
|
is_kick = False
|
||||||
|
if requester_user_id != user_id:
|
||||||
|
is_admin = yield self.store.is_user_admin_in_group(
|
||||||
|
group_id, requester_user_id
|
||||||
|
)
|
||||||
|
if not is_admin:
|
||||||
|
raise SynapseError(403, "User is not admin in group")
|
||||||
|
|
||||||
|
is_kick = True
|
||||||
|
|
||||||
|
yield self.store.remove_user_from_group(
|
||||||
|
group_id, user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_kick:
|
||||||
|
if self.hs.is_mine_id(user_id):
|
||||||
|
groups_local = self.hs.get_groups_local_handler()
|
||||||
|
yield groups_local.user_removed_from_group(group_id, user_id, {})
|
||||||
|
else:
|
||||||
|
yield self.transport_client.remove_user_from_group_notification(
|
||||||
|
get_domain_from_id(user_id), group_id, user_id, {}
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.hs.is_mine_id(user_id):
|
||||||
|
yield self.store.maybe_delete_remote_profile_cache(user_id)
|
||||||
|
|
||||||
|
defer.returnValue({})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def create_group(self, group_id, user_id, content):
|
||||||
|
group = yield self.check_group_is_ours(group_id)
|
||||||
|
|
||||||
|
_validate_group_id(group_id)
|
||||||
|
|
||||||
|
logger.info("Attempting to create group with ID: %r", group_id)
|
||||||
|
if group:
|
||||||
|
raise SynapseError(400, "Group already exists")
|
||||||
|
|
||||||
|
is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id))
|
||||||
|
if not is_admin:
|
||||||
|
if not self.hs.config.enable_group_creation:
|
||||||
|
raise SynapseError(
|
||||||
|
403, "Only server admin can create group on this server",
|
||||||
|
)
|
||||||
|
localpart = GroupID.from_string(group_id).localpart
|
||||||
|
if not localpart.startswith(self.hs.config.group_creation_prefix):
|
||||||
|
raise SynapseError(
|
||||||
|
400,
|
||||||
|
"Can only create groups with prefix %r on this server" % (
|
||||||
|
self.hs.config.group_creation_prefix,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
profile = content.get("profile", {})
|
||||||
|
name = profile.get("name")
|
||||||
|
avatar_url = profile.get("avatar_url")
|
||||||
|
short_description = profile.get("short_description")
|
||||||
|
long_description = profile.get("long_description")
|
||||||
|
user_profile = content.get("user_profile", {})
|
||||||
|
|
||||||
|
yield self.store.create_group(
|
||||||
|
group_id,
|
||||||
|
user_id,
|
||||||
|
name=name,
|
||||||
|
avatar_url=avatar_url,
|
||||||
|
short_description=short_description,
|
||||||
|
long_description=long_description,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.hs.is_mine_id(user_id):
|
||||||
|
remote_attestation = content["attestation"]
|
||||||
|
|
||||||
|
yield self.attestations.verify_attestation(
|
||||||
|
remote_attestation,
|
||||||
|
user_id=user_id,
|
||||||
|
group_id=group_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
local_attestation = self.attestations.create_attestation(group_id, user_id)
|
||||||
|
else:
|
||||||
|
local_attestation = None
|
||||||
|
remote_attestation = None
|
||||||
|
|
||||||
|
yield self.store.add_user_to_group(
|
||||||
|
group_id, user_id,
|
||||||
|
is_admin=True,
|
||||||
|
is_public=True, # TODO
|
||||||
|
local_attestation=local_attestation,
|
||||||
|
remote_attestation=remote_attestation,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.hs.is_mine_id(user_id):
|
||||||
|
yield self.store.add_remote_profile_cache(
|
||||||
|
user_id,
|
||||||
|
displayname=user_profile.get("displayname"),
|
||||||
|
avatar_url=user_profile.get("avatar_url"),
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue({
|
||||||
|
"group_id": group_id,
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_visibility_from_contents(content):
|
||||||
|
"""Given a content for a request parse out whether the entity should be
|
||||||
|
public or not
|
||||||
|
"""
|
||||||
|
|
||||||
|
visibility = content.get("visibility")
|
||||||
|
if visibility:
|
||||||
|
vis_type = visibility["type"]
|
||||||
|
if vis_type not in ("public", "private"):
|
||||||
|
raise SynapseError(
|
||||||
|
400, "Synapse only supports 'public'/'private' visibility"
|
||||||
|
)
|
||||||
|
is_public = vis_type == "public"
|
||||||
|
else:
|
||||||
|
is_public = True
|
||||||
|
|
||||||
|
return is_public
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_group_id(group_id):
|
||||||
|
"""Validates the group ID is valid for creation on this home server
|
||||||
|
"""
|
||||||
|
localpart = GroupID.from_string(group_id).localpart
|
||||||
|
|
||||||
|
if localpart.lower() != localpart:
|
||||||
|
raise SynapseError(400, "Group ID must be lower case")
|
||||||
|
|
||||||
|
if urllib.quote(localpart.encode('utf-8')) != localpart:
|
||||||
|
raise SynapseError(
|
||||||
|
400,
|
||||||
|
"Group ID can only contain characters a-z, 0-9, or '_-./'",
|
||||||
|
)
|
@ -20,7 +20,6 @@ from .room import (
|
|||||||
from .room_member import RoomMemberHandler
|
from .room_member import RoomMemberHandler
|
||||||
from .message import MessageHandler
|
from .message import MessageHandler
|
||||||
from .federation import FederationHandler
|
from .federation import FederationHandler
|
||||||
from .profile import ProfileHandler
|
|
||||||
from .directory import DirectoryHandler
|
from .directory import DirectoryHandler
|
||||||
from .admin import AdminHandler
|
from .admin import AdminHandler
|
||||||
from .identity import IdentityHandler
|
from .identity import IdentityHandler
|
||||||
@ -52,7 +51,6 @@ class Handlers(object):
|
|||||||
self.room_creation_handler = RoomCreationHandler(hs)
|
self.room_creation_handler = RoomCreationHandler(hs)
|
||||||
self.room_member_handler = RoomMemberHandler(hs)
|
self.room_member_handler = RoomMemberHandler(hs)
|
||||||
self.federation_handler = FederationHandler(hs)
|
self.federation_handler = FederationHandler(hs)
|
||||||
self.profile_handler = ProfileHandler(hs)
|
|
||||||
self.directory_handler = DirectoryHandler(hs)
|
self.directory_handler = DirectoryHandler(hs)
|
||||||
self.admin_handler = AdminHandler(hs)
|
self.admin_handler = AdminHandler(hs)
|
||||||
self.identity_handler = IdentityHandler(hs)
|
self.identity_handler = IdentityHandler(hs)
|
||||||
|
@ -40,6 +40,8 @@ class DirectoryHandler(BaseHandler):
|
|||||||
"directory", self.on_directory_query
|
"directory", self.on_directory_query
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.spam_checker = hs.get_spam_checker()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _create_association(self, room_alias, room_id, servers=None, creator=None):
|
def _create_association(self, room_alias, room_id, servers=None, creator=None):
|
||||||
# general association creation for both human users and app services
|
# general association creation for both human users and app services
|
||||||
@ -73,6 +75,11 @@ class DirectoryHandler(BaseHandler):
|
|||||||
# association creation for human users
|
# association creation for human users
|
||||||
# TODO(erikj): Do user auth.
|
# TODO(erikj): Do user auth.
|
||||||
|
|
||||||
|
if not self.spam_checker.user_may_create_room_alias(user_id, room_alias):
|
||||||
|
raise SynapseError(
|
||||||
|
403, "This user is not permitted to create this alias",
|
||||||
|
)
|
||||||
|
|
||||||
can_create = yield self.can_modify_alias(
|
can_create = yield self.can_modify_alias(
|
||||||
room_alias,
|
room_alias,
|
||||||
user_id=user_id
|
user_id=user_id
|
||||||
@ -327,6 +334,14 @@ class DirectoryHandler(BaseHandler):
|
|||||||
room_id (str)
|
room_id (str)
|
||||||
visibility (str): "public" or "private"
|
visibility (str): "public" or "private"
|
||||||
"""
|
"""
|
||||||
|
if not self.spam_checker.user_may_publish_room(
|
||||||
|
requester.user.to_string(), room_id
|
||||||
|
):
|
||||||
|
raise AuthError(
|
||||||
|
403,
|
||||||
|
"This user is not permitted to publish rooms to the room list"
|
||||||
|
)
|
||||||
|
|
||||||
if requester.is_guest:
|
if requester.is_guest:
|
||||||
raise AuthError(403, "Guests cannot edit the published room list")
|
raise AuthError(403, "Guests cannot edit the published room list")
|
||||||
|
|
||||||
|
@ -14,7 +14,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""Contains handlers for federation events."""
|
"""Contains handlers for federation events."""
|
||||||
import synapse.util.logcontext
|
|
||||||
from signedjson.key import decode_verify_key_bytes
|
from signedjson.key import decode_verify_key_bytes
|
||||||
from signedjson.sign import verify_signed_json
|
from signedjson.sign import verify_signed_json
|
||||||
from unpaddedbase64 import decode_base64
|
from unpaddedbase64 import decode_base64
|
||||||
@ -26,10 +25,7 @@ from synapse.api.errors import (
|
|||||||
)
|
)
|
||||||
from synapse.api.constants import EventTypes, Membership, RejectedReason
|
from synapse.api.constants import EventTypes, Membership, RejectedReason
|
||||||
from synapse.events.validator import EventValidator
|
from synapse.events.validator import EventValidator
|
||||||
from synapse.util import unwrapFirstError
|
from synapse.util import unwrapFirstError, logcontext
|
||||||
from synapse.util.logcontext import (
|
|
||||||
preserve_fn, preserve_context_over_deferred
|
|
||||||
)
|
|
||||||
from synapse.util.metrics import measure_func
|
from synapse.util.metrics import measure_func
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.util.async import run_on_reactor, Linearizer
|
from synapse.util.async import run_on_reactor, Linearizer
|
||||||
@ -77,6 +73,7 @@ class FederationHandler(BaseHandler):
|
|||||||
self.action_generator = hs.get_action_generator()
|
self.action_generator = hs.get_action_generator()
|
||||||
self.is_mine_id = hs.is_mine_id
|
self.is_mine_id = hs.is_mine_id
|
||||||
self.pusher_pool = hs.get_pusherpool()
|
self.pusher_pool = hs.get_pusherpool()
|
||||||
|
self.spam_checker = hs.get_spam_checker()
|
||||||
|
|
||||||
self.replication_layer.set_handler(self)
|
self.replication_layer.set_handler(self)
|
||||||
|
|
||||||
@ -125,6 +122,28 @@ class FederationHandler(BaseHandler):
|
|||||||
self.room_queues[pdu.room_id].append((pdu, origin))
|
self.room_queues[pdu.room_id].append((pdu, origin))
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# If we're no longer in the room just ditch the event entirely. This
|
||||||
|
# is probably an old server that has come back and thinks we're still
|
||||||
|
# in the room (or we've been rejoined to the room by a state reset).
|
||||||
|
#
|
||||||
|
# If we were never in the room then maybe our database got vaped and
|
||||||
|
# we should check if we *are* in fact in the room. If we are then we
|
||||||
|
# can magically rejoin the room.
|
||||||
|
is_in_room = yield self.auth.check_host_in_room(
|
||||||
|
pdu.room_id,
|
||||||
|
self.server_name
|
||||||
|
)
|
||||||
|
if not is_in_room:
|
||||||
|
was_in_room = yield self.store.was_host_joined(
|
||||||
|
pdu.room_id, self.server_name,
|
||||||
|
)
|
||||||
|
if was_in_room:
|
||||||
|
logger.info(
|
||||||
|
"Ignoring PDU %s for room %s from %s as we've left the room!",
|
||||||
|
pdu.event_id, pdu.room_id, origin,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
state = None
|
state = None
|
||||||
|
|
||||||
auth_chain = []
|
auth_chain = []
|
||||||
@ -591,9 +610,9 @@ class FederationHandler(BaseHandler):
|
|||||||
missing_auth - failed_to_fetch
|
missing_auth - failed_to_fetch
|
||||||
)
|
)
|
||||||
|
|
||||||
results = yield preserve_context_over_deferred(defer.gatherResults(
|
results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
|
||||||
[
|
[
|
||||||
preserve_fn(self.replication_layer.get_pdu)(
|
logcontext.preserve_fn(self.replication_layer.get_pdu)(
|
||||||
[dest],
|
[dest],
|
||||||
event_id,
|
event_id,
|
||||||
outlier=True,
|
outlier=True,
|
||||||
@ -785,10 +804,14 @@ class FederationHandler(BaseHandler):
|
|||||||
event_ids = list(extremities.keys())
|
event_ids = list(extremities.keys())
|
||||||
|
|
||||||
logger.debug("calling resolve_state_groups in _maybe_backfill")
|
logger.debug("calling resolve_state_groups in _maybe_backfill")
|
||||||
states = yield preserve_context_over_deferred(defer.gatherResults([
|
states = yield logcontext.make_deferred_yieldable(defer.gatherResults(
|
||||||
preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e])
|
[
|
||||||
for e in event_ids
|
logcontext.preserve_fn(self.state_handler.resolve_state_groups)(
|
||||||
]))
|
room_id, [e]
|
||||||
|
)
|
||||||
|
for e in event_ids
|
||||||
|
], consumeErrors=True,
|
||||||
|
))
|
||||||
states = dict(zip(event_ids, [s.state for s in states]))
|
states = dict(zip(event_ids, [s.state for s in states]))
|
||||||
|
|
||||||
state_map = yield self.store.get_events(
|
state_map = yield self.store.get_events(
|
||||||
@ -941,9 +964,7 @@ class FederationHandler(BaseHandler):
|
|||||||
# lots of requests for missing prev_events which we do actually
|
# lots of requests for missing prev_events which we do actually
|
||||||
# have. Hence we fire off the deferred, but don't wait for it.
|
# have. Hence we fire off the deferred, but don't wait for it.
|
||||||
|
|
||||||
synapse.util.logcontext.preserve_fn(self._handle_queued_pdus)(
|
logcontext.preserve_fn(self._handle_queued_pdus)(room_queue)
|
||||||
room_queue
|
|
||||||
)
|
|
||||||
|
|
||||||
defer.returnValue(True)
|
defer.returnValue(True)
|
||||||
|
|
||||||
@ -1070,6 +1091,9 @@ class FederationHandler(BaseHandler):
|
|||||||
"""
|
"""
|
||||||
event = pdu
|
event = pdu
|
||||||
|
|
||||||
|
if event.state_key is None:
|
||||||
|
raise SynapseError(400, "The invite event did not have a state key")
|
||||||
|
|
||||||
is_blocked = yield self.store.is_room_blocked(event.room_id)
|
is_blocked = yield self.store.is_room_blocked(event.room_id)
|
||||||
if is_blocked:
|
if is_blocked:
|
||||||
raise SynapseError(403, "This room has been blocked on this server")
|
raise SynapseError(403, "This room has been blocked on this server")
|
||||||
@ -1077,6 +1101,13 @@ class FederationHandler(BaseHandler):
|
|||||||
if self.hs.config.block_non_admin_invites:
|
if self.hs.config.block_non_admin_invites:
|
||||||
raise SynapseError(403, "This server does not accept room invites")
|
raise SynapseError(403, "This server does not accept room invites")
|
||||||
|
|
||||||
|
if not self.spam_checker.user_may_invite(
|
||||||
|
event.sender, event.state_key, event.room_id,
|
||||||
|
):
|
||||||
|
raise SynapseError(
|
||||||
|
403, "This user is not permitted to send invites to this server/user"
|
||||||
|
)
|
||||||
|
|
||||||
membership = event.content.get("membership")
|
membership = event.content.get("membership")
|
||||||
if event.type != EventTypes.Member or membership != Membership.INVITE:
|
if event.type != EventTypes.Member or membership != Membership.INVITE:
|
||||||
raise SynapseError(400, "The event was not an m.room.member invite event")
|
raise SynapseError(400, "The event was not an m.room.member invite event")
|
||||||
@ -1085,9 +1116,6 @@ class FederationHandler(BaseHandler):
|
|||||||
if sender_domain != origin:
|
if sender_domain != origin:
|
||||||
raise SynapseError(400, "The invite event was not from the server sending it")
|
raise SynapseError(400, "The invite event was not from the server sending it")
|
||||||
|
|
||||||
if event.state_key is None:
|
|
||||||
raise SynapseError(400, "The invite event did not have a state key")
|
|
||||||
|
|
||||||
if not self.is_mine_id(event.state_key):
|
if not self.is_mine_id(event.state_key):
|
||||||
raise SynapseError(400, "The invite event must be for this server")
|
raise SynapseError(400, "The invite event must be for this server")
|
||||||
|
|
||||||
@ -1430,7 +1458,7 @@ class FederationHandler(BaseHandler):
|
|||||||
if not backfilled:
|
if not backfilled:
|
||||||
# this intentionally does not yield: we don't care about the result
|
# this intentionally does not yield: we don't care about the result
|
||||||
# and don't need to wait for it.
|
# and don't need to wait for it.
|
||||||
preserve_fn(self.pusher_pool.on_new_notifications)(
|
logcontext.preserve_fn(self.pusher_pool.on_new_notifications)(
|
||||||
event_stream_id, max_stream_id
|
event_stream_id, max_stream_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1443,16 +1471,16 @@ class FederationHandler(BaseHandler):
|
|||||||
a bunch of outliers, but not a chunk of individual events that depend
|
a bunch of outliers, but not a chunk of individual events that depend
|
||||||
on each other for state calculations.
|
on each other for state calculations.
|
||||||
"""
|
"""
|
||||||
contexts = yield preserve_context_over_deferred(defer.gatherResults(
|
contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults(
|
||||||
[
|
[
|
||||||
preserve_fn(self._prep_event)(
|
logcontext.preserve_fn(self._prep_event)(
|
||||||
origin,
|
origin,
|
||||||
ev_info["event"],
|
ev_info["event"],
|
||||||
state=ev_info.get("state"),
|
state=ev_info.get("state"),
|
||||||
auth_events=ev_info.get("auth_events"),
|
auth_events=ev_info.get("auth_events"),
|
||||||
)
|
)
|
||||||
for ev_info in event_infos
|
for ev_info in event_infos
|
||||||
]
|
], consumeErrors=True,
|
||||||
))
|
))
|
||||||
|
|
||||||
yield self.store.persist_events(
|
yield self.store.persist_events(
|
||||||
@ -1760,18 +1788,17 @@ class FederationHandler(BaseHandler):
|
|||||||
# Do auth conflict res.
|
# Do auth conflict res.
|
||||||
logger.info("Different auth: %s", different_auth)
|
logger.info("Different auth: %s", different_auth)
|
||||||
|
|
||||||
different_events = yield preserve_context_over_deferred(defer.gatherResults(
|
different_events = yield logcontext.make_deferred_yieldable(
|
||||||
[
|
defer.gatherResults([
|
||||||
preserve_fn(self.store.get_event)(
|
logcontext.preserve_fn(self.store.get_event)(
|
||||||
d,
|
d,
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
allow_rejected=False,
|
allow_rejected=False,
|
||||||
)
|
)
|
||||||
for d in different_auth
|
for d in different_auth
|
||||||
if d in have_events and not have_events[d]
|
if d in have_events and not have_events[d]
|
||||||
],
|
], consumeErrors=True)
|
||||||
consumeErrors=True
|
).addErrback(unwrapFirstError)
|
||||||
)).addErrback(unwrapFirstError)
|
|
||||||
|
|
||||||
if different_events:
|
if different_events:
|
||||||
local_view = dict(auth_events)
|
local_view = dict(auth_events)
|
||||||
|
417
synapse/handlers/groups_local.py
Normal file
417
synapse/handlers/groups_local.py
Normal file
@ -0,0 +1,417 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2017 Vector Creations Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.api.errors import SynapseError
|
||||||
|
from synapse.types import get_domain_from_id
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_rerouter(func_name):
|
||||||
|
"""Returns a function that looks at the group id and calls the function
|
||||||
|
on federation or the local group server if the group is local
|
||||||
|
"""
|
||||||
|
def f(self, group_id, *args, **kwargs):
|
||||||
|
if self.is_mine_id(group_id):
|
||||||
|
return getattr(self.groups_server_handler, func_name)(
|
||||||
|
group_id, *args, **kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
destination = get_domain_from_id(group_id)
|
||||||
|
return getattr(self.transport_client, func_name)(
|
||||||
|
destination, group_id, *args, **kwargs
|
||||||
|
)
|
||||||
|
return f
|
||||||
|
|
||||||
|
|
||||||
|
class GroupsLocalHandler(object):
|
||||||
|
def __init__(self, hs):
|
||||||
|
self.hs = hs
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.room_list_handler = hs.get_room_list_handler()
|
||||||
|
self.groups_server_handler = hs.get_groups_server_handler()
|
||||||
|
self.transport_client = hs.get_federation_transport_client()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.keyring = hs.get_keyring()
|
||||||
|
self.is_mine_id = hs.is_mine_id
|
||||||
|
self.signing_key = hs.config.signing_key[0]
|
||||||
|
self.server_name = hs.hostname
|
||||||
|
self.notifier = hs.get_notifier()
|
||||||
|
self.attestations = hs.get_groups_attestation_signing()
|
||||||
|
|
||||||
|
self.profile_handler = hs.get_profile_handler()
|
||||||
|
|
||||||
|
# Ensure attestations get renewed
|
||||||
|
hs.get_groups_attestation_renewer()
|
||||||
|
|
||||||
|
# The following functions merely route the query to the local groups server
|
||||||
|
# or federation depending on if the group is local or remote
|
||||||
|
|
||||||
|
get_group_profile = _create_rerouter("get_group_profile")
|
||||||
|
update_group_profile = _create_rerouter("update_group_profile")
|
||||||
|
get_rooms_in_group = _create_rerouter("get_rooms_in_group")
|
||||||
|
|
||||||
|
get_invited_users_in_group = _create_rerouter("get_invited_users_in_group")
|
||||||
|
|
||||||
|
add_room_to_group = _create_rerouter("add_room_to_group")
|
||||||
|
remove_room_from_group = _create_rerouter("remove_room_from_group")
|
||||||
|
|
||||||
|
update_group_summary_room = _create_rerouter("update_group_summary_room")
|
||||||
|
delete_group_summary_room = _create_rerouter("delete_group_summary_room")
|
||||||
|
|
||||||
|
update_group_category = _create_rerouter("update_group_category")
|
||||||
|
delete_group_category = _create_rerouter("delete_group_category")
|
||||||
|
get_group_category = _create_rerouter("get_group_category")
|
||||||
|
get_group_categories = _create_rerouter("get_group_categories")
|
||||||
|
|
||||||
|
update_group_summary_user = _create_rerouter("update_group_summary_user")
|
||||||
|
delete_group_summary_user = _create_rerouter("delete_group_summary_user")
|
||||||
|
|
||||||
|
update_group_role = _create_rerouter("update_group_role")
|
||||||
|
delete_group_role = _create_rerouter("delete_group_role")
|
||||||
|
get_group_role = _create_rerouter("get_group_role")
|
||||||
|
get_group_roles = _create_rerouter("get_group_roles")
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_group_summary(self, group_id, requester_user_id):
|
||||||
|
"""Get the group summary for a group.
|
||||||
|
|
||||||
|
If the group is remote we check that the users have valid attestations.
|
||||||
|
"""
|
||||||
|
if self.is_mine_id(group_id):
|
||||||
|
res = yield self.groups_server_handler.get_group_summary(
|
||||||
|
group_id, requester_user_id
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
res = yield self.transport_client.get_group_summary(
|
||||||
|
get_domain_from_id(group_id), group_id, requester_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
group_server_name = get_domain_from_id(group_id)
|
||||||
|
|
||||||
|
# Loop through the users and validate the attestations.
|
||||||
|
chunk = res["users_section"]["users"]
|
||||||
|
valid_users = []
|
||||||
|
for entry in chunk:
|
||||||
|
g_user_id = entry["user_id"]
|
||||||
|
attestation = entry.pop("attestation", {})
|
||||||
|
try:
|
||||||
|
if get_domain_from_id(g_user_id) != group_server_name:
|
||||||
|
yield self.attestations.verify_attestation(
|
||||||
|
attestation,
|
||||||
|
group_id=group_id,
|
||||||
|
user_id=g_user_id,
|
||||||
|
server_name=get_domain_from_id(g_user_id),
|
||||||
|
)
|
||||||
|
valid_users.append(entry)
|
||||||
|
except Exception as e:
|
||||||
|
logger.info("Failed to verify user is in group: %s", e)
|
||||||
|
|
||||||
|
res["users_section"]["users"] = valid_users
|
||||||
|
|
||||||
|
res["users_section"]["users"].sort(key=lambda e: e.get("order", 0))
|
||||||
|
res["rooms_section"]["rooms"].sort(key=lambda e: e.get("order", 0))
|
||||||
|
|
||||||
|
# Add `is_publicised` flag to indicate whether the user has publicised their
|
||||||
|
# membership of the group on their profile
|
||||||
|
result = yield self.store.get_publicised_groups_for_user(requester_user_id)
|
||||||
|
is_publicised = group_id in result
|
||||||
|
|
||||||
|
res.setdefault("user", {})["is_publicised"] = is_publicised
|
||||||
|
|
||||||
|
defer.returnValue(res)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def create_group(self, group_id, user_id, content):
|
||||||
|
"""Create a group
|
||||||
|
"""
|
||||||
|
|
||||||
|
logger.info("Asking to create group with ID: %r", group_id)
|
||||||
|
|
||||||
|
if self.is_mine_id(group_id):
|
||||||
|
res = yield self.groups_server_handler.create_group(
|
||||||
|
group_id, user_id, content
|
||||||
|
)
|
||||||
|
local_attestation = None
|
||||||
|
remote_attestation = None
|
||||||
|
else:
|
||||||
|
local_attestation = self.attestations.create_attestation(group_id, user_id)
|
||||||
|
content["attestation"] = local_attestation
|
||||||
|
|
||||||
|
content["user_profile"] = yield self.profile_handler.get_profile(user_id)
|
||||||
|
|
||||||
|
res = yield self.transport_client.create_group(
|
||||||
|
get_domain_from_id(group_id), group_id, user_id, content,
|
||||||
|
)
|
||||||
|
|
||||||
|
remote_attestation = res["attestation"]
|
||||||
|
yield self.attestations.verify_attestation(
|
||||||
|
remote_attestation,
|
||||||
|
group_id=group_id,
|
||||||
|
user_id=user_id,
|
||||||
|
server_name=get_domain_from_id(group_id),
|
||||||
|
)
|
||||||
|
|
||||||
|
is_publicised = content.get("publicise", False)
|
||||||
|
token = yield self.store.register_user_group_membership(
|
||||||
|
group_id, user_id,
|
||||||
|
membership="join",
|
||||||
|
is_admin=True,
|
||||||
|
local_attestation=local_attestation,
|
||||||
|
remote_attestation=remote_attestation,
|
||||||
|
is_publicised=is_publicised,
|
||||||
|
)
|
||||||
|
self.notifier.on_new_event(
|
||||||
|
"groups_key", token, users=[user_id],
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue(res)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_users_in_group(self, group_id, requester_user_id):
|
||||||
|
"""Get users in a group
|
||||||
|
"""
|
||||||
|
if self.is_mine_id(group_id):
|
||||||
|
res = yield self.groups_server_handler.get_users_in_group(
|
||||||
|
group_id, requester_user_id
|
||||||
|
)
|
||||||
|
defer.returnValue(res)
|
||||||
|
|
||||||
|
group_server_name = get_domain_from_id(group_id)
|
||||||
|
|
||||||
|
res = yield self.transport_client.get_users_in_group(
|
||||||
|
get_domain_from_id(group_id), group_id, requester_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
chunk = res["chunk"]
|
||||||
|
valid_entries = []
|
||||||
|
for entry in chunk:
|
||||||
|
g_user_id = entry["user_id"]
|
||||||
|
attestation = entry.pop("attestation", {})
|
||||||
|
try:
|
||||||
|
if get_domain_from_id(g_user_id) != group_server_name:
|
||||||
|
yield self.attestations.verify_attestation(
|
||||||
|
attestation,
|
||||||
|
group_id=group_id,
|
||||||
|
user_id=g_user_id,
|
||||||
|
server_name=get_domain_from_id(g_user_id),
|
||||||
|
)
|
||||||
|
valid_entries.append(entry)
|
||||||
|
except Exception as e:
|
||||||
|
logger.info("Failed to verify user is in group: %s", e)
|
||||||
|
|
||||||
|
res["chunk"] = valid_entries
|
||||||
|
|
||||||
|
defer.returnValue(res)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def join_group(self, group_id, user_id, content):
|
||||||
|
"""Request to join a group
|
||||||
|
"""
|
||||||
|
raise NotImplementedError() # TODO
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def accept_invite(self, group_id, user_id, content):
|
||||||
|
"""Accept an invite to a group
|
||||||
|
"""
|
||||||
|
if self.is_mine_id(group_id):
|
||||||
|
yield self.groups_server_handler.accept_invite(
|
||||||
|
group_id, user_id, content
|
||||||
|
)
|
||||||
|
local_attestation = None
|
||||||
|
remote_attestation = None
|
||||||
|
else:
|
||||||
|
local_attestation = self.attestations.create_attestation(group_id, user_id)
|
||||||
|
content["attestation"] = local_attestation
|
||||||
|
|
||||||
|
res = yield self.transport_client.accept_group_invite(
|
||||||
|
get_domain_from_id(group_id), group_id, user_id, content,
|
||||||
|
)
|
||||||
|
|
||||||
|
remote_attestation = res["attestation"]
|
||||||
|
|
||||||
|
yield self.attestations.verify_attestation(
|
||||||
|
remote_attestation,
|
||||||
|
group_id=group_id,
|
||||||
|
user_id=user_id,
|
||||||
|
server_name=get_domain_from_id(group_id),
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: Check that the group is public and we're being added publically
|
||||||
|
is_publicised = content.get("publicise", False)
|
||||||
|
|
||||||
|
token = yield self.store.register_user_group_membership(
|
||||||
|
group_id, user_id,
|
||||||
|
membership="join",
|
||||||
|
is_admin=False,
|
||||||
|
local_attestation=local_attestation,
|
||||||
|
remote_attestation=remote_attestation,
|
||||||
|
is_publicised=is_publicised,
|
||||||
|
)
|
||||||
|
self.notifier.on_new_event(
|
||||||
|
"groups_key", token, users=[user_id],
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue({})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def invite(self, group_id, user_id, requester_user_id, config):
|
||||||
|
"""Invite a user to a group
|
||||||
|
"""
|
||||||
|
content = {
|
||||||
|
"requester_user_id": requester_user_id,
|
||||||
|
"config": config,
|
||||||
|
}
|
||||||
|
if self.is_mine_id(group_id):
|
||||||
|
res = yield self.groups_server_handler.invite_to_group(
|
||||||
|
group_id, user_id, requester_user_id, content,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
res = yield self.transport_client.invite_to_group(
|
||||||
|
get_domain_from_id(group_id), group_id, user_id, requester_user_id,
|
||||||
|
content,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue(res)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_invite(self, group_id, user_id, content):
|
||||||
|
"""One of our users were invited to a group
|
||||||
|
"""
|
||||||
|
# TODO: Support auto join and rejection
|
||||||
|
|
||||||
|
if not self.is_mine_id(user_id):
|
||||||
|
raise SynapseError(400, "User not on this server")
|
||||||
|
|
||||||
|
local_profile = {}
|
||||||
|
if "profile" in content:
|
||||||
|
if "name" in content["profile"]:
|
||||||
|
local_profile["name"] = content["profile"]["name"]
|
||||||
|
if "avatar_url" in content["profile"]:
|
||||||
|
local_profile["avatar_url"] = content["profile"]["avatar_url"]
|
||||||
|
|
||||||
|
token = yield self.store.register_user_group_membership(
|
||||||
|
group_id, user_id,
|
||||||
|
membership="invite",
|
||||||
|
content={"profile": local_profile, "inviter": content["inviter"]},
|
||||||
|
)
|
||||||
|
self.notifier.on_new_event(
|
||||||
|
"groups_key", token, users=[user_id],
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
user_profile = yield self.profile_handler.get_profile(user_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warn("No profile for user %s: %s", user_id, e)
|
||||||
|
user_profile = {}
|
||||||
|
|
||||||
|
defer.returnValue({"state": "invite", "user_profile": user_profile})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
|
||||||
|
"""Remove a user from a group
|
||||||
|
"""
|
||||||
|
if user_id == requester_user_id:
|
||||||
|
token = yield self.store.register_user_group_membership(
|
||||||
|
group_id, user_id,
|
||||||
|
membership="leave",
|
||||||
|
)
|
||||||
|
self.notifier.on_new_event(
|
||||||
|
"groups_key", token, users=[user_id],
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: Should probably remember that we tried to leave so that we can
|
||||||
|
# retry if the group server is currently down.
|
||||||
|
|
||||||
|
if self.is_mine_id(group_id):
|
||||||
|
res = yield self.groups_server_handler.remove_user_from_group(
|
||||||
|
group_id, user_id, requester_user_id, content,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
content["requester_user_id"] = requester_user_id
|
||||||
|
res = yield self.transport_client.remove_user_from_group(
|
||||||
|
get_domain_from_id(group_id), group_id, requester_user_id,
|
||||||
|
user_id, content,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue(res)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def user_removed_from_group(self, group_id, user_id, content):
|
||||||
|
"""One of our users was removed/kicked from a group
|
||||||
|
"""
|
||||||
|
# TODO: Check if user in group
|
||||||
|
token = yield self.store.register_user_group_membership(
|
||||||
|
group_id, user_id,
|
||||||
|
membership="leave",
|
||||||
|
)
|
||||||
|
self.notifier.on_new_event(
|
||||||
|
"groups_key", token, users=[user_id],
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_joined_groups(self, user_id):
|
||||||
|
group_ids = yield self.store.get_joined_groups(user_id)
|
||||||
|
defer.returnValue({"groups": group_ids})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_publicised_groups_for_user(self, user_id):
|
||||||
|
if self.hs.is_mine_id(user_id):
|
||||||
|
result = yield self.store.get_publicised_groups_for_user(user_id)
|
||||||
|
defer.returnValue({"groups": result})
|
||||||
|
else:
|
||||||
|
result = yield self.transport_client.get_publicised_groups_for_user(
|
||||||
|
get_domain_from_id(user_id), user_id
|
||||||
|
)
|
||||||
|
# TODO: Verify attestations
|
||||||
|
defer.returnValue(result)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def bulk_get_publicised_groups(self, user_ids, proxy=True):
|
||||||
|
destinations = {}
|
||||||
|
local_users = set()
|
||||||
|
|
||||||
|
for user_id in user_ids:
|
||||||
|
if self.hs.is_mine_id(user_id):
|
||||||
|
local_users.add(user_id)
|
||||||
|
else:
|
||||||
|
destinations.setdefault(
|
||||||
|
get_domain_from_id(user_id), set()
|
||||||
|
).add(user_id)
|
||||||
|
|
||||||
|
if not proxy and destinations:
|
||||||
|
raise SynapseError(400, "Some user_ids are not local")
|
||||||
|
|
||||||
|
results = {}
|
||||||
|
failed_results = []
|
||||||
|
for destination, dest_user_ids in destinations.iteritems():
|
||||||
|
try:
|
||||||
|
r = yield self.transport_client.bulk_get_publicised_groups(
|
||||||
|
destination, list(dest_user_ids),
|
||||||
|
)
|
||||||
|
results.update(r["users"])
|
||||||
|
except Exception:
|
||||||
|
failed_results.extend(dest_user_ids)
|
||||||
|
|
||||||
|
for uid in local_users:
|
||||||
|
results[uid] = yield self.store.get_publicised_groups_for_user(
|
||||||
|
uid
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue({"users": results})
|
@ -1,5 +1,6 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2014 - 2016 OpenMarket Ltd
|
# Copyright 2014 - 2016 OpenMarket Ltd
|
||||||
|
# Copyright 2017 New Vector Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@ -12,7 +13,6 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from synapse.events import spamcheck
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
@ -26,6 +26,7 @@ from synapse.types import (
|
|||||||
from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter
|
from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter
|
||||||
from synapse.util.logcontext import preserve_fn
|
from synapse.util.logcontext import preserve_fn
|
||||||
from synapse.util.metrics import measure_func
|
from synapse.util.metrics import measure_func
|
||||||
|
from synapse.util.frozenutils import unfreeze
|
||||||
from synapse.visibility import filter_events_for_client
|
from synapse.visibility import filter_events_for_client
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
@ -47,6 +48,7 @@ class MessageHandler(BaseHandler):
|
|||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.validator = EventValidator()
|
self.validator = EventValidator()
|
||||||
|
self.profile_handler = hs.get_profile_handler()
|
||||||
|
|
||||||
self.pagination_lock = ReadWriteLock()
|
self.pagination_lock = ReadWriteLock()
|
||||||
|
|
||||||
@ -58,6 +60,8 @@ class MessageHandler(BaseHandler):
|
|||||||
|
|
||||||
self.action_generator = hs.get_action_generator()
|
self.action_generator = hs.get_action_generator()
|
||||||
|
|
||||||
|
self.spam_checker = hs.get_spam_checker()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def purge_history(self, room_id, event_id):
|
def purge_history(self, room_id, event_id):
|
||||||
event = yield self.store.get_event(event_id)
|
event = yield self.store.get_event(event_id)
|
||||||
@ -210,7 +214,7 @@ class MessageHandler(BaseHandler):
|
|||||||
|
|
||||||
if membership in {Membership.JOIN, Membership.INVITE}:
|
if membership in {Membership.JOIN, Membership.INVITE}:
|
||||||
# If event doesn't include a display name, add one.
|
# If event doesn't include a display name, add one.
|
||||||
profile = self.hs.get_handlers().profile_handler
|
profile = self.profile_handler
|
||||||
content = builder.content
|
content = builder.content
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -322,9 +326,12 @@ class MessageHandler(BaseHandler):
|
|||||||
txn_id=txn_id
|
txn_id=txn_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if spamcheck.check_event_for_spam(event):
|
spam_error = self.spam_checker.check_event_for_spam(event)
|
||||||
|
if spam_error:
|
||||||
|
if not isinstance(spam_error, basestring):
|
||||||
|
spam_error = "Spam is not permitted here"
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
403, "Spam is not permitted here", Codes.FORBIDDEN
|
403, spam_error, Codes.FORBIDDEN
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.send_nonmember_event(
|
yield self.send_nonmember_event(
|
||||||
@ -418,6 +425,51 @@ class MessageHandler(BaseHandler):
|
|||||||
[serialize_event(c, now) for c in room_state.values()]
|
[serialize_event(c, now) for c in room_state.values()]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_joined_members(self, requester, room_id):
|
||||||
|
"""Get all the joined members in the room and their profile information.
|
||||||
|
|
||||||
|
If the user has left the room return the state events from when they left.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
requester(Requester): The user requesting state events.
|
||||||
|
room_id(str): The room ID to get all state events from.
|
||||||
|
Returns:
|
||||||
|
A dict of user_id to profile info
|
||||||
|
"""
|
||||||
|
user_id = requester.user.to_string()
|
||||||
|
if not requester.app_service:
|
||||||
|
# We check AS auth after fetching the room membership, as it
|
||||||
|
# requires us to pull out all joined members anyway.
|
||||||
|
membership, _ = yield self._check_in_room_or_world_readable(
|
||||||
|
room_id, user_id
|
||||||
|
)
|
||||||
|
if membership != Membership.JOIN:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Getting joined members after leaving is not implemented"
|
||||||
|
)
|
||||||
|
|
||||||
|
users_with_profile = yield self.state.get_current_user_in_room(room_id)
|
||||||
|
|
||||||
|
# If this is an AS, double check that they are allowed to see the members.
|
||||||
|
# This can either be because the AS user is in the room or becuase there
|
||||||
|
# is a user in the room that the AS is "interested in"
|
||||||
|
if requester.app_service and user_id not in users_with_profile:
|
||||||
|
for uid in users_with_profile:
|
||||||
|
if requester.app_service.is_interested_in_user(uid):
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Loop fell through, AS has no interested users in room
|
||||||
|
raise AuthError(403, "Appservice not in room")
|
||||||
|
|
||||||
|
defer.returnValue({
|
||||||
|
user_id: {
|
||||||
|
"avatar_url": profile.avatar_url,
|
||||||
|
"display_name": profile.display_name,
|
||||||
|
}
|
||||||
|
for user_id, profile in users_with_profile.iteritems()
|
||||||
|
})
|
||||||
|
|
||||||
@measure_func("_create_new_client_event")
|
@measure_func("_create_new_client_event")
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _create_new_client_event(self, builder, requester=None, prev_event_ids=None):
|
def _create_new_client_event(self, builder, requester=None, prev_event_ids=None):
|
||||||
@ -509,7 +561,7 @@ class MessageHandler(BaseHandler):
|
|||||||
|
|
||||||
# Ensure that we can round trip before trying to persist in db
|
# Ensure that we can round trip before trying to persist in db
|
||||||
try:
|
try:
|
||||||
dump = ujson.dumps(event.content)
|
dump = ujson.dumps(unfreeze(event.content))
|
||||||
ujson.loads(dump)
|
ujson.loads(dump)
|
||||||
except:
|
except:
|
||||||
logger.exception("Failed to encode content: %r", event.content)
|
logger.exception("Failed to encode content: %r", event.content)
|
||||||
|
@ -19,14 +19,15 @@ from twisted.internet import defer
|
|||||||
|
|
||||||
import synapse.types
|
import synapse.types
|
||||||
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
|
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID, get_domain_from_id
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ProfileHandler(BaseHandler):
|
class ProfileHandler(BaseHandler):
|
||||||
|
PROFILE_UPDATE_MS = 60 * 1000
|
||||||
|
PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(ProfileHandler, self).__init__(hs)
|
super(ProfileHandler, self).__init__(hs)
|
||||||
@ -36,6 +37,63 @@ class ProfileHandler(BaseHandler):
|
|||||||
"profile", self.on_profile_query
|
"profile", self.on_profile_query
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.clock.looping_call(self._update_remote_profile_cache, self.PROFILE_UPDATE_MS)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_profile(self, user_id):
|
||||||
|
target_user = UserID.from_string(user_id)
|
||||||
|
if self.hs.is_mine(target_user):
|
||||||
|
displayname = yield self.store.get_profile_displayname(
|
||||||
|
target_user.localpart
|
||||||
|
)
|
||||||
|
avatar_url = yield self.store.get_profile_avatar_url(
|
||||||
|
target_user.localpart
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue({
|
||||||
|
"displayname": displayname,
|
||||||
|
"avatar_url": avatar_url,
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
result = yield self.federation.make_query(
|
||||||
|
destination=target_user.domain,
|
||||||
|
query_type="profile",
|
||||||
|
args={
|
||||||
|
"user_id": user_id,
|
||||||
|
},
|
||||||
|
ignore_backoff=True,
|
||||||
|
)
|
||||||
|
defer.returnValue(result)
|
||||||
|
except CodeMessageException as e:
|
||||||
|
if e.code != 404:
|
||||||
|
logger.exception("Failed to get displayname")
|
||||||
|
|
||||||
|
raise
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_profile_from_cache(self, user_id):
|
||||||
|
"""Get the profile information from our local cache. If the user is
|
||||||
|
ours then the profile information will always be corect. Otherwise,
|
||||||
|
it may be out of date/missing.
|
||||||
|
"""
|
||||||
|
target_user = UserID.from_string(user_id)
|
||||||
|
if self.hs.is_mine(target_user):
|
||||||
|
displayname = yield self.store.get_profile_displayname(
|
||||||
|
target_user.localpart
|
||||||
|
)
|
||||||
|
avatar_url = yield self.store.get_profile_avatar_url(
|
||||||
|
target_user.localpart
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue({
|
||||||
|
"displayname": displayname,
|
||||||
|
"avatar_url": avatar_url,
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
profile = yield self.store.get_from_remote_profile_cache(user_id)
|
||||||
|
defer.returnValue(profile or {})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_displayname(self, target_user):
|
def get_displayname(self, target_user):
|
||||||
if self.hs.is_mine(target_user):
|
if self.hs.is_mine(target_user):
|
||||||
@ -182,3 +240,44 @@ class ProfileHandler(BaseHandler):
|
|||||||
"Failed to update join event for room %s - %s",
|
"Failed to update join event for room %s - %s",
|
||||||
room_id, str(e.message)
|
room_id, str(e.message)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _update_remote_profile_cache(self):
|
||||||
|
"""Called periodically to check profiles of remote users we haven't
|
||||||
|
checked in a while.
|
||||||
|
"""
|
||||||
|
entries = yield self.store.get_remote_profile_cache_entries_that_expire(
|
||||||
|
last_checked=self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS
|
||||||
|
)
|
||||||
|
|
||||||
|
for user_id, displayname, avatar_url in entries:
|
||||||
|
is_subscribed = yield self.store.is_subscribed_remote_profile_for_user(
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
if not is_subscribed:
|
||||||
|
yield self.store.maybe_delete_remote_profile_cache(user_id)
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
profile = yield self.federation.make_query(
|
||||||
|
destination=get_domain_from_id(user_id),
|
||||||
|
query_type="profile",
|
||||||
|
args={
|
||||||
|
"user_id": user_id,
|
||||||
|
},
|
||||||
|
ignore_backoff=True,
|
||||||
|
)
|
||||||
|
except:
|
||||||
|
logger.exception("Failed to get avatar_url")
|
||||||
|
|
||||||
|
yield self.store.update_remote_profile_cache(
|
||||||
|
user_id, displayname, avatar_url
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
new_name = profile.get("displayname")
|
||||||
|
new_avatar = profile.get("avatar_url")
|
||||||
|
|
||||||
|
# We always hit update to update the last_check timestamp
|
||||||
|
yield self.store.update_remote_profile_cache(
|
||||||
|
user_id, new_name, new_avatar
|
||||||
|
)
|
||||||
|
@ -12,6 +12,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from synapse.util import logcontext
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
||||||
@ -59,6 +60,8 @@ class ReceiptsHandler(BaseHandler):
|
|||||||
is_new = yield self._handle_new_receipts([receipt])
|
is_new = yield self._handle_new_receipts([receipt])
|
||||||
|
|
||||||
if is_new:
|
if is_new:
|
||||||
|
# fire off a process in the background to send the receipt to
|
||||||
|
# remote servers
|
||||||
self._push_remotes([receipt])
|
self._push_remotes([receipt])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -126,6 +129,7 @@ class ReceiptsHandler(BaseHandler):
|
|||||||
|
|
||||||
defer.returnValue(True)
|
defer.returnValue(True)
|
||||||
|
|
||||||
|
@logcontext.preserve_fn # caller should not yield on this
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _push_remotes(self, receipts):
|
def _push_remotes(self, receipts):
|
||||||
"""Given a list of receipts, works out which remote servers should be
|
"""Given a list of receipts, works out which remote servers should be
|
||||||
|
@ -36,6 +36,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
super(RegistrationHandler, self).__init__(hs)
|
super(RegistrationHandler, self).__init__(hs)
|
||||||
|
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
self.profile_handler = hs.get_profile_handler()
|
||||||
self.captcha_client = CaptchaServerHttpClient(hs)
|
self.captcha_client = CaptchaServerHttpClient(hs)
|
||||||
|
|
||||||
self._next_generated_user_id = None
|
self._next_generated_user_id = None
|
||||||
@ -423,8 +424,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
|
|
||||||
if displayname is not None:
|
if displayname is not None:
|
||||||
logger.info("setting user display name: %s -> %s", user_id, displayname)
|
logger.info("setting user display name: %s -> %s", user_id, displayname)
|
||||||
profile_handler = self.hs.get_handlers().profile_handler
|
yield self.profile_handler.set_displayname(
|
||||||
yield profile_handler.set_displayname(
|
|
||||||
user, requester, displayname, by_admin=True,
|
user, requester, displayname, by_admin=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -60,6 +60,11 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(RoomCreationHandler, self).__init__(hs)
|
||||||
|
|
||||||
|
self.spam_checker = hs.get_spam_checker()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def create_room(self, requester, config, ratelimit=True):
|
def create_room(self, requester, config, ratelimit=True):
|
||||||
""" Creates a new room.
|
""" Creates a new room.
|
||||||
@ -75,6 +80,9 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
"""
|
"""
|
||||||
user_id = requester.user.to_string()
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
if not self.spam_checker.user_may_create_room(user_id):
|
||||||
|
raise SynapseError(403, "You are not permitted to create rooms")
|
||||||
|
|
||||||
if ratelimit:
|
if ratelimit:
|
||||||
yield self.ratelimit(requester)
|
yield self.ratelimit(requester)
|
||||||
|
|
||||||
|
@ -276,13 +276,14 @@ class RoomListHandler(BaseHandler):
|
|||||||
# We've already got enough, so lets just drop it.
|
# We've already got enough, so lets just drop it.
|
||||||
return
|
return
|
||||||
|
|
||||||
result = yield self._generate_room_entry(room_id, num_joined_users)
|
result = yield self.generate_room_entry(room_id, num_joined_users)
|
||||||
|
|
||||||
if result and _matches_room_entry(result, search_filter):
|
if result and _matches_room_entry(result, search_filter):
|
||||||
chunk.append(result)
|
chunk.append(result)
|
||||||
|
|
||||||
@cachedInlineCallbacks(num_args=1, cache_context=True)
|
@cachedInlineCallbacks(num_args=1, cache_context=True)
|
||||||
def _generate_room_entry(self, room_id, num_joined_users, cache_context):
|
def generate_room_entry(self, room_id, num_joined_users, cache_context,
|
||||||
|
with_alias=True, allow_private=False):
|
||||||
"""Returns the entry for a room
|
"""Returns the entry for a room
|
||||||
"""
|
"""
|
||||||
result = {
|
result = {
|
||||||
@ -316,14 +317,15 @@ class RoomListHandler(BaseHandler):
|
|||||||
join_rules_event = current_state.get((EventTypes.JoinRules, ""))
|
join_rules_event = current_state.get((EventTypes.JoinRules, ""))
|
||||||
if join_rules_event:
|
if join_rules_event:
|
||||||
join_rule = join_rules_event.content.get("join_rule", None)
|
join_rule = join_rules_event.content.get("join_rule", None)
|
||||||
if join_rule and join_rule != JoinRules.PUBLIC:
|
if not allow_private and join_rule and join_rule != JoinRules.PUBLIC:
|
||||||
defer.returnValue(None)
|
defer.returnValue(None)
|
||||||
|
|
||||||
aliases = yield self.store.get_aliases_for_room(
|
if with_alias:
|
||||||
room_id, on_invalidate=cache_context.invalidate
|
aliases = yield self.store.get_aliases_for_room(
|
||||||
)
|
room_id, on_invalidate=cache_context.invalidate
|
||||||
if aliases:
|
)
|
||||||
result["aliases"] = aliases
|
if aliases:
|
||||||
|
result["aliases"] = aliases
|
||||||
|
|
||||||
name_event = yield current_state.get((EventTypes.Name, ""))
|
name_event = yield current_state.get((EventTypes.Name, ""))
|
||||||
if name_event:
|
if name_event:
|
||||||
|
@ -45,9 +45,12 @@ class RoomMemberHandler(BaseHandler):
|
|||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(RoomMemberHandler, self).__init__(hs)
|
super(RoomMemberHandler, self).__init__(hs)
|
||||||
|
|
||||||
|
self.profile_handler = hs.get_profile_handler()
|
||||||
|
|
||||||
self.member_linearizer = Linearizer(name="member")
|
self.member_linearizer = Linearizer(name="member")
|
||||||
|
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
self.spam_checker = hs.get_spam_checker()
|
||||||
|
|
||||||
self.distributor = hs.get_distributor()
|
self.distributor = hs.get_distributor()
|
||||||
self.distributor.declare("user_joined_room")
|
self.distributor.declare("user_joined_room")
|
||||||
@ -210,12 +213,26 @@ class RoomMemberHandler(BaseHandler):
|
|||||||
if is_blocked:
|
if is_blocked:
|
||||||
raise SynapseError(403, "This room has been blocked on this server")
|
raise SynapseError(403, "This room has been blocked on this server")
|
||||||
|
|
||||||
if (effective_membership_state == "invite" and
|
if effective_membership_state == "invite":
|
||||||
self.hs.config.block_non_admin_invites):
|
block_invite = False
|
||||||
is_requester_admin = yield self.auth.is_server_admin(
|
is_requester_admin = yield self.auth.is_server_admin(
|
||||||
requester.user,
|
requester.user,
|
||||||
)
|
)
|
||||||
if not is_requester_admin:
|
if not is_requester_admin:
|
||||||
|
if self.hs.config.block_non_admin_invites:
|
||||||
|
logger.info(
|
||||||
|
"Blocking invite: user is not admin and non-admin "
|
||||||
|
"invites disabled"
|
||||||
|
)
|
||||||
|
block_invite = True
|
||||||
|
|
||||||
|
if not self.spam_checker.user_may_invite(
|
||||||
|
requester.user.to_string(), target.to_string(), room_id,
|
||||||
|
):
|
||||||
|
logger.info("Blocking invite due to spam checker")
|
||||||
|
block_invite = True
|
||||||
|
|
||||||
|
if block_invite:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
403, "Invites have been disabled on this server",
|
403, "Invites have been disabled on this server",
|
||||||
)
|
)
|
||||||
@ -267,7 +284,7 @@ class RoomMemberHandler(BaseHandler):
|
|||||||
|
|
||||||
content["membership"] = Membership.JOIN
|
content["membership"] = Membership.JOIN
|
||||||
|
|
||||||
profile = self.hs.get_handlers().profile_handler
|
profile = self.profile_handler
|
||||||
if not content_specified:
|
if not content_specified:
|
||||||
content["displayname"] = yield profile.get_displayname(target)
|
content["displayname"] = yield profile.get_displayname(target)
|
||||||
content["avatar_url"] = yield profile.get_avatar_url(target)
|
content["avatar_url"] = yield profile.get_avatar_url(target)
|
||||||
|
@ -108,6 +108,17 @@ class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class GroupsSyncResult(collections.namedtuple("GroupsSyncResult", [
|
||||||
|
"join",
|
||||||
|
"invite",
|
||||||
|
"leave",
|
||||||
|
])):
|
||||||
|
__slots__ = []
|
||||||
|
|
||||||
|
def __nonzero__(self):
|
||||||
|
return bool(self.join or self.invite or self.leave)
|
||||||
|
|
||||||
|
|
||||||
class DeviceLists(collections.namedtuple("DeviceLists", [
|
class DeviceLists(collections.namedtuple("DeviceLists", [
|
||||||
"changed", # list of user_ids whose devices may have changed
|
"changed", # list of user_ids whose devices may have changed
|
||||||
"left", # list of user_ids whose devices we no longer track
|
"left", # list of user_ids whose devices we no longer track
|
||||||
@ -129,6 +140,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
|
|||||||
"device_lists", # List of user_ids whose devices have chanegd
|
"device_lists", # List of user_ids whose devices have chanegd
|
||||||
"device_one_time_keys_count", # Dict of algorithm to count for one time keys
|
"device_one_time_keys_count", # Dict of algorithm to count for one time keys
|
||||||
# for this device
|
# for this device
|
||||||
|
"groups",
|
||||||
])):
|
])):
|
||||||
__slots__ = []
|
__slots__ = []
|
||||||
|
|
||||||
@ -144,7 +156,8 @@ class SyncResult(collections.namedtuple("SyncResult", [
|
|||||||
self.archived or
|
self.archived or
|
||||||
self.account_data or
|
self.account_data or
|
||||||
self.to_device or
|
self.to_device or
|
||||||
self.device_lists
|
self.device_lists or
|
||||||
|
self.groups
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -595,6 +608,8 @@ class SyncHandler(object):
|
|||||||
user_id, device_id
|
user_id, device_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
yield self._generate_sync_entry_for_groups(sync_result_builder)
|
||||||
|
|
||||||
defer.returnValue(SyncResult(
|
defer.returnValue(SyncResult(
|
||||||
presence=sync_result_builder.presence,
|
presence=sync_result_builder.presence,
|
||||||
account_data=sync_result_builder.account_data,
|
account_data=sync_result_builder.account_data,
|
||||||
@ -603,10 +618,57 @@ class SyncHandler(object):
|
|||||||
archived=sync_result_builder.archived,
|
archived=sync_result_builder.archived,
|
||||||
to_device=sync_result_builder.to_device,
|
to_device=sync_result_builder.to_device,
|
||||||
device_lists=device_lists,
|
device_lists=device_lists,
|
||||||
|
groups=sync_result_builder.groups,
|
||||||
device_one_time_keys_count=one_time_key_counts,
|
device_one_time_keys_count=one_time_key_counts,
|
||||||
next_batch=sync_result_builder.now_token,
|
next_batch=sync_result_builder.now_token,
|
||||||
))
|
))
|
||||||
|
|
||||||
|
@measure_func("_generate_sync_entry_for_groups")
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _generate_sync_entry_for_groups(self, sync_result_builder):
|
||||||
|
user_id = sync_result_builder.sync_config.user.to_string()
|
||||||
|
since_token = sync_result_builder.since_token
|
||||||
|
now_token = sync_result_builder.now_token
|
||||||
|
|
||||||
|
if since_token and since_token.groups_key:
|
||||||
|
results = yield self.store.get_groups_changes_for_user(
|
||||||
|
user_id, since_token.groups_key, now_token.groups_key,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
results = yield self.store.get_all_groups_for_user(
|
||||||
|
user_id, now_token.groups_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
invited = {}
|
||||||
|
joined = {}
|
||||||
|
left = {}
|
||||||
|
for result in results:
|
||||||
|
membership = result["membership"]
|
||||||
|
group_id = result["group_id"]
|
||||||
|
gtype = result["type"]
|
||||||
|
content = result["content"]
|
||||||
|
|
||||||
|
if membership == "join":
|
||||||
|
if gtype == "membership":
|
||||||
|
# TODO: Add profile
|
||||||
|
content.pop("membership", None)
|
||||||
|
joined[group_id] = content["content"]
|
||||||
|
else:
|
||||||
|
joined.setdefault(group_id, {})[gtype] = content
|
||||||
|
elif membership == "invite":
|
||||||
|
if gtype == "membership":
|
||||||
|
content.pop("membership", None)
|
||||||
|
invited[group_id] = content["content"]
|
||||||
|
else:
|
||||||
|
if gtype == "membership":
|
||||||
|
left[group_id] = content["content"]
|
||||||
|
|
||||||
|
sync_result_builder.groups = GroupsSyncResult(
|
||||||
|
join=joined,
|
||||||
|
invite=invited,
|
||||||
|
leave=left,
|
||||||
|
)
|
||||||
|
|
||||||
@measure_func("_generate_sync_entry_for_device_list")
|
@measure_func("_generate_sync_entry_for_device_list")
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _generate_sync_entry_for_device_list(self, sync_result_builder,
|
def _generate_sync_entry_for_device_list(self, sync_result_builder,
|
||||||
@ -1368,6 +1430,7 @@ class SyncResultBuilder(object):
|
|||||||
self.invited = []
|
self.invited = []
|
||||||
self.archived = []
|
self.archived = []
|
||||||
self.device = []
|
self.device = []
|
||||||
|
self.groups = None
|
||||||
self.to_device = []
|
self.to_device = []
|
||||||
|
|
||||||
|
|
||||||
|
@ -354,16 +354,28 @@ def _get_hosts_for_srv_record(dns_client, host):
|
|||||||
|
|
||||||
return res[0]
|
return res[0]
|
||||||
|
|
||||||
def eb(res):
|
def eb(res, record_type):
|
||||||
res.trap(DNSNameError)
|
if res.check(DNSNameError):
|
||||||
return []
|
return []
|
||||||
|
logger.warn("Error looking up %s for %s: %s",
|
||||||
|
record_type, host, res, res.value)
|
||||||
|
return res
|
||||||
|
|
||||||
# no logcontexts here, so we can safely fire these off and gatherResults
|
# no logcontexts here, so we can safely fire these off and gatherResults
|
||||||
d1 = dns_client.lookupAddress(host).addCallbacks(cb, eb)
|
d1 = dns_client.lookupAddress(host).addCallbacks(cb, eb)
|
||||||
d2 = dns_client.lookupIPV6Address(host).addCallbacks(cb, eb)
|
d2 = dns_client.lookupIPV6Address(host).addCallbacks(cb, eb)
|
||||||
results = yield defer.gatherResults([d1, d2], consumeErrors=True)
|
results = yield defer.DeferredList(
|
||||||
|
[d1, d2], consumeErrors=True)
|
||||||
|
|
||||||
|
# if all of the lookups failed, raise an exception rather than blowing out
|
||||||
|
# the cache with an empty result.
|
||||||
|
if results and all(s == defer.FAILURE for (s, _) in results):
|
||||||
|
defer.returnValue(results[0][1])
|
||||||
|
|
||||||
|
for (success, result) in results:
|
||||||
|
if success == defer.FAILURE:
|
||||||
|
continue
|
||||||
|
|
||||||
for result in results:
|
|
||||||
for answer in result:
|
for answer in result:
|
||||||
if not answer.payload:
|
if not answer.payload:
|
||||||
continue
|
continue
|
||||||
|
@ -204,18 +204,15 @@ class MatrixFederationHttpClient(object):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"{%s} Sending request failed to %s: %s %s: %s - %s",
|
"{%s} Sending request failed to %s: %s %s: %s",
|
||||||
txn_id,
|
txn_id,
|
||||||
destination,
|
destination,
|
||||||
method,
|
method,
|
||||||
url_bytes,
|
url_bytes,
|
||||||
type(e).__name__,
|
|
||||||
_flatten_response_never_received(e),
|
_flatten_response_never_received(e),
|
||||||
)
|
)
|
||||||
|
|
||||||
log_result = "%s - %s" % (
|
log_result = _flatten_response_never_received(e)
|
||||||
type(e).__name__, _flatten_response_never_received(e),
|
|
||||||
)
|
|
||||||
|
|
||||||
if retries_left and not timeout:
|
if retries_left and not timeout:
|
||||||
if long_retries:
|
if long_retries:
|
||||||
@ -347,7 +344,7 @@ class MatrixFederationHttpClient(object):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def post_json(self, destination, path, data={}, long_retries=False,
|
def post_json(self, destination, path, data={}, long_retries=False,
|
||||||
timeout=None, ignore_backoff=False):
|
timeout=None, ignore_backoff=False, args={}):
|
||||||
""" Sends the specifed json data using POST
|
""" Sends the specifed json data using POST
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -383,6 +380,7 @@ class MatrixFederationHttpClient(object):
|
|||||||
destination,
|
destination,
|
||||||
"POST",
|
"POST",
|
||||||
path,
|
path,
|
||||||
|
query_bytes=encode_query_args(args),
|
||||||
body_callback=body_callback,
|
body_callback=body_callback,
|
||||||
headers_dict={"Content-Type": ["application/json"]},
|
headers_dict={"Content-Type": ["application/json"]},
|
||||||
long_retries=long_retries,
|
long_retries=long_retries,
|
||||||
@ -427,13 +425,6 @@ class MatrixFederationHttpClient(object):
|
|||||||
"""
|
"""
|
||||||
logger.debug("get_json args: %s", args)
|
logger.debug("get_json args: %s", args)
|
||||||
|
|
||||||
encoded_args = {}
|
|
||||||
for k, vs in args.items():
|
|
||||||
if isinstance(vs, basestring):
|
|
||||||
vs = [vs]
|
|
||||||
encoded_args[k] = [v.encode("UTF-8") for v in vs]
|
|
||||||
|
|
||||||
query_bytes = urllib.urlencode(encoded_args, True)
|
|
||||||
logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
|
logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
|
||||||
|
|
||||||
def body_callback(method, url_bytes, headers_dict):
|
def body_callback(method, url_bytes, headers_dict):
|
||||||
@ -444,7 +435,7 @@ class MatrixFederationHttpClient(object):
|
|||||||
destination,
|
destination,
|
||||||
"GET",
|
"GET",
|
||||||
path,
|
path,
|
||||||
query_bytes=query_bytes,
|
query_bytes=encode_query_args(args),
|
||||||
body_callback=body_callback,
|
body_callback=body_callback,
|
||||||
retry_on_dns_fail=retry_on_dns_fail,
|
retry_on_dns_fail=retry_on_dns_fail,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
@ -460,6 +451,52 @@ class MatrixFederationHttpClient(object):
|
|||||||
|
|
||||||
defer.returnValue(json.loads(body))
|
defer.returnValue(json.loads(body))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def delete_json(self, destination, path, long_retries=False,
|
||||||
|
timeout=None, ignore_backoff=False, args={}):
|
||||||
|
"""Send a DELETE request to the remote expecting some json response
|
||||||
|
|
||||||
|
Args:
|
||||||
|
destination (str): The remote server to send the HTTP request
|
||||||
|
to.
|
||||||
|
path (str): The HTTP path.
|
||||||
|
long_retries (bool): A boolean that indicates whether we should
|
||||||
|
retry for a short or long time.
|
||||||
|
timeout(int): How long to try (in ms) the destination for before
|
||||||
|
giving up. None indicates no timeout.
|
||||||
|
ignore_backoff (bool): true to ignore the historical backoff data and
|
||||||
|
try the request anyway.
|
||||||
|
Returns:
|
||||||
|
Deferred: Succeeds when we get a 2xx HTTP response. The result
|
||||||
|
will be the decoded JSON body.
|
||||||
|
|
||||||
|
Fails with ``HTTPRequestException`` if we get an HTTP response
|
||||||
|
code >= 300.
|
||||||
|
|
||||||
|
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||||
|
to retry this server.
|
||||||
|
"""
|
||||||
|
|
||||||
|
response = yield self._request(
|
||||||
|
destination,
|
||||||
|
"DELETE",
|
||||||
|
path,
|
||||||
|
query_bytes=encode_query_args(args),
|
||||||
|
headers_dict={"Content-Type": ["application/json"]},
|
||||||
|
long_retries=long_retries,
|
||||||
|
timeout=timeout,
|
||||||
|
ignore_backoff=ignore_backoff,
|
||||||
|
)
|
||||||
|
|
||||||
|
if 200 <= response.code < 300:
|
||||||
|
# We need to update the transactions table to say it was sent?
|
||||||
|
check_content_type_is_json(response.headers)
|
||||||
|
|
||||||
|
with logcontext.PreserveLoggingContext():
|
||||||
|
body = yield readBody(response)
|
||||||
|
|
||||||
|
defer.returnValue(json.loads(body))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_file(self, destination, path, output_stream, args={},
|
def get_file(self, destination, path, output_stream, args={},
|
||||||
retry_on_dns_fail=True, max_size=None,
|
retry_on_dns_fail=True, max_size=None,
|
||||||
@ -578,12 +615,14 @@ class _JsonProducer(object):
|
|||||||
|
|
||||||
def _flatten_response_never_received(e):
|
def _flatten_response_never_received(e):
|
||||||
if hasattr(e, "reasons"):
|
if hasattr(e, "reasons"):
|
||||||
return ", ".join(
|
reasons = ", ".join(
|
||||||
_flatten_response_never_received(f.value)
|
_flatten_response_never_received(f.value)
|
||||||
for f in e.reasons
|
for f in e.reasons
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return "%s:[%s]" % (type(e).__name__, reasons)
|
||||||
else:
|
else:
|
||||||
return "%s: %s" % (type(e).__name__, e.message,)
|
return repr(e)
|
||||||
|
|
||||||
|
|
||||||
def check_content_type_is_json(headers):
|
def check_content_type_is_json(headers):
|
||||||
@ -610,3 +649,15 @@ def check_content_type_is_json(headers):
|
|||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Content-Type not application/json: was '%s'" % c_type
|
"Content-Type not application/json: was '%s'" % c_type
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def encode_query_args(args):
|
||||||
|
encoded_args = {}
|
||||||
|
for k, vs in args.items():
|
||||||
|
if isinstance(vs, basestring):
|
||||||
|
vs = [vs]
|
||||||
|
encoded_args[k] = [v.encode("UTF-8") for v in vs]
|
||||||
|
|
||||||
|
query_bytes = urllib.urlencode(encoded_args, True)
|
||||||
|
|
||||||
|
return query_bytes
|
||||||
|
@ -145,7 +145,9 @@ def wrap_request_handler(request_handler, include_metrics=False):
|
|||||||
"error": "Internal server error",
|
"error": "Internal server error",
|
||||||
"errcode": Codes.UNKNOWN,
|
"errcode": Codes.UNKNOWN,
|
||||||
},
|
},
|
||||||
send_cors=True
|
send_cors=True,
|
||||||
|
pretty_print=_request_user_agent_is_curl(request),
|
||||||
|
version_string=self.version_string,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
try:
|
try:
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
# Copyright 2015, 2016 OpenMarket Ltd
|
# Copyright 2015, 2016 OpenMarket Ltd
|
||||||
|
# Copyright 2017 New Vector Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@ -238,6 +239,28 @@ BASE_APPEND_OVERRIDE_RULES = [
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
'rule_id': 'global/override/.m.rule.roomnotif',
|
||||||
|
'conditions': [
|
||||||
|
{
|
||||||
|
'kind': 'event_match',
|
||||||
|
'key': 'content.body',
|
||||||
|
'pattern': '@room',
|
||||||
|
'_id': '_roomnotif_content',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'kind': 'sender_notification_permission',
|
||||||
|
'key': 'room',
|
||||||
|
'_id': '_roomnotif_pl',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
'actions': [
|
||||||
|
'notify', {
|
||||||
|
'set_tweak': 'highlight',
|
||||||
|
'value': True,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2015 OpenMarket Ltd
|
# Copyright 2015 OpenMarket Ltd
|
||||||
|
# Copyright 2017 New Vector Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@ -19,11 +20,13 @@ from twisted.internet import defer
|
|||||||
|
|
||||||
from .push_rule_evaluator import PushRuleEvaluatorForEvent
|
from .push_rule_evaluator import PushRuleEvaluatorForEvent
|
||||||
|
|
||||||
|
from synapse.event_auth import get_user_power_level
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.metrics import get_metrics_for
|
from synapse.metrics import get_metrics_for
|
||||||
from synapse.util.caches import metrics as cache_metrics
|
from synapse.util.caches import metrics as cache_metrics
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
from synapse.util.async import Linearizer
|
from synapse.util.async import Linearizer
|
||||||
|
from synapse.state import POWER_KEY
|
||||||
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
@ -59,6 +62,7 @@ class BulkPushRuleEvaluator(object):
|
|||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
self.room_push_rule_cache_metrics = cache_metrics.register_cache(
|
self.room_push_rule_cache_metrics = cache_metrics.register_cache(
|
||||||
"cache",
|
"cache",
|
||||||
@ -108,6 +112,29 @@ class BulkPushRuleEvaluator(object):
|
|||||||
self.room_push_rule_cache_metrics,
|
self.room_push_rule_cache_metrics,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _get_power_levels_and_sender_level(self, event, context):
|
||||||
|
pl_event_id = context.prev_state_ids.get(POWER_KEY)
|
||||||
|
if pl_event_id:
|
||||||
|
# fastpath: if there's a power level event, that's all we need, and
|
||||||
|
# not having a power level event is an extreme edge case
|
||||||
|
pl_event = yield self.store.get_event(pl_event_id)
|
||||||
|
auth_events = {POWER_KEY: pl_event}
|
||||||
|
else:
|
||||||
|
auth_events_ids = yield self.auth.compute_auth_events(
|
||||||
|
event, context.prev_state_ids, for_verification=False,
|
||||||
|
)
|
||||||
|
auth_events = yield self.store.get_events(auth_events_ids)
|
||||||
|
auth_events = {
|
||||||
|
(e.type, e.state_key): e for e in auth_events.itervalues()
|
||||||
|
}
|
||||||
|
|
||||||
|
sender_level = get_user_power_level(event.sender, auth_events)
|
||||||
|
|
||||||
|
pl_event = auth_events.get(POWER_KEY)
|
||||||
|
|
||||||
|
defer.returnValue((pl_event.content if pl_event else {}, sender_level))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def action_for_event_by_user(self, event, context):
|
def action_for_event_by_user(self, event, context):
|
||||||
"""Given an event and context, evaluate the push rules and return
|
"""Given an event and context, evaluate the push rules and return
|
||||||
@ -123,7 +150,13 @@ class BulkPushRuleEvaluator(object):
|
|||||||
event, context
|
event, context
|
||||||
)
|
)
|
||||||
|
|
||||||
evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
|
(power_levels, sender_power_level) = (
|
||||||
|
yield self._get_power_levels_and_sender_level(event, context)
|
||||||
|
)
|
||||||
|
|
||||||
|
evaluator = PushRuleEvaluatorForEvent(
|
||||||
|
event, len(room_members), sender_power_level, power_levels,
|
||||||
|
)
|
||||||
|
|
||||||
condition_cache = {}
|
condition_cache = {}
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2015, 2016 OpenMarket Ltd
|
# Copyright 2015, 2016 OpenMarket Ltd
|
||||||
|
# Copyright 2017 New Vector Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@ -29,6 +30,21 @@ INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
|
|||||||
|
|
||||||
|
|
||||||
def _room_member_count(ev, condition, room_member_count):
|
def _room_member_count(ev, condition, room_member_count):
|
||||||
|
return _test_ineq_condition(condition, room_member_count)
|
||||||
|
|
||||||
|
|
||||||
|
def _sender_notification_permission(ev, condition, sender_power_level, power_levels):
|
||||||
|
notif_level_key = condition.get('key')
|
||||||
|
if notif_level_key is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
notif_levels = power_levels.get('notifications', {})
|
||||||
|
room_notif_level = notif_levels.get(notif_level_key, 50)
|
||||||
|
|
||||||
|
return sender_power_level >= room_notif_level
|
||||||
|
|
||||||
|
|
||||||
|
def _test_ineq_condition(condition, number):
|
||||||
if 'is' not in condition:
|
if 'is' not in condition:
|
||||||
return False
|
return False
|
||||||
m = INEQUALITY_EXPR.match(condition['is'])
|
m = INEQUALITY_EXPR.match(condition['is'])
|
||||||
@ -41,15 +57,15 @@ def _room_member_count(ev, condition, room_member_count):
|
|||||||
rhs = int(rhs)
|
rhs = int(rhs)
|
||||||
|
|
||||||
if ineq == '' or ineq == '==':
|
if ineq == '' or ineq == '==':
|
||||||
return room_member_count == rhs
|
return number == rhs
|
||||||
elif ineq == '<':
|
elif ineq == '<':
|
||||||
return room_member_count < rhs
|
return number < rhs
|
||||||
elif ineq == '>':
|
elif ineq == '>':
|
||||||
return room_member_count > rhs
|
return number > rhs
|
||||||
elif ineq == '>=':
|
elif ineq == '>=':
|
||||||
return room_member_count >= rhs
|
return number >= rhs
|
||||||
elif ineq == '<=':
|
elif ineq == '<=':
|
||||||
return room_member_count <= rhs
|
return number <= rhs
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -65,9 +81,11 @@ def tweaks_for_actions(actions):
|
|||||||
|
|
||||||
|
|
||||||
class PushRuleEvaluatorForEvent(object):
|
class PushRuleEvaluatorForEvent(object):
|
||||||
def __init__(self, event, room_member_count):
|
def __init__(self, event, room_member_count, sender_power_level, power_levels):
|
||||||
self._event = event
|
self._event = event
|
||||||
self._room_member_count = room_member_count
|
self._room_member_count = room_member_count
|
||||||
|
self._sender_power_level = sender_power_level
|
||||||
|
self._power_levels = power_levels
|
||||||
|
|
||||||
# Maps strings of e.g. 'content.body' -> event["content"]["body"]
|
# Maps strings of e.g. 'content.body' -> event["content"]["body"]
|
||||||
self._value_cache = _flatten_dict(event)
|
self._value_cache = _flatten_dict(event)
|
||||||
@ -81,6 +99,10 @@ class PushRuleEvaluatorForEvent(object):
|
|||||||
return _room_member_count(
|
return _room_member_count(
|
||||||
self._event, condition, self._room_member_count
|
self._event, condition, self._room_member_count
|
||||||
)
|
)
|
||||||
|
elif condition['kind'] == 'sender_notification_permission':
|
||||||
|
return _sender_notification_permission(
|
||||||
|
self._event, condition, self._sender_power_level, self._power_levels,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -183,7 +205,7 @@ def _glob_to_re(glob, word_boundary):
|
|||||||
r,
|
r,
|
||||||
)
|
)
|
||||||
if word_boundary:
|
if word_boundary:
|
||||||
r = r"\b%s\b" % (r,)
|
r = _re_word_boundary(r)
|
||||||
|
|
||||||
return re.compile(r, flags=re.IGNORECASE)
|
return re.compile(r, flags=re.IGNORECASE)
|
||||||
else:
|
else:
|
||||||
@ -192,7 +214,7 @@ def _glob_to_re(glob, word_boundary):
|
|||||||
return re.compile(r, flags=re.IGNORECASE)
|
return re.compile(r, flags=re.IGNORECASE)
|
||||||
elif word_boundary:
|
elif word_boundary:
|
||||||
r = re.escape(glob)
|
r = re.escape(glob)
|
||||||
r = r"\b%s\b" % (r,)
|
r = _re_word_boundary(r)
|
||||||
|
|
||||||
return re.compile(r, flags=re.IGNORECASE)
|
return re.compile(r, flags=re.IGNORECASE)
|
||||||
else:
|
else:
|
||||||
@ -200,6 +222,18 @@ def _glob_to_re(glob, word_boundary):
|
|||||||
return re.compile(r, flags=re.IGNORECASE)
|
return re.compile(r, flags=re.IGNORECASE)
|
||||||
|
|
||||||
|
|
||||||
|
def _re_word_boundary(r):
|
||||||
|
"""
|
||||||
|
Adds word boundary characters to the start and end of an
|
||||||
|
expression to require that the match occur as a whole word,
|
||||||
|
but do so respecting the fact that strings starting or ending
|
||||||
|
with non-word characters will change word boundaries.
|
||||||
|
"""
|
||||||
|
# we can't use \b as it chokes on unicode. however \W seems to be okay
|
||||||
|
# as shorthand for [^0-9A-Za-z_].
|
||||||
|
return r"(^|\W)%s(\W|$)" % (r,)
|
||||||
|
|
||||||
|
|
||||||
def _flatten_dict(d, prefix=[], result=None):
|
def _flatten_dict(d, prefix=[], result=None):
|
||||||
if result is None:
|
if result is None:
|
||||||
result = {}
|
result = {}
|
||||||
|
54
synapse/replication/slave/storage/groups.py
Normal file
54
synapse/replication/slave/storage/groups.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from ._base import BaseSlavedStore
|
||||||
|
from ._slaved_id_tracker import SlavedIdTracker
|
||||||
|
from synapse.storage import DataStore
|
||||||
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
|
|
||||||
|
|
||||||
|
class SlavedGroupServerStore(BaseSlavedStore):
|
||||||
|
def __init__(self, db_conn, hs):
|
||||||
|
super(SlavedGroupServerStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
|
self.hs = hs
|
||||||
|
|
||||||
|
self._group_updates_id_gen = SlavedIdTracker(
|
||||||
|
db_conn, "local_group_updates", "stream_id",
|
||||||
|
)
|
||||||
|
self._group_updates_stream_cache = StreamChangeCache(
|
||||||
|
"_group_updates_stream_cache", self._group_updates_id_gen.get_current_token(),
|
||||||
|
)
|
||||||
|
|
||||||
|
get_groups_changes_for_user = DataStore.get_groups_changes_for_user.__func__
|
||||||
|
get_group_stream_token = DataStore.get_group_stream_token.__func__
|
||||||
|
get_all_groups_for_user = DataStore.get_all_groups_for_user.__func__
|
||||||
|
|
||||||
|
def stream_positions(self):
|
||||||
|
result = super(SlavedGroupServerStore, self).stream_positions()
|
||||||
|
result["groups"] = self._group_updates_id_gen.get_current_token()
|
||||||
|
return result
|
||||||
|
|
||||||
|
def process_replication_rows(self, stream_name, token, rows):
|
||||||
|
if stream_name == "groups":
|
||||||
|
self._group_updates_id_gen.advance(token)
|
||||||
|
for row in rows:
|
||||||
|
self._group_updates_stream_cache.entity_has_changed(
|
||||||
|
row.user_id, token
|
||||||
|
)
|
||||||
|
|
||||||
|
return super(SlavedGroupServerStore, self).process_replication_rows(
|
||||||
|
stream_name, token, rows
|
||||||
|
)
|
@ -160,7 +160,11 @@ class ReplicationStreamer(object):
|
|||||||
"Getting stream: %s: %s -> %s",
|
"Getting stream: %s: %s -> %s",
|
||||||
stream.NAME, stream.last_token, stream.upto_token
|
stream.NAME, stream.last_token, stream.upto_token
|
||||||
)
|
)
|
||||||
updates, current_token = yield stream.get_updates()
|
try:
|
||||||
|
updates, current_token = yield stream.get_updates()
|
||||||
|
except:
|
||||||
|
logger.info("Failed to handle stream %s", stream.NAME)
|
||||||
|
raise
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Sending %d updates to %d connections",
|
"Sending %d updates to %d connections",
|
||||||
|
@ -118,6 +118,12 @@ CurrentStateDeltaStreamRow = namedtuple("CurrentStateDeltaStream", (
|
|||||||
"state_key", # str
|
"state_key", # str
|
||||||
"event_id", # str, optional
|
"event_id", # str, optional
|
||||||
))
|
))
|
||||||
|
GroupsStreamRow = namedtuple("GroupsStreamRow", (
|
||||||
|
"group_id", # str
|
||||||
|
"user_id", # str
|
||||||
|
"type", # str
|
||||||
|
"content", # dict
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
class Stream(object):
|
class Stream(object):
|
||||||
@ -464,6 +470,19 @@ class CurrentStateDeltaStream(Stream):
|
|||||||
super(CurrentStateDeltaStream, self).__init__(hs)
|
super(CurrentStateDeltaStream, self).__init__(hs)
|
||||||
|
|
||||||
|
|
||||||
|
class GroupServerStream(Stream):
|
||||||
|
NAME = "groups"
|
||||||
|
ROW_TYPE = GroupsStreamRow
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
store = hs.get_datastore()
|
||||||
|
|
||||||
|
self.current_token = store.get_group_stream_token
|
||||||
|
self.update_function = store.get_all_groups_changes
|
||||||
|
|
||||||
|
super(GroupServerStream, self).__init__(hs)
|
||||||
|
|
||||||
|
|
||||||
STREAMS_MAP = {
|
STREAMS_MAP = {
|
||||||
stream.NAME: stream
|
stream.NAME: stream
|
||||||
for stream in (
|
for stream in (
|
||||||
@ -482,5 +501,6 @@ STREAMS_MAP = {
|
|||||||
TagAccountDataStream,
|
TagAccountDataStream,
|
||||||
AccountDataStream,
|
AccountDataStream,
|
||||||
CurrentStateDeltaStream,
|
CurrentStateDeltaStream,
|
||||||
|
GroupServerStream,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -52,6 +52,7 @@ from synapse.rest.client.v2_alpha import (
|
|||||||
thirdparty,
|
thirdparty,
|
||||||
sendtodevice,
|
sendtodevice,
|
||||||
user_directory,
|
user_directory,
|
||||||
|
groups,
|
||||||
)
|
)
|
||||||
|
|
||||||
from synapse.http.server import JsonResource
|
from synapse.http.server import JsonResource
|
||||||
@ -102,3 +103,4 @@ class ClientRestResource(JsonResource):
|
|||||||
thirdparty.register_servlets(hs, client_resource)
|
thirdparty.register_servlets(hs, client_resource)
|
||||||
sendtodevice.register_servlets(hs, client_resource)
|
sendtodevice.register_servlets(hs, client_resource)
|
||||||
user_directory.register_servlets(hs, client_resource)
|
user_directory.register_servlets(hs, client_resource)
|
||||||
|
groups.register_servlets(hs, client_resource)
|
||||||
|
@ -26,13 +26,13 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
|
|||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(ProfileDisplaynameRestServlet, self).__init__(hs)
|
super(ProfileDisplaynameRestServlet, self).__init__(hs)
|
||||||
self.handlers = hs.get_handlers()
|
self.profile_handler = hs.get_profile_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, user_id):
|
def on_GET(self, request, user_id):
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
|
|
||||||
displayname = yield self.handlers.profile_handler.get_displayname(
|
displayname = yield self.profile_handler.get_displayname(
|
||||||
user,
|
user,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -55,7 +55,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
|
|||||||
except:
|
except:
|
||||||
defer.returnValue((400, "Unable to parse name"))
|
defer.returnValue((400, "Unable to parse name"))
|
||||||
|
|
||||||
yield self.handlers.profile_handler.set_displayname(
|
yield self.profile_handler.set_displayname(
|
||||||
user, requester, new_name, is_admin)
|
user, requester, new_name, is_admin)
|
||||||
|
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
@ -69,13 +69,13 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
|
|||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(ProfileAvatarURLRestServlet, self).__init__(hs)
|
super(ProfileAvatarURLRestServlet, self).__init__(hs)
|
||||||
self.handlers = hs.get_handlers()
|
self.profile_handler = hs.get_profile_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, user_id):
|
def on_GET(self, request, user_id):
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
|
|
||||||
avatar_url = yield self.handlers.profile_handler.get_avatar_url(
|
avatar_url = yield self.profile_handler.get_avatar_url(
|
||||||
user,
|
user,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -97,7 +97,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
|
|||||||
except:
|
except:
|
||||||
defer.returnValue((400, "Unable to parse name"))
|
defer.returnValue((400, "Unable to parse name"))
|
||||||
|
|
||||||
yield self.handlers.profile_handler.set_avatar_url(
|
yield self.profile_handler.set_avatar_url(
|
||||||
user, requester, new_name, is_admin)
|
user, requester, new_name, is_admin)
|
||||||
|
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
@ -111,16 +111,16 @@ class ProfileRestServlet(ClientV1RestServlet):
|
|||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(ProfileRestServlet, self).__init__(hs)
|
super(ProfileRestServlet, self).__init__(hs)
|
||||||
self.handlers = hs.get_handlers()
|
self.profile_handler = hs.get_profile_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, user_id):
|
def on_GET(self, request, user_id):
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
|
|
||||||
displayname = yield self.handlers.profile_handler.get_displayname(
|
displayname = yield self.profile_handler.get_displayname(
|
||||||
user,
|
user,
|
||||||
)
|
)
|
||||||
avatar_url = yield self.handlers.profile_handler.get_avatar_url(
|
avatar_url = yield self.profile_handler.get_avatar_url(
|
||||||
user,
|
user,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -398,22 +398,18 @@ class JoinedRoomMemberListRestServlet(ClientV1RestServlet):
|
|||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(JoinedRoomMemberListRestServlet, self).__init__(hs)
|
super(JoinedRoomMemberListRestServlet, self).__init__(hs)
|
||||||
self.state = hs.get_state_handler()
|
self.message_handler = hs.get_handlers().message_handler
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, room_id):
|
def on_GET(self, request, room_id):
|
||||||
yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
users_with_profile = yield self.state.get_current_user_in_room(room_id)
|
users_with_profile = yield self.message_handler.get_joined_members(
|
||||||
|
requester, room_id,
|
||||||
|
)
|
||||||
|
|
||||||
defer.returnValue((200, {
|
defer.returnValue((200, {
|
||||||
"joined": {
|
"joined": users_with_profile,
|
||||||
user_id: {
|
|
||||||
"avatar_url": profile.avatar_url,
|
|
||||||
"display_name": profile.display_name,
|
|
||||||
}
|
|
||||||
for user_id, profile in users_with_profile.iteritems()
|
|
||||||
}
|
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|
||||||
|
717
synapse/rest/client/v2_alpha/groups.py
Normal file
717
synapse/rest/client/v2_alpha/groups.py
Normal file
@ -0,0 +1,717 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2017 Vector Creations Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||||
|
from synapse.types import GroupID
|
||||||
|
|
||||||
|
from ._base import client_v2_patterns
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class GroupServlet(RestServlet):
|
||||||
|
"""Get the group profile
|
||||||
|
"""
|
||||||
|
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/profile$")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(GroupServlet, self).__init__()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, request, group_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
group_description = yield self.groups_handler.get_group_profile(group_id, user_id)
|
||||||
|
|
||||||
|
defer.returnValue((200, group_description))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, request, group_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
content = parse_json_object_from_request(request)
|
||||||
|
yield self.groups_handler.update_group_profile(
|
||||||
|
group_id, user_id, content,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
|
||||||
|
class GroupSummaryServlet(RestServlet):
|
||||||
|
"""Get the full group summary
|
||||||
|
"""
|
||||||
|
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/summary$")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(GroupSummaryServlet, self).__init__()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, request, group_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
get_group_summary = yield self.groups_handler.get_group_summary(group_id, user_id)
|
||||||
|
|
||||||
|
defer.returnValue((200, get_group_summary))
|
||||||
|
|
||||||
|
|
||||||
|
class GroupSummaryRoomsCatServlet(RestServlet):
|
||||||
|
"""Update/delete a rooms entry in the summary.
|
||||||
|
|
||||||
|
Matches both:
|
||||||
|
- /groups/:group/summary/rooms/:room_id
|
||||||
|
- /groups/:group/summary/categories/:category/rooms/:room_id
|
||||||
|
"""
|
||||||
|
PATTERNS = client_v2_patterns(
|
||||||
|
"/groups/(?P<group_id>[^/]*)/summary"
|
||||||
|
"(/categories/(?P<category_id>[^/]+))?"
|
||||||
|
"/rooms/(?P<room_id>[^/]*)$"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(GroupSummaryRoomsCatServlet, self).__init__()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_PUT(self, request, group_id, category_id, room_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
content = parse_json_object_from_request(request)
|
||||||
|
resp = yield self.groups_handler.update_group_summary_room(
|
||||||
|
group_id, user_id,
|
||||||
|
room_id=room_id,
|
||||||
|
category_id=category_id,
|
||||||
|
content=content,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, resp))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_DELETE(self, request, group_id, category_id, room_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
resp = yield self.groups_handler.delete_group_summary_room(
|
||||||
|
group_id, user_id,
|
||||||
|
room_id=room_id,
|
||||||
|
category_id=category_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, resp))
|
||||||
|
|
||||||
|
|
||||||
|
class GroupCategoryServlet(RestServlet):
|
||||||
|
"""Get/add/update/delete a group category
|
||||||
|
"""
|
||||||
|
PATTERNS = client_v2_patterns(
|
||||||
|
"/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(GroupCategoryServlet, self).__init__()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, request, group_id, category_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
category = yield self.groups_handler.get_group_category(
|
||||||
|
group_id, user_id,
|
||||||
|
category_id=category_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, category))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_PUT(self, request, group_id, category_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
content = parse_json_object_from_request(request)
|
||||||
|
resp = yield self.groups_handler.update_group_category(
|
||||||
|
group_id, user_id,
|
||||||
|
category_id=category_id,
|
||||||
|
content=content,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, resp))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_DELETE(self, request, group_id, category_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
resp = yield self.groups_handler.delete_group_category(
|
||||||
|
group_id, user_id,
|
||||||
|
category_id=category_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, resp))
|
||||||
|
|
||||||
|
|
||||||
|
class GroupCategoriesServlet(RestServlet):
|
||||||
|
"""Get all group categories
|
||||||
|
"""
|
||||||
|
PATTERNS = client_v2_patterns(
|
||||||
|
"/groups/(?P<group_id>[^/]*)/categories/$"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(GroupCategoriesServlet, self).__init__()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, request, group_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
category = yield self.groups_handler.get_group_categories(
|
||||||
|
group_id, user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, category))
|
||||||
|
|
||||||
|
|
||||||
|
class GroupRoleServlet(RestServlet):
|
||||||
|
"""Get/add/update/delete a group role
|
||||||
|
"""
|
||||||
|
PATTERNS = client_v2_patterns(
|
||||||
|
"/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(GroupRoleServlet, self).__init__()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, request, group_id, role_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
category = yield self.groups_handler.get_group_role(
|
||||||
|
group_id, user_id,
|
||||||
|
role_id=role_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, category))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_PUT(self, request, group_id, role_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
content = parse_json_object_from_request(request)
|
||||||
|
resp = yield self.groups_handler.update_group_role(
|
||||||
|
group_id, user_id,
|
||||||
|
role_id=role_id,
|
||||||
|
content=content,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, resp))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_DELETE(self, request, group_id, role_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
resp = yield self.groups_handler.delete_group_role(
|
||||||
|
group_id, user_id,
|
||||||
|
role_id=role_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, resp))
|
||||||
|
|
||||||
|
|
||||||
|
class GroupRolesServlet(RestServlet):
|
||||||
|
"""Get all group roles
|
||||||
|
"""
|
||||||
|
PATTERNS = client_v2_patterns(
|
||||||
|
"/groups/(?P<group_id>[^/]*)/roles/$"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(GroupRolesServlet, self).__init__()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, request, group_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
category = yield self.groups_handler.get_group_roles(
|
||||||
|
group_id, user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, category))
|
||||||
|
|
||||||
|
|
||||||
|
class GroupSummaryUsersRoleServlet(RestServlet):
|
||||||
|
"""Update/delete a user's entry in the summary.
|
||||||
|
|
||||||
|
Matches both:
|
||||||
|
- /groups/:group/summary/users/:room_id
|
||||||
|
- /groups/:group/summary/roles/:role/users/:user_id
|
||||||
|
"""
|
||||||
|
PATTERNS = client_v2_patterns(
|
||||||
|
"/groups/(?P<group_id>[^/]*)/summary"
|
||||||
|
"(/roles/(?P<role_id>[^/]+))?"
|
||||||
|
"/users/(?P<user_id>[^/]*)$"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(GroupSummaryUsersRoleServlet, self).__init__()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_PUT(self, request, group_id, role_id, user_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
content = parse_json_object_from_request(request)
|
||||||
|
resp = yield self.groups_handler.update_group_summary_user(
|
||||||
|
group_id, requester_user_id,
|
||||||
|
user_id=user_id,
|
||||||
|
role_id=role_id,
|
||||||
|
content=content,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, resp))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_DELETE(self, request, group_id, role_id, user_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
resp = yield self.groups_handler.delete_group_summary_user(
|
||||||
|
group_id, requester_user_id,
|
||||||
|
user_id=user_id,
|
||||||
|
role_id=role_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, resp))
|
||||||
|
|
||||||
|
|
||||||
|
class GroupRoomServlet(RestServlet):
|
||||||
|
"""Get all rooms in a group
|
||||||
|
"""
|
||||||
|
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(GroupRoomServlet, self).__init__()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, request, group_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
result = yield self.groups_handler.get_rooms_in_group(group_id, user_id)
|
||||||
|
|
||||||
|
defer.returnValue((200, result))
|
||||||
|
|
||||||
|
|
||||||
|
class GroupUsersServlet(RestServlet):
|
||||||
|
"""Get all users in a group
|
||||||
|
"""
|
||||||
|
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/users$")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(GroupUsersServlet, self).__init__()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, request, group_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
result = yield self.groups_handler.get_users_in_group(group_id, user_id)
|
||||||
|
|
||||||
|
defer.returnValue((200, result))
|
||||||
|
|
||||||
|
|
||||||
|
class GroupInvitedUsersServlet(RestServlet):
|
||||||
|
"""Get users invited to a group
|
||||||
|
"""
|
||||||
|
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(GroupInvitedUsersServlet, self).__init__()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, request, group_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
result = yield self.groups_handler.get_invited_users_in_group(group_id, user_id)
|
||||||
|
|
||||||
|
defer.returnValue((200, result))
|
||||||
|
|
||||||
|
|
||||||
|
class GroupCreateServlet(RestServlet):
|
||||||
|
"""Create a group
|
||||||
|
"""
|
||||||
|
PATTERNS = client_v2_patterns("/create_group$")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(GroupCreateServlet, self).__init__()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
self.server_name = hs.hostname
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, request):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
# TODO: Create group on remote server
|
||||||
|
content = parse_json_object_from_request(request)
|
||||||
|
localpart = content.pop("localpart")
|
||||||
|
group_id = GroupID.create(localpart, self.server_name).to_string()
|
||||||
|
|
||||||
|
result = yield self.groups_handler.create_group(group_id, user_id, content)
|
||||||
|
|
||||||
|
defer.returnValue((200, result))
|
||||||
|
|
||||||
|
|
||||||
|
class GroupAdminRoomsServlet(RestServlet):
|
||||||
|
"""Add a room to the group
|
||||||
|
"""
|
||||||
|
PATTERNS = client_v2_patterns(
|
||||||
|
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(GroupAdminRoomsServlet, self).__init__()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_PUT(self, request, group_id, room_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
content = parse_json_object_from_request(request)
|
||||||
|
result = yield self.groups_handler.add_room_to_group(
|
||||||
|
group_id, user_id, room_id, content,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, result))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_DELETE(self, request, group_id, room_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
result = yield self.groups_handler.remove_room_from_group(
|
||||||
|
group_id, user_id, room_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, result))
|
||||||
|
|
||||||
|
|
||||||
|
class GroupAdminUsersInviteServlet(RestServlet):
|
||||||
|
"""Invite a user to the group
|
||||||
|
"""
|
||||||
|
PATTERNS = client_v2_patterns(
|
||||||
|
"/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(GroupAdminUsersInviteServlet, self).__init__()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.is_mine_id = hs.is_mine_id
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_PUT(self, request, group_id, user_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
content = parse_json_object_from_request(request)
|
||||||
|
config = content.get("config", {})
|
||||||
|
result = yield self.groups_handler.invite(
|
||||||
|
group_id, user_id, requester_user_id, config,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, result))
|
||||||
|
|
||||||
|
|
||||||
|
class GroupAdminUsersKickServlet(RestServlet):
|
||||||
|
"""Kick a user from the group
|
||||||
|
"""
|
||||||
|
PATTERNS = client_v2_patterns(
|
||||||
|
"/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(GroupAdminUsersKickServlet, self).__init__()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_PUT(self, request, group_id, user_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
content = parse_json_object_from_request(request)
|
||||||
|
result = yield self.groups_handler.remove_user_from_group(
|
||||||
|
group_id, user_id, requester_user_id, content,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, result))
|
||||||
|
|
||||||
|
|
||||||
|
class GroupSelfLeaveServlet(RestServlet):
|
||||||
|
"""Leave a joined group
|
||||||
|
"""
|
||||||
|
PATTERNS = client_v2_patterns(
|
||||||
|
"/groups/(?P<group_id>[^/]*)/self/leave$"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(GroupSelfLeaveServlet, self).__init__()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_PUT(self, request, group_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
content = parse_json_object_from_request(request)
|
||||||
|
result = yield self.groups_handler.remove_user_from_group(
|
||||||
|
group_id, requester_user_id, requester_user_id, content,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, result))
|
||||||
|
|
||||||
|
|
||||||
|
class GroupSelfJoinServlet(RestServlet):
|
||||||
|
"""Attempt to join a group, or knock
|
||||||
|
"""
|
||||||
|
PATTERNS = client_v2_patterns(
|
||||||
|
"/groups/(?P<group_id>[^/]*)/self/join$"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(GroupSelfJoinServlet, self).__init__()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_PUT(self, request, group_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
content = parse_json_object_from_request(request)
|
||||||
|
result = yield self.groups_handler.join_group(
|
||||||
|
group_id, requester_user_id, content,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, result))
|
||||||
|
|
||||||
|
|
||||||
|
class GroupSelfAcceptInviteServlet(RestServlet):
|
||||||
|
"""Accept a group invite
|
||||||
|
"""
|
||||||
|
PATTERNS = client_v2_patterns(
|
||||||
|
"/groups/(?P<group_id>[^/]*)/self/accept_invite$"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(GroupSelfAcceptInviteServlet, self).__init__()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_PUT(self, request, group_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
content = parse_json_object_from_request(request)
|
||||||
|
result = yield self.groups_handler.accept_invite(
|
||||||
|
group_id, requester_user_id, content,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, result))
|
||||||
|
|
||||||
|
|
||||||
|
class GroupSelfUpdatePublicityServlet(RestServlet):
|
||||||
|
"""Update whether we publicise a users membership of a group
|
||||||
|
"""
|
||||||
|
PATTERNS = client_v2_patterns(
|
||||||
|
"/groups/(?P<group_id>[^/]*)/self/update_publicity$"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(GroupSelfUpdatePublicityServlet, self).__init__()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_PUT(self, request, group_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
content = parse_json_object_from_request(request)
|
||||||
|
publicise = content["publicise"]
|
||||||
|
yield self.store.update_group_publicity(
|
||||||
|
group_id, requester_user_id, publicise,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
|
||||||
|
class PublicisedGroupsForUserServlet(RestServlet):
|
||||||
|
"""Get the list of groups a user is advertising
|
||||||
|
"""
|
||||||
|
PATTERNS = client_v2_patterns(
|
||||||
|
"/publicised_groups/(?P<user_id>[^/]*)$"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(PublicisedGroupsForUserServlet, self).__init__()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, request, user_id):
|
||||||
|
yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
|
result = yield self.groups_handler.get_publicised_groups_for_user(
|
||||||
|
user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, result))
|
||||||
|
|
||||||
|
|
||||||
|
class PublicisedGroupsForUsersServlet(RestServlet):
|
||||||
|
"""Get the list of groups a user is advertising
|
||||||
|
"""
|
||||||
|
PATTERNS = client_v2_patterns(
|
||||||
|
"/publicised_groups$"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(PublicisedGroupsForUsersServlet, self).__init__()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, request):
|
||||||
|
yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
|
content = parse_json_object_from_request(request)
|
||||||
|
user_ids = content["user_ids"]
|
||||||
|
|
||||||
|
result = yield self.groups_handler.bulk_get_publicised_groups(
|
||||||
|
user_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, result))
|
||||||
|
|
||||||
|
|
||||||
|
class GroupsForUserServlet(RestServlet):
|
||||||
|
"""Get all groups the logged in user is joined to
|
||||||
|
"""
|
||||||
|
PATTERNS = client_v2_patterns(
|
||||||
|
"/joined_groups$"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(GroupsForUserServlet, self).__init__()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, request):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
result = yield self.groups_handler.get_joined_groups(user_id)
|
||||||
|
|
||||||
|
defer.returnValue((200, result))
|
||||||
|
|
||||||
|
|
||||||
|
def register_servlets(hs, http_server):
|
||||||
|
GroupServlet(hs).register(http_server)
|
||||||
|
GroupSummaryServlet(hs).register(http_server)
|
||||||
|
GroupInvitedUsersServlet(hs).register(http_server)
|
||||||
|
GroupUsersServlet(hs).register(http_server)
|
||||||
|
GroupRoomServlet(hs).register(http_server)
|
||||||
|
GroupCreateServlet(hs).register(http_server)
|
||||||
|
GroupAdminRoomsServlet(hs).register(http_server)
|
||||||
|
GroupAdminUsersInviteServlet(hs).register(http_server)
|
||||||
|
GroupAdminUsersKickServlet(hs).register(http_server)
|
||||||
|
GroupSelfLeaveServlet(hs).register(http_server)
|
||||||
|
GroupSelfJoinServlet(hs).register(http_server)
|
||||||
|
GroupSelfAcceptInviteServlet(hs).register(http_server)
|
||||||
|
GroupsForUserServlet(hs).register(http_server)
|
||||||
|
GroupCategoryServlet(hs).register(http_server)
|
||||||
|
GroupCategoriesServlet(hs).register(http_server)
|
||||||
|
GroupSummaryRoomsCatServlet(hs).register(http_server)
|
||||||
|
GroupRoleServlet(hs).register(http_server)
|
||||||
|
GroupRolesServlet(hs).register(http_server)
|
||||||
|
GroupSelfUpdatePublicityServlet(hs).register(http_server)
|
||||||
|
GroupSummaryUsersRoleServlet(hs).register(http_server)
|
||||||
|
PublicisedGroupsForUserServlet(hs).register(http_server)
|
||||||
|
PublicisedGroupsForUsersServlet(hs).register(http_server)
|
@ -17,8 +17,10 @@
|
|||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
import synapse
|
import synapse
|
||||||
|
import synapse.types
|
||||||
from synapse.api.auth import get_access_token_from_request, has_access_token
|
from synapse.api.auth import get_access_token_from_request, has_access_token
|
||||||
from synapse.api.constants import LoginType
|
from synapse.api.constants import LoginType
|
||||||
|
from synapse.types import RoomID, RoomAlias
|
||||||
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
|
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
|
||||||
from synapse.http.servlet import (
|
from synapse.http.servlet import (
|
||||||
RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string
|
RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string
|
||||||
@ -170,6 +172,7 @@ class RegisterRestServlet(RestServlet):
|
|||||||
self.auth_handler = hs.get_auth_handler()
|
self.auth_handler = hs.get_auth_handler()
|
||||||
self.registration_handler = hs.get_handlers().registration_handler
|
self.registration_handler = hs.get_handlers().registration_handler
|
||||||
self.identity_handler = hs.get_handlers().identity_handler
|
self.identity_handler = hs.get_handlers().identity_handler
|
||||||
|
self.room_member_handler = hs.get_handlers().room_member_handler
|
||||||
self.device_handler = hs.get_device_handler()
|
self.device_handler = hs.get_device_handler()
|
||||||
self.macaroon_gen = hs.get_macaroon_generator()
|
self.macaroon_gen = hs.get_macaroon_generator()
|
||||||
|
|
||||||
@ -340,6 +343,14 @@ class RegisterRestServlet(RestServlet):
|
|||||||
generate_token=False,
|
generate_token=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# auto-join the user to any rooms we're supposed to dump them into
|
||||||
|
fake_requester = synapse.types.create_requester(registered_user_id)
|
||||||
|
for r in self.hs.config.auto_join_rooms:
|
||||||
|
try:
|
||||||
|
yield self._join_user_to_room(fake_requester, r)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to join new user to %r: %r", r, e)
|
||||||
|
|
||||||
# remember that we've now registered that user account, and with
|
# remember that we've now registered that user account, and with
|
||||||
# what user ID (since the user may not have specified)
|
# what user ID (since the user may not have specified)
|
||||||
self.auth_handler.set_session_data(
|
self.auth_handler.set_session_data(
|
||||||
@ -372,6 +383,29 @@ class RegisterRestServlet(RestServlet):
|
|||||||
def on_OPTIONS(self, _):
|
def on_OPTIONS(self, _):
|
||||||
return 200, {}
|
return 200, {}
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _join_user_to_room(self, requester, room_identifier):
|
||||||
|
room_id = None
|
||||||
|
if RoomID.is_valid(room_identifier):
|
||||||
|
room_id = room_identifier
|
||||||
|
elif RoomAlias.is_valid(room_identifier):
|
||||||
|
room_alias = RoomAlias.from_string(room_identifier)
|
||||||
|
room_id, remote_room_hosts = (
|
||||||
|
yield self.room_member_handler.lookup_room_alias(room_alias)
|
||||||
|
)
|
||||||
|
room_id = room_id.to_string()
|
||||||
|
else:
|
||||||
|
raise SynapseError(400, "%s was not legal room ID or room alias" % (
|
||||||
|
room_identifier,
|
||||||
|
))
|
||||||
|
|
||||||
|
yield self.room_member_handler.update_membership(
|
||||||
|
requester=requester,
|
||||||
|
target=requester.user,
|
||||||
|
room_id=room_id,
|
||||||
|
action="join",
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _do_appservice_registration(self, username, as_token, body):
|
def _do_appservice_registration(self, username, as_token, body):
|
||||||
user_id = yield self.registration_handler.appservice_register(
|
user_id = yield self.registration_handler.appservice_register(
|
||||||
|
@ -200,6 +200,11 @@ class SyncRestServlet(RestServlet):
|
|||||||
"invite": invited,
|
"invite": invited,
|
||||||
"leave": archived,
|
"leave": archived,
|
||||||
},
|
},
|
||||||
|
"groups": {
|
||||||
|
"join": sync_result.groups.join,
|
||||||
|
"invite": sync_result.groups.invite,
|
||||||
|
"leave": sync_result.groups.leave,
|
||||||
|
},
|
||||||
"device_one_time_keys_count": sync_result.device_one_time_keys_count,
|
"device_one_time_keys_count": sync_result.device_one_time_keys_count,
|
||||||
"next_batch": sync_result.next_batch.to_string(),
|
"next_batch": sync_result.next_batch.to_string(),
|
||||||
}
|
}
|
||||||
|
@ -14,78 +14,200 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
|
import functools
|
||||||
|
|
||||||
|
NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
|
||||||
|
|
||||||
|
|
||||||
|
def _wrap_in_base_path(func):
|
||||||
|
"""Takes a function that returns a relative path and turns it into an
|
||||||
|
absolute path based on the location of the primary media store
|
||||||
|
"""
|
||||||
|
@functools.wraps(func)
|
||||||
|
def _wrapped(self, *args, **kwargs):
|
||||||
|
path = func(self, *args, **kwargs)
|
||||||
|
return os.path.join(self.base_path, path)
|
||||||
|
|
||||||
|
return _wrapped
|
||||||
|
|
||||||
|
|
||||||
class MediaFilePaths(object):
|
class MediaFilePaths(object):
|
||||||
|
"""Describes where files are stored on disk.
|
||||||
|
|
||||||
def __init__(self, base_path):
|
Most of the functions have a `*_rel` variant which returns a file path that
|
||||||
self.base_path = base_path
|
is relative to the base media store path. This is mainly used when we want
|
||||||
|
to write to the backup media store (when one is configured)
|
||||||
|
"""
|
||||||
|
|
||||||
def default_thumbnail(self, default_top_level, default_sub_type, width,
|
def __init__(self, primary_base_path):
|
||||||
height, content_type, method):
|
self.base_path = primary_base_path
|
||||||
|
|
||||||
|
def default_thumbnail_rel(self, default_top_level, default_sub_type, width,
|
||||||
|
height, content_type, method):
|
||||||
top_level_type, sub_type = content_type.split("/")
|
top_level_type, sub_type = content_type.split("/")
|
||||||
file_name = "%i-%i-%s-%s-%s" % (
|
file_name = "%i-%i-%s-%s-%s" % (
|
||||||
width, height, top_level_type, sub_type, method
|
width, height, top_level_type, sub_type, method
|
||||||
)
|
)
|
||||||
return os.path.join(
|
return os.path.join(
|
||||||
self.base_path, "default_thumbnails", default_top_level,
|
"default_thumbnails", default_top_level,
|
||||||
default_sub_type, file_name
|
default_sub_type, file_name
|
||||||
)
|
)
|
||||||
|
|
||||||
def local_media_filepath(self, media_id):
|
default_thumbnail = _wrap_in_base_path(default_thumbnail_rel)
|
||||||
|
|
||||||
|
def local_media_filepath_rel(self, media_id):
|
||||||
return os.path.join(
|
return os.path.join(
|
||||||
self.base_path, "local_content",
|
"local_content",
|
||||||
media_id[0:2], media_id[2:4], media_id[4:]
|
media_id[0:2], media_id[2:4], media_id[4:]
|
||||||
)
|
)
|
||||||
|
|
||||||
def local_media_thumbnail(self, media_id, width, height, content_type,
|
local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)
|
||||||
method):
|
|
||||||
|
def local_media_thumbnail_rel(self, media_id, width, height, content_type,
|
||||||
|
method):
|
||||||
top_level_type, sub_type = content_type.split("/")
|
top_level_type, sub_type = content_type.split("/")
|
||||||
file_name = "%i-%i-%s-%s-%s" % (
|
file_name = "%i-%i-%s-%s-%s" % (
|
||||||
width, height, top_level_type, sub_type, method
|
width, height, top_level_type, sub_type, method
|
||||||
)
|
)
|
||||||
return os.path.join(
|
return os.path.join(
|
||||||
self.base_path, "local_thumbnails",
|
"local_thumbnails",
|
||||||
media_id[0:2], media_id[2:4], media_id[4:],
|
media_id[0:2], media_id[2:4], media_id[4:],
|
||||||
file_name
|
file_name
|
||||||
)
|
)
|
||||||
|
|
||||||
def remote_media_filepath(self, server_name, file_id):
|
local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel)
|
||||||
|
|
||||||
|
def remote_media_filepath_rel(self, server_name, file_id):
|
||||||
return os.path.join(
|
return os.path.join(
|
||||||
self.base_path, "remote_content", server_name,
|
"remote_content", server_name,
|
||||||
file_id[0:2], file_id[2:4], file_id[4:]
|
file_id[0:2], file_id[2:4], file_id[4:]
|
||||||
)
|
)
|
||||||
|
|
||||||
def remote_media_thumbnail(self, server_name, file_id, width, height,
|
remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)
|
||||||
content_type, method):
|
|
||||||
|
def remote_media_thumbnail_rel(self, server_name, file_id, width, height,
|
||||||
|
content_type, method):
|
||||||
top_level_type, sub_type = content_type.split("/")
|
top_level_type, sub_type = content_type.split("/")
|
||||||
file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
|
file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
|
||||||
return os.path.join(
|
return os.path.join(
|
||||||
self.base_path, "remote_thumbnail", server_name,
|
"remote_thumbnail", server_name,
|
||||||
file_id[0:2], file_id[2:4], file_id[4:],
|
file_id[0:2], file_id[2:4], file_id[4:],
|
||||||
file_name
|
file_name
|
||||||
)
|
)
|
||||||
|
|
||||||
|
remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel)
|
||||||
|
|
||||||
def remote_media_thumbnail_dir(self, server_name, file_id):
|
def remote_media_thumbnail_dir(self, server_name, file_id):
|
||||||
return os.path.join(
|
return os.path.join(
|
||||||
self.base_path, "remote_thumbnail", server_name,
|
self.base_path, "remote_thumbnail", server_name,
|
||||||
file_id[0:2], file_id[2:4], file_id[4:],
|
file_id[0:2], file_id[2:4], file_id[4:],
|
||||||
)
|
)
|
||||||
|
|
||||||
def url_cache_filepath(self, media_id):
|
def url_cache_filepath_rel(self, media_id):
|
||||||
return os.path.join(
|
if NEW_FORMAT_ID_RE.match(media_id):
|
||||||
self.base_path, "url_cache",
|
# Media id is of the form <DATE><RANDOM_STRING>
|
||||||
media_id[0:2], media_id[2:4], media_id[4:]
|
# E.g.: 2017-09-28-fsdRDt24DS234dsf
|
||||||
)
|
return os.path.join(
|
||||||
|
"url_cache",
|
||||||
|
media_id[:10], media_id[11:]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return os.path.join(
|
||||||
|
"url_cache",
|
||||||
|
media_id[0:2], media_id[2:4], media_id[4:],
|
||||||
|
)
|
||||||
|
|
||||||
|
url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)
|
||||||
|
|
||||||
|
def url_cache_filepath_dirs_to_delete(self, media_id):
|
||||||
|
"The dirs to try and remove if we delete the media_id file"
|
||||||
|
if NEW_FORMAT_ID_RE.match(media_id):
|
||||||
|
return [
|
||||||
|
os.path.join(
|
||||||
|
self.base_path, "url_cache",
|
||||||
|
media_id[:10],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
return [
|
||||||
|
os.path.join(
|
||||||
|
self.base_path, "url_cache",
|
||||||
|
media_id[0:2], media_id[2:4],
|
||||||
|
),
|
||||||
|
os.path.join(
|
||||||
|
self.base_path, "url_cache",
|
||||||
|
media_id[0:2],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
def url_cache_thumbnail_rel(self, media_id, width, height, content_type,
|
||||||
|
method):
|
||||||
|
# Media id is of the form <DATE><RANDOM_STRING>
|
||||||
|
# E.g.: 2017-09-28-fsdRDt24DS234dsf
|
||||||
|
|
||||||
def url_cache_thumbnail(self, media_id, width, height, content_type,
|
|
||||||
method):
|
|
||||||
top_level_type, sub_type = content_type.split("/")
|
top_level_type, sub_type = content_type.split("/")
|
||||||
file_name = "%i-%i-%s-%s-%s" % (
|
file_name = "%i-%i-%s-%s-%s" % (
|
||||||
width, height, top_level_type, sub_type, method
|
width, height, top_level_type, sub_type, method
|
||||||
)
|
)
|
||||||
return os.path.join(
|
|
||||||
self.base_path, "url_cache_thumbnails",
|
if NEW_FORMAT_ID_RE.match(media_id):
|
||||||
media_id[0:2], media_id[2:4], media_id[4:],
|
return os.path.join(
|
||||||
file_name
|
"url_cache_thumbnails",
|
||||||
)
|
media_id[:10], media_id[11:],
|
||||||
|
file_name
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return os.path.join(
|
||||||
|
"url_cache_thumbnails",
|
||||||
|
media_id[0:2], media_id[2:4], media_id[4:],
|
||||||
|
file_name
|
||||||
|
)
|
||||||
|
|
||||||
|
url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)
|
||||||
|
|
||||||
|
def url_cache_thumbnail_directory(self, media_id):
|
||||||
|
# Media id is of the form <DATE><RANDOM_STRING>
|
||||||
|
# E.g.: 2017-09-28-fsdRDt24DS234dsf
|
||||||
|
|
||||||
|
if NEW_FORMAT_ID_RE.match(media_id):
|
||||||
|
return os.path.join(
|
||||||
|
self.base_path, "url_cache_thumbnails",
|
||||||
|
media_id[:10], media_id[11:],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return os.path.join(
|
||||||
|
self.base_path, "url_cache_thumbnails",
|
||||||
|
media_id[0:2], media_id[2:4], media_id[4:],
|
||||||
|
)
|
||||||
|
|
||||||
|
def url_cache_thumbnail_dirs_to_delete(self, media_id):
|
||||||
|
"The dirs to try and remove if we delete the media_id thumbnails"
|
||||||
|
# Media id is of the form <DATE><RANDOM_STRING>
|
||||||
|
# E.g.: 2017-09-28-fsdRDt24DS234dsf
|
||||||
|
if NEW_FORMAT_ID_RE.match(media_id):
|
||||||
|
return [
|
||||||
|
os.path.join(
|
||||||
|
self.base_path, "url_cache_thumbnails",
|
||||||
|
media_id[:10], media_id[11:],
|
||||||
|
),
|
||||||
|
os.path.join(
|
||||||
|
self.base_path, "url_cache_thumbnails",
|
||||||
|
media_id[:10],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
return [
|
||||||
|
os.path.join(
|
||||||
|
self.base_path, "url_cache_thumbnails",
|
||||||
|
media_id[0:2], media_id[2:4], media_id[4:],
|
||||||
|
),
|
||||||
|
os.path.join(
|
||||||
|
self.base_path, "url_cache_thumbnails",
|
||||||
|
media_id[0:2], media_id[2:4],
|
||||||
|
),
|
||||||
|
os.path.join(
|
||||||
|
self.base_path, "url_cache_thumbnails",
|
||||||
|
media_id[0:2],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
@ -33,7 +33,7 @@ from synapse.api.errors import SynapseError, HttpResponseException, \
|
|||||||
|
|
||||||
from synapse.util.async import Linearizer
|
from synapse.util.async import Linearizer
|
||||||
from synapse.util.stringutils import is_ascii
|
from synapse.util.stringutils import is_ascii
|
||||||
from synapse.util.logcontext import preserve_context_over_fn
|
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||||
from synapse.util.retryutils import NotRetryingDestination
|
from synapse.util.retryutils import NotRetryingDestination
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@ -59,7 +59,14 @@ class MediaRepository(object):
|
|||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.max_upload_size = hs.config.max_upload_size
|
self.max_upload_size = hs.config.max_upload_size
|
||||||
self.max_image_pixels = hs.config.max_image_pixels
|
self.max_image_pixels = hs.config.max_image_pixels
|
||||||
self.filepaths = MediaFilePaths(hs.config.media_store_path)
|
|
||||||
|
self.primary_base_path = hs.config.media_store_path
|
||||||
|
self.filepaths = MediaFilePaths(self.primary_base_path)
|
||||||
|
|
||||||
|
self.backup_base_path = hs.config.backup_media_store_path
|
||||||
|
|
||||||
|
self.synchronous_backup_media_store = hs.config.synchronous_backup_media_store
|
||||||
|
|
||||||
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
|
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
|
||||||
self.thumbnail_requirements = hs.config.thumbnail_requirements
|
self.thumbnail_requirements = hs.config.thumbnail_requirements
|
||||||
|
|
||||||
@ -87,18 +94,86 @@ class MediaRepository(object):
|
|||||||
if not os.path.exists(dirname):
|
if not os.path.exists(dirname):
|
||||||
os.makedirs(dirname)
|
os.makedirs(dirname)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _write_file_synchronously(source, fname):
|
||||||
|
"""Write `source` to the path `fname` synchronously. Should be called
|
||||||
|
from a thread.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source: A file like object to be written
|
||||||
|
fname (str): Path to write to
|
||||||
|
"""
|
||||||
|
MediaRepository._makedirs(fname)
|
||||||
|
source.seek(0) # Ensure we read from the start of the file
|
||||||
|
with open(fname, "wb") as f:
|
||||||
|
shutil.copyfileobj(source, f)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def write_to_file_and_backup(self, source, path):
|
||||||
|
"""Write `source` to the on disk media store, and also the backup store
|
||||||
|
if configured.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source: A file like object that should be written
|
||||||
|
path (str): Relative path to write file to
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[str]: the file path written to in the primary media store
|
||||||
|
"""
|
||||||
|
fname = os.path.join(self.primary_base_path, path)
|
||||||
|
|
||||||
|
# Write to the main repository
|
||||||
|
yield make_deferred_yieldable(threads.deferToThread(
|
||||||
|
self._write_file_synchronously, source, fname,
|
||||||
|
))
|
||||||
|
|
||||||
|
# Write to backup repository
|
||||||
|
yield self.copy_to_backup(path)
|
||||||
|
|
||||||
|
defer.returnValue(fname)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def copy_to_backup(self, path):
|
||||||
|
"""Copy a file from the primary to backup media store, if configured.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path(str): Relative path to write file to
|
||||||
|
"""
|
||||||
|
if self.backup_base_path:
|
||||||
|
primary_fname = os.path.join(self.primary_base_path, path)
|
||||||
|
backup_fname = os.path.join(self.backup_base_path, path)
|
||||||
|
|
||||||
|
# We can either wait for successful writing to the backup repository
|
||||||
|
# or write in the background and immediately return
|
||||||
|
if self.synchronous_backup_media_store:
|
||||||
|
yield make_deferred_yieldable(threads.deferToThread(
|
||||||
|
shutil.copyfile, primary_fname, backup_fname,
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
preserve_fn(threads.deferToThread)(
|
||||||
|
shutil.copyfile, primary_fname, backup_fname,
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def create_content(self, media_type, upload_name, content, content_length,
|
def create_content(self, media_type, upload_name, content, content_length,
|
||||||
auth_user):
|
auth_user):
|
||||||
|
"""Store uploaded content for a local user and return the mxc URL
|
||||||
|
|
||||||
|
Args:
|
||||||
|
media_type(str): The content type of the file
|
||||||
|
upload_name(str): The name of the file
|
||||||
|
content: A file like object that is the content to store
|
||||||
|
content_length(int): The length of the content
|
||||||
|
auth_user(str): The user_id of the uploader
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[str]: The mxc url of the stored content
|
||||||
|
"""
|
||||||
media_id = random_string(24)
|
media_id = random_string(24)
|
||||||
|
|
||||||
fname = self.filepaths.local_media_filepath(media_id)
|
fname = yield self.write_to_file_and_backup(
|
||||||
self._makedirs(fname)
|
content, self.filepaths.local_media_filepath_rel(media_id)
|
||||||
|
)
|
||||||
# This shouldn't block for very long because the content will have
|
|
||||||
# already been uploaded at this point.
|
|
||||||
with open(fname, "wb") as f:
|
|
||||||
f.write(content)
|
|
||||||
|
|
||||||
logger.info("Stored local media in file %r", fname)
|
logger.info("Stored local media in file %r", fname)
|
||||||
|
|
||||||
@ -115,7 +190,7 @@ class MediaRepository(object):
|
|||||||
"media_length": content_length,
|
"media_length": content_length,
|
||||||
}
|
}
|
||||||
|
|
||||||
yield self._generate_local_thumbnails(media_id, media_info)
|
yield self._generate_thumbnails(None, media_id, media_info)
|
||||||
|
|
||||||
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
|
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
|
||||||
|
|
||||||
@ -148,9 +223,10 @@ class MediaRepository(object):
|
|||||||
def _download_remote_file(self, server_name, media_id):
|
def _download_remote_file(self, server_name, media_id):
|
||||||
file_id = random_string(24)
|
file_id = random_string(24)
|
||||||
|
|
||||||
fname = self.filepaths.remote_media_filepath(
|
fpath = self.filepaths.remote_media_filepath_rel(
|
||||||
server_name, file_id
|
server_name, file_id
|
||||||
)
|
)
|
||||||
|
fname = os.path.join(self.primary_base_path, fpath)
|
||||||
self._makedirs(fname)
|
self._makedirs(fname)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -192,6 +268,8 @@ class MediaRepository(object):
|
|||||||
server_name, media_id)
|
server_name, media_id)
|
||||||
raise SynapseError(502, "Failed to fetch remote media")
|
raise SynapseError(502, "Failed to fetch remote media")
|
||||||
|
|
||||||
|
yield self.copy_to_backup(fpath)
|
||||||
|
|
||||||
media_type = headers["Content-Type"][0]
|
media_type = headers["Content-Type"][0]
|
||||||
time_now_ms = self.clock.time_msec()
|
time_now_ms = self.clock.time_msec()
|
||||||
|
|
||||||
@ -244,7 +322,7 @@ class MediaRepository(object):
|
|||||||
"filesystem_id": file_id,
|
"filesystem_id": file_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
yield self._generate_remote_thumbnails(
|
yield self._generate_thumbnails(
|
||||||
server_name, media_id, media_info
|
server_name, media_id, media_info
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -253,9 +331,8 @@ class MediaRepository(object):
|
|||||||
def _get_thumbnail_requirements(self, media_type):
|
def _get_thumbnail_requirements(self, media_type):
|
||||||
return self.thumbnail_requirements.get(media_type, ())
|
return self.thumbnail_requirements.get(media_type, ())
|
||||||
|
|
||||||
def _generate_thumbnail(self, input_path, t_path, t_width, t_height,
|
def _generate_thumbnail(self, thumbnailer, t_width, t_height,
|
||||||
t_method, t_type):
|
t_method, t_type):
|
||||||
thumbnailer = Thumbnailer(input_path)
|
|
||||||
m_width = thumbnailer.width
|
m_width = thumbnailer.width
|
||||||
m_height = thumbnailer.height
|
m_height = thumbnailer.height
|
||||||
|
|
||||||
@ -267,72 +344,105 @@ class MediaRepository(object):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if t_method == "crop":
|
if t_method == "crop":
|
||||||
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
|
t_byte_source = thumbnailer.crop(t_width, t_height, t_type)
|
||||||
elif t_method == "scale":
|
elif t_method == "scale":
|
||||||
t_width, t_height = thumbnailer.aspect(t_width, t_height)
|
t_width, t_height = thumbnailer.aspect(t_width, t_height)
|
||||||
t_width = min(m_width, t_width)
|
t_width = min(m_width, t_width)
|
||||||
t_height = min(m_height, t_height)
|
t_height = min(m_height, t_height)
|
||||||
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
|
t_byte_source = thumbnailer.scale(t_width, t_height, t_type)
|
||||||
else:
|
else:
|
||||||
t_len = None
|
t_byte_source = None
|
||||||
|
|
||||||
return t_len
|
return t_byte_source
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
|
def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
|
||||||
t_method, t_type):
|
t_method, t_type):
|
||||||
input_path = self.filepaths.local_media_filepath(media_id)
|
input_path = self.filepaths.local_media_filepath(media_id)
|
||||||
|
|
||||||
t_path = self.filepaths.local_media_thumbnail(
|
thumbnailer = Thumbnailer(input_path)
|
||||||
media_id, t_width, t_height, t_type, t_method
|
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
|
||||||
)
|
|
||||||
self._makedirs(t_path)
|
|
||||||
|
|
||||||
t_len = yield preserve_context_over_fn(
|
|
||||||
threads.deferToThread,
|
|
||||||
self._generate_thumbnail,
|
self._generate_thumbnail,
|
||||||
input_path, t_path, t_width, t_height, t_method, t_type
|
thumbnailer, t_width, t_height, t_method, t_type
|
||||||
)
|
))
|
||||||
|
|
||||||
|
if t_byte_source:
|
||||||
|
try:
|
||||||
|
output_path = yield self.write_to_file_and_backup(
|
||||||
|
t_byte_source,
|
||||||
|
self.filepaths.local_media_thumbnail_rel(
|
||||||
|
media_id, t_width, t_height, t_type, t_method
|
||||||
|
)
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
t_byte_source.close()
|
||||||
|
|
||||||
|
logger.info("Stored thumbnail in file %r", output_path)
|
||||||
|
|
||||||
|
t_len = os.path.getsize(output_path)
|
||||||
|
|
||||||
if t_len:
|
|
||||||
yield self.store.store_local_thumbnail(
|
yield self.store.store_local_thumbnail(
|
||||||
media_id, t_width, t_height, t_type, t_method, t_len
|
media_id, t_width, t_height, t_type, t_method, t_len
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue(t_path)
|
defer.returnValue(output_path)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
|
def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
|
||||||
t_width, t_height, t_method, t_type):
|
t_width, t_height, t_method, t_type):
|
||||||
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
|
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
|
||||||
|
|
||||||
t_path = self.filepaths.remote_media_thumbnail(
|
thumbnailer = Thumbnailer(input_path)
|
||||||
server_name, file_id, t_width, t_height, t_type, t_method
|
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
|
||||||
)
|
|
||||||
self._makedirs(t_path)
|
|
||||||
|
|
||||||
t_len = yield preserve_context_over_fn(
|
|
||||||
threads.deferToThread,
|
|
||||||
self._generate_thumbnail,
|
self._generate_thumbnail,
|
||||||
input_path, t_path, t_width, t_height, t_method, t_type
|
thumbnailer, t_width, t_height, t_method, t_type
|
||||||
)
|
))
|
||||||
|
|
||||||
|
if t_byte_source:
|
||||||
|
try:
|
||||||
|
output_path = yield self.write_to_file_and_backup(
|
||||||
|
t_byte_source,
|
||||||
|
self.filepaths.remote_media_thumbnail_rel(
|
||||||
|
server_name, file_id, t_width, t_height, t_type, t_method
|
||||||
|
)
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
t_byte_source.close()
|
||||||
|
|
||||||
|
logger.info("Stored thumbnail in file %r", output_path)
|
||||||
|
|
||||||
|
t_len = os.path.getsize(output_path)
|
||||||
|
|
||||||
if t_len:
|
|
||||||
yield self.store.store_remote_media_thumbnail(
|
yield self.store.store_remote_media_thumbnail(
|
||||||
server_name, media_id, file_id,
|
server_name, media_id, file_id,
|
||||||
t_width, t_height, t_type, t_method, t_len
|
t_width, t_height, t_type, t_method, t_len
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue(t_path)
|
defer.returnValue(output_path)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _generate_local_thumbnails(self, media_id, media_info, url_cache=False):
|
def _generate_thumbnails(self, server_name, media_id, media_info, url_cache=False):
|
||||||
|
"""Generate and store thumbnails for an image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_name(str|None): The server name if remote media, else None if local
|
||||||
|
media_id(str)
|
||||||
|
media_info(dict)
|
||||||
|
url_cache(bool): If we are thumbnailing images downloaded for the URL cache,
|
||||||
|
used exclusively by the url previewer
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[dict]: Dict with "width" and "height" keys of original image
|
||||||
|
"""
|
||||||
media_type = media_info["media_type"]
|
media_type = media_info["media_type"]
|
||||||
|
file_id = media_info.get("filesystem_id")
|
||||||
requirements = self._get_thumbnail_requirements(media_type)
|
requirements = self._get_thumbnail_requirements(media_type)
|
||||||
if not requirements:
|
if not requirements:
|
||||||
return
|
return
|
||||||
|
|
||||||
if url_cache:
|
if server_name:
|
||||||
|
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
|
||||||
|
elif url_cache:
|
||||||
input_path = self.filepaths.url_cache_filepath(media_id)
|
input_path = self.filepaths.url_cache_filepath(media_id)
|
||||||
else:
|
else:
|
||||||
input_path = self.filepaths.local_media_filepath(media_id)
|
input_path = self.filepaths.local_media_filepath(media_id)
|
||||||
@ -348,135 +458,72 @@ class MediaRepository(object):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
local_thumbnails = []
|
# We deduplicate the thumbnail sizes by ignoring the cropped versions if
|
||||||
|
# they have the same dimensions of a scaled one.
|
||||||
|
thumbnails = {}
|
||||||
|
for r_width, r_height, r_method, r_type in requirements:
|
||||||
|
if r_method == "crop":
|
||||||
|
thumbnails.setdefault((r_width, r_height, r_type), r_method)
|
||||||
|
elif r_method == "scale":
|
||||||
|
t_width, t_height = thumbnailer.aspect(r_width, r_height)
|
||||||
|
t_width = min(m_width, t_width)
|
||||||
|
t_height = min(m_height, t_height)
|
||||||
|
thumbnails[(t_width, t_height, r_type)] = r_method
|
||||||
|
|
||||||
def generate_thumbnails():
|
# Now we generate the thumbnails for each dimension, store it
|
||||||
scales = set()
|
for (t_width, t_height, t_type), t_method in thumbnails.iteritems():
|
||||||
crops = set()
|
# Work out the correct file name for thumbnail
|
||||||
for r_width, r_height, r_method, r_type in requirements:
|
if server_name:
|
||||||
if r_method == "scale":
|
file_path = self.filepaths.remote_media_thumbnail_rel(
|
||||||
t_width, t_height = thumbnailer.aspect(r_width, r_height)
|
|
||||||
scales.add((
|
|
||||||
min(m_width, t_width), min(m_height, t_height), r_type,
|
|
||||||
))
|
|
||||||
elif r_method == "crop":
|
|
||||||
crops.add((r_width, r_height, r_type))
|
|
||||||
|
|
||||||
for t_width, t_height, t_type in scales:
|
|
||||||
t_method = "scale"
|
|
||||||
if url_cache:
|
|
||||||
t_path = self.filepaths.url_cache_thumbnail(
|
|
||||||
media_id, t_width, t_height, t_type, t_method
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
t_path = self.filepaths.local_media_thumbnail(
|
|
||||||
media_id, t_width, t_height, t_type, t_method
|
|
||||||
)
|
|
||||||
self._makedirs(t_path)
|
|
||||||
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
|
|
||||||
|
|
||||||
local_thumbnails.append((
|
|
||||||
media_id, t_width, t_height, t_type, t_method, t_len
|
|
||||||
))
|
|
||||||
|
|
||||||
for t_width, t_height, t_type in crops:
|
|
||||||
if (t_width, t_height, t_type) in scales:
|
|
||||||
# If the aspect ratio of the cropped thumbnail matches a purely
|
|
||||||
# scaled one then there is no point in calculating a separate
|
|
||||||
# thumbnail.
|
|
||||||
continue
|
|
||||||
t_method = "crop"
|
|
||||||
if url_cache:
|
|
||||||
t_path = self.filepaths.url_cache_thumbnail(
|
|
||||||
media_id, t_width, t_height, t_type, t_method
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
t_path = self.filepaths.local_media_thumbnail(
|
|
||||||
media_id, t_width, t_height, t_type, t_method
|
|
||||||
)
|
|
||||||
self._makedirs(t_path)
|
|
||||||
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
|
|
||||||
local_thumbnails.append((
|
|
||||||
media_id, t_width, t_height, t_type, t_method, t_len
|
|
||||||
))
|
|
||||||
|
|
||||||
yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
|
|
||||||
|
|
||||||
for l in local_thumbnails:
|
|
||||||
yield self.store.store_local_thumbnail(*l)
|
|
||||||
|
|
||||||
defer.returnValue({
|
|
||||||
"width": m_width,
|
|
||||||
"height": m_height,
|
|
||||||
})
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _generate_remote_thumbnails(self, server_name, media_id, media_info):
|
|
||||||
media_type = media_info["media_type"]
|
|
||||||
file_id = media_info["filesystem_id"]
|
|
||||||
requirements = self._get_thumbnail_requirements(media_type)
|
|
||||||
if not requirements:
|
|
||||||
return
|
|
||||||
|
|
||||||
remote_thumbnails = []
|
|
||||||
|
|
||||||
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
|
|
||||||
thumbnailer = Thumbnailer(input_path)
|
|
||||||
m_width = thumbnailer.width
|
|
||||||
m_height = thumbnailer.height
|
|
||||||
|
|
||||||
def generate_thumbnails():
|
|
||||||
if m_width * m_height >= self.max_image_pixels:
|
|
||||||
logger.info(
|
|
||||||
"Image too large to thumbnail %r x %r > %r",
|
|
||||||
m_width, m_height, self.max_image_pixels
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
scales = set()
|
|
||||||
crops = set()
|
|
||||||
for r_width, r_height, r_method, r_type in requirements:
|
|
||||||
if r_method == "scale":
|
|
||||||
t_width, t_height = thumbnailer.aspect(r_width, r_height)
|
|
||||||
scales.add((
|
|
||||||
min(m_width, t_width), min(m_height, t_height), r_type,
|
|
||||||
))
|
|
||||||
elif r_method == "crop":
|
|
||||||
crops.add((r_width, r_height, r_type))
|
|
||||||
|
|
||||||
for t_width, t_height, t_type in scales:
|
|
||||||
t_method = "scale"
|
|
||||||
t_path = self.filepaths.remote_media_thumbnail(
|
|
||||||
server_name, file_id, t_width, t_height, t_type, t_method
|
server_name, file_id, t_width, t_height, t_type, t_method
|
||||||
)
|
)
|
||||||
self._makedirs(t_path)
|
elif url_cache:
|
||||||
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
|
file_path = self.filepaths.url_cache_thumbnail_rel(
|
||||||
remote_thumbnails.append([
|
media_id, t_width, t_height, t_type, t_method
|
||||||
server_name, media_id, file_id,
|
|
||||||
t_width, t_height, t_type, t_method, t_len
|
|
||||||
])
|
|
||||||
|
|
||||||
for t_width, t_height, t_type in crops:
|
|
||||||
if (t_width, t_height, t_type) in scales:
|
|
||||||
# If the aspect ratio of the cropped thumbnail matches a purely
|
|
||||||
# scaled one then there is no point in calculating a separate
|
|
||||||
# thumbnail.
|
|
||||||
continue
|
|
||||||
t_method = "crop"
|
|
||||||
t_path = self.filepaths.remote_media_thumbnail(
|
|
||||||
server_name, file_id, t_width, t_height, t_type, t_method
|
|
||||||
)
|
)
|
||||||
self._makedirs(t_path)
|
else:
|
||||||
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
|
file_path = self.filepaths.local_media_thumbnail_rel(
|
||||||
remote_thumbnails.append([
|
media_id, t_width, t_height, t_type, t_method
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate the thumbnail
|
||||||
|
if t_method == "crop":
|
||||||
|
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
|
||||||
|
thumbnailer.crop,
|
||||||
|
t_width, t_height, t_type,
|
||||||
|
))
|
||||||
|
elif t_method == "scale":
|
||||||
|
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
|
||||||
|
thumbnailer.scale,
|
||||||
|
t_width, t_height, t_type,
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
logger.error("Unrecognized method: %r", t_method)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not t_byte_source:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Write to disk
|
||||||
|
output_path = yield self.write_to_file_and_backup(
|
||||||
|
t_byte_source, file_path,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
t_byte_source.close()
|
||||||
|
|
||||||
|
t_len = os.path.getsize(output_path)
|
||||||
|
|
||||||
|
# Write to database
|
||||||
|
if server_name:
|
||||||
|
yield self.store.store_remote_media_thumbnail(
|
||||||
server_name, media_id, file_id,
|
server_name, media_id, file_id,
|
||||||
t_width, t_height, t_type, t_method, t_len
|
t_width, t_height, t_type, t_method, t_len
|
||||||
])
|
)
|
||||||
|
else:
|
||||||
yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
|
yield self.store.store_local_thumbnail(
|
||||||
|
media_id, t_width, t_height, t_type, t_method, t_len
|
||||||
for r in remote_thumbnails:
|
)
|
||||||
yield self.store.store_remote_media_thumbnail(*r)
|
|
||||||
|
|
||||||
defer.returnValue({
|
defer.returnValue({
|
||||||
"width": m_width,
|
"width": m_width,
|
||||||
@ -497,6 +544,8 @@ class MediaRepository(object):
|
|||||||
|
|
||||||
logger.info("Deleting: %r", key)
|
logger.info("Deleting: %r", key)
|
||||||
|
|
||||||
|
# TODO: Should we delete from the backup store
|
||||||
|
|
||||||
with (yield self.remote_media_linearizer.queue(key)):
|
with (yield self.remote_media_linearizer.queue(key)):
|
||||||
full_path = self.filepaths.remote_media_filepath(origin, file_id)
|
full_path = self.filepaths.remote_media_filepath(origin, file_id)
|
||||||
try:
|
try:
|
||||||
|
@ -36,6 +36,9 @@ import cgi
|
|||||||
import ujson as json
|
import ujson as json
|
||||||
import urlparse
|
import urlparse
|
||||||
import itertools
|
import itertools
|
||||||
|
import datetime
|
||||||
|
import errno
|
||||||
|
import shutil
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -56,6 +59,7 @@ class PreviewUrlResource(Resource):
|
|||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.client = SpiderHttpClient(hs)
|
self.client = SpiderHttpClient(hs)
|
||||||
self.media_repo = media_repo
|
self.media_repo = media_repo
|
||||||
|
self.primary_base_path = media_repo.primary_base_path
|
||||||
|
|
||||||
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
|
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
|
||||||
|
|
||||||
@ -70,6 +74,10 @@ class PreviewUrlResource(Resource):
|
|||||||
|
|
||||||
self.downloads = {}
|
self.downloads = {}
|
||||||
|
|
||||||
|
self._cleaner_loop = self.clock.looping_call(
|
||||||
|
self._expire_url_cache_data, 10 * 1000
|
||||||
|
)
|
||||||
|
|
||||||
def render_GET(self, request):
|
def render_GET(self, request):
|
||||||
self._async_render_GET(request)
|
self._async_render_GET(request)
|
||||||
return NOT_DONE_YET
|
return NOT_DONE_YET
|
||||||
@ -130,7 +138,7 @@ class PreviewUrlResource(Resource):
|
|||||||
cache_result = yield self.store.get_url_cache(url, ts)
|
cache_result = yield self.store.get_url_cache(url, ts)
|
||||||
if (
|
if (
|
||||||
cache_result and
|
cache_result and
|
||||||
cache_result["download_ts"] + cache_result["expires"] > ts and
|
cache_result["expires_ts"] > ts and
|
||||||
cache_result["response_code"] / 100 == 2
|
cache_result["response_code"] / 100 == 2
|
||||||
):
|
):
|
||||||
respond_with_json_bytes(
|
respond_with_json_bytes(
|
||||||
@ -163,8 +171,8 @@ class PreviewUrlResource(Resource):
|
|||||||
logger.debug("got media_info of '%s'" % media_info)
|
logger.debug("got media_info of '%s'" % media_info)
|
||||||
|
|
||||||
if _is_media(media_info['media_type']):
|
if _is_media(media_info['media_type']):
|
||||||
dims = yield self.media_repo._generate_local_thumbnails(
|
dims = yield self.media_repo._generate_thumbnails(
|
||||||
media_info['filesystem_id'], media_info, url_cache=True,
|
None, media_info['filesystem_id'], media_info, url_cache=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
og = {
|
og = {
|
||||||
@ -209,8 +217,8 @@ class PreviewUrlResource(Resource):
|
|||||||
|
|
||||||
if _is_media(image_info['media_type']):
|
if _is_media(image_info['media_type']):
|
||||||
# TODO: make sure we don't choke on white-on-transparent images
|
# TODO: make sure we don't choke on white-on-transparent images
|
||||||
dims = yield self.media_repo._generate_local_thumbnails(
|
dims = yield self.media_repo._generate_thumbnails(
|
||||||
image_info['filesystem_id'], image_info, url_cache=True,
|
None, image_info['filesystem_id'], image_info, url_cache=True,
|
||||||
)
|
)
|
||||||
if dims:
|
if dims:
|
||||||
og["og:image:width"] = dims['width']
|
og["og:image:width"] = dims['width']
|
||||||
@ -239,7 +247,7 @@ class PreviewUrlResource(Resource):
|
|||||||
url,
|
url,
|
||||||
media_info["response_code"],
|
media_info["response_code"],
|
||||||
media_info["etag"],
|
media_info["etag"],
|
||||||
media_info["expires"],
|
media_info["expires"] + media_info["created_ts"],
|
||||||
json.dumps(og),
|
json.dumps(og),
|
||||||
media_info["filesystem_id"],
|
media_info["filesystem_id"],
|
||||||
media_info["created_ts"],
|
media_info["created_ts"],
|
||||||
@ -253,10 +261,10 @@ class PreviewUrlResource(Resource):
|
|||||||
# we're most likely being explicitly triggered by a human rather than a
|
# we're most likely being explicitly triggered by a human rather than a
|
||||||
# bot, so are we really a robot?
|
# bot, so are we really a robot?
|
||||||
|
|
||||||
# XXX: horrible duplication with base_resource's _download_remote_file()
|
file_id = datetime.date.today().isoformat() + '_' + random_string(16)
|
||||||
file_id = random_string(24)
|
|
||||||
|
|
||||||
fname = self.filepaths.url_cache_filepath(file_id)
|
fpath = self.filepaths.url_cache_filepath_rel(file_id)
|
||||||
|
fname = os.path.join(self.primary_base_path, fpath)
|
||||||
self.media_repo._makedirs(fname)
|
self.media_repo._makedirs(fname)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -267,6 +275,8 @@ class PreviewUrlResource(Resource):
|
|||||||
)
|
)
|
||||||
# FIXME: pass through 404s and other error messages nicely
|
# FIXME: pass through 404s and other error messages nicely
|
||||||
|
|
||||||
|
yield self.media_repo.copy_to_backup(fpath)
|
||||||
|
|
||||||
media_type = headers["Content-Type"][0]
|
media_type = headers["Content-Type"][0]
|
||||||
time_now_ms = self.clock.time_msec()
|
time_now_ms = self.clock.time_msec()
|
||||||
|
|
||||||
@ -328,6 +338,91 @@ class PreviewUrlResource(Resource):
|
|||||||
"etag": headers["ETag"][0] if "ETag" in headers else None,
|
"etag": headers["ETag"][0] if "ETag" in headers else None,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _expire_url_cache_data(self):
|
||||||
|
"""Clean up expired url cache content, media and thumbnails.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# TODO: Delete from backup media store
|
||||||
|
|
||||||
|
now = self.clock.time_msec()
|
||||||
|
|
||||||
|
# First we delete expired url cache entries
|
||||||
|
media_ids = yield self.store.get_expired_url_cache(now)
|
||||||
|
|
||||||
|
removed_media = []
|
||||||
|
for media_id in media_ids:
|
||||||
|
fname = self.filepaths.url_cache_filepath(media_id)
|
||||||
|
try:
|
||||||
|
os.remove(fname)
|
||||||
|
except OSError as e:
|
||||||
|
# If the path doesn't exist, meh
|
||||||
|
if e.errno != errno.ENOENT:
|
||||||
|
logger.warn("Failed to remove media: %r: %s", media_id, e)
|
||||||
|
continue
|
||||||
|
|
||||||
|
removed_media.append(media_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
|
||||||
|
for dir in dirs:
|
||||||
|
os.rmdir(dir)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
yield self.store.delete_url_cache(removed_media)
|
||||||
|
|
||||||
|
if removed_media:
|
||||||
|
logger.info("Deleted %d entries from url cache", len(removed_media))
|
||||||
|
|
||||||
|
# Now we delete old images associated with the url cache.
|
||||||
|
# These may be cached for a bit on the client (i.e., they
|
||||||
|
# may have a room open with a preview url thing open).
|
||||||
|
# So we wait a couple of days before deleting, just in case.
|
||||||
|
expire_before = now - 2 * 24 * 60 * 60 * 1000
|
||||||
|
media_ids = yield self.store.get_url_cache_media_before(expire_before)
|
||||||
|
|
||||||
|
removed_media = []
|
||||||
|
for media_id in media_ids:
|
||||||
|
fname = self.filepaths.url_cache_filepath(media_id)
|
||||||
|
try:
|
||||||
|
os.remove(fname)
|
||||||
|
except OSError as e:
|
||||||
|
# If the path doesn't exist, meh
|
||||||
|
if e.errno != errno.ENOENT:
|
||||||
|
logger.warn("Failed to remove media: %r: %s", media_id, e)
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
|
||||||
|
for dir in dirs:
|
||||||
|
os.rmdir(dir)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
thumbnail_dir = self.filepaths.url_cache_thumbnail_directory(media_id)
|
||||||
|
try:
|
||||||
|
shutil.rmtree(thumbnail_dir)
|
||||||
|
except OSError as e:
|
||||||
|
# If the path doesn't exist, meh
|
||||||
|
if e.errno != errno.ENOENT:
|
||||||
|
logger.warn("Failed to remove media: %r: %s", media_id, e)
|
||||||
|
continue
|
||||||
|
|
||||||
|
removed_media.append(media_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
dirs = self.filepaths.url_cache_thumbnail_dirs_to_delete(media_id)
|
||||||
|
for dir in dirs:
|
||||||
|
os.rmdir(dir)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
yield self.store.delete_url_cache_media(removed_media)
|
||||||
|
|
||||||
|
if removed_media:
|
||||||
|
logger.info("Deleted %d media from url cache", len(removed_media))
|
||||||
|
|
||||||
|
|
||||||
def decode_and_calc_og(body, media_uri, request_encoding=None):
|
def decode_and_calc_og(body, media_uri, request_encoding=None):
|
||||||
from lxml import etree
|
from lxml import etree
|
||||||
|
@ -50,12 +50,16 @@ class Thumbnailer(object):
|
|||||||
else:
|
else:
|
||||||
return ((max_height * self.width) // self.height, max_height)
|
return ((max_height * self.width) // self.height, max_height)
|
||||||
|
|
||||||
def scale(self, output_path, width, height, output_type):
|
def scale(self, width, height, output_type):
|
||||||
"""Rescales the image to the given dimensions"""
|
"""Rescales the image to the given dimensions.
|
||||||
scaled = self.image.resize((width, height), Image.ANTIALIAS)
|
|
||||||
return self.save_image(scaled, output_type, output_path)
|
|
||||||
|
|
||||||
def crop(self, output_path, width, height, output_type):
|
Returns:
|
||||||
|
BytesIO: the bytes of the encoded image ready to be written to disk
|
||||||
|
"""
|
||||||
|
scaled = self.image.resize((width, height), Image.ANTIALIAS)
|
||||||
|
return self._encode_image(scaled, output_type)
|
||||||
|
|
||||||
|
def crop(self, width, height, output_type):
|
||||||
"""Rescales and crops the image to the given dimensions preserving
|
"""Rescales and crops the image to the given dimensions preserving
|
||||||
aspect::
|
aspect::
|
||||||
(w_in / h_in) = (w_scaled / h_scaled)
|
(w_in / h_in) = (w_scaled / h_scaled)
|
||||||
@ -65,6 +69,9 @@ class Thumbnailer(object):
|
|||||||
Args:
|
Args:
|
||||||
max_width: The largest possible width.
|
max_width: The largest possible width.
|
||||||
max_height: The larget possible height.
|
max_height: The larget possible height.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BytesIO: the bytes of the encoded image ready to be written to disk
|
||||||
"""
|
"""
|
||||||
if width * self.height > height * self.width:
|
if width * self.height > height * self.width:
|
||||||
scaled_height = (width * self.height) // self.width
|
scaled_height = (width * self.height) // self.width
|
||||||
@ -82,13 +89,9 @@ class Thumbnailer(object):
|
|||||||
crop_left = (scaled_width - width) // 2
|
crop_left = (scaled_width - width) // 2
|
||||||
crop_right = width + crop_left
|
crop_right = width + crop_left
|
||||||
cropped = scaled_image.crop((crop_left, 0, crop_right, height))
|
cropped = scaled_image.crop((crop_left, 0, crop_right, height))
|
||||||
return self.save_image(cropped, output_type, output_path)
|
return self._encode_image(cropped, output_type)
|
||||||
|
|
||||||
def save_image(self, output_image, output_type, output_path):
|
def _encode_image(self, output_image, output_type):
|
||||||
output_bytes_io = BytesIO()
|
output_bytes_io = BytesIO()
|
||||||
output_image.save(output_bytes_io, self.FORMATS[output_type], quality=80)
|
output_image.save(output_bytes_io, self.FORMATS[output_type], quality=80)
|
||||||
output_bytes = output_bytes_io.getvalue()
|
return output_bytes_io
|
||||||
with open(output_path, "wb") as output_file:
|
|
||||||
output_file.write(output_bytes)
|
|
||||||
logger.info("Stored thumbnail in file %r", output_path)
|
|
||||||
return len(output_bytes)
|
|
||||||
|
@ -93,7 +93,7 @@ class UploadResource(Resource):
|
|||||||
# TODO(markjh): parse content-dispostion
|
# TODO(markjh): parse content-dispostion
|
||||||
|
|
||||||
content_uri = yield self.media_repo.create_content(
|
content_uri = yield self.media_repo.create_content(
|
||||||
media_type, upload_name, request.content.read(),
|
media_type, upload_name, request.content,
|
||||||
content_length, requester.user
|
content_length, requester.user
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -31,6 +31,7 @@ from synapse.appservice.api import ApplicationServiceApi
|
|||||||
from synapse.appservice.scheduler import ApplicationServiceScheduler
|
from synapse.appservice.scheduler import ApplicationServiceScheduler
|
||||||
from synapse.crypto.keyring import Keyring
|
from synapse.crypto.keyring import Keyring
|
||||||
from synapse.events.builder import EventBuilderFactory
|
from synapse.events.builder import EventBuilderFactory
|
||||||
|
from synapse.events.spamcheck import SpamChecker
|
||||||
from synapse.federation import initialize_http_replication
|
from synapse.federation import initialize_http_replication
|
||||||
from synapse.federation.send_queue import FederationRemoteSendQueue
|
from synapse.federation.send_queue import FederationRemoteSendQueue
|
||||||
from synapse.federation.transport.client import TransportLayerClient
|
from synapse.federation.transport.client import TransportLayerClient
|
||||||
@ -50,6 +51,10 @@ from synapse.handlers.initial_sync import InitialSyncHandler
|
|||||||
from synapse.handlers.receipts import ReceiptsHandler
|
from synapse.handlers.receipts import ReceiptsHandler
|
||||||
from synapse.handlers.read_marker import ReadMarkerHandler
|
from synapse.handlers.read_marker import ReadMarkerHandler
|
||||||
from synapse.handlers.user_directory import UserDirectoyHandler
|
from synapse.handlers.user_directory import UserDirectoyHandler
|
||||||
|
from synapse.handlers.groups_local import GroupsLocalHandler
|
||||||
|
from synapse.handlers.profile import ProfileHandler
|
||||||
|
from synapse.groups.groups_server import GroupsServerHandler
|
||||||
|
from synapse.groups.attestations import GroupAttestionRenewer, GroupAttestationSigning
|
||||||
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
|
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
|
||||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||||
from synapse.notifier import Notifier
|
from synapse.notifier import Notifier
|
||||||
@ -111,6 +116,7 @@ class HomeServer(object):
|
|||||||
'application_service_scheduler',
|
'application_service_scheduler',
|
||||||
'application_service_handler',
|
'application_service_handler',
|
||||||
'device_message_handler',
|
'device_message_handler',
|
||||||
|
'profile_handler',
|
||||||
'notifier',
|
'notifier',
|
||||||
'distributor',
|
'distributor',
|
||||||
'client_resource',
|
'client_resource',
|
||||||
@ -139,6 +145,11 @@ class HomeServer(object):
|
|||||||
'read_marker_handler',
|
'read_marker_handler',
|
||||||
'action_generator',
|
'action_generator',
|
||||||
'user_directory_handler',
|
'user_directory_handler',
|
||||||
|
'groups_local_handler',
|
||||||
|
'groups_server_handler',
|
||||||
|
'groups_attestation_signing',
|
||||||
|
'groups_attestation_renewer',
|
||||||
|
'spam_checker',
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, hostname, **kwargs):
|
def __init__(self, hostname, **kwargs):
|
||||||
@ -251,6 +262,9 @@ class HomeServer(object):
|
|||||||
def build_initial_sync_handler(self):
|
def build_initial_sync_handler(self):
|
||||||
return InitialSyncHandler(self)
|
return InitialSyncHandler(self)
|
||||||
|
|
||||||
|
def build_profile_handler(self):
|
||||||
|
return ProfileHandler(self)
|
||||||
|
|
||||||
def build_event_sources(self):
|
def build_event_sources(self):
|
||||||
return EventSources(self)
|
return EventSources(self)
|
||||||
|
|
||||||
@ -309,6 +323,21 @@ class HomeServer(object):
|
|||||||
def build_user_directory_handler(self):
|
def build_user_directory_handler(self):
|
||||||
return UserDirectoyHandler(self)
|
return UserDirectoyHandler(self)
|
||||||
|
|
||||||
|
def build_groups_local_handler(self):
|
||||||
|
return GroupsLocalHandler(self)
|
||||||
|
|
||||||
|
def build_groups_server_handler(self):
|
||||||
|
return GroupsServerHandler(self)
|
||||||
|
|
||||||
|
def build_groups_attestation_signing(self):
|
||||||
|
return GroupAttestationSigning(self)
|
||||||
|
|
||||||
|
def build_groups_attestation_renewer(self):
|
||||||
|
return GroupAttestionRenewer(self)
|
||||||
|
|
||||||
|
def build_spam_checker(self):
|
||||||
|
return SpamChecker(self)
|
||||||
|
|
||||||
def remove_pusher(self, app_id, push_key, user_id):
|
def remove_pusher(self, app_id, push_key, user_id):
|
||||||
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
|
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
|
||||||
|
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
import synapse.api.auth
|
import synapse.api.auth
|
||||||
|
import synapse.federation.transaction_queue
|
||||||
|
import synapse.federation.transport.client
|
||||||
import synapse.handlers
|
import synapse.handlers
|
||||||
import synapse.handlers.auth
|
import synapse.handlers.auth
|
||||||
import synapse.handlers.device
|
import synapse.handlers.device
|
||||||
@ -27,3 +29,9 @@ class HomeServer(object):
|
|||||||
|
|
||||||
def get_state_handler(self) -> synapse.state.StateHandler:
|
def get_state_handler(self) -> synapse.state.StateHandler:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def get_federation_sender(self) -> synapse.federation.transaction_queue.TransactionQueue:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_federation_transport_client(self) -> synapse.federation.transport.client.TransportLayerClient:
|
||||||
|
pass
|
||||||
|
@ -288,6 +288,9 @@ class StateHandler(object):
|
|||||||
"""
|
"""
|
||||||
logger.debug("resolve_state_groups event_ids %s", event_ids)
|
logger.debug("resolve_state_groups event_ids %s", event_ids)
|
||||||
|
|
||||||
|
# map from state group id to the state in that state group (where
|
||||||
|
# 'state' is a map from state key to event id)
|
||||||
|
# dict[int, dict[(str, str), str]]
|
||||||
state_groups_ids = yield self.store.get_state_groups_ids(
|
state_groups_ids = yield self.store.get_state_groups_ids(
|
||||||
room_id, event_ids
|
room_id, event_ids
|
||||||
)
|
)
|
||||||
@ -320,11 +323,15 @@ class StateHandler(object):
|
|||||||
"Resolving state for %s with %d groups", room_id, len(state_groups_ids)
|
"Resolving state for %s with %d groups", room_id, len(state_groups_ids)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# build a map from state key to the event_ids which set that state.
|
||||||
|
# dict[(str, str), set[str])
|
||||||
state = {}
|
state = {}
|
||||||
for st in state_groups_ids.values():
|
for st in state_groups_ids.values():
|
||||||
for key, e_id in st.items():
|
for key, e_id in st.items():
|
||||||
state.setdefault(key, set()).add(e_id)
|
state.setdefault(key, set()).add(e_id)
|
||||||
|
|
||||||
|
# build a map from state key to the event_ids which set that state,
|
||||||
|
# including only those where there are state keys in conflict.
|
||||||
conflicted_state = {
|
conflicted_state = {
|
||||||
k: list(v)
|
k: list(v)
|
||||||
for k, v in state.items()
|
for k, v in state.items()
|
||||||
@ -494,8 +501,14 @@ def _resolve_with_state_fac(unconflicted_state, conflicted_state,
|
|||||||
|
|
||||||
logger.info("Asking for %d conflicted events", len(needed_events))
|
logger.info("Asking for %d conflicted events", len(needed_events))
|
||||||
|
|
||||||
|
# dict[str, FrozenEvent]: a map from state event id to event. Only includes
|
||||||
|
# the state events which are in conflict.
|
||||||
state_map = yield state_map_factory(needed_events)
|
state_map = yield state_map_factory(needed_events)
|
||||||
|
|
||||||
|
# get the ids of the auth events which allow us to authenticate the
|
||||||
|
# conflicted state, picking only from the unconflicting state.
|
||||||
|
#
|
||||||
|
# dict[(str, str), str]: a map from state key to event id
|
||||||
auth_events = _create_auth_events_from_maps(
|
auth_events = _create_auth_events_from_maps(
|
||||||
unconflicted_state, conflicted_state, state_map
|
unconflicted_state, conflicted_state, state_map
|
||||||
)
|
)
|
||||||
|
@ -37,7 +37,7 @@ from .media_repository import MediaRepositoryStore
|
|||||||
from .rejections import RejectionsStore
|
from .rejections import RejectionsStore
|
||||||
from .event_push_actions import EventPushActionsStore
|
from .event_push_actions import EventPushActionsStore
|
||||||
from .deviceinbox import DeviceInboxStore
|
from .deviceinbox import DeviceInboxStore
|
||||||
|
from .group_server import GroupServerStore
|
||||||
from .state import StateStore
|
from .state import StateStore
|
||||||
from .signatures import SignatureStore
|
from .signatures import SignatureStore
|
||||||
from .filtering import FilteringStore
|
from .filtering import FilteringStore
|
||||||
@ -88,6 +88,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||||||
DeviceStore,
|
DeviceStore,
|
||||||
DeviceInboxStore,
|
DeviceInboxStore,
|
||||||
UserDirectoryStore,
|
UserDirectoryStore,
|
||||||
|
GroupServerStore,
|
||||||
):
|
):
|
||||||
|
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, db_conn, hs):
|
||||||
@ -135,6 +136,9 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||||||
db_conn, "pushers", "id",
|
db_conn, "pushers", "id",
|
||||||
extra_tables=[("deleted_pushers", "stream_id")],
|
extra_tables=[("deleted_pushers", "stream_id")],
|
||||||
)
|
)
|
||||||
|
self._group_updates_id_gen = StreamIdGenerator(
|
||||||
|
db_conn, "local_group_updates", "stream_id",
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(self.database_engine, PostgresEngine):
|
if isinstance(self.database_engine, PostgresEngine):
|
||||||
self._cache_id_gen = StreamIdGenerator(
|
self._cache_id_gen = StreamIdGenerator(
|
||||||
@ -235,6 +239,18 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||||||
prefilled_cache=curr_state_delta_prefill,
|
prefilled_cache=curr_state_delta_prefill,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_group_updates_prefill, min_group_updates_id = self._get_cache_dict(
|
||||||
|
db_conn, "local_group_updates",
|
||||||
|
entity_column="user_id",
|
||||||
|
stream_column="stream_id",
|
||||||
|
max_value=self._group_updates_id_gen.get_current_token(),
|
||||||
|
limit=1000,
|
||||||
|
)
|
||||||
|
self._group_updates_stream_cache = StreamChangeCache(
|
||||||
|
"_group_updates_stream_cache", min_group_updates_id,
|
||||||
|
prefilled_cache=_group_updates_prefill,
|
||||||
|
)
|
||||||
|
|
||||||
cur = LoggingTransaction(
|
cur = LoggingTransaction(
|
||||||
db_conn.cursor(),
|
db_conn.cursor(),
|
||||||
name="_find_stream_orderings_for_times_txn",
|
name="_find_stream_orderings_for_times_txn",
|
||||||
|
@ -743,6 +743,33 @@ class SQLBaseStore(object):
|
|||||||
txn.execute(sql, values)
|
txn.execute(sql, values)
|
||||||
return cls.cursor_to_dict(txn)
|
return cls.cursor_to_dict(txn)
|
||||||
|
|
||||||
|
def _simple_update(self, table, keyvalues, updatevalues, desc):
|
||||||
|
return self.runInteraction(
|
||||||
|
desc,
|
||||||
|
self._simple_update_txn,
|
||||||
|
table, keyvalues, updatevalues,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _simple_update_txn(txn, table, keyvalues, updatevalues):
|
||||||
|
if keyvalues:
|
||||||
|
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
|
||||||
|
else:
|
||||||
|
where = ""
|
||||||
|
|
||||||
|
update_sql = "UPDATE %s SET %s %s" % (
|
||||||
|
table,
|
||||||
|
", ".join("%s = ?" % (k,) for k in updatevalues),
|
||||||
|
where,
|
||||||
|
)
|
||||||
|
|
||||||
|
txn.execute(
|
||||||
|
update_sql,
|
||||||
|
updatevalues.values() + keyvalues.values()
|
||||||
|
)
|
||||||
|
|
||||||
|
return txn.rowcount
|
||||||
|
|
||||||
def _simple_update_one(self, table, keyvalues, updatevalues,
|
def _simple_update_one(self, table, keyvalues, updatevalues,
|
||||||
desc="_simple_update_one"):
|
desc="_simple_update_one"):
|
||||||
"""Executes an UPDATE query on the named table, setting new values for
|
"""Executes an UPDATE query on the named table, setting new values for
|
||||||
@ -768,27 +795,13 @@ class SQLBaseStore(object):
|
|||||||
table, keyvalues, updatevalues,
|
table, keyvalues, updatevalues,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
|
def _simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
|
||||||
if keyvalues:
|
rowcount = cls._simple_update_txn(txn, table, keyvalues, updatevalues)
|
||||||
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
|
|
||||||
else:
|
|
||||||
where = ""
|
|
||||||
|
|
||||||
update_sql = "UPDATE %s SET %s %s" % (
|
if rowcount == 0:
|
||||||
table,
|
|
||||||
", ".join("%s = ?" % (k,) for k in updatevalues),
|
|
||||||
where,
|
|
||||||
)
|
|
||||||
|
|
||||||
txn.execute(
|
|
||||||
update_sql,
|
|
||||||
updatevalues.values() + keyvalues.values()
|
|
||||||
)
|
|
||||||
|
|
||||||
if txn.rowcount == 0:
|
|
||||||
raise StoreError(404, "No row found")
|
raise StoreError(404, "No row found")
|
||||||
if txn.rowcount > 1:
|
if rowcount > 1:
|
||||||
raise StoreError(500, "More than one row matched")
|
raise StoreError(500, "More than one row matched")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -21,7 +21,7 @@ from synapse.events.utils import prune_event
|
|||||||
|
|
||||||
from synapse.util.async import ObservableDeferred
|
from synapse.util.async import ObservableDeferred
|
||||||
from synapse.util.logcontext import (
|
from synapse.util.logcontext import (
|
||||||
preserve_fn, PreserveLoggingContext, preserve_context_over_deferred
|
preserve_fn, PreserveLoggingContext, make_deferred_yieldable
|
||||||
)
|
)
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
@ -88,13 +88,23 @@ class _EventPeristenceQueue(object):
|
|||||||
def add_to_queue(self, room_id, events_and_contexts, backfilled):
|
def add_to_queue(self, room_id, events_and_contexts, backfilled):
|
||||||
"""Add events to the queue, with the given persist_event options.
|
"""Add events to the queue, with the given persist_event options.
|
||||||
|
|
||||||
|
NB: due to the normal usage pattern of this method, it does *not*
|
||||||
|
follow the synapse logcontext rules, and leaves the logcontext in
|
||||||
|
place whether or not the returned deferred is ready.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room_id (str):
|
room_id (str):
|
||||||
events_and_contexts (list[(EventBase, EventContext)]):
|
events_and_contexts (list[(EventBase, EventContext)]):
|
||||||
backfilled (bool):
|
backfilled (bool):
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
defer.Deferred: a deferred which will resolve once the events are
|
||||||
|
persisted. Runs its callbacks *without* a logcontext.
|
||||||
"""
|
"""
|
||||||
queue = self._event_persist_queues.setdefault(room_id, deque())
|
queue = self._event_persist_queues.setdefault(room_id, deque())
|
||||||
if queue:
|
if queue:
|
||||||
|
# if the last item in the queue has the same `backfilled` setting,
|
||||||
|
# we can just add these new events to that item.
|
||||||
end_item = queue[-1]
|
end_item = queue[-1]
|
||||||
if end_item.backfilled == backfilled:
|
if end_item.backfilled == backfilled:
|
||||||
end_item.events_and_contexts.extend(events_and_contexts)
|
end_item.events_and_contexts.extend(events_and_contexts)
|
||||||
@ -113,11 +123,11 @@ class _EventPeristenceQueue(object):
|
|||||||
def handle_queue(self, room_id, per_item_callback):
|
def handle_queue(self, room_id, per_item_callback):
|
||||||
"""Attempts to handle the queue for a room if not already being handled.
|
"""Attempts to handle the queue for a room if not already being handled.
|
||||||
|
|
||||||
The given callback will be invoked with for each item in the queue,1
|
The given callback will be invoked with for each item in the queue,
|
||||||
of type _EventPersistQueueItem. The per_item_callback will continuously
|
of type _EventPersistQueueItem. The per_item_callback will continuously
|
||||||
be called with new items, unless the queue becomnes empty. The return
|
be called with new items, unless the queue becomnes empty. The return
|
||||||
value of the function will be given to the deferreds waiting on the item,
|
value of the function will be given to the deferreds waiting on the item,
|
||||||
exceptions will be passed to the deferres as well.
|
exceptions will be passed to the deferreds as well.
|
||||||
|
|
||||||
This function should therefore be called whenever anything is added
|
This function should therefore be called whenever anything is added
|
||||||
to the queue.
|
to the queue.
|
||||||
@ -233,7 +243,7 @@ class EventsStore(SQLBaseStore):
|
|||||||
|
|
||||||
deferreds = []
|
deferreds = []
|
||||||
for room_id, evs_ctxs in partitioned.iteritems():
|
for room_id, evs_ctxs in partitioned.iteritems():
|
||||||
d = preserve_fn(self._event_persist_queue.add_to_queue)(
|
d = self._event_persist_queue.add_to_queue(
|
||||||
room_id, evs_ctxs,
|
room_id, evs_ctxs,
|
||||||
backfilled=backfilled,
|
backfilled=backfilled,
|
||||||
)
|
)
|
||||||
@ -242,7 +252,7 @@ class EventsStore(SQLBaseStore):
|
|||||||
for room_id in partitioned:
|
for room_id in partitioned:
|
||||||
self._maybe_start_persisting(room_id)
|
self._maybe_start_persisting(room_id)
|
||||||
|
|
||||||
return preserve_context_over_deferred(
|
return make_deferred_yieldable(
|
||||||
defer.gatherResults(deferreds, consumeErrors=True)
|
defer.gatherResults(deferreds, consumeErrors=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -267,7 +277,7 @@ class EventsStore(SQLBaseStore):
|
|||||||
|
|
||||||
self._maybe_start_persisting(event.room_id)
|
self._maybe_start_persisting(event.room_id)
|
||||||
|
|
||||||
yield preserve_context_over_deferred(deferred)
|
yield make_deferred_yieldable(deferred)
|
||||||
|
|
||||||
max_persisted_id = yield self._stream_id_gen.get_current_token()
|
max_persisted_id = yield self._stream_id_gen.get_current_token()
|
||||||
defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id))
|
defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id))
|
||||||
@ -784,6 +794,9 @@ class EventsStore(SQLBaseStore):
|
|||||||
self._invalidate_cache_and_stream(
|
self._invalidate_cache_and_stream(
|
||||||
txn, self.is_host_joined, (room_id, host)
|
txn, self.is_host_joined, (room_id, host)
|
||||||
)
|
)
|
||||||
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self.was_host_joined, (room_id, host)
|
||||||
|
)
|
||||||
|
|
||||||
self._invalidate_cache_and_stream(
|
self._invalidate_cache_and_stream(
|
||||||
txn, self.get_users_in_room, (room_id,)
|
txn, self.get_users_in_room, (room_id,)
|
||||||
@ -1523,7 +1536,7 @@ class EventsStore(SQLBaseStore):
|
|||||||
if not allow_rejected:
|
if not allow_rejected:
|
||||||
rows[:] = [r for r in rows if not r["rejects"]]
|
rows[:] = [r for r in rows if not r["rejects"]]
|
||||||
|
|
||||||
res = yield preserve_context_over_deferred(defer.gatherResults(
|
res = yield make_deferred_yieldable(defer.gatherResults(
|
||||||
[
|
[
|
||||||
preserve_fn(self._get_event_from_row)(
|
preserve_fn(self._get_event_from_row)(
|
||||||
row["internal_metadata"], row["json"], row["redacts"],
|
row["internal_metadata"], row["json"], row["redacts"],
|
||||||
|
1199
synapse/storage/group_server.py
Normal file
1199
synapse/storage/group_server.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -62,7 +62,7 @@ class MediaRepositoryStore(SQLBaseStore):
|
|||||||
def get_url_cache_txn(txn):
|
def get_url_cache_txn(txn):
|
||||||
# get the most recently cached result (relative to the given ts)
|
# get the most recently cached result (relative to the given ts)
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT response_code, etag, expires, og, media_id, download_ts"
|
"SELECT response_code, etag, expires_ts, og, media_id, download_ts"
|
||||||
" FROM local_media_repository_url_cache"
|
" FROM local_media_repository_url_cache"
|
||||||
" WHERE url = ? AND download_ts <= ?"
|
" WHERE url = ? AND download_ts <= ?"
|
||||||
" ORDER BY download_ts DESC LIMIT 1"
|
" ORDER BY download_ts DESC LIMIT 1"
|
||||||
@ -74,7 +74,7 @@ class MediaRepositoryStore(SQLBaseStore):
|
|||||||
# ...or if we've requested a timestamp older than the oldest
|
# ...or if we've requested a timestamp older than the oldest
|
||||||
# copy in the cache, return the oldest copy (if any)
|
# copy in the cache, return the oldest copy (if any)
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT response_code, etag, expires, og, media_id, download_ts"
|
"SELECT response_code, etag, expires_ts, og, media_id, download_ts"
|
||||||
" FROM local_media_repository_url_cache"
|
" FROM local_media_repository_url_cache"
|
||||||
" WHERE url = ? AND download_ts > ?"
|
" WHERE url = ? AND download_ts > ?"
|
||||||
" ORDER BY download_ts ASC LIMIT 1"
|
" ORDER BY download_ts ASC LIMIT 1"
|
||||||
@ -86,14 +86,14 @@ class MediaRepositoryStore(SQLBaseStore):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
return dict(zip((
|
return dict(zip((
|
||||||
'response_code', 'etag', 'expires', 'og', 'media_id', 'download_ts'
|
'response_code', 'etag', 'expires_ts', 'og', 'media_id', 'download_ts'
|
||||||
), row))
|
), row))
|
||||||
|
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"get_url_cache", get_url_cache_txn
|
"get_url_cache", get_url_cache_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
def store_url_cache(self, url, response_code, etag, expires, og, media_id,
|
def store_url_cache(self, url, response_code, etag, expires_ts, og, media_id,
|
||||||
download_ts):
|
download_ts):
|
||||||
return self._simple_insert(
|
return self._simple_insert(
|
||||||
"local_media_repository_url_cache",
|
"local_media_repository_url_cache",
|
||||||
@ -101,7 +101,7 @@ class MediaRepositoryStore(SQLBaseStore):
|
|||||||
"url": url,
|
"url": url,
|
||||||
"response_code": response_code,
|
"response_code": response_code,
|
||||||
"etag": etag,
|
"etag": etag,
|
||||||
"expires": expires,
|
"expires_ts": expires_ts,
|
||||||
"og": og,
|
"og": og,
|
||||||
"media_id": media_id,
|
"media_id": media_id,
|
||||||
"download_ts": download_ts,
|
"download_ts": download_ts,
|
||||||
@ -238,3 +238,64 @@ class MediaRepositoryStore(SQLBaseStore):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
return self.runInteraction("delete_remote_media", delete_remote_media_txn)
|
return self.runInteraction("delete_remote_media", delete_remote_media_txn)
|
||||||
|
|
||||||
|
def get_expired_url_cache(self, now_ts):
|
||||||
|
sql = (
|
||||||
|
"SELECT media_id FROM local_media_repository_url_cache"
|
||||||
|
" WHERE expires_ts < ?"
|
||||||
|
" ORDER BY expires_ts ASC"
|
||||||
|
" LIMIT 500"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_expired_url_cache_txn(txn):
|
||||||
|
txn.execute(sql, (now_ts,))
|
||||||
|
return [row[0] for row in txn]
|
||||||
|
|
||||||
|
return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn)
|
||||||
|
|
||||||
|
def delete_url_cache(self, media_ids):
|
||||||
|
sql = (
|
||||||
|
"DELETE FROM local_media_repository_url_cache"
|
||||||
|
" WHERE media_id = ?"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _delete_url_cache_txn(txn):
|
||||||
|
txn.executemany(sql, [(media_id,) for media_id in media_ids])
|
||||||
|
|
||||||
|
return self.runInteraction("delete_url_cache", _delete_url_cache_txn)
|
||||||
|
|
||||||
|
def get_url_cache_media_before(self, before_ts):
|
||||||
|
sql = (
|
||||||
|
"SELECT media_id FROM local_media_repository"
|
||||||
|
" WHERE created_ts < ? AND url_cache IS NOT NULL"
|
||||||
|
" ORDER BY created_ts ASC"
|
||||||
|
" LIMIT 500"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_url_cache_media_before_txn(txn):
|
||||||
|
txn.execute(sql, (before_ts,))
|
||||||
|
return [row[0] for row in txn]
|
||||||
|
|
||||||
|
return self.runInteraction(
|
||||||
|
"get_url_cache_media_before", _get_url_cache_media_before_txn,
|
||||||
|
)
|
||||||
|
|
||||||
|
def delete_url_cache_media(self, media_ids):
|
||||||
|
def _delete_url_cache_media_txn(txn):
|
||||||
|
sql = (
|
||||||
|
"DELETE FROM local_media_repository"
|
||||||
|
" WHERE media_id = ?"
|
||||||
|
)
|
||||||
|
|
||||||
|
txn.executemany(sql, [(media_id,) for media_id in media_ids])
|
||||||
|
|
||||||
|
sql = (
|
||||||
|
"DELETE FROM local_media_repository_thumbnails"
|
||||||
|
" WHERE media_id = ?"
|
||||||
|
)
|
||||||
|
|
||||||
|
txn.executemany(sql, [(media_id,) for media_id in media_ids])
|
||||||
|
|
||||||
|
return self.runInteraction(
|
||||||
|
"delete_url_cache_media", _delete_url_cache_media_txn,
|
||||||
|
)
|
||||||
|
@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
# Remember to update this number every time a change is made to database
|
# Remember to update this number every time a change is made to database
|
||||||
# schema files, so the users will be informed on server restarts.
|
# schema files, so the users will be informed on server restarts.
|
||||||
SCHEMA_VERSION = 43
|
SCHEMA_VERSION = 45
|
||||||
|
|
||||||
dir_path = os.path.abspath(os.path.dirname(__file__))
|
dir_path = os.path.abspath(os.path.dirname(__file__))
|
||||||
|
|
||||||
|
@ -13,6 +13,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
|
|
||||||
|
|
||||||
@ -55,3 +57,99 @@ class ProfileStore(SQLBaseStore):
|
|||||||
updatevalues={"avatar_url": new_avatar_url},
|
updatevalues={"avatar_url": new_avatar_url},
|
||||||
desc="set_profile_avatar_url",
|
desc="set_profile_avatar_url",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_from_remote_profile_cache(self, user_id):
|
||||||
|
return self._simple_select_one(
|
||||||
|
table="remote_profile_cache",
|
||||||
|
keyvalues={"user_id": user_id},
|
||||||
|
retcols=("displayname", "avatar_url",),
|
||||||
|
allow_none=True,
|
||||||
|
desc="get_from_remote_profile_cache",
|
||||||
|
)
|
||||||
|
|
||||||
|
def add_remote_profile_cache(self, user_id, displayname, avatar_url):
|
||||||
|
"""Ensure we are caching the remote user's profiles.
|
||||||
|
|
||||||
|
This should only be called when `is_subscribed_remote_profile_for_user`
|
||||||
|
would return true for the user.
|
||||||
|
"""
|
||||||
|
return self._simple_upsert(
|
||||||
|
table="remote_profile_cache",
|
||||||
|
keyvalues={"user_id": user_id},
|
||||||
|
values={
|
||||||
|
"displayname": displayname,
|
||||||
|
"avatar_url": avatar_url,
|
||||||
|
"last_check": self._clock.time_msec(),
|
||||||
|
},
|
||||||
|
desc="add_remote_profile_cache",
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_remote_profile_cache(self, user_id, displayname, avatar_url):
|
||||||
|
return self._simple_update(
|
||||||
|
table="remote_profile_cache",
|
||||||
|
keyvalues={"user_id": user_id},
|
||||||
|
values={
|
||||||
|
"displayname": displayname,
|
||||||
|
"avatar_url": avatar_url,
|
||||||
|
"last_check": self._clock.time_msec(),
|
||||||
|
},
|
||||||
|
desc="update_remote_profile_cache",
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def maybe_delete_remote_profile_cache(self, user_id):
|
||||||
|
"""Check if we still care about the remote user's profile, and if we
|
||||||
|
don't then remove their profile from the cache
|
||||||
|
"""
|
||||||
|
subscribed = yield self.is_subscribed_remote_profile_for_user(user_id)
|
||||||
|
if not subscribed:
|
||||||
|
yield self._simple_delete(
|
||||||
|
table="remote_profile_cache",
|
||||||
|
keyvalues={"user_id": user_id},
|
||||||
|
desc="delete_remote_profile_cache",
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_remote_profile_cache_entries_that_expire(self, last_checked):
|
||||||
|
"""Get all users who haven't been checked since `last_checked`
|
||||||
|
"""
|
||||||
|
def _get_remote_profile_cache_entries_that_expire_txn(txn):
|
||||||
|
sql = """
|
||||||
|
SELECT user_id, displayname, avatar_url
|
||||||
|
FROM remote_profile_cache
|
||||||
|
WHERE last_check < ?
|
||||||
|
"""
|
||||||
|
|
||||||
|
txn.execute(sql, (last_checked,))
|
||||||
|
|
||||||
|
return self.cursor_to_dict(txn)
|
||||||
|
|
||||||
|
return self.runInteraction(
|
||||||
|
"get_remote_profile_cache_entries_that_expire",
|
||||||
|
_get_remote_profile_cache_entries_that_expire_txn,
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def is_subscribed_remote_profile_for_user(self, user_id):
|
||||||
|
"""Check whether we are interested in a remote user's profile.
|
||||||
|
"""
|
||||||
|
res = yield self._simple_select_one_onecol(
|
||||||
|
table="group_users",
|
||||||
|
keyvalues={"user_id": user_id},
|
||||||
|
retcol="user_id",
|
||||||
|
allow_none=True,
|
||||||
|
desc="should_update_remote_profile_cache_for_user",
|
||||||
|
)
|
||||||
|
|
||||||
|
if res:
|
||||||
|
defer.returnValue(True)
|
||||||
|
|
||||||
|
res = yield self._simple_select_one_onecol(
|
||||||
|
table="group_invites",
|
||||||
|
keyvalues={"user_id": user_id},
|
||||||
|
retcol="user_id",
|
||||||
|
allow_none=True,
|
||||||
|
desc="should_update_remote_profile_cache_for_user",
|
||||||
|
)
|
||||||
|
|
||||||
|
if res:
|
||||||
|
defer.returnValue(True)
|
||||||
|
@ -533,6 +533,46 @@ class RoomMemberStore(SQLBaseStore):
|
|||||||
|
|
||||||
defer.returnValue(True)
|
defer.returnValue(True)
|
||||||
|
|
||||||
|
@cachedInlineCallbacks()
|
||||||
|
def was_host_joined(self, room_id, host):
|
||||||
|
"""Check whether the server is or ever was in the room.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id (str)
|
||||||
|
host (str)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred: Resolves to True if the host is/was in the room, otherwise
|
||||||
|
False.
|
||||||
|
"""
|
||||||
|
if '%' in host or '_' in host:
|
||||||
|
raise Exception("Invalid host name")
|
||||||
|
|
||||||
|
sql = """
|
||||||
|
SELECT user_id FROM room_memberships
|
||||||
|
WHERE room_id = ?
|
||||||
|
AND user_id LIKE ?
|
||||||
|
AND membership = 'join'
|
||||||
|
LIMIT 1
|
||||||
|
"""
|
||||||
|
|
||||||
|
# We do need to be careful to ensure that host doesn't have any wild cards
|
||||||
|
# in it, but we checked above for known ones and we'll check below that
|
||||||
|
# the returned user actually has the correct domain.
|
||||||
|
like_clause = "%:" + host
|
||||||
|
|
||||||
|
rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause)
|
||||||
|
|
||||||
|
if not rows:
|
||||||
|
defer.returnValue(False)
|
||||||
|
|
||||||
|
user_id = rows[0][0]
|
||||||
|
if get_domain_from_id(user_id) != host:
|
||||||
|
# This can only happen if the host name has something funky in it
|
||||||
|
raise Exception("Invalid host name")
|
||||||
|
|
||||||
|
defer.returnValue(True)
|
||||||
|
|
||||||
def get_joined_hosts(self, room_id, state_entry):
|
def get_joined_hosts(self, room_id, state_entry):
|
||||||
state_group = state_entry.state_group
|
state_group = state_entry.state_group
|
||||||
if not state_group:
|
if not state_group:
|
||||||
|
38
synapse/storage/schema/delta/44/expire_url_cache.sql
Normal file
38
synapse/storage/schema/delta/44/expire_url_cache.sql
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
/* Copyright 2017 New Vector Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
CREATE INDEX local_media_repository_url_idx ON local_media_repository(created_ts) WHERE url_cache IS NOT NULL;
|
||||||
|
|
||||||
|
-- we need to change `expires` to `expires_ts` so that we can index on it. SQLite doesn't support
|
||||||
|
-- indices on expressions until 3.9.
|
||||||
|
CREATE TABLE local_media_repository_url_cache_new(
|
||||||
|
url TEXT,
|
||||||
|
response_code INTEGER,
|
||||||
|
etag TEXT,
|
||||||
|
expires_ts BIGINT,
|
||||||
|
og TEXT,
|
||||||
|
media_id TEXT,
|
||||||
|
download_ts BIGINT
|
||||||
|
);
|
||||||
|
|
||||||
|
INSERT INTO local_media_repository_url_cache_new
|
||||||
|
SELECT url, response_code, etag, expires + download_ts, og, media_id, download_ts FROM local_media_repository_url_cache;
|
||||||
|
|
||||||
|
DROP TABLE local_media_repository_url_cache;
|
||||||
|
ALTER TABLE local_media_repository_url_cache_new RENAME TO local_media_repository_url_cache;
|
||||||
|
|
||||||
|
CREATE INDEX local_media_repository_url_cache_expires_idx ON local_media_repository_url_cache(expires_ts);
|
||||||
|
CREATE INDEX local_media_repository_url_cache_by_url_download_ts ON local_media_repository_url_cache(url, download_ts);
|
||||||
|
CREATE INDEX local_media_repository_url_cache_media_idx ON local_media_repository_url_cache(media_id);
|
167
synapse/storage/schema/delta/45/group_server.sql
Normal file
167
synapse/storage/schema/delta/45/group_server.sql
Normal file
@ -0,0 +1,167 @@
|
|||||||
|
/* Copyright 2017 Vector Creations Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
CREATE TABLE groups (
|
||||||
|
group_id TEXT NOT NULL,
|
||||||
|
name TEXT, -- the display name of the room
|
||||||
|
avatar_url TEXT,
|
||||||
|
short_description TEXT,
|
||||||
|
long_description TEXT
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX groups_idx ON groups(group_id);
|
||||||
|
|
||||||
|
|
||||||
|
-- list of users the group server thinks are joined
|
||||||
|
CREATE TABLE group_users (
|
||||||
|
group_id TEXT NOT NULL,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
is_admin BOOLEAN NOT NULL,
|
||||||
|
is_public BOOLEAN NOT NULL -- whether the users membership can be seen by everyone
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
|
CREATE INDEX groups_users_g_idx ON group_users(group_id, user_id);
|
||||||
|
CREATE INDEX groups_users_u_idx ON group_users(user_id);
|
||||||
|
|
||||||
|
-- list of users the group server thinks are invited
|
||||||
|
CREATE TABLE group_invites (
|
||||||
|
group_id TEXT NOT NULL,
|
||||||
|
user_id TEXT NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX groups_invites_g_idx ON group_invites(group_id, user_id);
|
||||||
|
CREATE INDEX groups_invites_u_idx ON group_invites(user_id);
|
||||||
|
|
||||||
|
|
||||||
|
CREATE TABLE group_rooms (
|
||||||
|
group_id TEXT NOT NULL,
|
||||||
|
room_id TEXT NOT NULL,
|
||||||
|
is_public BOOLEAN NOT NULL -- whether the room can be seen by everyone
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX groups_rooms_g_idx ON group_rooms(group_id, room_id);
|
||||||
|
CREATE INDEX groups_rooms_r_idx ON group_rooms(room_id);
|
||||||
|
|
||||||
|
|
||||||
|
-- Rooms to include in the summary
|
||||||
|
CREATE TABLE group_summary_rooms (
|
||||||
|
group_id TEXT NOT NULL,
|
||||||
|
room_id TEXT NOT NULL,
|
||||||
|
category_id TEXT NOT NULL,
|
||||||
|
room_order BIGINT NOT NULL,
|
||||||
|
is_public BOOLEAN NOT NULL, -- whether the room should be show to everyone
|
||||||
|
UNIQUE (group_id, category_id, room_id, room_order),
|
||||||
|
CHECK (room_order > 0)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX group_summary_rooms_g_idx ON group_summary_rooms(group_id, room_id, category_id);
|
||||||
|
|
||||||
|
|
||||||
|
-- Categories to include in the summary
|
||||||
|
CREATE TABLE group_summary_room_categories (
|
||||||
|
group_id TEXT NOT NULL,
|
||||||
|
category_id TEXT NOT NULL,
|
||||||
|
cat_order BIGINT NOT NULL,
|
||||||
|
UNIQUE (group_id, category_id, cat_order),
|
||||||
|
CHECK (cat_order > 0)
|
||||||
|
);
|
||||||
|
|
||||||
|
-- The categories in the group
|
||||||
|
CREATE TABLE group_room_categories (
|
||||||
|
group_id TEXT NOT NULL,
|
||||||
|
category_id TEXT NOT NULL,
|
||||||
|
profile TEXT NOT NULL,
|
||||||
|
is_public BOOLEAN NOT NULL, -- whether the category should be show to everyone
|
||||||
|
UNIQUE (group_id, category_id)
|
||||||
|
);
|
||||||
|
|
||||||
|
-- The users to include in the group summary
|
||||||
|
CREATE TABLE group_summary_users (
|
||||||
|
group_id TEXT NOT NULL,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
role_id TEXT NOT NULL,
|
||||||
|
user_order BIGINT NOT NULL,
|
||||||
|
is_public BOOLEAN NOT NULL -- whether the user should be show to everyone
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX group_summary_users_g_idx ON group_summary_users(group_id);
|
||||||
|
|
||||||
|
-- The roles to include in the group summary
|
||||||
|
CREATE TABLE group_summary_roles (
|
||||||
|
group_id TEXT NOT NULL,
|
||||||
|
role_id TEXT NOT NULL,
|
||||||
|
role_order BIGINT NOT NULL,
|
||||||
|
UNIQUE (group_id, role_id, role_order),
|
||||||
|
CHECK (role_order > 0)
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
|
-- The roles in a groups
|
||||||
|
CREATE TABLE group_roles (
|
||||||
|
group_id TEXT NOT NULL,
|
||||||
|
role_id TEXT NOT NULL,
|
||||||
|
profile TEXT NOT NULL,
|
||||||
|
is_public BOOLEAN NOT NULL, -- whether the role should be show to everyone
|
||||||
|
UNIQUE (group_id, role_id)
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
|
-- List of attestations we've given out and need to renew
|
||||||
|
CREATE TABLE group_attestations_renewals (
|
||||||
|
group_id TEXT NOT NULL,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
valid_until_ms BIGINT NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX group_attestations_renewals_g_idx ON group_attestations_renewals(group_id, user_id);
|
||||||
|
CREATE INDEX group_attestations_renewals_u_idx ON group_attestations_renewals(user_id);
|
||||||
|
CREATE INDEX group_attestations_renewals_v_idx ON group_attestations_renewals(valid_until_ms);
|
||||||
|
|
||||||
|
|
||||||
|
-- List of attestations we've received from remotes and are interested in.
|
||||||
|
CREATE TABLE group_attestations_remote (
|
||||||
|
group_id TEXT NOT NULL,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
valid_until_ms BIGINT NOT NULL,
|
||||||
|
attestation_json TEXT NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX group_attestations_remote_g_idx ON group_attestations_remote(group_id, user_id);
|
||||||
|
CREATE INDEX group_attestations_remote_u_idx ON group_attestations_remote(user_id);
|
||||||
|
CREATE INDEX group_attestations_remote_v_idx ON group_attestations_remote(valid_until_ms);
|
||||||
|
|
||||||
|
|
||||||
|
-- The group membership for the HS's users
|
||||||
|
CREATE TABLE local_group_membership (
|
||||||
|
group_id TEXT NOT NULL,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
is_admin BOOLEAN NOT NULL,
|
||||||
|
membership TEXT NOT NULL,
|
||||||
|
is_publicised BOOLEAN NOT NULL, -- if the user is publicising their membership
|
||||||
|
content TEXT NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX local_group_membership_u_idx ON local_group_membership(user_id, group_id);
|
||||||
|
CREATE INDEX local_group_membership_g_idx ON local_group_membership(group_id);
|
||||||
|
|
||||||
|
|
||||||
|
CREATE TABLE local_group_updates (
|
||||||
|
stream_id BIGINT NOT NULL,
|
||||||
|
group_id TEXT NOT NULL,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
type TEXT NOT NULL,
|
||||||
|
content TEXT NOT NULL
|
||||||
|
);
|
28
synapse/storage/schema/delta/45/profile_cache.sql
Normal file
28
synapse/storage/schema/delta/45/profile_cache.sql
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
/* Copyright 2017 New Vector Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
-- A subset of remote users whose profiles we have cached.
|
||||||
|
-- Whether a user is in this table or not is defined by the storage function
|
||||||
|
-- `is_subscribed_remote_profile_for_user`
|
||||||
|
CREATE TABLE remote_profile_cache (
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
displayname TEXT,
|
||||||
|
avatar_url TEXT,
|
||||||
|
last_check BIGINT NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX remote_profile_cache_user_id ON remote_profile_cache(user_id);
|
||||||
|
CREATE INDEX remote_profile_cache_time ON remote_profile_cache(last_check);
|
@ -45,6 +45,7 @@ class EventSources(object):
|
|||||||
push_rules_key, _ = self.store.get_push_rules_stream_token()
|
push_rules_key, _ = self.store.get_push_rules_stream_token()
|
||||||
to_device_key = self.store.get_to_device_stream_token()
|
to_device_key = self.store.get_to_device_stream_token()
|
||||||
device_list_key = self.store.get_device_stream_token()
|
device_list_key = self.store.get_device_stream_token()
|
||||||
|
groups_key = self.store.get_group_stream_token()
|
||||||
|
|
||||||
token = StreamToken(
|
token = StreamToken(
|
||||||
room_key=(
|
room_key=(
|
||||||
@ -65,6 +66,7 @@ class EventSources(object):
|
|||||||
push_rules_key=push_rules_key,
|
push_rules_key=push_rules_key,
|
||||||
to_device_key=to_device_key,
|
to_device_key=to_device_key,
|
||||||
device_list_key=device_list_key,
|
device_list_key=device_list_key,
|
||||||
|
groups_key=groups_key,
|
||||||
)
|
)
|
||||||
defer.returnValue(token)
|
defer.returnValue(token)
|
||||||
|
|
||||||
@ -73,6 +75,7 @@ class EventSources(object):
|
|||||||
push_rules_key, _ = self.store.get_push_rules_stream_token()
|
push_rules_key, _ = self.store.get_push_rules_stream_token()
|
||||||
to_device_key = self.store.get_to_device_stream_token()
|
to_device_key = self.store.get_to_device_stream_token()
|
||||||
device_list_key = self.store.get_device_stream_token()
|
device_list_key = self.store.get_device_stream_token()
|
||||||
|
groups_key = self.store.get_group_stream_token()
|
||||||
|
|
||||||
token = StreamToken(
|
token = StreamToken(
|
||||||
room_key=(
|
room_key=(
|
||||||
@ -93,5 +96,6 @@ class EventSources(object):
|
|||||||
push_rules_key=push_rules_key,
|
push_rules_key=push_rules_key,
|
||||||
to_device_key=to_device_key,
|
to_device_key=to_device_key,
|
||||||
device_list_key=device_list_key,
|
device_list_key=device_list_key,
|
||||||
|
groups_key=groups_key,
|
||||||
)
|
)
|
||||||
defer.returnValue(token)
|
defer.returnValue(token)
|
||||||
|
@ -156,6 +156,11 @@ class EventID(DomainSpecificString):
|
|||||||
SIGIL = "$"
|
SIGIL = "$"
|
||||||
|
|
||||||
|
|
||||||
|
class GroupID(DomainSpecificString):
|
||||||
|
"""Structure representing a group ID."""
|
||||||
|
SIGIL = "+"
|
||||||
|
|
||||||
|
|
||||||
class StreamToken(
|
class StreamToken(
|
||||||
namedtuple("Token", (
|
namedtuple("Token", (
|
||||||
"room_key",
|
"room_key",
|
||||||
@ -166,6 +171,7 @@ class StreamToken(
|
|||||||
"push_rules_key",
|
"push_rules_key",
|
||||||
"to_device_key",
|
"to_device_key",
|
||||||
"device_list_key",
|
"device_list_key",
|
||||||
|
"groups_key",
|
||||||
))
|
))
|
||||||
):
|
):
|
||||||
_SEPARATOR = "_"
|
_SEPARATOR = "_"
|
||||||
@ -204,6 +210,7 @@ class StreamToken(
|
|||||||
or (int(other.push_rules_key) < int(self.push_rules_key))
|
or (int(other.push_rules_key) < int(self.push_rules_key))
|
||||||
or (int(other.to_device_key) < int(self.to_device_key))
|
or (int(other.to_device_key) < int(self.to_device_key))
|
||||||
or (int(other.device_list_key) < int(self.device_list_key))
|
or (int(other.device_list_key) < int(self.device_list_key))
|
||||||
|
or (int(other.groups_key) < int(self.groups_key))
|
||||||
)
|
)
|
||||||
|
|
||||||
def copy_and_advance(self, key, new_value):
|
def copy_and_advance(self, key, new_value):
|
||||||
|
@ -19,7 +19,7 @@ from twisted.internet import defer, reactor
|
|||||||
from .logcontext import (
|
from .logcontext import (
|
||||||
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
|
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
|
||||||
)
|
)
|
||||||
from synapse.util import unwrapFirstError
|
from synapse.util import logcontext, unwrapFirstError
|
||||||
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
@ -53,6 +53,11 @@ class ObservableDeferred(object):
|
|||||||
|
|
||||||
Cancelling or otherwise resolving an observer will not affect the original
|
Cancelling or otherwise resolving an observer will not affect the original
|
||||||
ObservableDeferred.
|
ObservableDeferred.
|
||||||
|
|
||||||
|
NB that it does not attempt to do anything with logcontexts; in general
|
||||||
|
you should probably make_deferred_yieldable the deferreds
|
||||||
|
returned by `observe`, and ensure that the original deferred runs its
|
||||||
|
callbacks in the sentinel logcontext.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__slots__ = ["_deferred", "_observers", "_result"]
|
__slots__ = ["_deferred", "_observers", "_result"]
|
||||||
@ -155,7 +160,7 @@ def concurrently_execute(func, args, limit):
|
|||||||
except StopIteration:
|
except StopIteration:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return preserve_context_over_deferred(defer.gatherResults([
|
return logcontext.make_deferred_yieldable(defer.gatherResults([
|
||||||
preserve_fn(_concurrently_execute_inner)()
|
preserve_fn(_concurrently_execute_inner)()
|
||||||
for _ in xrange(limit)
|
for _ in xrange(limit)
|
||||||
], consumeErrors=True)).addErrback(unwrapFirstError)
|
], consumeErrors=True)).addErrback(unwrapFirstError)
|
||||||
@ -203,7 +208,26 @@ class Linearizer(object):
|
|||||||
except:
|
except:
|
||||||
logger.exception("Unexpected exception in Linearizer")
|
logger.exception("Unexpected exception in Linearizer")
|
||||||
|
|
||||||
logger.info("Acquired linearizer lock %r for key %r", self.name, key)
|
logger.info("Acquired linearizer lock %r for key %r", self.name,
|
||||||
|
key)
|
||||||
|
|
||||||
|
# if the code holding the lock completes synchronously, then it
|
||||||
|
# will recursively run the next claimant on the list. That can
|
||||||
|
# relatively rapidly lead to stack exhaustion. This is essentially
|
||||||
|
# the same problem as http://twistedmatrix.com/trac/ticket/9304.
|
||||||
|
#
|
||||||
|
# In order to break the cycle, we add a cheeky sleep(0) here to
|
||||||
|
# ensure that we fall back to the reactor between each iteration.
|
||||||
|
#
|
||||||
|
# (There's no particular need for it to happen before we return
|
||||||
|
# the context manager, but it needs to happen while we hold the
|
||||||
|
# lock, and the context manager's exit code must be synchronous,
|
||||||
|
# so actually this is the only sensible place.
|
||||||
|
yield run_on_reactor()
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger.info("Acquired uncontended linearizer lock %r for key %r",
|
||||||
|
self.name, key)
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def _ctx_manager():
|
def _ctx_manager():
|
||||||
@ -211,7 +235,8 @@ class Linearizer(object):
|
|||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
logger.info("Releasing linearizer lock %r for key %r", self.name, key)
|
logger.info("Releasing linearizer lock %r for key %r", self.name, key)
|
||||||
new_defer.callback(None)
|
with PreserveLoggingContext():
|
||||||
|
new_defer.callback(None)
|
||||||
current_d = self.key_to_defer.get(key)
|
current_d = self.key_to_defer.get(key)
|
||||||
if current_d is new_defer:
|
if current_d is new_defer:
|
||||||
self.key_to_defer.pop(key, None)
|
self.key_to_defer.pop(key, None)
|
||||||
|
51
synapse/util/logformatter.py
Normal file
51
synapse/util/logformatter.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2017 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import StringIO
|
||||||
|
import logging
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
|
||||||
|
class LogFormatter(logging.Formatter):
|
||||||
|
"""Log formatter which gives more detail for exceptions
|
||||||
|
|
||||||
|
This is the same as the standard log formatter, except that when logging
|
||||||
|
exceptions [typically via log.foo("msg", exc_info=1)], it prints the
|
||||||
|
sequence that led up to the point at which the exception was caught.
|
||||||
|
(Normally only stack frames between the point the exception was raised and
|
||||||
|
where it was caught are logged).
|
||||||
|
"""
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(LogFormatter, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def formatException(self, ei):
|
||||||
|
sio = StringIO.StringIO()
|
||||||
|
(typ, val, tb) = ei
|
||||||
|
|
||||||
|
# log the stack above the exception capture point if possible, but
|
||||||
|
# check that we actually have an f_back attribute to work around
|
||||||
|
# https://twistedmatrix.com/trac/ticket/9305
|
||||||
|
|
||||||
|
if tb and hasattr(tb.tb_frame, 'f_back'):
|
||||||
|
sio.write("Capture point (most recent call last):\n")
|
||||||
|
traceback.print_stack(tb.tb_frame.f_back, None, sio)
|
||||||
|
|
||||||
|
traceback.print_exception(typ, val, tb, None, sio)
|
||||||
|
s = sio.getvalue()
|
||||||
|
sio.close()
|
||||||
|
if s[-1:] == "\n":
|
||||||
|
s = s[:-1]
|
||||||
|
return s
|
42
synapse/util/module_loader.py
Normal file
42
synapse/util/module_loader.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2017 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from synapse.config._base import ConfigError
|
||||||
|
|
||||||
|
|
||||||
|
def load_module(provider):
|
||||||
|
""" Loads a module with its config
|
||||||
|
Take a dict with keys 'module' (the module name) and 'config'
|
||||||
|
(the config dict).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
Tuple of (provider class, parsed config object)
|
||||||
|
"""
|
||||||
|
# We need to import the module, and then pick the class out of
|
||||||
|
# that, so we split based on the last dot.
|
||||||
|
module, clz = provider['module'].rsplit(".", 1)
|
||||||
|
module = importlib.import_module(module)
|
||||||
|
provider_class = getattr(module, clz)
|
||||||
|
|
||||||
|
try:
|
||||||
|
provider_config = provider_class.parse_config(provider["config"])
|
||||||
|
except Exception as e:
|
||||||
|
raise ConfigError(
|
||||||
|
"Failed to parse config for %r: %r" % (provider['module'], e)
|
||||||
|
)
|
||||||
|
|
||||||
|
return provider_class, provider_config
|
@ -62,8 +62,6 @@ class ProfileTestCase(unittest.TestCase):
|
|||||||
self.ratelimiter = hs.get_ratelimiter()
|
self.ratelimiter = hs.get_ratelimiter()
|
||||||
self.ratelimiter.send_message.return_value = (True, 0)
|
self.ratelimiter.send_message.return_value = (True, 0)
|
||||||
|
|
||||||
hs.handlers = ProfileHandlers(hs)
|
|
||||||
|
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
self.frank = UserID.from_string("@1234ABCD:test")
|
self.frank = UserID.from_string("@1234ABCD:test")
|
||||||
@ -72,7 +70,7 @@ class ProfileTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
yield self.store.create_profile(self.frank.localpart)
|
yield self.store.create_profile(self.frank.localpart)
|
||||||
|
|
||||||
self.handler = hs.get_handlers().profile_handler
|
self.handler = hs.get_profile_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_my_name(self):
|
def test_get_my_name(self):
|
||||||
|
@ -40,13 +40,14 @@ class RegistrationTestCase(unittest.TestCase):
|
|||||||
self.hs = yield setup_test_homeserver(
|
self.hs = yield setup_test_homeserver(
|
||||||
handlers=None,
|
handlers=None,
|
||||||
http_client=None,
|
http_client=None,
|
||||||
expire_access_token=True)
|
expire_access_token=True,
|
||||||
|
profile_handler=Mock(),
|
||||||
|
)
|
||||||
self.macaroon_generator = Mock(
|
self.macaroon_generator = Mock(
|
||||||
generate_access_token=Mock(return_value='secret'))
|
generate_access_token=Mock(return_value='secret'))
|
||||||
self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
|
self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
|
||||||
self.hs.handlers = RegistrationHandlers(self.hs)
|
self.hs.handlers = RegistrationHandlers(self.hs)
|
||||||
self.handler = self.hs.get_handlers().registration_handler
|
self.handler = self.hs.get_handlers().registration_handler
|
||||||
self.hs.get_handlers().profile_handler = Mock()
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_user_is_created_and_logged_in_if_doesnt_exist(self):
|
def test_user_is_created_and_logged_in_if_doesnt_exist(self):
|
||||||
|
@ -46,6 +46,7 @@ class ProfileTestCase(unittest.TestCase):
|
|||||||
resource_for_client=self.mock_resource,
|
resource_for_client=self.mock_resource,
|
||||||
federation=Mock(),
|
federation=Mock(),
|
||||||
replication_layer=Mock(),
|
replication_layer=Mock(),
|
||||||
|
profile_handler=self.mock_handler
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_user_by_req(request=None, allow_guest=False):
|
def _get_user_by_req(request=None, allow_guest=False):
|
||||||
@ -53,8 +54,6 @@ class ProfileTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
hs.get_v1auth().get_user_by_req = _get_user_by_req
|
hs.get_v1auth().get_user_by_req = _get_user_by_req
|
||||||
|
|
||||||
hs.get_handlers().profile_handler = self.mock_handler
|
|
||||||
|
|
||||||
profile.register_servlets(hs, self.mock_resource)
|
profile.register_servlets(hs, self.mock_resource)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -1032,7 +1032,7 @@ class RoomMessageListTestCase(RestTestCase):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_topo_token_is_accepted(self):
|
def test_topo_token_is_accepted(self):
|
||||||
token = "t1-0_0_0_0_0_0_0_0"
|
token = "t1-0_0_0_0_0_0_0_0_0"
|
||||||
(code, response) = yield self.mock_resource.trigger_get(
|
(code, response) = yield self.mock_resource.trigger_get(
|
||||||
"/rooms/%s/messages?access_token=x&from=%s" %
|
"/rooms/%s/messages?access_token=x&from=%s" %
|
||||||
(self.room_id, token))
|
(self.room_id, token))
|
||||||
@ -1044,7 +1044,7 @@ class RoomMessageListTestCase(RestTestCase):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_stream_token_is_accepted_for_fwd_pagianation(self):
|
def test_stream_token_is_accepted_for_fwd_pagianation(self):
|
||||||
token = "s0_0_0_0_0_0_0_0"
|
token = "s0_0_0_0_0_0_0_0_0"
|
||||||
(code, response) = yield self.mock_resource.trigger_get(
|
(code, response) = yield self.mock_resource.trigger_get(
|
||||||
"/rooms/%s/messages?access_token=x&from=%s" %
|
"/rooms/%s/messages?access_token=x&from=%s" %
|
||||||
(self.room_id, token))
|
(self.room_id, token))
|
||||||
|
@ -47,6 +47,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||||||
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
|
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
|
||||||
self.hs.get_device_handler = Mock(return_value=self.device_handler)
|
self.hs.get_device_handler = Mock(return_value=self.device_handler)
|
||||||
self.hs.config.enable_registration = True
|
self.hs.config.enable_registration = True
|
||||||
|
self.hs.config.auto_join_rooms = []
|
||||||
|
|
||||||
# init the thing we're testing
|
# init the thing we're testing
|
||||||
self.servlet = RegisterRestServlet(self.hs)
|
self.servlet = RegisterRestServlet(self.hs)
|
||||||
|
@ -1,76 +0,0 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
# Copyright 2015, 2016 OpenMarket Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
|
||||||
|
|
||||||
|
|
||||||
class EventInjector:
|
|
||||||
def __init__(self, hs):
|
|
||||||
self.hs = hs
|
|
||||||
self.store = hs.get_datastore()
|
|
||||||
self.message_handler = hs.get_handlers().message_handler
|
|
||||||
self.event_builder_factory = hs.get_event_builder_factory()
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def create_room(self, room, user):
|
|
||||||
builder = self.event_builder_factory.new({
|
|
||||||
"type": EventTypes.Create,
|
|
||||||
"sender": user.to_string(),
|
|
||||||
"room_id": room.to_string(),
|
|
||||||
"content": {},
|
|
||||||
})
|
|
||||||
|
|
||||||
event, context = yield self.message_handler._create_new_client_event(
|
|
||||||
builder
|
|
||||||
)
|
|
||||||
|
|
||||||
yield self.store.persist_event(event, context)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def inject_room_member(self, room, user, membership):
|
|
||||||
builder = self.event_builder_factory.new({
|
|
||||||
"type": EventTypes.Member,
|
|
||||||
"sender": user.to_string(),
|
|
||||||
"state_key": user.to_string(),
|
|
||||||
"room_id": room.to_string(),
|
|
||||||
"content": {"membership": membership},
|
|
||||||
})
|
|
||||||
|
|
||||||
event, context = yield self.message_handler._create_new_client_event(
|
|
||||||
builder
|
|
||||||
)
|
|
||||||
|
|
||||||
yield self.store.persist_event(event, context)
|
|
||||||
|
|
||||||
defer.returnValue(event)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def inject_message(self, room, user, body):
|
|
||||||
builder = self.event_builder_factory.new({
|
|
||||||
"type": EventTypes.Message,
|
|
||||||
"sender": user.to_string(),
|
|
||||||
"state_key": user.to_string(),
|
|
||||||
"room_id": room.to_string(),
|
|
||||||
"content": {"body": body, "msgtype": u"message"},
|
|
||||||
})
|
|
||||||
|
|
||||||
event, context = yield self.message_handler._create_new_client_event(
|
|
||||||
builder
|
|
||||||
)
|
|
||||||
|
|
||||||
yield self.store.persist_event(event, context)
|
|
@ -12,8 +12,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from synapse.util import async, logcontext
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
@ -38,7 +37,28 @@ class LinearizerTestCase(unittest.TestCase):
|
|||||||
with cm1:
|
with cm1:
|
||||||
self.assertFalse(d2.called)
|
self.assertFalse(d2.called)
|
||||||
|
|
||||||
self.assertTrue(d2.called)
|
|
||||||
|
|
||||||
with (yield d2):
|
with (yield d2):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_lots_of_queued_things(self):
|
||||||
|
# we have one slow thing, and lots of fast things queued up behind it.
|
||||||
|
# it should *not* explode the stack.
|
||||||
|
linearizer = Linearizer()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def func(i, sleep=False):
|
||||||
|
with logcontext.LoggingContext("func(%s)" % i) as lc:
|
||||||
|
with (yield linearizer.queue("")):
|
||||||
|
self.assertEqual(
|
||||||
|
logcontext.LoggingContext.current_context(), lc)
|
||||||
|
if sleep:
|
||||||
|
yield async.sleep(0)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
logcontext.LoggingContext.current_context(), lc)
|
||||||
|
|
||||||
|
func(0, sleep=True)
|
||||||
|
for i in xrange(1, 100):
|
||||||
|
func(i)
|
||||||
|
|
||||||
|
return func(1000)
|
||||||
|
@ -94,3 +94,41 @@ class LoggingContextTestCase(unittest.TestCase):
|
|||||||
yield defer.succeed(None)
|
yield defer.succeed(None)
|
||||||
|
|
||||||
return self._test_preserve_fn(nonblocking_function)
|
return self._test_preserve_fn(nonblocking_function)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_make_deferred_yieldable(self):
|
||||||
|
# a function which retuns an incomplete deferred, but doesn't follow
|
||||||
|
# the synapse rules.
|
||||||
|
def blocking_function():
|
||||||
|
d = defer.Deferred()
|
||||||
|
reactor.callLater(0, d.callback, None)
|
||||||
|
return d
|
||||||
|
|
||||||
|
sentinel_context = LoggingContext.current_context()
|
||||||
|
|
||||||
|
with LoggingContext() as context_one:
|
||||||
|
context_one.test_key = "one"
|
||||||
|
|
||||||
|
d1 = logcontext.make_deferred_yieldable(blocking_function())
|
||||||
|
# make sure that the context was reset by make_deferred_yieldable
|
||||||
|
self.assertIs(LoggingContext.current_context(), sentinel_context)
|
||||||
|
|
||||||
|
yield d1
|
||||||
|
|
||||||
|
# now it should be restored
|
||||||
|
self._check_test_key("one")
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_make_deferred_yieldable_on_non_deferred(self):
|
||||||
|
"""Check that make_deferred_yieldable does the right thing when its
|
||||||
|
argument isn't actually a deferred"""
|
||||||
|
|
||||||
|
with LoggingContext() as context_one:
|
||||||
|
context_one.test_key = "one"
|
||||||
|
|
||||||
|
d1 = logcontext.make_deferred_yieldable("bum")
|
||||||
|
self._check_test_key("one")
|
||||||
|
|
||||||
|
r = yield d1
|
||||||
|
self.assertEqual(r, "bum")
|
||||||
|
self._check_test_key("one")
|
Loading…
Reference in New Issue
Block a user