mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-08-03 15:34:10 -04:00
Merge branch 'matthew/gin_work_mem' into matthew/hit_the_gin
This commit is contained in:
commit
ddb6a79b68
116 changed files with 4226 additions and 1834 deletions
|
@ -16,4 +16,4 @@
|
|||
""" This is a reference implementation of a Matrix home server.
|
||||
"""
|
||||
|
||||
__version__ = "0.25.1"
|
||||
__version__ = "0.26.0"
|
||||
|
|
|
@ -46,6 +46,7 @@ class Codes(object):
|
|||
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
|
||||
THREEPID_IN_USE = "M_THREEPID_IN_USE"
|
||||
THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND"
|
||||
THREEPID_DENIED = "M_THREEPID_DENIED"
|
||||
INVALID_USERNAME = "M_INVALID_USERNAME"
|
||||
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
|
||||
|
||||
|
@ -140,6 +141,32 @@ class RegistrationError(SynapseError):
|
|||
pass
|
||||
|
||||
|
||||
class FederationDeniedError(SynapseError):
|
||||
"""An error raised when the server tries to federate with a server which
|
||||
is not on its federation whitelist.
|
||||
|
||||
Attributes:
|
||||
destination (str): The destination which has been denied
|
||||
"""
|
||||
|
||||
def __init__(self, destination):
|
||||
"""Raised by federation client or server to indicate that we are
|
||||
deliberately not attempting to contact a given server because it is
|
||||
not on our federation whitelist.
|
||||
|
||||
Args:
|
||||
destination (str): the domain in question
|
||||
"""
|
||||
|
||||
self.destination = destination
|
||||
|
||||
super(FederationDeniedError, self).__init__(
|
||||
code=403,
|
||||
msg="Federation denied with %s." % (self.destination,),
|
||||
errcode=Codes.FORBIDDEN,
|
||||
)
|
||||
|
||||
|
||||
class InteractiveAuthIncompleteError(Exception):
|
||||
"""An error raised when UI auth is not yet complete
|
||||
|
||||
|
|
|
@ -49,19 +49,6 @@ class AppserviceSlaveStore(
|
|||
|
||||
|
||||
class AppserviceServer(HomeServer):
|
||||
def get_db_conn(self, run_new_connection=True):
|
||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||
# not be passed to the database engine.
|
||||
db_params = {
|
||||
k: v for k, v in self.db_config.get("args", {}).items()
|
||||
if not k.startswith("cp_")
|
||||
}
|
||||
db_conn = self.database_engine.module.connect(**db_params)
|
||||
|
||||
if run_new_connection:
|
||||
self.database_engine.on_new_connection(db_conn)
|
||||
return db_conn
|
||||
|
||||
def setup(self):
|
||||
logger.info("Setting up.")
|
||||
self.datastore = AppserviceSlaveStore(self.get_db_conn(), self)
|
||||
|
|
|
@ -64,19 +64,6 @@ class ClientReaderSlavedStore(
|
|||
|
||||
|
||||
class ClientReaderServer(HomeServer):
|
||||
def get_db_conn(self, run_new_connection=True):
|
||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||
# not be passed to the database engine.
|
||||
db_params = {
|
||||
k: v for k, v in self.db_config.get("args", {}).items()
|
||||
if not k.startswith("cp_")
|
||||
}
|
||||
db_conn = self.database_engine.module.connect(**db_params)
|
||||
|
||||
if run_new_connection:
|
||||
self.database_engine.on_new_connection(db_conn)
|
||||
return db_conn
|
||||
|
||||
def setup(self):
|
||||
logger.info("Setting up.")
|
||||
self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self)
|
||||
|
|
|
@ -58,19 +58,6 @@ class FederationReaderSlavedStore(
|
|||
|
||||
|
||||
class FederationReaderServer(HomeServer):
|
||||
def get_db_conn(self, run_new_connection=True):
|
||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||
# not be passed to the database engine.
|
||||
db_params = {
|
||||
k: v for k, v in self.db_config.get("args", {}).items()
|
||||
if not k.startswith("cp_")
|
||||
}
|
||||
db_conn = self.database_engine.module.connect(**db_params)
|
||||
|
||||
if run_new_connection:
|
||||
self.database_engine.on_new_connection(db_conn)
|
||||
return db_conn
|
||||
|
||||
def setup(self):
|
||||
logger.info("Setting up.")
|
||||
self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self)
|
||||
|
|
|
@ -76,19 +76,6 @@ class FederationSenderSlaveStore(
|
|||
|
||||
|
||||
class FederationSenderServer(HomeServer):
|
||||
def get_db_conn(self, run_new_connection=True):
|
||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||
# not be passed to the database engine.
|
||||
db_params = {
|
||||
k: v for k, v in self.db_config.get("args", {}).items()
|
||||
if not k.startswith("cp_")
|
||||
}
|
||||
db_conn = self.database_engine.module.connect(**db_params)
|
||||
|
||||
if run_new_connection:
|
||||
self.database_engine.on_new_connection(db_conn)
|
||||
return db_conn
|
||||
|
||||
def setup(self):
|
||||
logger.info("Setting up.")
|
||||
self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self)
|
||||
|
|
|
@ -118,19 +118,6 @@ class FrontendProxySlavedStore(
|
|||
|
||||
|
||||
class FrontendProxyServer(HomeServer):
|
||||
def get_db_conn(self, run_new_connection=True):
|
||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||
# not be passed to the database engine.
|
||||
db_params = {
|
||||
k: v for k, v in self.db_config.get("args", {}).items()
|
||||
if not k.startswith("cp_")
|
||||
}
|
||||
db_conn = self.database_engine.module.connect(**db_params)
|
||||
|
||||
if run_new_connection:
|
||||
self.database_engine.on_new_connection(db_conn)
|
||||
return db_conn
|
||||
|
||||
def setup(self):
|
||||
logger.info("Setting up.")
|
||||
self.datastore = FrontendProxySlavedStore(self.get_db_conn(), self)
|
||||
|
|
|
@ -266,19 +266,6 @@ class SynapseHomeServer(HomeServer):
|
|||
except IncorrectDatabaseSetup as e:
|
||||
quit_with_error(e.message)
|
||||
|
||||
def get_db_conn(self, run_new_connection=True):
|
||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||
# not be passed to the database engine.
|
||||
db_params = {
|
||||
k: v for k, v in self.db_config.get("args", {}).items()
|
||||
if not k.startswith("cp_")
|
||||
}
|
||||
db_conn = self.database_engine.module.connect(**db_params)
|
||||
|
||||
if run_new_connection:
|
||||
self.database_engine.on_new_connection(db_conn)
|
||||
return db_conn
|
||||
|
||||
|
||||
def setup(config_options):
|
||||
"""
|
||||
|
|
|
@ -60,19 +60,6 @@ class MediaRepositorySlavedStore(
|
|||
|
||||
|
||||
class MediaRepositoryServer(HomeServer):
|
||||
def get_db_conn(self, run_new_connection=True):
|
||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||
# not be passed to the database engine.
|
||||
db_params = {
|
||||
k: v for k, v in self.db_config.get("args", {}).items()
|
||||
if not k.startswith("cp_")
|
||||
}
|
||||
db_conn = self.database_engine.module.connect(**db_params)
|
||||
|
||||
if run_new_connection:
|
||||
self.database_engine.on_new_connection(db_conn)
|
||||
return db_conn
|
||||
|
||||
def setup(self):
|
||||
logger.info("Setting up.")
|
||||
self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self)
|
||||
|
|
|
@ -81,19 +81,6 @@ class PusherSlaveStore(
|
|||
|
||||
|
||||
class PusherServer(HomeServer):
|
||||
def get_db_conn(self, run_new_connection=True):
|
||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||
# not be passed to the database engine.
|
||||
db_params = {
|
||||
k: v for k, v in self.db_config.get("args", {}).items()
|
||||
if not k.startswith("cp_")
|
||||
}
|
||||
db_conn = self.database_engine.module.connect(**db_params)
|
||||
|
||||
if run_new_connection:
|
||||
self.database_engine.on_new_connection(db_conn)
|
||||
return db_conn
|
||||
|
||||
def setup(self):
|
||||
logger.info("Setting up.")
|
||||
self.datastore = PusherSlaveStore(self.get_db_conn(), self)
|
||||
|
|
|
@ -246,19 +246,6 @@ class SynchrotronApplicationService(object):
|
|||
|
||||
|
||||
class SynchrotronServer(HomeServer):
|
||||
def get_db_conn(self, run_new_connection=True):
|
||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||
# not be passed to the database engine.
|
||||
db_params = {
|
||||
k: v for k, v in self.db_config.get("args", {}).items()
|
||||
if not k.startswith("cp_")
|
||||
}
|
||||
db_conn = self.database_engine.module.connect(**db_params)
|
||||
|
||||
if run_new_connection:
|
||||
self.database_engine.on_new_connection(db_conn)
|
||||
return db_conn
|
||||
|
||||
def setup(self):
|
||||
logger.info("Setting up.")
|
||||
self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self)
|
||||
|
|
|
@ -184,6 +184,9 @@ def main():
|
|||
worker_configfiles.append(worker_configfile)
|
||||
|
||||
if options.all_processes:
|
||||
# To start the main synapse with -a you need to add a worker file
|
||||
# with worker_app == "synapse.app.homeserver"
|
||||
start_stop_synapse = False
|
||||
worker_configdir = options.all_processes
|
||||
if not os.path.isdir(worker_configdir):
|
||||
write(
|
||||
|
@ -200,11 +203,29 @@ def main():
|
|||
with open(worker_configfile) as stream:
|
||||
worker_config = yaml.load(stream)
|
||||
worker_app = worker_config["worker_app"]
|
||||
worker_pidfile = worker_config["worker_pid_file"]
|
||||
worker_daemonize = worker_config["worker_daemonize"]
|
||||
assert worker_daemonize, "In config %r: expected '%s' to be True" % (
|
||||
worker_configfile, "worker_daemonize")
|
||||
worker_cache_factor = worker_config.get("synctl_cache_factor")
|
||||
if worker_app == "synapse.app.homeserver":
|
||||
# We need to special case all of this to pick up options that may
|
||||
# be set in the main config file or in this worker config file.
|
||||
worker_pidfile = (
|
||||
worker_config.get("pid_file")
|
||||
or pidfile
|
||||
)
|
||||
worker_cache_factor = worker_config.get("synctl_cache_factor") or cache_factor
|
||||
daemonize = worker_config.get("daemonize") or config.get("daemonize")
|
||||
assert daemonize, "Main process must have daemonize set to true"
|
||||
|
||||
# The master process doesn't support using worker_* config.
|
||||
for key in worker_config:
|
||||
if key == "worker_app": # But we allow worker_app
|
||||
continue
|
||||
assert not key.startswith("worker_"), \
|
||||
"Main process cannot use worker_* config"
|
||||
else:
|
||||
worker_pidfile = worker_config["worker_pid_file"]
|
||||
worker_daemonize = worker_config["worker_daemonize"]
|
||||
assert worker_daemonize, "In config %r: expected '%s' to be True" % (
|
||||
worker_configfile, "worker_daemonize")
|
||||
worker_cache_factor = worker_config.get("synctl_cache_factor")
|
||||
workers.append(Worker(
|
||||
worker_app, worker_configfile, worker_pidfile, worker_cache_factor,
|
||||
))
|
||||
|
|
|
@ -92,19 +92,6 @@ class UserDirectorySlaveStore(
|
|||
|
||||
|
||||
class UserDirectoryServer(HomeServer):
|
||||
def get_db_conn(self, run_new_connection=True):
|
||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||
# not be passed to the database engine.
|
||||
db_params = {
|
||||
k: v for k, v in self.db_config.get("args", {}).items()
|
||||
if not k.startswith("cp_")
|
||||
}
|
||||
db_conn = self.database_engine.module.connect(**db_params)
|
||||
|
||||
if run_new_connection:
|
||||
self.database_engine.on_new_connection(db_conn)
|
||||
return db_conn
|
||||
|
||||
def setup(self):
|
||||
logger.info("Setting up.")
|
||||
self.datastore = UserDirectorySlaveStore(self.get_db_conn(), self)
|
||||
|
|
|
@ -28,27 +28,27 @@ DEFAULT_LOG_CONFIG = Template("""
|
|||
version: 1
|
||||
|
||||
formatters:
|
||||
precise:
|
||||
format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s\
|
||||
- %(message)s'
|
||||
precise:
|
||||
format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - \
|
||||
%(request)s - %(message)s'
|
||||
|
||||
filters:
|
||||
context:
|
||||
(): synapse.util.logcontext.LoggingContextFilter
|
||||
request: ""
|
||||
context:
|
||||
(): synapse.util.logcontext.LoggingContextFilter
|
||||
request: ""
|
||||
|
||||
handlers:
|
||||
file:
|
||||
class: logging.handlers.RotatingFileHandler
|
||||
formatter: precise
|
||||
filename: ${log_file}
|
||||
maxBytes: 104857600
|
||||
backupCount: 10
|
||||
filters: [context]
|
||||
console:
|
||||
class: logging.StreamHandler
|
||||
formatter: precise
|
||||
filters: [context]
|
||||
file:
|
||||
class: logging.handlers.RotatingFileHandler
|
||||
formatter: precise
|
||||
filename: ${log_file}
|
||||
maxBytes: 104857600
|
||||
backupCount: 10
|
||||
filters: [context]
|
||||
console:
|
||||
class: logging.StreamHandler
|
||||
formatter: precise
|
||||
filters: [context]
|
||||
|
||||
loggers:
|
||||
synapse:
|
||||
|
@ -74,17 +74,10 @@ class LoggingConfig(Config):
|
|||
self.log_file = self.abspath(config.get("log_file"))
|
||||
|
||||
def default_config(self, config_dir_path, server_name, **kwargs):
|
||||
log_file = self.abspath("homeserver.log")
|
||||
log_config = self.abspath(
|
||||
os.path.join(config_dir_path, server_name + ".log.config")
|
||||
)
|
||||
return """
|
||||
# Logging verbosity level. Ignored if log_config is specified.
|
||||
verbose: 0
|
||||
|
||||
# File to write logging to. Ignored if log_config is specified.
|
||||
log_file: "%(log_file)s"
|
||||
|
||||
# A yaml python logging config file
|
||||
log_config: "%(log_config)s"
|
||||
""" % locals()
|
||||
|
@ -123,9 +116,10 @@ class LoggingConfig(Config):
|
|||
def generate_files(self, config):
|
||||
log_config = config.get("log_config")
|
||||
if log_config and not os.path.exists(log_config):
|
||||
log_file = self.abspath("homeserver.log")
|
||||
with open(log_config, "wb") as log_config_file:
|
||||
log_config_file.write(
|
||||
DEFAULT_LOG_CONFIG.substitute(log_file=config["log_file"])
|
||||
DEFAULT_LOG_CONFIG.substitute(log_file=log_file)
|
||||
)
|
||||
|
||||
|
||||
|
@ -150,6 +144,9 @@ def setup_logging(config, use_worker_options=False):
|
|||
)
|
||||
|
||||
if log_config is None:
|
||||
# We don't have a log config file, so fall back to the 'verbosity' param from
|
||||
# the config or cmdline. (Note that we generate a log config for new
|
||||
# installs, so this will be an unusual case)
|
||||
level = logging.INFO
|
||||
level_for_storage = logging.INFO
|
||||
if config.verbosity:
|
||||
|
@ -157,11 +154,10 @@ def setup_logging(config, use_worker_options=False):
|
|||
if config.verbosity > 1:
|
||||
level_for_storage = logging.DEBUG
|
||||
|
||||
# FIXME: we need a logging.WARN for a -q quiet option
|
||||
logger = logging.getLogger('')
|
||||
logger.setLevel(level)
|
||||
|
||||
logging.getLogger('synapse.storage').setLevel(level_for_storage)
|
||||
logging.getLogger('synapse.storage.SQL').setLevel(level_for_storage)
|
||||
|
||||
formatter = logging.Formatter(log_format)
|
||||
if log_file:
|
||||
|
|
|
@ -31,6 +31,8 @@ class RegistrationConfig(Config):
|
|||
strtobool(str(config["disable_registration"]))
|
||||
)
|
||||
|
||||
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
|
||||
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
|
||||
self.registration_shared_secret = config.get("registration_shared_secret")
|
||||
|
||||
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
|
||||
|
@ -52,6 +54,23 @@ class RegistrationConfig(Config):
|
|||
# Enable registration for new users.
|
||||
enable_registration: False
|
||||
|
||||
# The user must provide all of the below types of 3PID when registering.
|
||||
#
|
||||
# registrations_require_3pid:
|
||||
# - email
|
||||
# - msisdn
|
||||
|
||||
# Mandate that users are only allowed to associate certain formats of
|
||||
# 3PIDs with accounts on this server.
|
||||
#
|
||||
# allowed_local_3pids:
|
||||
# - medium: email
|
||||
# pattern: ".*@matrix\\.org"
|
||||
# - medium: email
|
||||
# pattern: ".*@vector\\.im"
|
||||
# - medium: msisdn
|
||||
# pattern: "\\+44"
|
||||
|
||||
# If set, allows registration by anyone who also has the shared
|
||||
# secret, even if registration is otherwise disabled.
|
||||
registration_shared_secret: "%(registration_shared_secret)s"
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
from ._base import Config, ConfigError
|
||||
from collections import namedtuple
|
||||
|
||||
from synapse.util.module_loader import load_module
|
||||
|
||||
|
||||
MISSING_NETADDR = (
|
||||
"Missing netaddr library. This is required for URL preview API."
|
||||
|
@ -36,6 +38,14 @@ ThumbnailRequirement = namedtuple(
|
|||
"ThumbnailRequirement", ["width", "height", "method", "media_type"]
|
||||
)
|
||||
|
||||
MediaStorageProviderConfig = namedtuple(
|
||||
"MediaStorageProviderConfig", (
|
||||
"store_local", # Whether to store newly uploaded local files
|
||||
"store_remote", # Whether to store newly downloaded remote files
|
||||
"store_synchronous", # Whether to wait for successful storage for local uploads
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def parse_thumbnail_requirements(thumbnail_sizes):
|
||||
""" Takes a list of dictionaries with "width", "height", and "method" keys
|
||||
|
@ -73,16 +83,61 @@ class ContentRepositoryConfig(Config):
|
|||
|
||||
self.media_store_path = self.ensure_directory(config["media_store_path"])
|
||||
|
||||
self.backup_media_store_path = config.get("backup_media_store_path")
|
||||
if self.backup_media_store_path:
|
||||
self.backup_media_store_path = self.ensure_directory(
|
||||
self.backup_media_store_path
|
||||
)
|
||||
backup_media_store_path = config.get("backup_media_store_path")
|
||||
|
||||
self.synchronous_backup_media_store = config.get(
|
||||
synchronous_backup_media_store = config.get(
|
||||
"synchronous_backup_media_store", False
|
||||
)
|
||||
|
||||
storage_providers = config.get("media_storage_providers", [])
|
||||
|
||||
if backup_media_store_path:
|
||||
if storage_providers:
|
||||
raise ConfigError(
|
||||
"Cannot use both 'backup_media_store_path' and 'storage_providers'"
|
||||
)
|
||||
|
||||
storage_providers = [{
|
||||
"module": "file_system",
|
||||
"store_local": True,
|
||||
"store_synchronous": synchronous_backup_media_store,
|
||||
"store_remote": True,
|
||||
"config": {
|
||||
"directory": backup_media_store_path,
|
||||
}
|
||||
}]
|
||||
|
||||
# This is a list of config that can be used to create the storage
|
||||
# providers. The entries are tuples of (Class, class_config,
|
||||
# MediaStorageProviderConfig), where Class is the class of the provider,
|
||||
# class_config is the config to pass to it, and
|
||||
# MediaStorageProviderConfig are options for StorageProviderWrapper.
|
||||
#
|
||||
# We don't create the storage providers here as not all workers need
|
||||
# them to be started.
|
||||
self.media_storage_providers = []
|
||||
|
||||
for provider_config in storage_providers:
|
||||
# We special case the module "file_system" so as not to need to
|
||||
# expose FileStorageProviderBackend
|
||||
if provider_config["module"] == "file_system":
|
||||
provider_config["module"] = (
|
||||
"synapse.rest.media.v1.storage_provider"
|
||||
".FileStorageProviderBackend"
|
||||
)
|
||||
|
||||
provider_class, parsed_config = load_module(provider_config)
|
||||
|
||||
wrapper_config = MediaStorageProviderConfig(
|
||||
provider_config.get("store_local", False),
|
||||
provider_config.get("store_remote", False),
|
||||
provider_config.get("store_synchronous", False),
|
||||
)
|
||||
|
||||
self.media_storage_providers.append(
|
||||
(provider_class, parsed_config, wrapper_config,)
|
||||
)
|
||||
|
||||
self.uploads_path = self.ensure_directory(config["uploads_path"])
|
||||
self.dynamic_thumbnails = config["dynamic_thumbnails"]
|
||||
self.thumbnail_requirements = parse_thumbnail_requirements(
|
||||
|
@ -127,13 +182,19 @@ class ContentRepositoryConfig(Config):
|
|||
# Directory where uploaded images and attachments are stored.
|
||||
media_store_path: "%(media_store)s"
|
||||
|
||||
# A secondary directory where uploaded images and attachments are
|
||||
# stored as a backup.
|
||||
# backup_media_store_path: "%(media_store)s"
|
||||
|
||||
# Whether to wait for successful write to backup media store before
|
||||
# returning successfully.
|
||||
# synchronous_backup_media_store: false
|
||||
# Media storage providers allow media to be stored in different
|
||||
# locations.
|
||||
# media_storage_providers:
|
||||
# - module: file_system
|
||||
# # Whether to write new local files.
|
||||
# store_local: false
|
||||
# # Whether to write new remote media
|
||||
# store_remote: false
|
||||
# # Whether to block upload requests waiting for write to this
|
||||
# # provider to complete
|
||||
# store_synchronous: false
|
||||
# config:
|
||||
# directory: /mnt/some/other/directory
|
||||
|
||||
# Directory where in-progress uploads are stored.
|
||||
uploads_path: "%(uploads_path)s"
|
||||
|
|
|
@ -55,6 +55,17 @@ class ServerConfig(Config):
|
|||
"block_non_admin_invites", False,
|
||||
)
|
||||
|
||||
# FIXME: federation_domain_whitelist needs sytests
|
||||
self.federation_domain_whitelist = None
|
||||
federation_domain_whitelist = config.get(
|
||||
"federation_domain_whitelist", None
|
||||
)
|
||||
# turn the whitelist into a hash for speed of lookup
|
||||
if federation_domain_whitelist is not None:
|
||||
self.federation_domain_whitelist = {}
|
||||
for domain in federation_domain_whitelist:
|
||||
self.federation_domain_whitelist[domain] = True
|
||||
|
||||
if self.public_baseurl is not None:
|
||||
if self.public_baseurl[-1] != '/':
|
||||
self.public_baseurl += '/'
|
||||
|
@ -210,6 +221,17 @@ class ServerConfig(Config):
|
|||
# (except those sent by local server admins). The default is False.
|
||||
# block_non_admin_invites: True
|
||||
|
||||
# Restrict federation to the following whitelist of domains.
|
||||
# N.B. we recommend also firewalling your federation listener to limit
|
||||
# inbound federation traffic as early as possible, rather than relying
|
||||
# purely on this application-layer restriction. If not specified, the
|
||||
# default is to whitelist everything.
|
||||
#
|
||||
# federation_domain_whitelist:
|
||||
# - lon.example.com
|
||||
# - nyc.example.com
|
||||
# - syd.example.com
|
||||
|
||||
# List of ports that Synapse should listen on, their purpose and their
|
||||
# configuration.
|
||||
listeners:
|
||||
|
|
|
@ -96,7 +96,7 @@ class TlsConfig(Config):
|
|||
# certificates returned by this server match one of the fingerprints.
|
||||
#
|
||||
# Synapse automatically adds the fingerprint of its own certificate
|
||||
# to the list. So if federation traffic is handle directly by synapse
|
||||
# to the list. So if federation traffic is handled directly by synapse
|
||||
# then no modification to the list is required.
|
||||
#
|
||||
# If synapse is run behind a load balancer that handles the TLS then it
|
||||
|
|
|
@ -23,6 +23,11 @@ class WorkerConfig(Config):
|
|||
|
||||
def read_config(self, config):
|
||||
self.worker_app = config.get("worker_app")
|
||||
|
||||
# Canonicalise worker_app so that master always has None
|
||||
if self.worker_app == "synapse.app.homeserver":
|
||||
self.worker_app = None
|
||||
|
||||
self.worker_listeners = config.get("worker_listeners")
|
||||
self.worker_daemonize = config.get("worker_daemonize")
|
||||
self.worker_pid_file = config.get("worker_pid_file")
|
||||
|
|
|
@ -319,7 +319,7 @@ def _is_membership_change_allowed(event, auth_events):
|
|||
# TODO (erikj): Implement kicks.
|
||||
if target_banned and user_level < ban_level:
|
||||
raise AuthError(
|
||||
403, "You cannot unban user &s." % (target_user_id,)
|
||||
403, "You cannot unban user %s." % (target_user_id,)
|
||||
)
|
||||
elif target_user_id != event.user_id:
|
||||
kick_level = _get_named_level(auth_events, "kick", 50)
|
||||
|
|
|
@ -25,7 +25,9 @@ class EventContext(object):
|
|||
The current state map excluding the current event.
|
||||
(type, state_key) -> event_id
|
||||
|
||||
state_group (int): state group id
|
||||
state_group (int|None): state group id, if the state has been stored
|
||||
as a state group. This is usually only None if e.g. the event is
|
||||
an outlier.
|
||||
rejected (bool|str): A rejection reason if the event was rejected, else
|
||||
False
|
||||
|
||||
|
|
|
@ -16,7 +16,9 @@ import logging
|
|||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.crypto.event_signing import check_event_content_hash
|
||||
from synapse.events import FrozenEvent
|
||||
from synapse.events.utils import prune_event
|
||||
from synapse.http.servlet import assert_params_in_request
|
||||
from synapse.util import unwrapFirstError, logcontext
|
||||
from twisted.internet import defer
|
||||
|
||||
|
@ -169,3 +171,28 @@ class FederationBase(object):
|
|||
)
|
||||
|
||||
return deferreds
|
||||
|
||||
|
||||
def event_from_pdu_json(pdu_json, outlier=False):
|
||||
"""Construct a FrozenEvent from an event json received over federation
|
||||
|
||||
Args:
|
||||
pdu_json (object): pdu as received over federation
|
||||
outlier (bool): True to mark this event as an outlier
|
||||
|
||||
Returns:
|
||||
FrozenEvent
|
||||
|
||||
Raises:
|
||||
SynapseError: if the pdu is missing required fields
|
||||
"""
|
||||
# we could probably enforce a bunch of other fields here (room_id, sender,
|
||||
# origin, etc etc)
|
||||
assert_params_in_request(pdu_json, ('event_id', 'type'))
|
||||
event = FrozenEvent(
|
||||
pdu_json
|
||||
)
|
||||
|
||||
event.internal_metadata.outlier = outlier
|
||||
|
||||
return event
|
||||
|
|
|
@ -14,28 +14,28 @@
|
|||
# limitations under the License.
|
||||
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from .federation_base import FederationBase
|
||||
from synapse.api.constants import Membership
|
||||
|
||||
from synapse.api.errors import (
|
||||
CodeMessageException, HttpResponseException, SynapseError,
|
||||
)
|
||||
from synapse.util import unwrapFirstError, logcontext
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||
from synapse.events import FrozenEvent, builder
|
||||
import synapse.metrics
|
||||
|
||||
from synapse.util.retryutils import NotRetryingDestination
|
||||
|
||||
import copy
|
||||
import itertools
|
||||
import logging
|
||||
import random
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.api.errors import (
|
||||
CodeMessageException, HttpResponseException, SynapseError, FederationDeniedError
|
||||
)
|
||||
from synapse.events import builder
|
||||
from synapse.federation.federation_base import (
|
||||
FederationBase,
|
||||
event_from_pdu_json,
|
||||
)
|
||||
import synapse.metrics
|
||||
from synapse.util import logcontext, unwrapFirstError
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.retryutils import NotRetryingDestination
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -184,7 +184,7 @@ class FederationClient(FederationBase):
|
|||
logger.debug("backfill transaction_data=%s", repr(transaction_data))
|
||||
|
||||
pdus = [
|
||||
self.event_from_pdu_json(p, outlier=False)
|
||||
event_from_pdu_json(p, outlier=False)
|
||||
for p in transaction_data["pdus"]
|
||||
]
|
||||
|
||||
|
@ -244,7 +244,7 @@ class FederationClient(FederationBase):
|
|||
logger.debug("transaction_data %r", transaction_data)
|
||||
|
||||
pdu_list = [
|
||||
self.event_from_pdu_json(p, outlier=outlier)
|
||||
event_from_pdu_json(p, outlier=outlier)
|
||||
for p in transaction_data["pdus"]
|
||||
]
|
||||
|
||||
|
@ -266,6 +266,9 @@ class FederationClient(FederationBase):
|
|||
except NotRetryingDestination as e:
|
||||
logger.info(e.message)
|
||||
continue
|
||||
except FederationDeniedError as e:
|
||||
logger.info(e.message)
|
||||
continue
|
||||
except Exception as e:
|
||||
pdu_attempts[destination] = now
|
||||
|
||||
|
@ -336,11 +339,11 @@ class FederationClient(FederationBase):
|
|||
)
|
||||
|
||||
pdus = [
|
||||
self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
|
||||
event_from_pdu_json(p, outlier=True) for p in result["pdus"]
|
||||
]
|
||||
|
||||
auth_chain = [
|
||||
self.event_from_pdu_json(p, outlier=True)
|
||||
event_from_pdu_json(p, outlier=True)
|
||||
for p in result.get("auth_chain", [])
|
||||
]
|
||||
|
||||
|
@ -441,7 +444,7 @@ class FederationClient(FederationBase):
|
|||
)
|
||||
|
||||
auth_chain = [
|
||||
self.event_from_pdu_json(p, outlier=True)
|
||||
event_from_pdu_json(p, outlier=True)
|
||||
for p in res["auth_chain"]
|
||||
]
|
||||
|
||||
|
@ -570,12 +573,12 @@ class FederationClient(FederationBase):
|
|||
logger.debug("Got content: %s", content)
|
||||
|
||||
state = [
|
||||
self.event_from_pdu_json(p, outlier=True)
|
||||
event_from_pdu_json(p, outlier=True)
|
||||
for p in content.get("state", [])
|
||||
]
|
||||
|
||||
auth_chain = [
|
||||
self.event_from_pdu_json(p, outlier=True)
|
||||
event_from_pdu_json(p, outlier=True)
|
||||
for p in content.get("auth_chain", [])
|
||||
]
|
||||
|
||||
|
@ -650,7 +653,7 @@ class FederationClient(FederationBase):
|
|||
|
||||
logger.debug("Got response to send_invite: %s", pdu_dict)
|
||||
|
||||
pdu = self.event_from_pdu_json(pdu_dict)
|
||||
pdu = event_from_pdu_json(pdu_dict)
|
||||
|
||||
# Check signatures are correct.
|
||||
pdu = yield self._check_sigs_and_hash(pdu)
|
||||
|
@ -740,7 +743,7 @@ class FederationClient(FederationBase):
|
|||
)
|
||||
|
||||
auth_chain = [
|
||||
self.event_from_pdu_json(e)
|
||||
event_from_pdu_json(e)
|
||||
for e in content["auth_chain"]
|
||||
]
|
||||
|
||||
|
@ -788,7 +791,7 @@ class FederationClient(FederationBase):
|
|||
)
|
||||
|
||||
events = [
|
||||
self.event_from_pdu_json(e)
|
||||
event_from_pdu_json(e)
|
||||
for e in content.get("events", [])
|
||||
]
|
||||
|
||||
|
@ -805,15 +808,6 @@ class FederationClient(FederationBase):
|
|||
|
||||
defer.returnValue(signed_events)
|
||||
|
||||
def event_from_pdu_json(self, pdu_json, outlier=False):
|
||||
event = FrozenEvent(
|
||||
pdu_json
|
||||
)
|
||||
|
||||
event.internal_metadata.outlier = outlier
|
||||
|
||||
return event
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def forward_third_party_invite(self, destinations, room_id, event_dict):
|
||||
for destination in destinations:
|
||||
|
|
|
@ -12,25 +12,24 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from twisted.internet import defer
|
||||
|
||||
from .federation_base import FederationBase
|
||||
from .units import Transaction, Edu
|
||||
|
||||
from synapse.util import async
|
||||
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
from synapse.events import FrozenEvent
|
||||
from synapse.types import get_domain_from_id
|
||||
import synapse.metrics
|
||||
|
||||
from synapse.api.errors import AuthError, FederationError, SynapseError
|
||||
|
||||
from synapse.crypto.event_signing import compute_event_signature
|
||||
import logging
|
||||
|
||||
import simplejson as json
|
||||
import logging
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import AuthError, FederationError, SynapseError
|
||||
from synapse.crypto.event_signing import compute_event_signature
|
||||
from synapse.federation.federation_base import (
|
||||
FederationBase,
|
||||
event_from_pdu_json,
|
||||
)
|
||||
from synapse.federation.units import Edu, Transaction
|
||||
import synapse.metrics
|
||||
from synapse.types import get_domain_from_id
|
||||
from synapse.util import async
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||
from synapse.util.logutils import log_function
|
||||
|
||||
# when processing incoming transactions, we try to handle multiple rooms in
|
||||
# parallel, up to this limit.
|
||||
|
@ -172,7 +171,7 @@ class FederationServer(FederationBase):
|
|||
p["age_ts"] = request_time - int(p["age"])
|
||||
del p["age"]
|
||||
|
||||
event = self.event_from_pdu_json(p)
|
||||
event = event_from_pdu_json(p)
|
||||
room_id = event.room_id
|
||||
pdus_by_room.setdefault(room_id, []).append(event)
|
||||
|
||||
|
@ -346,7 +345,7 @@ class FederationServer(FederationBase):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_invite_request(self, origin, content):
|
||||
pdu = self.event_from_pdu_json(content)
|
||||
pdu = event_from_pdu_json(content)
|
||||
ret_pdu = yield self.handler.on_invite_request(origin, pdu)
|
||||
time_now = self._clock.time_msec()
|
||||
defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
|
||||
|
@ -354,7 +353,7 @@ class FederationServer(FederationBase):
|
|||
@defer.inlineCallbacks
|
||||
def on_send_join_request(self, origin, content):
|
||||
logger.debug("on_send_join_request: content: %s", content)
|
||||
pdu = self.event_from_pdu_json(content)
|
||||
pdu = event_from_pdu_json(content)
|
||||
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
|
||||
res_pdus = yield self.handler.on_send_join_request(origin, pdu)
|
||||
time_now = self._clock.time_msec()
|
||||
|
@ -374,7 +373,7 @@ class FederationServer(FederationBase):
|
|||
@defer.inlineCallbacks
|
||||
def on_send_leave_request(self, origin, content):
|
||||
logger.debug("on_send_leave_request: content: %s", content)
|
||||
pdu = self.event_from_pdu_json(content)
|
||||
pdu = event_from_pdu_json(content)
|
||||
logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
|
||||
yield self.handler.on_send_leave_request(origin, pdu)
|
||||
defer.returnValue((200, {}))
|
||||
|
@ -411,7 +410,7 @@ class FederationServer(FederationBase):
|
|||
"""
|
||||
with (yield self._server_linearizer.queue((origin, room_id))):
|
||||
auth_chain = [
|
||||
self.event_from_pdu_json(e)
|
||||
event_from_pdu_json(e)
|
||||
for e in content["auth_chain"]
|
||||
]
|
||||
|
||||
|
@ -586,15 +585,6 @@ class FederationServer(FederationBase):
|
|||
def __str__(self):
|
||||
return "<ReplicationLayer(%s)>" % self.server_name
|
||||
|
||||
def event_from_pdu_json(self, pdu_json, outlier=False):
    """Construct a FrozenEvent from PDU JSON and set whether it is an outlier.

    Args:
        pdu_json: the PDU content, passed straight to ``FrozenEvent``
        outlier (bool): stored as ``internal_metadata.outlier`` on the
            resulting event

    Returns:
        FrozenEvent: the event built from ``pdu_json``
    """
    frozen = FrozenEvent(pdu_json)
    frozen.internal_metadata.outlier = outlier
    return frozen
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def exchange_third_party_invite(
|
||||
self,
|
||||
|
|
|
@ -19,7 +19,7 @@ from twisted.internet import defer
|
|||
from .persistence import TransactionActions
|
||||
from .units import Transaction, Edu
|
||||
|
||||
from synapse.api.errors import HttpResponseException
|
||||
from synapse.api.errors import HttpResponseException, FederationDeniedError
|
||||
from synapse.util import logcontext, PreserveLoggingContext
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
|
||||
|
@ -42,6 +42,8 @@ sent_edus_counter = client_metrics.register_counter("sent_edus")
|
|||
|
||||
sent_transactions_counter = client_metrics.register_counter("sent_transactions")
|
||||
|
||||
events_processed_counter = client_metrics.register_counter("events_processed")
|
||||
|
||||
|
||||
class TransactionQueue(object):
|
||||
"""This class makes sure we only have one transaction in flight at
|
||||
|
@ -205,6 +207,8 @@ class TransactionQueue(object):
|
|||
|
||||
self._send_pdu(event, destinations)
|
||||
|
||||
events_processed_counter.inc_by(len(events))
|
||||
|
||||
yield self.store.update_federation_out_pos(
|
||||
"events", next_token
|
||||
)
|
||||
|
@ -486,6 +490,8 @@ class TransactionQueue(object):
|
|||
(e.retry_last_ts + e.retry_interval) / 1000.0
|
||||
),
|
||||
)
|
||||
except FederationDeniedError as e:
|
||||
logger.info(e)
|
||||
except Exception as e:
|
||||
logger.warn(
|
||||
"TX [%s] Failed to send transaction: %s",
|
||||
|
|
|
@ -212,6 +212,9 @@ class TransportLayerClient(object):
|
|||
|
||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||
to retry this server.
|
||||
|
||||
Fails with ``FederationDeniedError`` if the remote destination
|
||||
is not in our federation whitelist
|
||||
"""
|
||||
valid_memberships = {Membership.JOIN, Membership.LEAVE}
|
||||
if membership not in valid_memberships:
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.api.errors import Codes, SynapseError, FederationDeniedError
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.http.servlet import (
|
||||
parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
|
||||
|
@ -81,6 +81,7 @@ class Authenticator(object):
|
|||
self.keyring = hs.get_keyring()
|
||||
self.server_name = hs.hostname
|
||||
self.store = hs.get_datastore()
|
||||
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
|
||||
|
||||
# A method just so we can pass 'self' as the authenticator to the Servlets
|
||||
@defer.inlineCallbacks
|
||||
|
@ -92,6 +93,12 @@ class Authenticator(object):
|
|||
"signatures": {},
|
||||
}
|
||||
|
||||
if (
|
||||
self.federation_domain_whitelist is not None and
|
||||
self.server_name not in self.federation_domain_whitelist
|
||||
):
|
||||
raise FederationDeniedError(self.server_name)
|
||||
|
||||
if content is not None:
|
||||
json_request["content"] = content
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.util.metrics import Measure
|
||||
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||
|
@ -23,6 +24,10 @@ import logging
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
metrics = synapse.metrics.get_metrics_for(__name__)
|
||||
|
||||
events_processed_counter = metrics.register_counter("events_processed")
|
||||
|
||||
|
||||
def log_failure(failure):
|
||||
logger.error(
|
||||
|
@ -103,6 +108,8 @@ class ApplicationServicesHandler(object):
|
|||
service, event
|
||||
)
|
||||
|
||||
events_processed_counter.inc_by(len(events))
|
||||
|
||||
yield self.store.set_appservice_last_pos(upper_bound)
|
||||
finally:
|
||||
self.is_processing = False
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from twisted.internet import defer
|
||||
from twisted.internet import defer, threads
|
||||
|
||||
from ._base import BaseHandler
|
||||
from synapse.api.constants import LoginType
|
||||
|
@ -25,6 +25,7 @@ from synapse.module_api import ModuleApi
|
|||
from synapse.types import UserID
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.util.logcontext import make_deferred_yieldable
|
||||
|
||||
from twisted.web.client import PartialDownloadError
|
||||
|
||||
|
@ -714,7 +715,7 @@ class AuthHandler(BaseHandler):
|
|||
if not lookupres:
|
||||
defer.returnValue(None)
|
||||
(user_id, password_hash) = lookupres
|
||||
result = self.validate_hash(password, password_hash)
|
||||
result = yield self.validate_hash(password, password_hash)
|
||||
if not result:
|
||||
logger.warn("Failed password login for user %s", user_id)
|
||||
defer.returnValue(None)
|
||||
|
@ -842,10 +843,13 @@ class AuthHandler(BaseHandler):
|
|||
password (str): Password to hash.
|
||||
|
||||
Returns:
|
||||
Hashed password (str).
|
||||
Deferred(str): Hashed password.
|
||||
"""
|
||||
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
|
||||
bcrypt.gensalt(self.bcrypt_rounds))
|
||||
def _do_hash():
|
||||
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
|
||||
bcrypt.gensalt(self.bcrypt_rounds))
|
||||
|
||||
return make_deferred_yieldable(threads.deferToThread(_do_hash))
|
||||
|
||||
def validate_hash(self, password, stored_hash):
|
||||
"""Validates that self.hash(password) == stored_hash.
|
||||
|
@ -855,13 +859,17 @@ class AuthHandler(BaseHandler):
|
|||
stored_hash (str): Expected hash value.
|
||||
|
||||
Returns:
|
||||
Whether self.hash(password) == stored_hash (bool).
|
||||
Deferred(bool): Whether self.hash(password) == stored_hash.
|
||||
"""
|
||||
if stored_hash:
|
||||
|
||||
def _do_validate_hash():
|
||||
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
|
||||
stored_hash.encode('utf8')) == stored_hash
|
||||
|
||||
if stored_hash:
|
||||
return make_deferred_yieldable(threads.deferToThread(_do_validate_hash))
|
||||
else:
|
||||
return False
|
||||
return defer.succeed(False)
|
||||
|
||||
|
||||
class MacaroonGeneartor(object):
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# limitations under the License.
|
||||
from synapse.api import errors
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.errors import FederationDeniedError
|
||||
from synapse.util import stringutils
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
|
@ -513,6 +514,9 @@ class DeviceListEduUpdater(object):
|
|||
# This makes it more likely that the device lists will
|
||||
# eventually become consistent.
|
||||
return
|
||||
except FederationDeniedError as e:
|
||||
logger.info(e)
|
||||
return
|
||||
except Exception:
|
||||
# TODO: Remember that we are now out of sync and try again
|
||||
# later
|
||||
|
|
|
@ -17,7 +17,8 @@ import logging
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.types import get_domain_from_id
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.types import get_domain_from_id, UserID
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
|
||||
|
@ -33,7 +34,7 @@ class DeviceMessageHandler(object):
|
|||
"""
|
||||
self.store = hs.get_datastore()
|
||||
self.notifier = hs.get_notifier()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.is_mine = hs.is_mine
|
||||
self.federation = hs.get_federation_sender()
|
||||
|
||||
hs.get_replication_layer().register_edu_handler(
|
||||
|
@ -52,6 +53,12 @@ class DeviceMessageHandler(object):
|
|||
message_type = content["type"]
|
||||
message_id = content["message_id"]
|
||||
for user_id, by_device in content["messages"].items():
|
||||
# we use UserID.from_string to catch invalid user ids
|
||||
if not self.is_mine(UserID.from_string(user_id)):
|
||||
logger.warning("Request for keys for non-local user %s",
|
||||
user_id)
|
||||
raise SynapseError(400, "Not a user here")
|
||||
|
||||
messages_by_device = {
|
||||
device_id: {
|
||||
"content": message_content,
|
||||
|
@ -77,7 +84,8 @@ class DeviceMessageHandler(object):
|
|||
local_messages = {}
|
||||
remote_messages = {}
|
||||
for user_id, by_device in messages.items():
|
||||
if self.is_mine_id(user_id):
|
||||
# we use UserID.from_string to catch invalid user ids
|
||||
if self.is_mine(UserID.from_string(user_id)):
|
||||
messages_by_device = {
|
||||
device_id: {
|
||||
"content": message_content,
|
||||
|
|
|
@ -34,6 +34,7 @@ class DirectoryHandler(BaseHandler):
|
|||
|
||||
self.state = hs.get_state_handler()
|
||||
self.appservice_handler = hs.get_application_service_handler()
|
||||
self.event_creation_handler = hs.get_event_creation_handler()
|
||||
|
||||
self.federation = hs.get_replication_layer()
|
||||
self.federation.register_query_handler(
|
||||
|
@ -249,8 +250,7 @@ class DirectoryHandler(BaseHandler):
|
|||
def send_room_alias_update_event(self, requester, user_id, room_id):
|
||||
aliases = yield self.store.get_aliases_for_room(room_id)
|
||||
|
||||
msg_handler = self.hs.get_handlers().message_handler
|
||||
yield msg_handler.create_and_send_nonmember_event(
|
||||
yield self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester,
|
||||
{
|
||||
"type": EventTypes.Aliases,
|
||||
|
@ -272,8 +272,7 @@ class DirectoryHandler(BaseHandler):
|
|||
if not alias_event or alias_event.content.get("alias", "") != alias_str:
|
||||
return
|
||||
|
||||
msg_handler = self.hs.get_handlers().message_handler
|
||||
yield msg_handler.create_and_send_nonmember_event(
|
||||
yield self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester,
|
||||
{
|
||||
"type": EventTypes.CanonicalAlias,
|
||||
|
|
|
@ -19,8 +19,10 @@ import logging
|
|||
from canonicaljson import encode_canonical_json
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import SynapseError, CodeMessageException
|
||||
from synapse.types import get_domain_from_id
|
||||
from synapse.api.errors import (
|
||||
SynapseError, CodeMessageException, FederationDeniedError,
|
||||
)
|
||||
from synapse.types import get_domain_from_id, UserID
|
||||
from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
|
||||
from synapse.util.retryutils import NotRetryingDestination
|
||||
|
||||
|
@ -32,7 +34,7 @@ class E2eKeysHandler(object):
|
|||
self.store = hs.get_datastore()
|
||||
self.federation = hs.get_replication_layer()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.is_mine = hs.is_mine
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
# doesn't really work as part of the generic query API, because the
|
||||
|
@ -70,7 +72,8 @@ class E2eKeysHandler(object):
|
|||
remote_queries = {}
|
||||
|
||||
for user_id, device_ids in device_keys_query.items():
|
||||
if self.is_mine_id(user_id):
|
||||
# we use UserID.from_string to catch invalid user ids
|
||||
if self.is_mine(UserID.from_string(user_id)):
|
||||
local_query[user_id] = device_ids
|
||||
else:
|
||||
remote_queries[user_id] = device_ids
|
||||
|
@ -139,6 +142,10 @@ class E2eKeysHandler(object):
|
|||
failures[destination] = {
|
||||
"status": 503, "message": "Not ready for retry",
|
||||
}
|
||||
except FederationDeniedError as e:
|
||||
failures[destination] = {
|
||||
"status": 403, "message": "Federation Denied",
|
||||
}
|
||||
except Exception as e:
|
||||
# include ConnectionRefused and other errors
|
||||
failures[destination] = {
|
||||
|
@ -170,7 +177,8 @@ class E2eKeysHandler(object):
|
|||
|
||||
result_dict = {}
|
||||
for user_id, device_ids in query.items():
|
||||
if not self.is_mine_id(user_id):
|
||||
# we use UserID.from_string to catch invalid user ids
|
||||
if not self.is_mine(UserID.from_string(user_id)):
|
||||
logger.warning("Request for keys for non-local user %s",
|
||||
user_id)
|
||||
raise SynapseError(400, "Not a user here")
|
||||
|
@ -213,7 +221,8 @@ class E2eKeysHandler(object):
|
|||
remote_queries = {}
|
||||
|
||||
for user_id, device_keys in query.get("one_time_keys", {}).items():
|
||||
if self.is_mine_id(user_id):
|
||||
# we use UserID.from_string to catch invalid user ids
|
||||
if self.is_mine(UserID.from_string(user_id)):
|
||||
for device_id, algorithm in device_keys.items():
|
||||
local_query.append((user_id, device_id, algorithm))
|
||||
else:
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -22,6 +23,7 @@ from ._base import BaseHandler
|
|||
|
||||
from synapse.api.errors import (
|
||||
AuthError, FederationError, StoreError, CodeMessageException, SynapseError,
|
||||
FederationDeniedError,
|
||||
)
|
||||
from synapse.api.constants import EventTypes, Membership, RejectedReason
|
||||
from synapse.events.validator import EventValidator
|
||||
|
@ -74,6 +76,7 @@ class FederationHandler(BaseHandler):
|
|||
self.is_mine_id = hs.is_mine_id
|
||||
self.pusher_pool = hs.get_pusherpool()
|
||||
self.spam_checker = hs.get_spam_checker()
|
||||
self.event_creation_handler = hs.get_event_creation_handler()
|
||||
|
||||
self.replication_layer.set_handler(self)
|
||||
|
||||
|
@ -782,6 +785,9 @@ class FederationHandler(BaseHandler):
|
|||
except NotRetryingDestination as e:
|
||||
logger.info(e.message)
|
||||
continue
|
||||
except FederationDeniedError as e:
|
||||
logger.info(e)
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to backfill from %s because %s",
|
||||
|
@ -804,13 +810,12 @@ class FederationHandler(BaseHandler):
|
|||
event_ids = list(extremities.keys())
|
||||
|
||||
logger.debug("calling resolve_state_groups in _maybe_backfill")
|
||||
resolve = logcontext.preserve_fn(
|
||||
self.state_handler.resolve_state_groups_for_events
|
||||
)
|
||||
states = yield logcontext.make_deferred_yieldable(defer.gatherResults(
|
||||
[
|
||||
logcontext.preserve_fn(self.state_handler.resolve_state_groups)(
|
||||
room_id, [e]
|
||||
)
|
||||
for e in event_ids
|
||||
], consumeErrors=True,
|
||||
[resolve(room_id, [e]) for e in event_ids],
|
||||
consumeErrors=True,
|
||||
))
|
||||
states = dict(zip(event_ids, [s.state for s in states]))
|
||||
|
||||
|
@ -1004,8 +1009,7 @@ class FederationHandler(BaseHandler):
|
|||
})
|
||||
|
||||
try:
|
||||
message_handler = self.hs.get_handlers().message_handler
|
||||
event, context = yield message_handler._create_new_client_event(
|
||||
event, context = yield self.event_creation_handler.create_new_client_event(
|
||||
builder=builder,
|
||||
)
|
||||
except AuthError as e:
|
||||
|
@ -1245,8 +1249,7 @@ class FederationHandler(BaseHandler):
|
|||
"state_key": user_id,
|
||||
})
|
||||
|
||||
message_handler = self.hs.get_handlers().message_handler
|
||||
event, context = yield message_handler._create_new_client_event(
|
||||
event, context = yield self.event_creation_handler.create_new_client_event(
|
||||
builder=builder,
|
||||
)
|
||||
|
||||
|
@ -1828,8 +1831,8 @@ class FederationHandler(BaseHandler):
|
|||
current_state = set(e.event_id for e in auth_events.values())
|
||||
different_auth = event_auth_events - current_state
|
||||
|
||||
self._update_context_for_auth_events(
|
||||
context, auth_events, event_key,
|
||||
yield self._update_context_for_auth_events(
|
||||
event, context, auth_events, event_key,
|
||||
)
|
||||
|
||||
if different_auth and not event.internal_metadata.is_outlier():
|
||||
|
@ -1910,8 +1913,8 @@ class FederationHandler(BaseHandler):
|
|||
# 4. Look at rejects and their proofs.
|
||||
# TODO.
|
||||
|
||||
self._update_context_for_auth_events(
|
||||
context, auth_events, event_key,
|
||||
yield self._update_context_for_auth_events(
|
||||
event, context, auth_events, event_key,
|
||||
)
|
||||
|
||||
try:
|
||||
|
@ -1920,11 +1923,15 @@ class FederationHandler(BaseHandler):
|
|||
logger.warn("Failed auth resolution for %r because %s", event, e)
|
||||
raise e
|
||||
|
||||
def _update_context_for_auth_events(self, context, auth_events,
|
||||
@defer.inlineCallbacks
|
||||
def _update_context_for_auth_events(self, event, context, auth_events,
|
||||
event_key):
|
||||
"""Update the state_ids in an event context after auth event resolution
|
||||
"""Update the state_ids in an event context after auth event resolution,
|
||||
storing the changes as a new state group.
|
||||
|
||||
Args:
|
||||
event (Event): The event we're handling the context for
|
||||
|
||||
context (synapse.events.snapshot.EventContext): event context
|
||||
to be updated
|
||||
|
||||
|
@ -1947,7 +1954,13 @@ class FederationHandler(BaseHandler):
|
|||
context.prev_state_ids.update({
|
||||
k: a.event_id for k, a in auth_events.iteritems()
|
||||
})
|
||||
context.state_group = self.store.get_next_state_group()
|
||||
context.state_group = yield self.store.store_state_group(
|
||||
event.event_id,
|
||||
event.room_id,
|
||||
prev_group=context.prev_group,
|
||||
delta_ids=context.delta_ids,
|
||||
current_state_ids=context.current_state_ids,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def construct_auth_difference(self, local_auth, remote_auth):
|
||||
|
@ -2117,8 +2130,7 @@ class FederationHandler(BaseHandler):
|
|||
if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
|
||||
builder = self.event_builder_factory.new(event_dict)
|
||||
EventValidator().validate_new(builder)
|
||||
message_handler = self.hs.get_handlers().message_handler
|
||||
event, context = yield message_handler._create_new_client_event(
|
||||
event, context = yield self.event_creation_handler.create_new_client_event(
|
||||
builder=builder
|
||||
)
|
||||
|
||||
|
@ -2156,8 +2168,7 @@ class FederationHandler(BaseHandler):
|
|||
"""
|
||||
builder = self.event_builder_factory.new(event_dict)
|
||||
|
||||
message_handler = self.hs.get_handlers().message_handler
|
||||
event, context = yield message_handler._create_new_client_event(
|
||||
event, context = yield self.event_creation_handler.create_new_client_event(
|
||||
builder=builder,
|
||||
)
|
||||
|
||||
|
@ -2207,8 +2218,9 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
builder = self.event_builder_factory.new(event_dict)
|
||||
EventValidator().validate_new(builder)
|
||||
message_handler = self.hs.get_handlers().message_handler
|
||||
event, context = yield message_handler._create_new_client_event(builder=builder)
|
||||
event, context = yield self.event_creation_handler.create_new_client_event(
|
||||
builder=builder,
|
||||
)
|
||||
defer.returnValue((event, context))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
|
|
@ -383,11 +383,12 @@ class GroupsLocalHandler(object):
|
|||
|
||||
defer.returnValue({"groups": result})
|
||||
else:
|
||||
result = yield self.transport_client.get_publicised_groups_for_user(
|
||||
get_domain_from_id(user_id), user_id
|
||||
bulk_result = yield self.transport_client.bulk_get_publicised_groups(
|
||||
get_domain_from_id(user_id), [user_id],
|
||||
)
|
||||
result = bulk_result.get("users", {}).get(user_id)
|
||||
# TODO: Verify attestations
|
||||
defer.returnValue(result)
|
||||
defer.returnValue({"groups": result})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def bulk_get_publicised_groups(self, user_ids, proxy=True):
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 - 2016 OpenMarket Ltd
|
||||
# Copyright 2017 New Vector Ltd
|
||||
# Copyright 2017 - 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -47,23 +47,11 @@ class MessageHandler(BaseHandler):
|
|||
self.hs = hs
|
||||
self.state = hs.get_state_handler()
|
||||
self.clock = hs.get_clock()
|
||||
self.validator = EventValidator()
|
||||
self.profile_handler = hs.get_profile_handler()
|
||||
|
||||
self.pagination_lock = ReadWriteLock()
|
||||
|
||||
self.pusher_pool = hs.get_pusherpool()
|
||||
|
||||
# We arbitrarily limit concurrent event creation for a room to 5.
|
||||
# This is to stop us from diverging history *too* much.
|
||||
self.limiter = Limiter(max_count=5)
|
||||
|
||||
self.action_generator = hs.get_action_generator()
|
||||
|
||||
self.spam_checker = hs.get_spam_checker()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def purge_history(self, room_id, event_id):
|
||||
def purge_history(self, room_id, event_id, delete_local_events=False):
|
||||
event = yield self.store.get_event(event_id)
|
||||
|
||||
if event.room_id != room_id:
|
||||
|
@ -72,7 +60,7 @@ class MessageHandler(BaseHandler):
|
|||
depth = event.depth
|
||||
|
||||
with (yield self.pagination_lock.write(room_id)):
|
||||
yield self.store.delete_old_state(room_id, depth)
|
||||
yield self.store.purge_history(room_id, depth, delete_local_events)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_messages(self, requester, room_id=None, pagin_config=None,
|
||||
|
@ -182,166 +170,6 @@ class MessageHandler(BaseHandler):
|
|||
|
||||
defer.returnValue(chunk)
|
||||
|
||||
@defer.inlineCallbacks
def create_event(self, requester, event_dict, token_id=None, txn_id=None,
                 prev_event_ids=None):
    """Turn a client-supplied event dict into an event and its context.

    The dict is fed to the event builder factory, validated, and then
    handed to ``_create_new_client_event``. For join and invite membership
    events, ``displayname`` and ``avatar_url`` are filled in from the
    profile handler when the content does not already carry them.

    Args:
        requester: forwarded to ``_create_new_client_event``
        event_dict (dict): An entire event
        token_id (str): recorded on the builder's internal metadata when
            it is not None
        txn_id (str): recorded on the builder's internal metadata when it
            is not None
        prev_event_ids (list): The prev event ids to use when creating the
            event

    Returns:
        Deferred resolving to a tuple of (created event, context)
    """
    # NOTE(review): nesting rebuilt from a rendering that had lost its
    # indentation -- confirm the extent of the ``with`` block upstream.
    new_builder = self.event_builder_factory.new(event_dict)

    with (yield self.limiter.queue(new_builder.room_id)):
        self.validator.validate_new(new_builder)

        if new_builder.type == EventTypes.Member:
            membership = new_builder.content.get("membership", None)
            target = UserID.from_string(new_builder.state_key)

            if membership in {Membership.JOIN, Membership.INVITE}:
                # Fill in whichever profile fields the event left out.
                profile = self.profile_handler
                body = new_builder.content

                try:
                    if "displayname" not in body:
                        body["displayname"] = yield profile.get_displayname(target)
                    if "avatar_url" not in body:
                        body["avatar_url"] = yield profile.get_avatar_url(target)
                except Exception as e:
                    # Best effort: a failed profile lookup must not stop
                    # the event from being created.
                    logger.info(
                        "Failed to get profile information for %r: %s",
                        target, e
                    )

        if token_id is not None:
            new_builder.internal_metadata.token_id = token_id

        if txn_id is not None:
            new_builder.internal_metadata.txn_id = txn_id

        created = yield self._create_new_client_event(
            builder=new_builder,
            requester=requester,
            prev_event_ids=prev_event_ids,
        )
        event, context = created

    defer.returnValue((event, context))
|
||||
|
||||
@defer.inlineCallbacks
def send_nonmember_event(self, requester, event, context, ratelimit=True):
    """
    Persists and notifies local clients and federation of an event.

    Args:
        requester: passed to the rate limiter and on to
            ``handle_new_client_event``
        event (FrozenEvent): the event to send; must not be a membership
            event.
        context (Context): the context of the event.
        ratelimit (bool): Whether to rate limit this send.

    Returns:
        Deferred. When the event is a state event and
        ``deduplicate_state_event`` finds an existing copy, it resolves to
        that existing event.

    Raises:
        SynapseError: 500 if ``event`` is a membership event.
    """
    if event.type == EventTypes.Member:
        raise SynapseError(
            500,
            "Tried to send member event through non-member codepath"
        )

    # Cheap early check so a rate-limited sender does not cost us any
    # further work; the limit is checked again just before the real send.
    yield self.ratelimit(requester, update=False)

    sender = UserID.from_string(event.sender)
    assert self.hs.is_mine(sender), "User must be our own: %s" % (sender,)

    if event.is_state():
        existing = yield self.deduplicate_state_event(event, context)
        if existing is not None:
            defer.returnValue(existing)

    yield self.handle_new_client_event(
        requester=requester,
        event=event,
        context=context,
        ratelimit=ratelimit,
    )

    if event.type == EventTypes.Message:
        # Presence code can be slow, so fire it off without waiting rather
        # than blocking the message send on it.
        presence_handler = self.hs.get_presence_handler()
        bump_active = preserve_fn(presence_handler.bump_presence_active_time)
        bump_active(sender)
|
||||
|
||||
@defer.inlineCallbacks
def deduplicate_state_event(self, event, context):
    """
    Checks whether an equivalent of event is already in the context's
    previous state.

    The previous event for the same (type, state_key) counts as a duplicate
    when it has the same user_id and its content has the same canonical
    JSON encoding.

    Returns:
        Deferred resolving to the previous event if it is a duplicate,
        otherwise None.
    """
    state_slot = (event.type, event.state_key)
    existing_id = context.prev_state_ids.get(state_slot)
    existing = yield self.store.get_event(existing_id, allow_none=True)

    # Nothing stored for this slot: cannot be a duplicate.
    if not existing:
        return

    # A different sender is never treated as a duplicate.
    if event.user_id != existing.user_id:
        return

    old_encoded = encode_canonical_json(existing.content)
    new_encoded = encode_canonical_json(event.content)
    if old_encoded == new_encoded:
        defer.returnValue(existing)
|
||||
|
||||
@defer.inlineCallbacks
def create_and_send_nonmember_event(
    self,
    requester,
    event_dict,
    ratelimit=True,
    txn_id=None
):
    """
    Creates an event, runs it past the spam checker, then sends it.

    See self.create_event and self.send_nonmember_event.

    Returns:
        Deferred resolving to the created event.

    Raises:
        SynapseError: 403 if the spam checker returns a truthy result.
    """
    event, context = yield self.create_event(
        requester,
        event_dict,
        token_id=requester.access_token_id,
        txn_id=txn_id
    )

    verdict = self.spam_checker.check_event_for_spam(event)
    if verdict:
        # A string verdict is used as the error message; any other truthy
        # verdict gets the stock message.
        if isinstance(verdict, basestring):
            message = verdict
        else:
            message = "Spam is not permitted here"
        raise SynapseError(403, message, Codes.FORBIDDEN)

    yield self.send_nonmember_event(
        requester, event, context, ratelimit=ratelimit,
    )
    defer.returnValue(event)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_room_data(self, user_id=None, room_id=None,
|
||||
event_type=None, state_key="", is_guest=False):
|
||||
|
@ -470,9 +298,192 @@ class MessageHandler(BaseHandler):
|
|||
for user_id, profile in users_with_profile.iteritems()
|
||||
})
|
||||
|
||||
@measure_func("_create_new_client_event")
|
||||
|
||||
class EventCreationHandler(object):
|
||||
def __init__(self, hs):
    """Collect the homeserver services used when creating events.

    Args:
        hs: the homeserver object; each dependency below is read from it
    """
    self.hs = hs

    # Services obtained from the homeserver.
    self.auth = hs.get_auth()
    self.store = hs.get_datastore()
    self.state = hs.get_state_handler()
    self.clock = hs.get_clock()
    self.validator = EventValidator()
    self.profile_handler = hs.get_profile_handler()
    self.event_builder_factory = hs.get_event_builder_factory()
    self.server_name = hs.hostname
    self.ratelimiter = hs.get_ratelimiter()
    self.notifier = hs.get_notifier()

    # Held only so that we can reach the ratelimit function and
    # maybe_kick_guest_users.
    self.base_handler = BaseHandler(hs)

    self.pusher_pool = hs.get_pusherpool()

    # Concurrent event creation for a room is arbitrarily capped at 5, to
    # stop us from diverging history *too* much.
    self.limiter = Limiter(max_count=5)

    self.action_generator = hs.get_action_generator()

    self.spam_checker = hs.get_spam_checker()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _create_new_client_event(self, builder, requester=None, prev_event_ids=None):
|
||||
def create_event(self, requester, event_dict, token_id=None, txn_id=None,
                 prev_event_ids=None):
    """Build a new event from a client-supplied dict.

    The dict is turned into a builder, validated, and handed to
    create_new_client_event. For join and invite membership events a
    missing "displayname" / "avatar_url" is filled in from the target
    user's profile; failure to fetch the profile is logged and ignored.

    Args:
        requester: passed through to create_new_client_event
        event_dict (dict): An entire event
        token_id (str): if not None, stored in the event's internal
            metadata
        txn_id (str): if not None, stored in the event's internal
            metadata
        prev_event_ids (list): The prev event ids to use when creating
            the event

    Returns:
        Tuple of created event (FrozenEvent), Context
    """
    builder = self.event_builder_factory.new(event_dict)

    # Everything up to and including event creation happens while holding
    # a slot in the per-room limiter.
    with (yield self.limiter.queue(builder.room_id)):
        self.validator.validate_new(builder)

        if builder.type == EventTypes.Member:
            membership = builder.content.get("membership", None)
            target = UserID.from_string(builder.state_key)

            if membership in {Membership.JOIN, Membership.INVITE}:
                # Fill in any profile fields the client left out.
                member_content = builder.content
                try:
                    if "displayname" not in member_content:
                        member_content["displayname"] = (
                            yield self.profile_handler.get_displayname(target)
                        )
                    if "avatar_url" not in member_content:
                        member_content["avatar_url"] = (
                            yield self.profile_handler.get_avatar_url(target)
                        )
                except Exception as e:
                    # Best effort only: the event is still created without
                    # the profile data.
                    logger.info(
                        "Failed to get profile information for %r: %s",
                        target, e
                    )

        if token_id is not None:
            builder.internal_metadata.token_id = token_id

        if txn_id is not None:
            builder.internal_metadata.txn_id = txn_id

        created = yield self.create_new_client_event(
            builder=builder,
            requester=requester,
            prev_event_ids=prev_event_ids,
        )
        event, context = created

    defer.returnValue((event, context))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_nonmember_event(self, requester, event, context, ratelimit=True):
    """Send a non-membership event via handle_new_client_event.

    State events are first passed to deduplicate_state_event; if that
    yields an existing event, it is returned and nothing new is sent.

    Args:
        requester: passed through to handle_new_client_event.
        event (FrozenEvent): the event to send.
        context (Context): the context of the event.
        ratelimit (bool): Whether to rate limit this send.

    Returns:
        The already-existing event when the new state event is a
        duplicate of it, otherwise None.

    Raises:
        SynapseError: (500) if called with a membership event.
    """
    if event.type == EventTypes.Member:
        raise SynapseError(
            500,
            "Tried to send member event through non-member codepath"
        )

    sender = UserID.from_string(event.sender)

    assert self.hs.is_mine(sender), "User must be our own: %s" % (sender,)

    if event.is_state():
        existing = yield self.deduplicate_state_event(event, context)
        if existing is not None:
            defer.returnValue(existing)

    yield self.handle_new_client_event(
        requester=requester,
        event=event,
        context=context,
        ratelimit=ratelimit,
    )

    if event.type == EventTypes.Message:
        # Presence code can sometimes take a while, so fire it off without
        # waiting: message sending must not block on it.
        presence_handler = self.hs.get_presence_handler()
        preserve_fn(presence_handler.bump_presence_active_time)(sender)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def deduplicate_state_event(self, event, context):
    """Look for an identical event already present in the context's state.

    The event currently stored under (event.type, event.state_key) in
    context.prev_state_ids is fetched. It counts as a duplicate when it
    has the same user_id as the new event and its content has the same
    canonical JSON encoding.

    Returns:
        The existing event if the new one duplicates it, otherwise None.
    """
    existing_id = context.prev_state_ids.get((event.type, event.state_key))
    existing = yield self.store.get_event(existing_id, allow_none=True)

    if existing and event.user_id == existing.user_id:
        existing_json = encode_canonical_json(existing.content)
        new_json = encode_canonical_json(event.content)
        if existing_json == new_json:
            defer.returnValue(existing)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_and_send_nonmember_event(
    self,
    requester,
    event_dict,
    ratelimit=True,
    txn_id=None
):
    """Create an event from event_dict, spam-check it, and send it.

    This chains self.create_event and self.send_nonmember_event, with a
    spam check in between.

    Args:
        requester: the requester; its access_token_id is recorded on the
            event.
        event_dict (dict): the event to create.
        ratelimit (bool): passed through to send_nonmember_event.
        txn_id (str): passed through to create_event.

    Returns:
        The created event.

    Raises:
        SynapseError: (403) if the spam checker flags the event.
    """
    event, context = yield self.create_event(
        requester,
        event_dict,
        token_id=requester.access_token_id,
        txn_id=txn_id
    )

    verdict = self.spam_checker.check_event_for_spam(event)
    if verdict:
        # A truthy verdict that is not a string carries no message of its
        # own, so substitute a generic one.
        if isinstance(verdict, basestring):
            reason = verdict
        else:
            reason = "Spam is not permitted here"
        raise SynapseError(403, reason, Codes.FORBIDDEN)

    yield self.send_nonmember_event(
        requester,
        event,
        context,
        ratelimit=ratelimit,
    )
    defer.returnValue(event)
|
||||
|
||||
@measure_func("create_new_client_event")
|
||||
@defer.inlineCallbacks
|
||||
def create_new_client_event(self, builder, requester=None, prev_event_ids=None):
|
||||
if prev_event_ids:
|
||||
prev_events = yield self.store.add_event_hashes(prev_event_ids)
|
||||
prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids)
|
||||
|
@ -509,9 +520,7 @@ class MessageHandler(BaseHandler):
|
|||
builder.prev_events = prev_events
|
||||
builder.depth = depth
|
||||
|
||||
state_handler = self.state_handler
|
||||
|
||||
context = yield state_handler.compute_event_context(builder)
|
||||
context = yield self.state.compute_event_context(builder)
|
||||
if requester:
|
||||
context.app_service = requester.app_service
|
||||
|
||||
|
@ -551,7 +560,7 @@ class MessageHandler(BaseHandler):
|
|||
# We now need to go and hit out to wherever we need to hit out to.
|
||||
|
||||
if ratelimit:
|
||||
yield self.ratelimit(requester)
|
||||
yield self.base_handler.ratelimit(requester)
|
||||
|
||||
try:
|
||||
yield self.auth.check_from_context(event, context)
|
||||
|
@ -567,7 +576,7 @@ class MessageHandler(BaseHandler):
|
|||
logger.exception("Failed to encode content: %r", event.content)
|
||||
raise
|
||||
|
||||
yield self.maybe_kick_guest_users(event, context)
|
||||
yield self.base_handler.maybe_kick_guest_users(event, context)
|
||||
|
||||
if event.type == EventTypes.CanonicalAlias:
|
||||
# Check the alias is actually valid (at this time at least)
|
||||
|
|
|
@ -25,6 +25,7 @@ from synapse.http.client import CaptchaServerHttpClient
|
|||
from synapse import types
|
||||
from synapse.types import UserID
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.util.threepids import check_3pid_allowed
|
||||
from ._base import BaseHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -131,7 +132,7 @@ class RegistrationHandler(BaseHandler):
|
|||
yield run_on_reactor()
|
||||
password_hash = None
|
||||
if password:
|
||||
password_hash = self.auth_handler().hash(password)
|
||||
password_hash = yield self.auth_handler().hash(password)
|
||||
|
||||
if localpart:
|
||||
yield self.check_username(localpart, guest_access_token=guest_access_token)
|
||||
|
@ -293,7 +294,7 @@ class RegistrationHandler(BaseHandler):
|
|||
"""
|
||||
|
||||
for c in threepidCreds:
|
||||
logger.info("validating theeepidcred sid %s on id server %s",
|
||||
logger.info("validating threepidcred sid %s on id server %s",
|
||||
c['sid'], c['idServer'])
|
||||
try:
|
||||
identity_handler = self.hs.get_handlers().identity_handler
|
||||
|
@ -307,6 +308,11 @@ class RegistrationHandler(BaseHandler):
|
|||
logger.info("got threepid with medium '%s' and address '%s'",
|
||||
threepid['medium'], threepid['address'])
|
||||
|
||||
if not check_3pid_allowed(self.hs, threepid['medium'], threepid['address']):
|
||||
raise RegistrationError(
|
||||
403, "Third party identifier is not allowed"
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def bind_emails(self, user_id, threepidCreds):
|
||||
"""Links emails with a user ID and informs an identity server.
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 - 2016 OpenMarket Ltd
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -64,6 +65,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
super(RoomCreationHandler, self).__init__(hs)
|
||||
|
||||
self.spam_checker = hs.get_spam_checker()
|
||||
self.event_creation_handler = hs.get_event_creation_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_room(self, requester, config, ratelimit=True):
|
||||
|
@ -163,13 +165,11 @@ class RoomCreationHandler(BaseHandler):
|
|||
|
||||
creation_content = config.get("creation_content", {})
|
||||
|
||||
msg_handler = self.hs.get_handlers().message_handler
|
||||
room_member_handler = self.hs.get_handlers().room_member_handler
|
||||
|
||||
yield self._send_events_for_new_room(
|
||||
requester,
|
||||
room_id,
|
||||
msg_handler,
|
||||
room_member_handler,
|
||||
preset_config=preset_config,
|
||||
invite_list=invite_list,
|
||||
|
@ -181,7 +181,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
|
||||
if "name" in config:
|
||||
name = config["name"]
|
||||
yield msg_handler.create_and_send_nonmember_event(
|
||||
yield self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester,
|
||||
{
|
||||
"type": EventTypes.Name,
|
||||
|
@ -194,7 +194,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
|
||||
if "topic" in config:
|
||||
topic = config["topic"]
|
||||
yield msg_handler.create_and_send_nonmember_event(
|
||||
yield self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester,
|
||||
{
|
||||
"type": EventTypes.Topic,
|
||||
|
@ -249,7 +249,6 @@ class RoomCreationHandler(BaseHandler):
|
|||
self,
|
||||
creator, # A Requester object.
|
||||
room_id,
|
||||
msg_handler,
|
||||
room_member_handler,
|
||||
preset_config,
|
||||
invite_list,
|
||||
|
@ -272,7 +271,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
@defer.inlineCallbacks
|
||||
def send(etype, content, **kwargs):
|
||||
event = create(etype, content, **kwargs)
|
||||
yield msg_handler.create_and_send_nonmember_event(
|
||||
yield self.event_creation_handler.create_and_send_nonmember_event(
|
||||
creator,
|
||||
event,
|
||||
ratelimit=False
|
||||
|
|
|
@ -203,7 +203,8 @@ class RoomListHandler(BaseHandler):
|
|||
if limit:
|
||||
step = limit + 1
|
||||
else:
|
||||
step = len(rooms_to_scan)
|
||||
# step cannot be zero
|
||||
step = len(rooms_to_scan) if len(rooms_to_scan) != 0 else 1
|
||||
|
||||
chunk = []
|
||||
for i in xrange(0, len(rooms_to_scan), step):
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -46,6 +47,7 @@ class RoomMemberHandler(BaseHandler):
|
|||
super(RoomMemberHandler, self).__init__(hs)
|
||||
|
||||
self.profile_handler = hs.get_profile_handler()
|
||||
self.event_creation_hander = hs.get_event_creation_handler()
|
||||
|
||||
self.member_linearizer = Linearizer(name="member")
|
||||
|
||||
|
@ -66,13 +68,12 @@ class RoomMemberHandler(BaseHandler):
|
|||
):
|
||||
if content is None:
|
||||
content = {}
|
||||
msg_handler = self.hs.get_handlers().message_handler
|
||||
|
||||
content["membership"] = membership
|
||||
if requester.is_guest:
|
||||
content["kind"] = "guest"
|
||||
|
||||
event, context = yield msg_handler.create_event(
|
||||
event, context = yield self.event_creation_hander.create_event(
|
||||
requester,
|
||||
{
|
||||
"type": EventTypes.Member,
|
||||
|
@ -90,12 +91,14 @@ class RoomMemberHandler(BaseHandler):
|
|||
)
|
||||
|
||||
# Check if this event matches the previous membership event for the user.
|
||||
duplicate = yield msg_handler.deduplicate_state_event(event, context)
|
||||
duplicate = yield self.event_creation_hander.deduplicate_state_event(
|
||||
event, context,
|
||||
)
|
||||
if duplicate is not None:
|
||||
# Discard the new event since this membership change is a no-op.
|
||||
defer.returnValue(duplicate)
|
||||
|
||||
yield msg_handler.handle_new_client_event(
|
||||
yield self.event_creation_hander.handle_new_client_event(
|
||||
requester,
|
||||
event,
|
||||
context,
|
||||
|
@ -394,8 +397,9 @@ class RoomMemberHandler(BaseHandler):
|
|||
else:
|
||||
requester = synapse.types.create_requester(target_user)
|
||||
|
||||
message_handler = self.hs.get_handlers().message_handler
|
||||
prev_event = yield message_handler.deduplicate_state_event(event, context)
|
||||
prev_event = yield self.event_creation_hander.deduplicate_state_event(
|
||||
event, context,
|
||||
)
|
||||
if prev_event is not None:
|
||||
return
|
||||
|
||||
|
@ -412,7 +416,7 @@ class RoomMemberHandler(BaseHandler):
|
|||
if is_blocked:
|
||||
raise SynapseError(403, "This room has been blocked on this server")
|
||||
|
||||
yield message_handler.handle_new_client_event(
|
||||
yield self.event_creation_hander.handle_new_client_event(
|
||||
requester,
|
||||
event,
|
||||
context,
|
||||
|
@ -644,8 +648,7 @@ class RoomMemberHandler(BaseHandler):
|
|||
)
|
||||
)
|
||||
|
||||
msg_handler = self.hs.get_handlers().message_handler
|
||||
yield msg_handler.create_and_send_nonmember_event(
|
||||
yield self.event_creation_hander.create_and_send_nonmember_event(
|
||||
requester,
|
||||
{
|
||||
"type": EventTypes.ThirdPartyInvite,
|
||||
|
|
|
@ -31,7 +31,7 @@ class SetPasswordHandler(BaseHandler):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def set_password(self, user_id, newpassword, requester=None):
|
||||
password_hash = self._auth_handler.hash(newpassword)
|
||||
password_hash = yield self._auth_handler.hash(newpassword)
|
||||
|
||||
except_device_id = requester.device_id if requester else None
|
||||
except_access_token_id = requester.access_token_id if requester else None
|
||||
|
|
|
@ -18,6 +18,7 @@ from OpenSSL.SSL import VERIFY_NONE
|
|||
from synapse.api.errors import (
|
||||
CodeMessageException, MatrixCodeMessageException, SynapseError, Codes,
|
||||
)
|
||||
from synapse.util.caches import CACHE_SIZE_FACTOR
|
||||
from synapse.util.logcontext import make_deferred_yieldable
|
||||
from synapse.util import logcontext
|
||||
import synapse.metrics
|
||||
|
@ -30,6 +31,7 @@ from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
|
|||
from twisted.web.client import (
|
||||
BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent,
|
||||
readBody, PartialDownloadError,
|
||||
HTTPConnectionPool,
|
||||
)
|
||||
from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer
|
||||
from twisted.web.http import PotentialDataLoss
|
||||
|
@ -64,13 +66,23 @@ class SimpleHttpClient(object):
|
|||
"""
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
|
||||
pool = HTTPConnectionPool(reactor)
|
||||
|
||||
# the pusher makes lots of concurrent SSL connections to sygnal, and
|
||||
# tends to do so in batches, so we need to allow the pool to keep lots
|
||||
# of idle connections around.
|
||||
pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5))
|
||||
pool.cachedConnectionTimeout = 2 * 60
|
||||
|
||||
# The default context factory in Twisted 14.0.0 (which we require) is
|
||||
# BrowserLikePolicyForHTTPS which will do regular cert validation
|
||||
# 'like a browser'
|
||||
self.agent = Agent(
|
||||
reactor,
|
||||
connectTimeout=15,
|
||||
contextFactory=hs.get_http_client_context_factory()
|
||||
contextFactory=hs.get_http_client_context_factory(),
|
||||
pool=pool,
|
||||
)
|
||||
self.user_agent = hs.version_string
|
||||
self.clock = hs.get_clock()
|
||||
|
|
|
@ -357,8 +357,7 @@ def _get_hosts_for_srv_record(dns_client, host):
|
|||
def eb(res, record_type):
|
||||
if res.check(DNSNameError):
|
||||
return []
|
||||
logger.warn("Error looking up %s for %s: %s",
|
||||
record_type, host, res, res.value)
|
||||
logger.warn("Error looking up %s for %s: %s", record_type, host, res)
|
||||
return res
|
||||
|
||||
# no logcontexts here, so we can safely fire these off and gatherResults
|
||||
|
|
|
@ -27,7 +27,7 @@ import synapse.metrics
|
|||
from canonicaljson import encode_canonical_json
|
||||
|
||||
from synapse.api.errors import (
|
||||
SynapseError, Codes, HttpResponseException,
|
||||
SynapseError, Codes, HttpResponseException, FederationDeniedError,
|
||||
)
|
||||
|
||||
from signedjson.sign import sign_json
|
||||
|
@ -123,11 +123,22 @@ class MatrixFederationHttpClient(object):
|
|||
|
||||
Fails with ``HTTPRequestException``: if we get an HTTP response
|
||||
code >= 300.
|
||||
|
||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||
to retry this server.
|
||||
|
||||
Fails with ``FederationDeniedError`` if this destination
|
||||
is not on our federation whitelist
|
||||
|
||||
(May also fail with plenty of other Exceptions for things like DNS
|
||||
failures, connection failures, SSL failures.)
|
||||
"""
|
||||
if (
|
||||
self.hs.config.federation_domain_whitelist and
|
||||
destination not in self.hs.config.federation_domain_whitelist
|
||||
):
|
||||
raise FederationDeniedError(destination)
|
||||
|
||||
limiter = yield synapse.util.retryutils.get_retry_limiter(
|
||||
destination,
|
||||
self.clock,
|
||||
|
@ -308,6 +319,9 @@ class MatrixFederationHttpClient(object):
|
|||
|
||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||
to retry this server.
|
||||
|
||||
Fails with ``FederationDeniedError`` if this destination
|
||||
is not on our federation whitelist
|
||||
"""
|
||||
|
||||
if not json_data_callback:
|
||||
|
@ -368,6 +382,9 @@ class MatrixFederationHttpClient(object):
|
|||
|
||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||
to retry this server.
|
||||
|
||||
Fails with ``FederationDeniedError`` if this destination
|
||||
is not on our federation whitelist
|
||||
"""
|
||||
|
||||
def body_callback(method, url_bytes, headers_dict):
|
||||
|
@ -422,6 +439,9 @@ class MatrixFederationHttpClient(object):
|
|||
|
||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||
to retry this server.
|
||||
|
||||
Fails with ``FederationDeniedError`` if this destination
|
||||
is not on our federation whitelist
|
||||
"""
|
||||
logger.debug("get_json args: %s", args)
|
||||
|
||||
|
@ -475,6 +495,9 @@ class MatrixFederationHttpClient(object):
|
|||
|
||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||
to retry this server.
|
||||
|
||||
Fails with ``FederationDeniedError`` if this destination
|
||||
is not on our federation whitelist
|
||||
"""
|
||||
|
||||
response = yield self._request(
|
||||
|
@ -518,6 +541,9 @@ class MatrixFederationHttpClient(object):
|
|||
|
||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||
to retry this server.
|
||||
|
||||
Fails with ``FederationDeniedError`` if this destination
|
||||
is not on our federation whitelist
|
||||
"""
|
||||
|
||||
encoded_args = {}
|
||||
|
|
|
@ -42,36 +42,70 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
metrics = synapse.metrics.get_metrics_for(__name__)
|
||||
|
||||
incoming_requests_counter = metrics.register_counter(
|
||||
"requests",
|
||||
# total number of responses served, split by method/servlet/tag
|
||||
response_count = metrics.register_counter(
|
||||
"response_count",
|
||||
labels=["method", "servlet", "tag"],
|
||||
alternative_names=(
|
||||
# the following are all deprecated aliases for the same metric
|
||||
metrics.name_prefix + x for x in (
|
||||
"_requests",
|
||||
"_response_time:count",
|
||||
"_response_ru_utime:count",
|
||||
"_response_ru_stime:count",
|
||||
"_response_db_txn_count:count",
|
||||
"_response_db_txn_duration:count",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
outgoing_responses_counter = metrics.register_counter(
|
||||
"responses",
|
||||
labels=["method", "code"],
|
||||
)
|
||||
|
||||
response_timer = metrics.register_distribution(
|
||||
"response_time",
|
||||
labels=["method", "servlet", "tag"]
|
||||
response_timer = metrics.register_counter(
|
||||
"response_time_seconds",
|
||||
labels=["method", "servlet", "tag"],
|
||||
alternative_names=(
|
||||
metrics.name_prefix + "_response_time:total",
|
||||
),
|
||||
)
|
||||
|
||||
response_ru_utime = metrics.register_distribution(
|
||||
"response_ru_utime", labels=["method", "servlet", "tag"]
|
||||
response_ru_utime = metrics.register_counter(
|
||||
"response_ru_utime_seconds", labels=["method", "servlet", "tag"],
|
||||
alternative_names=(
|
||||
metrics.name_prefix + "_response_ru_utime:total",
|
||||
),
|
||||
)
|
||||
|
||||
response_ru_stime = metrics.register_distribution(
|
||||
"response_ru_stime", labels=["method", "servlet", "tag"]
|
||||
response_ru_stime = metrics.register_counter(
|
||||
"response_ru_stime_seconds", labels=["method", "servlet", "tag"],
|
||||
alternative_names=(
|
||||
metrics.name_prefix + "_response_ru_stime:total",
|
||||
),
|
||||
)
|
||||
|
||||
response_db_txn_count = metrics.register_distribution(
|
||||
"response_db_txn_count", labels=["method", "servlet", "tag"]
|
||||
response_db_txn_count = metrics.register_counter(
|
||||
"response_db_txn_count", labels=["method", "servlet", "tag"],
|
||||
alternative_names=(
|
||||
metrics.name_prefix + "_response_db_txn_count:total",
|
||||
),
|
||||
)
|
||||
|
||||
response_db_txn_duration = metrics.register_distribution(
|
||||
"response_db_txn_duration", labels=["method", "servlet", "tag"]
|
||||
# seconds spent waiting for db txns, excluding scheduling time, when processing
|
||||
# this request
|
||||
response_db_txn_duration = metrics.register_counter(
|
||||
"response_db_txn_duration_seconds", labels=["method", "servlet", "tag"],
|
||||
alternative_names=(
|
||||
metrics.name_prefix + "_response_db_txn_duration:total",
|
||||
),
|
||||
)
|
||||
|
||||
# seconds spent waiting for a db connection, when processing this request
|
||||
response_db_sched_duration = metrics.register_counter(
|
||||
"response_db_sched_duration_seconds", labels=["method", "servlet", "tag"]
|
||||
)
|
||||
|
||||
_next_request_id = 0
|
||||
|
||||
|
@ -107,6 +141,10 @@ def wrap_request_handler(request_handler, include_metrics=False):
|
|||
with LoggingContext(request_id) as request_context:
|
||||
with Measure(self.clock, "wrapped_request_handler"):
|
||||
request_metrics = RequestMetrics()
|
||||
# we start the request metrics timer here with an initial stab
|
||||
# at the servlet name. For most requests that name will be
|
||||
# JsonResource (or a subclass), and JsonResource._async_render
|
||||
# will update it once it picks a servlet.
|
||||
request_metrics.start(self.clock, name=self.__class__.__name__)
|
||||
|
||||
request_context.request = request_id
|
||||
|
@ -249,12 +287,23 @@ class JsonResource(HttpServer, resource.Resource):
|
|||
if not m:
|
||||
continue
|
||||
|
||||
# We found a match! Trigger callback and then return the
|
||||
# returned response. We pass both the request and any
|
||||
# matched groups from the regex to the callback.
|
||||
# We found a match! First update the metrics object to indicate
|
||||
# which servlet is handling the request.
|
||||
|
||||
callback = path_entry.callback
|
||||
|
||||
servlet_instance = getattr(callback, "__self__", None)
|
||||
if servlet_instance is not None:
|
||||
servlet_classname = servlet_instance.__class__.__name__
|
||||
else:
|
||||
servlet_classname = "%r" % callback
|
||||
|
||||
request_metrics.name = servlet_classname
|
||||
|
||||
# Now trigger the callback. If it returns a response, we send it
|
||||
# here. If it throws an exception, that is handled by the wrapper
|
||||
# installed by @request_handler.
|
||||
|
||||
kwargs = intern_dict({
|
||||
name: urllib.unquote(value).decode("UTF-8") if value else value
|
||||
for name, value in m.groupdict().items()
|
||||
|
@ -265,30 +314,14 @@ class JsonResource(HttpServer, resource.Resource):
|
|||
code, response = callback_return
|
||||
self._send_response(request, code, response)
|
||||
|
||||
servlet_instance = getattr(callback, "__self__", None)
|
||||
if servlet_instance is not None:
|
||||
servlet_classname = servlet_instance.__class__.__name__
|
||||
else:
|
||||
servlet_classname = "%r" % callback
|
||||
|
||||
request_metrics.name = servlet_classname
|
||||
|
||||
return
|
||||
|
||||
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
|
||||
request_metrics.name = self.__class__.__name__ + ".UnrecognizedRequest"
|
||||
raise UnrecognizedRequestError()
|
||||
|
||||
def _send_response(self, request, code, response_json_object,
|
||||
response_code_message=None):
|
||||
# could alternatively use request.notifyFinish() and flip a flag when
|
||||
# the Deferred fires, but since the flag is RIGHT THERE it seems like
|
||||
# a waste.
|
||||
if request._disconnected:
|
||||
logger.warn(
|
||||
"Not sending response to request %s, already disconnected.",
|
||||
request)
|
||||
return
|
||||
|
||||
outgoing_responses_counter.inc(request.method, str(code))
|
||||
|
||||
# TODO: Only enable CORS for the requests that need it.
|
||||
|
@ -322,7 +355,7 @@ class RequestMetrics(object):
|
|||
)
|
||||
return
|
||||
|
||||
incoming_requests_counter.inc(request.method, self.name, tag)
|
||||
response_count.inc(request.method, self.name, tag)
|
||||
|
||||
response_timer.inc_by(
|
||||
clock.time_msec() - self.start, request.method,
|
||||
|
@ -341,7 +374,10 @@ class RequestMetrics(object):
|
|||
context.db_txn_count, request.method, self.name, tag
|
||||
)
|
||||
response_db_txn_duration.inc_by(
|
||||
context.db_txn_duration, request.method, self.name, tag
|
||||
context.db_txn_duration_ms / 1000., request.method, self.name, tag
|
||||
)
|
||||
response_db_sched_duration.inc_by(
|
||||
context.db_sched_duration_ms / 1000., request.method, self.name, tag
|
||||
)
|
||||
|
||||
|
||||
|
@ -364,6 +400,15 @@ class RootRedirect(resource.Resource):
|
|||
def respond_with_json(request, code, json_object, send_cors=False,
|
||||
response_code_message=None, pretty_print=False,
|
||||
version_string="", canonical_json=True):
|
||||
# could alternatively use request.notifyFinish() and flip a flag when
|
||||
# the Deferred fires, but since the flag is RIGHT THERE it seems like
|
||||
# a waste.
|
||||
if request._disconnected:
|
||||
logger.warn(
|
||||
"Not sending response to request %s, already disconnected.",
|
||||
request)
|
||||
return
|
||||
|
||||
if pretty_print:
|
||||
json_bytes = encode_pretty_printed_json(json_object) + "\n"
|
||||
else:
|
||||
|
|
|
@ -148,11 +148,13 @@ def parse_string_from_args(args, name, default=None, required=False,
|
|||
return default
|
||||
|
||||
|
||||
def parse_json_value_from_request(request):
|
||||
def parse_json_value_from_request(request, allow_empty_body=False):
|
||||
"""Parse a JSON value from the body of a twisted HTTP request.
|
||||
|
||||
Args:
|
||||
request: the twisted HTTP request.
|
||||
allow_empty_body (bool): if True, an empty body will be accepted and
|
||||
turned into None
|
||||
|
||||
Returns:
|
||||
The JSON value.
|
||||
|
@ -165,6 +167,9 @@ def parse_json_value_from_request(request):
|
|||
except Exception:
|
||||
raise SynapseError(400, "Error reading JSON content.")
|
||||
|
||||
if not content_bytes and allow_empty_body:
|
||||
return None
|
||||
|
||||
try:
|
||||
content = simplejson.loads(content_bytes)
|
||||
except Exception as e:
|
||||
|
@ -174,17 +179,24 @@ def parse_json_value_from_request(request):
|
|||
return content
|
||||
|
||||
|
||||
def parse_json_object_from_request(request):
|
||||
def parse_json_object_from_request(request, allow_empty_body=False):
|
||||
"""Parse a JSON object from the body of a twisted HTTP request.
|
||||
|
||||
Args:
|
||||
request: the twisted HTTP request.
|
||||
allow_empty_body (bool): if True, an empty body will be accepted and
|
||||
turned into an empty dict.
|
||||
|
||||
Raises:
|
||||
SynapseError if the request body couldn't be decoded as JSON or
|
||||
if it wasn't a JSON object.
|
||||
"""
|
||||
content = parse_json_value_from_request(request)
|
||||
content = parse_json_value_from_request(
|
||||
request, allow_empty_body=allow_empty_body,
|
||||
)
|
||||
|
||||
if allow_empty_body and content is None:
|
||||
return {}
|
||||
|
||||
if type(content) != dict:
|
||||
message = "Content must be a JSON object."
|
||||
|
|
|
@ -66,14 +66,15 @@ class SynapseRequest(Request):
|
|||
context = LoggingContext.current_context()
|
||||
ru_utime, ru_stime = context.get_resource_usage()
|
||||
db_txn_count = context.db_txn_count
|
||||
db_txn_duration = context.db_txn_duration
|
||||
db_txn_duration_ms = context.db_txn_duration_ms
|
||||
db_sched_duration_ms = context.db_sched_duration_ms
|
||||
except Exception:
|
||||
ru_utime, ru_stime = (0, 0)
|
||||
db_txn_count, db_txn_duration = (0, 0)
|
||||
db_txn_count, db_txn_duration_ms = (0, 0)
|
||||
|
||||
self.site.access_logger.info(
|
||||
"%s - %s - {%s}"
|
||||
" Processed request: %dms (%dms, %dms) (%dms/%d)"
|
||||
" Processed request: %dms (%dms, %dms) (%dms/%dms/%d)"
|
||||
" %sB %s \"%s %s %s\" \"%s\"",
|
||||
self.getClientIP(),
|
||||
self.site.site_tag,
|
||||
|
@ -81,7 +82,8 @@ class SynapseRequest(Request):
|
|||
int(time.time() * 1000) - self.start_time,
|
||||
int(ru_utime * 1000),
|
||||
int(ru_stime * 1000),
|
||||
int(db_txn_duration * 1000),
|
||||
db_sched_duration_ms,
|
||||
db_txn_duration_ms,
|
||||
int(db_txn_count),
|
||||
self.sentLength,
|
||||
self.code,
|
||||
|
|
|
@ -146,10 +146,15 @@ def runUntilCurrentTimer(func):
|
|||
num_pending += 1
|
||||
|
||||
num_pending += len(reactor.threadCallQueue)
|
||||
|
||||
start = time.time() * 1000
|
||||
ret = func(*args, **kwargs)
|
||||
end = time.time() * 1000
|
||||
|
||||
# record the amount of wallclock time spent running pending calls.
|
||||
# This is a proxy for the actual amount of time between reactor polls,
|
||||
# since about 25% of time is actually spent running things triggered by
|
||||
# I/O events, but that is harder to capture without rewriting half the
|
||||
# reactor.
|
||||
tick_time.inc_by(end - start)
|
||||
pending_calls_metric.inc_by(num_pending)
|
||||
|
||||
|
|
|
@ -15,18 +15,38 @@
|
|||
|
||||
|
||||
from itertools import chain
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# TODO(paul): I can't believe Python doesn't have one of these
|
||||
def map_concat(func, items):
    """Apply func to every item and concatenate the iterables it returns.

    Args:
        func: callable returning an iterable for each item
        items: iterable of inputs to func

    Returns:
        list: all elements of all the results, in order
    """
    return [element for chunk in map(func, items) for element in chunk]
|
||||
def flatten(items):
    """Concatenate an iterable of iterables into a single list.

    Only one level of nesting is removed.

    Args:
        items: iterable[iterable[X]]

    Returns:
        list[X]: the elements of each inner iterable, in order
    """
    return [element for inner in items for element in inner]
|
||||
|
||||
|
||||
class BaseMetric(object):
|
||||
"""Base class for metrics which report a single value per label set
|
||||
"""
|
||||
|
||||
def __init__(self, name, labels=[]):
|
||||
self.name = name
|
||||
def __init__(self, name, labels=None, alternative_names=None):
    """
    Args:
        name (str): principal name for this metric
        labels (list(str)|None): names of the labels which will be
            reported for this metric. Omitted or None means no labels.
        alternative_names (iterable(str)|None): alternative names for
            this metric. This can be useful to provide a migration path
            when renaming metrics. Omitted or None means no alternatives.
    """
    # None sentinels are used instead of mutable `[]` defaults so that
    # metrics created without labels do not all share one list object.
    if alternative_names is None:
        alternative_names = ()
    self._names = [name] + list(alternative_names)

    # A caller-supplied list is stored as-is (not cloned).
    self.labels = [] if labels is None else labels
|
||||
|
||||
def dimension(self):
|
||||
|
@ -36,7 +56,7 @@ class BaseMetric(object):
|
|||
return not len(self.labels)
|
||||
|
||||
def _render_labelvalue(self, value):
|
||||
# TODO: some kind of value escape
|
||||
# TODO: escape backslashes, quotes and newlines
|
||||
return '"%s"' % (value)
|
||||
|
||||
def _render_key(self, values):
|
||||
|
@ -47,19 +67,60 @@ class BaseMetric(object):
|
|||
for k, v in zip(self.labels, values)])
|
||||
)
|
||||
|
||||
def _render_for_labels(self, label_values, value):
|
||||
"""Render this metric for a single set of labels
|
||||
|
||||
Args:
|
||||
label_values (list[str]): values for each of the labels
|
||||
value: value of the metric at with these labels
|
||||
|
||||
Returns:
|
||||
iterable[str]: rendered metric
|
||||
"""
|
||||
rendered_labels = self._render_key(label_values)
|
||||
return (
|
||||
"%s%s %.12g" % (name, rendered_labels, value)
|
||||
for name in self._names
|
||||
)
|
||||
|
||||
def render(self):
    """Render this metric; this implementation always raises.

    Each metric is expected to be rendered in the form:

        name{label1="val1",label2="val2"} value

    See
    https://prometheus.io/docs/instrumenting/exposition_formats/#text-format-details

    Returns:
        iterable[str]: rendered metrics

    Raises:
        NotImplementedError: always, in this implementation
    """
    raise NotImplementedError()
|
||||
|
||||
|
||||
class CounterMetric(BaseMetric):
|
||||
"""The simplest kind of metric; one that stores a monotonically-increasing
|
||||
integer that counts events."""
|
||||
value that counts events or running totals.
|
||||
|
||||
Example use cases for Counters:
|
||||
- Number of requests processed
|
||||
- Number of items that were inserted into a queue
|
||||
- Total amount of data that a system has processed
|
||||
Counters can only go up (and be reset when the process restarts).
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(CounterMetric, self).__init__(*args, **kwargs)
|
||||
|
||||
# dict[list[str]]: value for each set of label values. the keys are the
|
||||
# label values, in the same order as the labels in self.labels.
|
||||
#
|
||||
# (if the metric is a scalar, the (single) key is the empty list).
|
||||
self.counts = {}
|
||||
|
||||
# Scalar metrics are never empty
|
||||
if self.is_scalar():
|
||||
self.counts[()] = 0
|
||||
self.counts[()] = 0.
|
||||
|
||||
def inc_by(self, incr, *values):
|
||||
if len(values) != self.dimension():
|
||||
|
@ -77,11 +138,11 @@ class CounterMetric(BaseMetric):
|
|||
def inc(self, *values):
|
||||
self.inc_by(1, *values)
|
||||
|
||||
def render_item(self, k):
|
||||
return ["%s%s %d" % (self.name, self._render_key(k), self.counts[k])]
|
||||
|
||||
def render(self):
|
||||
return map_concat(self.render_item, sorted(self.counts.keys()))
|
||||
return flatten(
|
||||
self._render_for_labels(k, self.counts[k])
|
||||
for k in sorted(self.counts.keys())
|
||||
)
|
||||
|
||||
|
||||
class CallbackMetric(BaseMetric):
|
||||
|
@ -95,13 +156,19 @@ class CallbackMetric(BaseMetric):
|
|||
self.callback = callback
|
||||
|
||||
def render(self):
|
||||
value = self.callback()
|
||||
try:
|
||||
value = self.callback()
|
||||
except Exception:
|
||||
logger.exception("Failed to render %s", self.name)
|
||||
return ["# FAILED to render " + self.name]
|
||||
|
||||
if self.is_scalar():
|
||||
return ["%s %.12g" % (self.name, value)]
|
||||
return list(self._render_for_labels([], value))
|
||||
|
||||
return ["%s%s %.12g" % (self.name, self._render_key(k), value[k])
|
||||
for k in sorted(value.keys())]
|
||||
return flatten(
|
||||
self._render_for_labels(k, value[k])
|
||||
for k in sorted(value.keys())
|
||||
)
|
||||
|
||||
|
||||
class DistributionMetric(object):
|
||||
|
@ -126,7 +193,9 @@ class DistributionMetric(object):
|
|||
|
||||
|
||||
class CacheMetric(object):
|
||||
__slots__ = ("name", "cache_name", "hits", "misses", "size_callback")
|
||||
__slots__ = (
|
||||
"name", "cache_name", "hits", "misses", "evicted_size", "size_callback",
|
||||
)
|
||||
|
||||
def __init__(self, name, size_callback, cache_name):
|
||||
self.name = name
|
||||
|
@ -134,6 +203,7 @@ class CacheMetric(object):
|
|||
|
||||
self.hits = 0
|
||||
self.misses = 0
|
||||
self.evicted_size = 0
|
||||
|
||||
self.size_callback = size_callback
|
||||
|
||||
|
@ -143,6 +213,9 @@ class CacheMetric(object):
|
|||
def inc_misses(self):
|
||||
self.misses += 1
|
||||
|
||||
def inc_evictions(self, size=1):
    """Add `size` (default 1) to the running evicted_size total."""
    self.evicted_size = self.evicted_size + size
|
||||
|
||||
def render(self):
|
||||
size = self.size_callback()
|
||||
hits = self.hits
|
||||
|
@ -152,6 +225,9 @@ class CacheMetric(object):
|
|||
"""%s:hits{name="%s"} %d""" % (self.name, self.cache_name, hits),
|
||||
"""%s:total{name="%s"} %d""" % (self.name, self.cache_name, total),
|
||||
"""%s:size{name="%s"} %d""" % (self.name, self.cache_name, size),
|
||||
"""%s:evicted_size{name="%s"} %d""" % (
|
||||
self.name, self.cache_name, self.evicted_size
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -13,21 +13,30 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.push import PusherConfigException
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
|
||||
|
||||
import logging
|
||||
import push_rule_evaluator
|
||||
import push_tools
|
||||
|
||||
import synapse
|
||||
from synapse.push import PusherConfigException
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
metrics = synapse.metrics.get_metrics_for(__name__)
|
||||
|
||||
http_push_processed_counter = metrics.register_counter(
|
||||
"http_pushes_processed",
|
||||
)
|
||||
|
||||
http_push_failed_counter = metrics.register_counter(
|
||||
"http_pushes_failed",
|
||||
)
|
||||
|
||||
|
||||
class HttpPusher(object):
|
||||
INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes
|
||||
|
@ -152,9 +161,16 @@ class HttpPusher(object):
|
|||
self.user_id, self.last_stream_ordering, self.max_stream_ordering
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Processing %i unprocessed push actions for %s starting at "
|
||||
"stream_ordering %s",
|
||||
len(unprocessed), self.name, self.last_stream_ordering,
|
||||
)
|
||||
|
||||
for push_action in unprocessed:
|
||||
processed = yield self._process_one(push_action)
|
||||
if processed:
|
||||
http_push_processed_counter.inc()
|
||||
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
|
||||
self.last_stream_ordering = push_action['stream_ordering']
|
||||
yield self.store.update_pusher_last_stream_ordering_and_success(
|
||||
|
@ -169,6 +185,7 @@ class HttpPusher(object):
|
|||
self.failing_since
|
||||
)
|
||||
else:
|
||||
http_push_failed_counter.inc()
|
||||
if not self.failing_since:
|
||||
self.failing_since = self.clock.time_msec()
|
||||
yield self.store.update_pusher_failing_since(
|
||||
|
@ -316,7 +333,10 @@ class HttpPusher(object):
|
|||
try:
|
||||
resp = yield self.http_client.post_json_get_json(self.url, notification_dict)
|
||||
except Exception:
|
||||
logger.warn("Failed to push %s ", self.url)
|
||||
logger.warn(
|
||||
"Failed to push event %s to %s",
|
||||
event.event_id, self.name, exc_info=True,
|
||||
)
|
||||
defer.returnValue(False)
|
||||
rejected = []
|
||||
if 'rejected' in resp:
|
||||
|
@ -325,7 +345,7 @@ class HttpPusher(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def _send_badge(self, badge):
|
||||
logger.info("Sending updated badge count %d to %r", badge, self.user_id)
|
||||
logger.info("Sending updated badge count %d to %s", badge, self.name)
|
||||
d = {
|
||||
'notification': {
|
||||
'id': '',
|
||||
|
@ -347,7 +367,10 @@ class HttpPusher(object):
|
|||
try:
|
||||
resp = yield self.http_client.post_json_get_json(self.url, d)
|
||||
except Exception:
|
||||
logger.exception("Failed to push %s ", self.url)
|
||||
logger.warn(
|
||||
"Failed to send badge count to %s",
|
||||
self.name, exc_info=True,
|
||||
)
|
||||
defer.returnValue(False)
|
||||
rejected = []
|
||||
if 'rejected' in resp:
|
||||
|
|
|
@ -19,7 +19,7 @@ from synapse.storage import DataStore
|
|||
from synapse.storage.event_federation import EventFederationStore
|
||||
from synapse.storage.event_push_actions import EventPushActionsStore
|
||||
from synapse.storage.roommember import RoomMemberStore
|
||||
from synapse.storage.state import StateGroupReadStore
|
||||
from synapse.storage.state import StateGroupWorkerStore
|
||||
from synapse.storage.stream import StreamStore
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
from ._base import BaseSlavedStore
|
||||
|
@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
|
|||
# the method descriptor on the DataStore and chuck them into our class.
|
||||
|
||||
|
||||
class SlavedEventStore(StateGroupReadStore, BaseSlavedStore):
|
||||
class SlavedEventStore(StateGroupWorkerStore, BaseSlavedStore):
|
||||
|
||||
def __init__(self, db_conn, hs):
|
||||
super(SlavedEventStore, self).__init__(db_conn, hs)
|
||||
|
|
|
@ -517,25 +517,28 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
|||
self.send_error("Wrong remote")
|
||||
|
||||
def on_RDATA(self, cmd):
|
||||
stream_name = cmd.stream_name
|
||||
inbound_rdata_count.inc(stream_name)
|
||||
|
||||
try:
|
||||
row = STREAMS_MAP[cmd.stream_name].ROW_TYPE(*cmd.row)
|
||||
row = STREAMS_MAP[stream_name].ROW_TYPE(*cmd.row)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"[%s] Failed to parse RDATA: %r %r",
|
||||
self.id(), cmd.stream_name, cmd.row
|
||||
self.id(), stream_name, cmd.row
|
||||
)
|
||||
raise
|
||||
|
||||
if cmd.token is None:
|
||||
# I.e. this is part of a batch of updates for this stream. Batch
|
||||
# until we get an update for the stream with a non None token
|
||||
self.pending_batches.setdefault(cmd.stream_name, []).append(row)
|
||||
self.pending_batches.setdefault(stream_name, []).append(row)
|
||||
else:
|
||||
# Check if this is the last of a batch of updates
|
||||
rows = self.pending_batches.pop(cmd.stream_name, [])
|
||||
rows = self.pending_batches.pop(stream_name, [])
|
||||
rows.append(row)
|
||||
|
||||
self.handler.on_rdata(cmd.stream_name, cmd.token, rows)
|
||||
self.handler.on_rdata(stream_name, cmd.token, rows)
|
||||
|
||||
def on_POSITION(self, cmd):
|
||||
self.handler.on_position(cmd.stream_name, cmd.token)
|
||||
|
@ -644,3 +647,9 @@ metrics.register_callback(
|
|||
},
|
||||
labels=["command", "name", "conn_id"],
|
||||
)
|
||||
|
||||
# number of updates received for each RDATA stream
|
||||
inbound_rdata_count = metrics.register_counter(
|
||||
"inbound_rdata_count",
|
||||
labels=["stream_name"],
|
||||
)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -128,7 +129,16 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
|
|||
if not is_admin:
|
||||
raise AuthError(403, "You are not a server admin")
|
||||
|
||||
yield self.handlers.message_handler.purge_history(room_id, event_id)
|
||||
body = parse_json_object_from_request(request, allow_empty_body=True)
|
||||
|
||||
delete_local_events = bool(
|
||||
body.get("delete_local_history", False)
|
||||
)
|
||||
|
||||
yield self.handlers.message_handler.purge_history(
|
||||
room_id, event_id,
|
||||
delete_local_events=delete_local_events,
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
@ -171,6 +181,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
|
|||
self.store = hs.get_datastore()
|
||||
self.handlers = hs.get_handlers()
|
||||
self.state = hs.get_state_handler()
|
||||
self.event_creation_handler = hs.get_event_creation_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, room_id):
|
||||
|
@ -203,8 +214,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
|
|||
)
|
||||
new_room_id = info["room_id"]
|
||||
|
||||
msg_handler = self.handlers.message_handler
|
||||
yield msg_handler.create_and_send_nonmember_event(
|
||||
yield self.event_creation_handler.create_and_send_nonmember_event(
|
||||
room_creator_requester,
|
||||
{
|
||||
"type": "m.room.message",
|
||||
|
@ -289,6 +299,27 @@ class QuarantineMediaInRoom(ClientV1RestServlet):
|
|||
defer.returnValue((200, {"num_quarantined": num_quarantined}))
|
||||
|
||||
|
||||
class ListMediaInRoom(ClientV1RestServlet):
    """Admin servlet which lists all of the media in a given room.

    A GET responds with 200 and a body of the form
    {"local": ..., "remote": ...}, holding the two values returned by the
    datastore's get_media_mxcs_in_room(room_id). Requesters who are not
    server admins are rejected with a 403 AuthError.
    """
    PATTERNS = client_path_patterns("/admin/room/(?P<room_id>[^/]+)/media")

    def __init__(self, hs):
        super(ListMediaInRoom, self).__init__(hs)
        self.store = hs.get_datastore()

    @defer.inlineCallbacks
    def on_GET(self, request, room_id):
        requester = yield self.auth.get_user_by_req(request)

        # only server admins may list a room's media
        requester_is_admin = yield self.auth.is_server_admin(requester.user)
        if not requester_is_admin:
            raise AuthError(403, "You are not a server admin")

        mxcs = yield self.store.get_media_mxcs_in_room(room_id)
        local_mxcs, remote_mxcs = mxcs

        response_body = {"local": local_mxcs, "remote": remote_mxcs}
        defer.returnValue((200, response_body))
|
||||
|
||||
|
||||
class ResetPasswordRestServlet(ClientV1RestServlet):
|
||||
"""Post request to allow an administrator reset password for a user.
|
||||
This needs user to have administrator access in Synapse.
|
||||
|
@ -487,3 +518,4 @@ def register_servlets(hs, http_server):
|
|||
SearchUsersRestServlet(hs).register(http_server)
|
||||
ShutdownRoomRestServlet(hs).register(http_server)
|
||||
QuarantineMediaInRoom(hs).register(http_server)
|
||||
ListMediaInRoom(hs).register(http_server)
|
||||
|
|
|
@ -191,19 +191,25 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
|
||||
# convert threepid identifiers to user IDs
|
||||
if identifier["type"] == "m.id.thirdparty":
|
||||
if 'medium' not in identifier or 'address' not in identifier:
|
||||
address = identifier.get('address')
|
||||
medium = identifier.get('medium')
|
||||
|
||||
if medium is None or address is None:
|
||||
raise SynapseError(400, "Invalid thirdparty identifier")
|
||||
|
||||
address = identifier['address']
|
||||
if identifier['medium'] == 'email':
|
||||
if medium == 'email':
|
||||
# For emails, transform the address to lowercase.
|
||||
# We store all email addresses as lowercase in the DB.
|
||||
# (See add_threepid in synapse/handlers/auth.py)
|
||||
address = address.lower()
|
||||
user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||
identifier['medium'], address
|
||||
medium, address,
|
||||
)
|
||||
if not user_id:
|
||||
logger.warn(
|
||||
"unknown 3pid identifier medium %s, address %r",
|
||||
medium, address,
|
||||
)
|
||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||
|
||||
identifier = {
|
||||
|
|
|
@ -70,10 +70,15 @@ class RegisterRestServlet(ClientV1RestServlet):
|
|||
self.handlers = hs.get_handlers()
|
||||
|
||||
def on_GET(self, request):
|
||||
|
||||
require_email = 'email' in self.hs.config.registrations_require_3pid
|
||||
require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
|
||||
|
||||
flows = []
|
||||
if self.hs.config.enable_registration_captcha:
|
||||
return (
|
||||
200,
|
||||
{"flows": [
|
||||
# only support the email-only flow if we don't require MSISDN 3PIDs
|
||||
if not require_msisdn:
|
||||
flows.extend([
|
||||
{
|
||||
"type": LoginType.RECAPTCHA,
|
||||
"stages": [
|
||||
|
@ -82,27 +87,34 @@ class RegisterRestServlet(ClientV1RestServlet):
|
|||
LoginType.PASSWORD
|
||||
]
|
||||
},
|
||||
])
|
||||
# only support 3PIDless registration if no 3PIDs are required
|
||||
if not require_email and not require_msisdn:
|
||||
flows.extend([
|
||||
{
|
||||
"type": LoginType.RECAPTCHA,
|
||||
"stages": [LoginType.RECAPTCHA, LoginType.PASSWORD]
|
||||
}
|
||||
]}
|
||||
)
|
||||
])
|
||||
else:
|
||||
return (
|
||||
200,
|
||||
{"flows": [
|
||||
# only support the email-only flow if we don't require MSISDN 3PIDs
|
||||
if require_email or not require_msisdn:
|
||||
flows.extend([
|
||||
{
|
||||
"type": LoginType.EMAIL_IDENTITY,
|
||||
"stages": [
|
||||
LoginType.EMAIL_IDENTITY, LoginType.PASSWORD
|
||||
]
|
||||
},
|
||||
}
|
||||
])
|
||||
# only support 3PIDless registration if no 3PIDs are required
|
||||
if not require_email and not require_msisdn:
|
||||
flows.extend([
|
||||
{
|
||||
"type": LoginType.PASSWORD
|
||||
}
|
||||
]}
|
||||
)
|
||||
])
|
||||
return (200, {"flows": flows})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -82,6 +83,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
|
|||
def __init__(self, hs):
|
||||
super(RoomStateEventRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
self.event_creation_hander = hs.get_event_creation_handler()
|
||||
|
||||
def register(self, http_server):
|
||||
# /room/$roomid/state/$eventtype
|
||||
|
@ -162,15 +164,16 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
|
|||
content=content,
|
||||
)
|
||||
else:
|
||||
msg_handler = self.handlers.message_handler
|
||||
event, context = yield msg_handler.create_event(
|
||||
event, context = yield self.event_creation_hander.create_event(
|
||||
requester,
|
||||
event_dict,
|
||||
token_id=requester.access_token_id,
|
||||
txn_id=txn_id,
|
||||
)
|
||||
|
||||
yield msg_handler.send_nonmember_event(requester, event, context)
|
||||
yield self.event_creation_hander.send_nonmember_event(
|
||||
requester, event, context,
|
||||
)
|
||||
|
||||
ret = {}
|
||||
if event:
|
||||
|
@ -184,6 +187,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
|
|||
def __init__(self, hs):
|
||||
super(RoomSendEventRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
self.event_creation_hander = hs.get_event_creation_handler()
|
||||
|
||||
def register(self, http_server):
|
||||
# /rooms/$roomid/send/$event_type[/$txn_id]
|
||||
|
@ -195,15 +199,19 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
|
|||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
msg_handler = self.handlers.message_handler
|
||||
event = yield msg_handler.create_and_send_nonmember_event(
|
||||
event_dict = {
|
||||
"type": event_type,
|
||||
"content": content,
|
||||
"room_id": room_id,
|
||||
"sender": requester.user.to_string(),
|
||||
}
|
||||
|
||||
if 'ts' in request.args and requester.app_service:
|
||||
event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)
|
||||
|
||||
event = yield self.event_creation_hander.create_and_send_nonmember_event(
|
||||
requester,
|
||||
{
|
||||
"type": event_type,
|
||||
"content": content,
|
||||
"room_id": room_id,
|
||||
"sender": requester.user.to_string(),
|
||||
},
|
||||
event_dict,
|
||||
txn_id=txn_id,
|
||||
)
|
||||
|
||||
|
@ -487,13 +495,35 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
|
|||
defer.returnValue((200, content))
|
||||
|
||||
|
||||
class RoomEventContext(ClientV1RestServlet):
|
||||
class RoomEventServlet(ClientV1RestServlet):
    """Servlet which returns a single event, looked up by event ID.

    A GET asks the event handler for the event on behalf of the requesting
    user. A truthy result is serialized and returned with a 200; otherwise
    the response is a 404 with the body "Event not found.".
    """
    PATTERNS = client_path_patterns(
        "/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$"
    )

    def __init__(self, hs):
        super(RoomEventServlet, self).__init__(hs)
        self.clock = hs.get_clock()
        self.event_handler = hs.get_event_handler()

    @defer.inlineCallbacks
    def on_GET(self, request, room_id, event_id):
        # NB: room_id is matched by PATTERNS but is not consulted here; the
        # lookup is by requesting user and event_id only.
        requester = yield self.auth.get_user_by_req(request)
        found_event = yield self.event_handler.get_event(
            requester.user, event_id,
        )

        now_ms = self.clock.time_msec()
        response = (
            (200, serialize_event(found_event, now_ms))
            if found_event
            else (404, "Event not found.")
        )
        defer.returnValue(response)
|
||||
|
||||
|
||||
class RoomEventContextServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns(
|
||||
"/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RoomEventContext, self).__init__(hs)
|
||||
super(RoomEventContextServlet, self).__init__(hs)
|
||||
self.clock = hs.get_clock()
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
|
@ -643,6 +673,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
|
|||
def __init__(self, hs):
|
||||
super(RoomRedactEventRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
self.event_creation_handler = hs.get_event_creation_handler()
|
||||
|
||||
def register(self, http_server):
|
||||
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
|
||||
|
@ -653,8 +684,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
|
|||
requester = yield self.auth.get_user_by_req(request)
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
msg_handler = self.handlers.message_handler
|
||||
event = yield msg_handler.create_and_send_nonmember_event(
|
||||
event = yield self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester,
|
||||
{
|
||||
"type": EventTypes.Redaction,
|
||||
|
@ -803,4 +833,5 @@ def register_servlets(hs, http_server):
|
|||
RoomTypingRestServlet(hs).register(http_server)
|
||||
SearchRestServlet(hs).register(http_server)
|
||||
JoinedRoomsRestServlet(hs).register(http_server)
|
||||
RoomEventContext(hs).register(http_server)
|
||||
RoomEventServlet(hs).register(http_server)
|
||||
RoomEventContextServlet(hs).register(http_server)
|
||||
|
|
|
@ -26,6 +26,7 @@ from synapse.http.servlet import (
|
|||
)
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.util.msisdn import phone_number_to_msisdn
|
||||
from synapse.util.threepids import check_3pid_allowed
|
||||
from ._base import client_v2_patterns, interactive_auth_handler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -47,6 +48,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
|
|||
'id_server', 'client_secret', 'email', 'send_attempt'
|
||||
])
|
||||
|
||||
if not check_3pid_allowed(self.hs, "email", body['email']):
|
||||
raise SynapseError(
|
||||
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||
)
|
||||
|
||||
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||
'email', body['email']
|
||||
)
|
||||
|
@ -78,6 +84,11 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
|
|||
|
||||
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
||||
|
||||
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
|
||||
raise SynapseError(
|
||||
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||
)
|
||||
|
||||
existingUid = yield self.datastore.get_user_id_by_threepid(
|
||||
'msisdn', msisdn
|
||||
)
|
||||
|
@ -217,6 +228,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
|
|||
if absent:
|
||||
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
||||
|
||||
if not check_3pid_allowed(self.hs, "email", body['email']):
|
||||
raise SynapseError(
|
||||
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||
)
|
||||
|
||||
existingUid = yield self.datastore.get_user_id_by_threepid(
|
||||
'email', body['email']
|
||||
)
|
||||
|
@ -255,6 +271,11 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
|
|||
|
||||
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
||||
|
||||
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
|
||||
raise SynapseError(
|
||||
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||
)
|
||||
|
||||
existingUid = yield self.datastore.get_user_id_by_threepid(
|
||||
'msisdn', msisdn
|
||||
)
|
||||
|
|
|
@ -26,6 +26,7 @@ from synapse.http.servlet import (
|
|||
RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string
|
||||
)
|
||||
from synapse.util.msisdn import phone_number_to_msisdn
|
||||
from synapse.util.threepids import check_3pid_allowed
|
||||
|
||||
from ._base import client_v2_patterns, interactive_auth_handler
|
||||
|
||||
|
@ -70,6 +71,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
|
|||
'id_server', 'client_secret', 'email', 'send_attempt'
|
||||
])
|
||||
|
||||
if not check_3pid_allowed(self.hs, "email", body['email']):
|
||||
raise SynapseError(
|
||||
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||
)
|
||||
|
||||
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||
'email', body['email']
|
||||
)
|
||||
|
@ -105,6 +111,11 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
|
|||
|
||||
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
||||
|
||||
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
|
||||
raise SynapseError(
|
||||
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||
)
|
||||
|
||||
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||
'msisdn', msisdn
|
||||
)
|
||||
|
@ -305,31 +316,67 @@ class RegisterRestServlet(RestServlet):
|
|||
if 'x_show_msisdn' in body and body['x_show_msisdn']:
|
||||
show_msisdn = True
|
||||
|
||||
# FIXME: need a better error than "no auth flow found" for scenarios
|
||||
# where we required 3PID for registration but the user didn't give one
|
||||
require_email = 'email' in self.hs.config.registrations_require_3pid
|
||||
require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
|
||||
|
||||
flows = []
|
||||
if self.hs.config.enable_registration_captcha:
|
||||
flows = [
|
||||
[LoginType.RECAPTCHA],
|
||||
[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
|
||||
]
|
||||
# only support 3PIDless registration if no 3PIDs are required
|
||||
if not require_email and not require_msisdn:
|
||||
flows.extend([[LoginType.RECAPTCHA]])
|
||||
# only support the email-only flow if we don't require MSISDN 3PIDs
|
||||
if not require_msisdn:
|
||||
flows.extend([[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]])
|
||||
|
||||
if show_msisdn:
|
||||
# only support the MSISDN-only flow if we don't require email 3PIDs
|
||||
if not require_email:
|
||||
flows.extend([[LoginType.MSISDN, LoginType.RECAPTCHA]])
|
||||
# always let users provide both MSISDN & email
|
||||
flows.extend([
|
||||
[LoginType.MSISDN, LoginType.RECAPTCHA],
|
||||
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
|
||||
])
|
||||
else:
|
||||
flows = [
|
||||
[LoginType.DUMMY],
|
||||
[LoginType.EMAIL_IDENTITY],
|
||||
]
|
||||
# only support 3PIDless registration if no 3PIDs are required
|
||||
if not require_email and not require_msisdn:
|
||||
flows.extend([[LoginType.DUMMY]])
|
||||
# only support the email-only flow if we don't require MSISDN 3PIDs
|
||||
if not require_msisdn:
|
||||
flows.extend([[LoginType.EMAIL_IDENTITY]])
|
||||
|
||||
if show_msisdn:
|
||||
# only support the MSISDN-only flow if we don't require email 3PIDs
|
||||
if not require_email or require_msisdn:
|
||||
flows.extend([[LoginType.MSISDN]])
|
||||
# always let users provide both MSISDN & email
|
||||
flows.extend([
|
||||
[LoginType.MSISDN],
|
||||
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
|
||||
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY]
|
||||
])
|
||||
|
||||
auth_result, params, session_id = yield self.auth_handler.check_auth(
|
||||
flows, body, self.hs.get_ip_from_request(request)
|
||||
)
|
||||
|
||||
# Check that we're not trying to register a denied 3pid.
|
||||
#
|
||||
# the user-facing checks will probably already have happened in
|
||||
# /register/email/requestToken when we requested a 3pid, but that's not
|
||||
# guaranteed.
|
||||
|
||||
if auth_result:
|
||||
for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
|
||||
if login_type in auth_result:
|
||||
medium = auth_result[login_type]['medium']
|
||||
address = auth_result[login_type]['address']
|
||||
|
||||
if not check_3pid_allowed(self.hs, medium, address):
|
||||
raise SynapseError(
|
||||
403, "Third party identifier is not allowed",
|
||||
Codes.THREEPID_DENIED,
|
||||
)
|
||||
|
||||
if registered_user_id is not None:
|
||||
logger.info(
|
||||
"Already registered user ID %r for this session",
|
||||
|
|
|
@ -93,6 +93,7 @@ class RemoteKey(Resource):
|
|||
self.store = hs.get_datastore()
|
||||
self.version_string = hs.version_string
|
||||
self.clock = hs.get_clock()
|
||||
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
|
||||
|
||||
def render_GET(self, request):
|
||||
self.async_render_GET(request)
|
||||
|
@ -137,6 +138,13 @@ class RemoteKey(Resource):
|
|||
logger.info("Handling query for keys %r", query)
|
||||
store_queries = []
|
||||
for server_name, key_ids in query.items():
|
||||
if (
|
||||
self.federation_domain_whitelist is not None and
|
||||
server_name not in self.federation_domain_whitelist
|
||||
):
|
||||
logger.debug("Federation denied with %s", server_name)
|
||||
continue
|
||||
|
||||
if not key_ids:
|
||||
key_ids = (None,)
|
||||
for key_id in key_ids:
|
||||
|
|
|
@ -70,38 +70,11 @@ def respond_with_file(request, media_type, file_path,
|
|||
logger.debug("Responding with %r", file_path)
|
||||
|
||||
if os.path.isfile(file_path):
|
||||
request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
|
||||
if upload_name:
|
||||
if is_ascii(upload_name):
|
||||
request.setHeader(
|
||||
b"Content-Disposition",
|
||||
b"inline; filename=%s" % (
|
||||
urllib.quote(upload_name.encode("utf-8")),
|
||||
),
|
||||
)
|
||||
else:
|
||||
request.setHeader(
|
||||
b"Content-Disposition",
|
||||
b"inline; filename*=utf-8''%s" % (
|
||||
urllib.quote(upload_name.encode("utf-8")),
|
||||
),
|
||||
)
|
||||
|
||||
# cache for at least a day.
|
||||
# XXX: we might want to turn this off for data we don't want to
|
||||
# recommend caching as it's sensitive or private - or at least
|
||||
# select private. don't bother setting Expires as all our
|
||||
# clients are smart enough to be happy with Cache-Control
|
||||
request.setHeader(
|
||||
b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
|
||||
)
|
||||
if file_size is None:
|
||||
stat = os.stat(file_path)
|
||||
file_size = stat.st_size
|
||||
|
||||
request.setHeader(
|
||||
b"Content-Length", b"%d" % (file_size,)
|
||||
)
|
||||
add_file_headers(request, media_type, file_size, upload_name)
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
yield logcontext.make_deferred_yieldable(
|
||||
|
@ -111,3 +84,118 @@ def respond_with_file(request, media_type, file_path,
|
|||
finish_request(request)
|
||||
else:
|
||||
respond_404(request)
|
||||
|
||||
|
||||
def add_file_headers(request, media_type, file_size, upload_name):
    """Adds the correct response headers in preparation for responding with the
    media.

    Sets Content-Type, Content-Disposition (only if an upload name is given),
    Cache-Control and, when the size is known, Content-Length.

    Args:
        request (twisted.web.http.Request)
        media_type (str): The media/content type.
        file_size (int|None): Size in bytes of the media. If it is None (size
            not known), no Content-Length header is set.
        upload_name (str|None): The name of the requested file, if any.
    """
    request.setHeader(b"Content-Type", media_type.encode("UTF-8"))

    if upload_name:
        # ASCII names use the plain `filename=` parameter; anything else uses
        # the `filename*=` form with an explicit utf-8 charset marker.
        if is_ascii(upload_name):
            disposition_format = b"inline; filename=%s"
        else:
            disposition_format = b"inline; filename*=utf-8''%s"

        request.setHeader(
            b"Content-Disposition",
            disposition_format % (
                urllib.quote(upload_name.encode("utf-8")),
            ),
        )

    # cache for at least a day.
    # XXX: we might want to turn this off for data we don't want to
    # recommend caching as it's sensitive or private - or at least
    # select private. don't bother setting Expires as all our
    # clients are smart enough to be happy with Cache-Control
    request.setHeader(
        b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
    )

    # Callers may pass file_size=None when the size is not known; formatting
    # None with %d would raise a TypeError, so leave the header out instead.
    if file_size is not None:
        request.setHeader(
            b"Content-Length", b"%d" % (file_size,)
        )
|
||||
|
||||
|
||||
@defer.inlineCallbacks
def respond_with_responder(request, responder, media_type, file_size, upload_name=None):
    """Streams the media held by the given responder back to the requester.
    A falsy responder (e.g. None) results in a 404 response instead.

    Args:
        request (twisted.web.http.Request)
        responder (Responder|None)
        media_type (str): The media/content type.
        file_size (int|None): Size in bytes of the media. If not known it should be None
        upload_name (str|None): The name of the requested file, if any.
    """
    if responder:
        add_file_headers(request, media_type, file_size, upload_name)

        # The responder is used as a context manager so that it is exited
        # whether or not the write succeeds.
        with responder:
            yield responder.write_to_consumer(request)

        finish_request(request)
    else:
        respond_404(request)
|
||||
|
||||
|
||||
class Responder(object):
    """Represents a response that can be streamed to the requester.

    Responder is a context manager which *must* be used, so that any resources
    held can be cleaned up.
    """
    def write_to_consumer(self, consumer):
        """Stream response into consumer

        The base implementation does nothing and returns None; subclasses are
        presumably expected to override it.

        Args:
            consumer (IConsumer)

        Returns:
            Deferred: Resolves once the response has finished being written
        """
        pass

    def __enter__(self):
        """Enter the context.

        Returns:
            Responder: this responder, so that `with responder as r:` binds
            the responder itself rather than None.
        """
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Leave the context. The base implementation holds no resources and
        returns None, so any exception raised inside the block propagates.
        """
        pass
|
||||
|
||||
|
||||
class FileInfo(object):
    """Plain record describing a requested/uploaded file.

    Attributes:
        server_name (str): The server name where the media originated from,
            or None if local.
        file_id (str): The local ID of the file. For local files this is the
            same as the media_id
        url_cache (bool): If the file is for the url preview cache
        thumbnail (bool): Whether the file is a thumbnail or not.
        thumbnail_width (int)
        thumbnail_height (int)
        thumbnail_method (str)
        thumbnail_type (str): Content type of thumbnail, e.g. image/png
    """
    def __init__(self, server_name, file_id, url_cache=False,
                 thumbnail=False, thumbnail_width=None, thumbnail_height=None,
                 thumbnail_method=None, thumbnail_type=None):
        # Which file this is and where it came from
        self.file_id = file_id
        self.server_name = server_name
        self.url_cache = url_cache

        # Thumbnail details; the constructor stores them unvalidated
        self.thumbnail = thumbnail
        self.thumbnail_type = thumbnail_type
        self.thumbnail_method = thumbnail_method
        self.thumbnail_height = thumbnail_height
        self.thumbnail_width = thumbnail_width
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# limitations under the License.
|
||||
import synapse.http.servlet
|
||||
|
||||
from ._base import parse_media_id, respond_with_file, respond_404
|
||||
from ._base import parse_media_id, respond_404
|
||||
from twisted.web.resource import Resource
|
||||
from synapse.http.server import request_handler, set_cors_headers
|
||||
|
||||
|
@ -32,12 +32,12 @@ class DownloadResource(Resource):
|
|||
def __init__(self, hs, media_repo):
|
||||
Resource.__init__(self)
|
||||
|
||||
self.filepaths = media_repo.filepaths
|
||||
self.media_repo = media_repo
|
||||
self.server_name = hs.hostname
|
||||
self.store = hs.get_datastore()
|
||||
self.version_string = hs.version_string
|
||||
|
||||
# Both of these are expected by @request_handler()
|
||||
self.clock = hs.get_clock()
|
||||
self.version_string = hs.version_string
|
||||
|
||||
def render_GET(self, request):
|
||||
self._async_render_GET(request)
|
||||
|
@ -57,59 +57,16 @@ class DownloadResource(Resource):
|
|||
)
|
||||
server_name, media_id, name = parse_media_id(request)
|
||||
if server_name == self.server_name:
|
||||
yield self._respond_local_file(request, media_id, name)
|
||||
yield self.media_repo.get_local_media(request, media_id, name)
|
||||
else:
|
||||
yield self._respond_remote_file(
|
||||
request, server_name, media_id, name
|
||||
)
|
||||
allow_remote = synapse.http.servlet.parse_boolean(
|
||||
request, "allow_remote", default=True)
|
||||
if not allow_remote:
|
||||
logger.info(
|
||||
"Rejecting request for remote media %s/%s due to allow_remote",
|
||||
server_name, media_id,
|
||||
)
|
||||
respond_404(request)
|
||||
return
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _respond_local_file(self, request, media_id, name):
|
||||
media_info = yield self.store.get_local_media(media_id)
|
||||
if not media_info or media_info["quarantined_by"]:
|
||||
respond_404(request)
|
||||
return
|
||||
|
||||
media_type = media_info["media_type"]
|
||||
media_length = media_info["media_length"]
|
||||
upload_name = name if name else media_info["upload_name"]
|
||||
if media_info["url_cache"]:
|
||||
# TODO: Check the file still exists, if it doesn't we can redownload
|
||||
# it from the url `media_info["url_cache"]`
|
||||
file_path = self.filepaths.url_cache_filepath(media_id)
|
||||
else:
|
||||
file_path = self.filepaths.local_media_filepath(media_id)
|
||||
|
||||
yield respond_with_file(
|
||||
request, media_type, file_path, media_length,
|
||||
upload_name=upload_name,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _respond_remote_file(self, request, server_name, media_id, name):
|
||||
# don't forward requests for remote media if allow_remote is false
|
||||
allow_remote = synapse.http.servlet.parse_boolean(
|
||||
request, "allow_remote", default=True)
|
||||
if not allow_remote:
|
||||
logger.info(
|
||||
"Rejecting request for remote media %s/%s due to allow_remote",
|
||||
server_name, media_id,
|
||||
)
|
||||
respond_404(request)
|
||||
return
|
||||
|
||||
media_info = yield self.media_repo.get_remote_media(server_name, media_id)
|
||||
|
||||
media_type = media_info["media_type"]
|
||||
media_length = media_info["media_length"]
|
||||
filesystem_id = media_info["filesystem_id"]
|
||||
upload_name = name if name else media_info["upload_name"]
|
||||
|
||||
file_path = self.filepaths.remote_media_filepath(
|
||||
server_name, filesystem_id
|
||||
)
|
||||
|
||||
yield respond_with_file(
|
||||
request, media_type, file_path, media_length,
|
||||
upload_name=upload_name,
|
||||
)
|
||||
yield self.media_repo.get_remote_media(request, server_name, media_id, name)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -18,6 +19,7 @@ import twisted.internet.error
|
|||
import twisted.web.http
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from ._base import respond_404, FileInfo, respond_with_responder
|
||||
from .upload_resource import UploadResource
|
||||
from .download_resource import DownloadResource
|
||||
from .thumbnail_resource import ThumbnailResource
|
||||
|
@ -25,15 +27,18 @@ from .identicon_resource import IdenticonResource
|
|||
from .preview_url_resource import PreviewUrlResource
|
||||
from .filepath import MediaFilePaths
|
||||
from .thumbnailer import Thumbnailer
|
||||
from .storage_provider import StorageProviderWrapper
|
||||
from .media_storage import MediaStorage
|
||||
|
||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||
from synapse.util.stringutils import random_string
|
||||
from synapse.api.errors import SynapseError, HttpResponseException, \
|
||||
NotFoundError
|
||||
from synapse.api.errors import (
|
||||
SynapseError, HttpResponseException, NotFoundError, FederationDeniedError,
|
||||
)
|
||||
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.util.stringutils import is_ascii
|
||||
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||
from synapse.util.logcontext import make_deferred_yieldable
|
||||
from synapse.util.retryutils import NotRetryingDestination
|
||||
|
||||
import os
|
||||
|
@ -47,7 +52,7 @@ import urlparse
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
UPDATE_RECENTLY_ACCESSED_REMOTES_TS = 60 * 1000
|
||||
UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
|
||||
|
||||
|
||||
class MediaRepository(object):
|
||||
|
@ -63,96 +68,62 @@ class MediaRepository(object):
|
|||
self.primary_base_path = hs.config.media_store_path
|
||||
self.filepaths = MediaFilePaths(self.primary_base_path)
|
||||
|
||||
self.backup_base_path = hs.config.backup_media_store_path
|
||||
|
||||
self.synchronous_backup_media_store = hs.config.synchronous_backup_media_store
|
||||
|
||||
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
|
||||
self.thumbnail_requirements = hs.config.thumbnail_requirements
|
||||
|
||||
self.remote_media_linearizer = Linearizer(name="media_remote")
|
||||
|
||||
self.recently_accessed_remotes = set()
|
||||
self.recently_accessed_locals = set()
|
||||
|
||||
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
|
||||
|
||||
# List of StorageProviders where we should search for media and
|
||||
# potentially upload to.
|
||||
storage_providers = []
|
||||
|
||||
for clz, provider_config, wrapper_config in hs.config.media_storage_providers:
|
||||
backend = clz(hs, provider_config)
|
||||
provider = StorageProviderWrapper(
|
||||
backend,
|
||||
store_local=wrapper_config.store_local,
|
||||
store_remote=wrapper_config.store_remote,
|
||||
store_synchronous=wrapper_config.store_synchronous,
|
||||
)
|
||||
storage_providers.append(provider)
|
||||
|
||||
self.media_storage = MediaStorage(
|
||||
self.primary_base_path, self.filepaths, storage_providers,
|
||||
)
|
||||
|
||||
self.clock.looping_call(
|
||||
self._update_recently_accessed_remotes,
|
||||
UPDATE_RECENTLY_ACCESSED_REMOTES_TS
|
||||
self._update_recently_accessed,
|
||||
UPDATE_RECENTLY_ACCESSED_TS,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _update_recently_accessed_remotes(self):
|
||||
media = self.recently_accessed_remotes
|
||||
def _update_recently_accessed(self):
|
||||
remote_media = self.recently_accessed_remotes
|
||||
self.recently_accessed_remotes = set()
|
||||
|
||||
local_media = self.recently_accessed_locals
|
||||
self.recently_accessed_locals = set()
|
||||
|
||||
yield self.store.update_cached_last_access_time(
|
||||
media, self.clock.time_msec()
|
||||
local_media, remote_media, self.clock.time_msec()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _makedirs(filepath):
|
||||
dirname = os.path.dirname(filepath)
|
||||
if not os.path.exists(dirname):
|
||||
os.makedirs(dirname)
|
||||
|
||||
@staticmethod
|
||||
def _write_file_synchronously(source, fname):
|
||||
"""Write `source` to the path `fname` synchronously. Should be called
|
||||
from a thread.
|
||||
def mark_recently_accessed(self, server_name, media_id):
|
||||
"""Mark the given media as recently accessed.
|
||||
|
||||
Args:
|
||||
source: A file like object to be written
|
||||
fname (str): Path to write to
|
||||
server_name (str|None): Origin server of media, or None if local
|
||||
media_id (str): The media ID of the content
|
||||
"""
|
||||
MediaRepository._makedirs(fname)
|
||||
source.seek(0) # Ensure we read from the start of the file
|
||||
with open(fname, "wb") as f:
|
||||
shutil.copyfileobj(source, f)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def write_to_file_and_backup(self, source, path):
|
||||
"""Write `source` to the on disk media store, and also the backup store
|
||||
if configured.
|
||||
|
||||
Args:
|
||||
source: A file like object that should be written
|
||||
path (str): Relative path to write file to
|
||||
|
||||
Returns:
|
||||
Deferred[str]: the file path written to in the primary media store
|
||||
"""
|
||||
fname = os.path.join(self.primary_base_path, path)
|
||||
|
||||
# Write to the main repository
|
||||
yield make_deferred_yieldable(threads.deferToThread(
|
||||
self._write_file_synchronously, source, fname,
|
||||
))
|
||||
|
||||
# Write to backup repository
|
||||
yield self.copy_to_backup(path)
|
||||
|
||||
defer.returnValue(fname)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def copy_to_backup(self, path):
|
||||
"""Copy a file from the primary to backup media store, if configured.
|
||||
|
||||
Args:
|
||||
path(str): Relative path to write file to
|
||||
"""
|
||||
if self.backup_base_path:
|
||||
primary_fname = os.path.join(self.primary_base_path, path)
|
||||
backup_fname = os.path.join(self.backup_base_path, path)
|
||||
|
||||
# We can either wait for successful writing to the backup repository
|
||||
# or write in the background and immediately return
|
||||
if self.synchronous_backup_media_store:
|
||||
yield make_deferred_yieldable(threads.deferToThread(
|
||||
shutil.copyfile, primary_fname, backup_fname,
|
||||
))
|
||||
else:
|
||||
preserve_fn(threads.deferToThread)(
|
||||
shutil.copyfile, primary_fname, backup_fname,
|
||||
)
|
||||
if server_name:
|
||||
self.recently_accessed_remotes.add((server_name, media_id))
|
||||
else:
|
||||
self.recently_accessed_locals.add(media_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_content(self, media_type, upload_name, content, content_length,
|
||||
|
@ -171,10 +142,13 @@ class MediaRepository(object):
|
|||
"""
|
||||
media_id = random_string(24)
|
||||
|
||||
fname = yield self.write_to_file_and_backup(
|
||||
content, self.filepaths.local_media_filepath_rel(media_id)
|
||||
file_info = FileInfo(
|
||||
server_name=None,
|
||||
file_id=media_id,
|
||||
)
|
||||
|
||||
fname = yield self.media_storage.store_file(content, file_info)
|
||||
|
||||
logger.info("Stored local media in file %r", fname)
|
||||
|
||||
yield self.store.store_local_media(
|
||||
|
@ -185,134 +159,275 @@ class MediaRepository(object):
|
|||
media_length=content_length,
|
||||
user_id=auth_user,
|
||||
)
|
||||
media_info = {
|
||||
"media_type": media_type,
|
||||
"media_length": content_length,
|
||||
}
|
||||
|
||||
yield self._generate_thumbnails(None, media_id, media_info)
|
||||
yield self._generate_thumbnails(
|
||||
None, media_id, media_id, media_type,
|
||||
)
|
||||
|
||||
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_remote_media(self, server_name, media_id):
|
||||
def get_local_media(self, request, media_id, name):
    """Responds to requests for local media, if it exists, or returns 404.

    Media that is unknown to the datastore, or that has been quarantined,
    gets a 404 and is not marked as recently accessed.

    Args:
        request(twisted.web.http.Request)
        media_id (str): The media ID of the content. (This is the same as
            the file_id for local content.)
        name (str|None): Optional name that, if specified, will be used as
            the filename in the Content-Disposition header of the response.

    Returns:
        Deferred: Resolves once a response has successfully been written
            to request
    """
    media_info = yield self.store.get_local_media(media_id)

    is_servable = bool(media_info) and not media_info["quarantined_by"]
    if not is_servable:
        respond_404(request)
        return

    self.mark_recently_accessed(None, media_id)

    media_type = media_info["media_type"]
    media_length = media_info["media_length"]
    # An explicitly requested name takes precedence over the stored one
    upload_name = name or media_info["upload_name"]

    # Local media has no origin server, and its file_id is the media_id
    file_info = FileInfo(
        server_name=None,
        file_id=media_id,
        url_cache=media_info["url_cache"],
    )

    responder = yield self.media_storage.fetch_media(file_info)
    yield respond_with_responder(
        request, responder, media_type, media_length, upload_name,
    )
|
||||
|
||||
@defer.inlineCallbacks
def get_remote_media(self, request, server_name, media_id, name):
    """Respond to requests for remote media.

    Args:
        request(twisted.web.http.Request)
        server_name (str): Remote server_name where the media originated.
        media_id (str): The media ID of the content (as defined by the
            remote server).
        name (str|None): Optional name that, if specified, will be used as
            the filename in the Content-Disposition header of the response.

    Returns:
        Deferred: Resolves once a response has successfully been written
            to request

    Raises:
        FederationDeniedError: if a federation whitelist is configured and
            server_name is not on it.
    """
    if (
        self.federation_domain_whitelist is not None and
        server_name not in self.federation_domain_whitelist
    ):
        raise FederationDeniedError(server_name)

    self.mark_recently_accessed(server_name, media_id)

    # We linearize here to ensure that we don't try and download remote
    # media multiple times concurrently
    key = (server_name, media_id)
    with (yield self.remote_media_linearizer.queue(key)):
        # Fetch exactly once: a second call here would produce another
        # responder that is never entered, and so never cleaned up.
        responder, media_info = yield self._get_remote_media_impl(
            server_name, media_id,
        )

    # We deliberately stream the file outside the lock
    if responder:
        media_type = media_info["media_type"]
        media_length = media_info["media_length"]
        upload_name = name if name else media_info["upload_name"]
        yield respond_with_responder(
            request, responder, media_type, media_length, upload_name,
        )
    else:
        respond_404(request)
|
||||
|
||||
@defer.inlineCallbacks
def get_remote_media_info(self, server_name, media_id):
    """Fetches the media info for a remote file, downloading the file first
    if necessary. Nothing is streamed to a client.

    Args:
        server_name (str): Remote server_name where the media originated.
        media_id (str): The media ID of the content (as defined by the
            remote server).

    Returns:
        Deferred[dict]: The media_info of the file
    """
    whitelist = self.federation_domain_whitelist
    if whitelist is not None and server_name not in whitelist:
        raise FederationDeniedError(server_name)

    # Requests for the same remote media are queued one after another so
    # that we never download the same file concurrently.
    with (yield self.remote_media_linearizer.queue((server_name, media_id))):
        fetched = yield self._get_remote_media_impl(
            server_name, media_id,
        )
    responder, media_info = fetched

    # We only want the info, but the responder still has to be entered and
    # exited so that it releases its resources.
    if responder:
        with responder:
            pass

    defer.returnValue(media_info)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_remote_media_impl(self, server_name, media_id):
|
||||
"""Looks for media in local cache, if not there then attempt to
|
||||
download from remote server.
|
||||
|
||||
Args:
|
||||
server_name (str): Remote server_name where the media originated.
|
||||
media_id (str): The media ID of the content (as defined by the
|
||||
remote server).
|
||||
|
||||
Returns:
|
||||
Deferred[(Responder, media_info)]
|
||||
"""
|
||||
media_info = yield self.store.get_cached_remote_media(
|
||||
server_name, media_id
|
||||
)
|
||||
if not media_info:
|
||||
media_info = yield self._download_remote_file(
|
||||
server_name, media_id
|
||||
)
|
||||
elif media_info["quarantined_by"]:
|
||||
raise NotFoundError()
|
||||
|
||||
# file_id is the ID we use to track the file locally. If we've already
|
||||
# seen the file then reuse the existing ID, otherwise generate a new
|
||||
# one.
|
||||
if media_info:
|
||||
file_id = media_info["filesystem_id"]
|
||||
else:
|
||||
self.recently_accessed_remotes.add((server_name, media_id))
|
||||
yield self.store.update_cached_last_access_time(
|
||||
[(server_name, media_id)], self.clock.time_msec()
|
||||
)
|
||||
defer.returnValue(media_info)
|
||||
file_id = random_string(24)
|
||||
|
||||
file_info = FileInfo(server_name, file_id)
|
||||
|
||||
# If we have an entry in the DB, try and look for it
|
||||
if media_info:
|
||||
if media_info["quarantined_by"]:
|
||||
logger.info("Media is quarantined")
|
||||
raise NotFoundError()
|
||||
|
||||
responder = yield self.media_storage.fetch_media(file_info)
|
||||
if responder:
|
||||
defer.returnValue((responder, media_info))
|
||||
|
||||
# Failed to find the file anywhere, let's download it.
|
||||
|
||||
media_info = yield self._download_remote_file(
|
||||
server_name, media_id, file_id
|
||||
)
|
||||
|
||||
responder = yield self.media_storage.fetch_media(file_info)
|
||||
defer.returnValue((responder, media_info))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _download_remote_file(self, server_name, media_id):
|
||||
file_id = random_string(24)
|
||||
def _download_remote_file(self, server_name, media_id, file_id):
|
||||
"""Attempt to download the remote file from the given server name,
|
||||
using the given file_id as the local id.
|
||||
|
||||
fpath = self.filepaths.remote_media_filepath_rel(
|
||||
server_name, file_id
|
||||
Args:
|
||||
server_name (str): Originating server
|
||||
media_id (str): The media ID of the content (as defined by the
|
||||
remote server). This is different than the file_id, which is
|
||||
locally generated.
|
||||
file_id (str): Local file ID
|
||||
|
||||
Returns:
|
||||
Deferred[MediaInfo]
|
||||
"""
|
||||
|
||||
file_info = FileInfo(
|
||||
server_name=server_name,
|
||||
file_id=file_id,
|
||||
)
|
||||
fname = os.path.join(self.primary_base_path, fpath)
|
||||
self._makedirs(fname)
|
||||
|
||||
try:
|
||||
with open(fname, "wb") as f:
|
||||
request_path = "/".join((
|
||||
"/_matrix/media/v1/download", server_name, media_id,
|
||||
))
|
||||
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
|
||||
request_path = "/".join((
|
||||
"/_matrix/media/v1/download", server_name, media_id,
|
||||
))
|
||||
try:
|
||||
length, headers = yield self.client.get_file(
|
||||
server_name, request_path, output_stream=f,
|
||||
max_size=self.max_upload_size, args={
|
||||
# tell the remote server to 404 if it doesn't
|
||||
# recognise the server_name, to make sure we don't
|
||||
# end up with a routing loop.
|
||||
"allow_remote": "false",
|
||||
}
|
||||
)
|
||||
except twisted.internet.error.DNSLookupError as e:
|
||||
logger.warn("HTTP error fetching remote media %s/%s: %r",
|
||||
server_name, media_id, e)
|
||||
raise NotFoundError()
|
||||
|
||||
except HttpResponseException as e:
|
||||
logger.warn("HTTP error fetching remote media %s/%s: %s",
|
||||
server_name, media_id, e.response)
|
||||
if e.code == twisted.web.http.NOT_FOUND:
|
||||
raise SynapseError.from_http_response_exception(e)
|
||||
raise SynapseError(502, "Failed to fetch remote media")
|
||||
|
||||
except SynapseError:
|
||||
logger.exception("Failed to fetch remote media %s/%s",
|
||||
server_name, media_id)
|
||||
raise
|
||||
except NotRetryingDestination:
|
||||
logger.warn("Not retrying destination %r", server_name)
|
||||
raise SynapseError(502, "Failed to fetch remote media")
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch remote media %s/%s",
|
||||
server_name, media_id)
|
||||
raise SynapseError(502, "Failed to fetch remote media")
|
||||
|
||||
yield finish()
|
||||
|
||||
media_type = headers["Content-Type"][0]
|
||||
|
||||
time_now_ms = self.clock.time_msec()
|
||||
|
||||
content_disposition = headers.get("Content-Disposition", None)
|
||||
if content_disposition:
|
||||
_, params = cgi.parse_header(content_disposition[0],)
|
||||
upload_name = None
|
||||
|
||||
# First check if there is a valid UTF-8 filename
|
||||
upload_name_utf8 = params.get("filename*", None)
|
||||
if upload_name_utf8:
|
||||
if upload_name_utf8.lower().startswith("utf-8''"):
|
||||
upload_name = upload_name_utf8[7:]
|
||||
|
||||
# If there isn't check for an ascii name.
|
||||
if not upload_name:
|
||||
upload_name_ascii = params.get("filename", None)
|
||||
if upload_name_ascii and is_ascii(upload_name_ascii):
|
||||
upload_name = upload_name_ascii
|
||||
|
||||
if upload_name:
|
||||
upload_name = urlparse.unquote(upload_name)
|
||||
try:
|
||||
length, headers = yield self.client.get_file(
|
||||
server_name, request_path, output_stream=f,
|
||||
max_size=self.max_upload_size, args={
|
||||
# tell the remote server to 404 if it doesn't
|
||||
# recognise the server_name, to make sure we don't
|
||||
# end up with a routing loop.
|
||||
"allow_remote": "false",
|
||||
}
|
||||
)
|
||||
except twisted.internet.error.DNSLookupError as e:
|
||||
logger.warn("HTTP error fetching remote media %s/%s: %r",
|
||||
server_name, media_id, e)
|
||||
raise NotFoundError()
|
||||
upload_name = upload_name.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
upload_name = None
|
||||
else:
|
||||
upload_name = None
|
||||
|
||||
except HttpResponseException as e:
|
||||
logger.warn("HTTP error fetching remote media %s/%s: %s",
|
||||
server_name, media_id, e.response)
|
||||
if e.code == twisted.web.http.NOT_FOUND:
|
||||
raise SynapseError.from_http_response_exception(e)
|
||||
raise SynapseError(502, "Failed to fetch remote media")
|
||||
logger.info("Stored remote media in file %r", fname)
|
||||
|
||||
except SynapseError:
|
||||
logger.exception("Failed to fetch remote media %s/%s",
|
||||
server_name, media_id)
|
||||
raise
|
||||
except NotRetryingDestination:
|
||||
logger.warn("Not retrying destination %r", server_name)
|
||||
raise SynapseError(502, "Failed to fetch remote media")
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch remote media %s/%s",
|
||||
server_name, media_id)
|
||||
raise SynapseError(502, "Failed to fetch remote media")
|
||||
|
||||
yield self.copy_to_backup(fpath)
|
||||
|
||||
media_type = headers["Content-Type"][0]
|
||||
time_now_ms = self.clock.time_msec()
|
||||
|
||||
content_disposition = headers.get("Content-Disposition", None)
|
||||
if content_disposition:
|
||||
_, params = cgi.parse_header(content_disposition[0],)
|
||||
upload_name = None
|
||||
|
||||
# First check if there is a valid UTF-8 filename
|
||||
upload_name_utf8 = params.get("filename*", None)
|
||||
if upload_name_utf8:
|
||||
if upload_name_utf8.lower().startswith("utf-8''"):
|
||||
upload_name = upload_name_utf8[7:]
|
||||
|
||||
# If there isn't check for an ascii name.
|
||||
if not upload_name:
|
||||
upload_name_ascii = params.get("filename", None)
|
||||
if upload_name_ascii and is_ascii(upload_name_ascii):
|
||||
upload_name = upload_name_ascii
|
||||
|
||||
if upload_name:
|
||||
upload_name = urlparse.unquote(upload_name)
|
||||
try:
|
||||
upload_name = upload_name.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
upload_name = None
|
||||
else:
|
||||
upload_name = None
|
||||
|
||||
logger.info("Stored remote media in file %r", fname)
|
||||
|
||||
yield self.store.store_cached_remote_media(
|
||||
origin=server_name,
|
||||
media_id=media_id,
|
||||
media_type=media_type,
|
||||
time_now_ms=self.clock.time_msec(),
|
||||
upload_name=upload_name,
|
||||
media_length=length,
|
||||
filesystem_id=file_id,
|
||||
)
|
||||
except Exception:
|
||||
os.remove(fname)
|
||||
raise
|
||||
yield self.store.store_cached_remote_media(
|
||||
origin=server_name,
|
||||
media_id=media_id,
|
||||
media_type=media_type,
|
||||
time_now_ms=self.clock.time_msec(),
|
||||
upload_name=upload_name,
|
||||
media_length=length,
|
||||
filesystem_id=file_id,
|
||||
)
|
||||
|
||||
media_info = {
|
||||
"media_type": media_type,
|
||||
|
@ -323,7 +438,7 @@ class MediaRepository(object):
|
|||
}
|
||||
|
||||
yield self._generate_thumbnails(
|
||||
server_name, media_id, media_info
|
||||
server_name, media_id, file_id, media_type,
|
||||
)
|
||||
|
||||
defer.returnValue(media_info)
|
||||
|
@ -357,8 +472,10 @@ class MediaRepository(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
|
||||
t_method, t_type):
|
||||
input_path = self.filepaths.local_media_filepath(media_id)
|
||||
t_method, t_type, url_cache):
|
||||
input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
|
||||
None, media_id, url_cache=url_cache,
|
||||
))
|
||||
|
||||
thumbnailer = Thumbnailer(input_path)
|
||||
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
|
||||
|
@ -368,11 +485,19 @@ class MediaRepository(object):
|
|||
|
||||
if t_byte_source:
|
||||
try:
|
||||
output_path = yield self.write_to_file_and_backup(
|
||||
t_byte_source,
|
||||
self.filepaths.local_media_thumbnail_rel(
|
||||
media_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
file_info = FileInfo(
|
||||
server_name=None,
|
||||
file_id=media_id,
|
||||
url_cache=url_cache,
|
||||
thumbnail=True,
|
||||
thumbnail_width=t_width,
|
||||
thumbnail_height=t_height,
|
||||
thumbnail_method=t_method,
|
||||
thumbnail_type=t_type,
|
||||
)
|
||||
|
||||
output_path = yield self.media_storage.store_file(
|
||||
t_byte_source, file_info,
|
||||
)
|
||||
finally:
|
||||
t_byte_source.close()
|
||||
|
@ -390,7 +515,9 @@ class MediaRepository(object):
|
|||
@defer.inlineCallbacks
|
||||
def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
|
||||
t_width, t_height, t_method, t_type):
|
||||
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
|
||||
input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
|
||||
server_name, file_id, url_cache=False,
|
||||
))
|
||||
|
||||
thumbnailer = Thumbnailer(input_path)
|
||||
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
|
||||
|
@ -400,11 +527,18 @@ class MediaRepository(object):
|
|||
|
||||
if t_byte_source:
|
||||
try:
|
||||
output_path = yield self.write_to_file_and_backup(
|
||||
t_byte_source,
|
||||
self.filepaths.remote_media_thumbnail_rel(
|
||||
server_name, file_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
file_info = FileInfo(
|
||||
server_name=server_name,
|
||||
file_id=media_id,
|
||||
thumbnail=True,
|
||||
thumbnail_width=t_width,
|
||||
thumbnail_height=t_height,
|
||||
thumbnail_method=t_method,
|
||||
thumbnail_type=t_type,
|
||||
)
|
||||
|
||||
output_path = yield self.media_storage.store_file(
|
||||
t_byte_source, file_info,
|
||||
)
|
||||
finally:
|
||||
t_byte_source.close()
|
||||
|
@ -421,31 +555,29 @@ class MediaRepository(object):
|
|||
defer.returnValue(output_path)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _generate_thumbnails(self, server_name, media_id, media_info, url_cache=False):
|
||||
def _generate_thumbnails(self, server_name, media_id, file_id, media_type,
|
||||
url_cache=False):
|
||||
"""Generate and store thumbnails for an image.
|
||||
|
||||
Args:
|
||||
server_name(str|None): The server name if remote media, else None if local
|
||||
media_id(str)
|
||||
media_info(dict)
|
||||
url_cache(bool): If we are thumbnailing images downloaded for the URL cache,
|
||||
server_name (str|None): The server name if remote media, else None if local
|
||||
media_id (str): The media ID of the content. (This is the same as
|
||||
the file_id for local content)
|
||||
file_id (str): Local file ID
|
||||
media_type (str): The content type of the file
|
||||
url_cache (bool): If we are thumbnailing images downloaded for the URL cache,
|
||||
used exclusively by the url previewer
|
||||
|
||||
Returns:
|
||||
Deferred[dict]: Dict with "width" and "height" keys of original image
|
||||
"""
|
||||
media_type = media_info["media_type"]
|
||||
file_id = media_info.get("filesystem_id")
|
||||
requirements = self._get_thumbnail_requirements(media_type)
|
||||
if not requirements:
|
||||
return
|
||||
|
||||
if server_name:
|
||||
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
|
||||
elif url_cache:
|
||||
input_path = self.filepaths.url_cache_filepath(media_id)
|
||||
else:
|
||||
input_path = self.filepaths.local_media_filepath(media_id)
|
||||
input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
|
||||
server_name, file_id, url_cache=url_cache,
|
||||
))
|
||||
|
||||
thumbnailer = Thumbnailer(input_path)
|
||||
m_width = thumbnailer.width
|
||||
|
@ -472,20 +604,6 @@ class MediaRepository(object):
|
|||
|
||||
# Now we generate the thumbnails for each dimension, store it
|
||||
for (t_width, t_height, t_type), t_method in thumbnails.iteritems():
|
||||
# Work out the correct file name for thumbnail
|
||||
if server_name:
|
||||
file_path = self.filepaths.remote_media_thumbnail_rel(
|
||||
server_name, file_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
elif url_cache:
|
||||
file_path = self.filepaths.url_cache_thumbnail_rel(
|
||||
media_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
else:
|
||||
file_path = self.filepaths.local_media_thumbnail_rel(
|
||||
media_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
|
||||
# Generate the thumbnail
|
||||
if t_method == "crop":
|
||||
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
|
||||
|
@ -505,9 +623,19 @@ class MediaRepository(object):
|
|||
continue
|
||||
|
||||
try:
|
||||
# Write to disk
|
||||
output_path = yield self.write_to_file_and_backup(
|
||||
t_byte_source, file_path,
|
||||
file_info = FileInfo(
|
||||
server_name=server_name,
|
||||
file_id=file_id,
|
||||
thumbnail=True,
|
||||
thumbnail_width=t_width,
|
||||
thumbnail_height=t_height,
|
||||
thumbnail_method=t_method,
|
||||
thumbnail_type=t_type,
|
||||
url_cache=url_cache,
|
||||
)
|
||||
|
||||
output_path = yield self.media_storage.store_file(
|
||||
t_byte_source, file_info,
|
||||
)
|
||||
finally:
|
||||
t_byte_source.close()
|
||||
|
@ -620,7 +748,11 @@ class MediaRepositoryResource(Resource):
|
|||
|
||||
self.putChild("upload", UploadResource(hs, media_repo))
|
||||
self.putChild("download", DownloadResource(hs, media_repo))
|
||||
self.putChild("thumbnail", ThumbnailResource(hs, media_repo))
|
||||
self.putChild("thumbnail", ThumbnailResource(
|
||||
hs, media_repo, media_repo.media_storage,
|
||||
))
|
||||
self.putChild("identicon", IdenticonResource())
|
||||
if hs.config.url_preview_enabled:
|
||||
self.putChild("preview_url", PreviewUrlResource(hs, media_repo))
|
||||
self.putChild("preview_url", PreviewUrlResource(
|
||||
hs, media_repo, media_repo.media_storage,
|
||||
))
|
||||
|
|
274
synapse/rest/media/v1/media_storage.py
Normal file
274
synapse/rest/media/v1/media_storage.py
Normal file
|
@ -0,0 +1,274 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer, threads
|
||||
from twisted.protocols.basic import FileSender
|
||||
|
||||
from ._base import Responder
|
||||
|
||||
from synapse.util.file_consumer import BackgroundFileConsumer
|
||||
from synapse.util.logcontext import make_deferred_yieldable
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import logging
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MediaStorage(object):
    """Responsible for storing/fetching files from local sources.

    Files live under `local_media_directory` (the local cache); the
    configured storage providers are told about every new file and are
    consulted when a file is missing from the local cache.

    Args:
        local_media_directory (str): Base path where we store media on disk
        filepaths (MediaFilePaths)
        storage_providers ([StorageProvider]): List of StorageProvider that are
            used to fetch and store files.
    """

    def __init__(self, local_media_directory, filepaths, storage_providers):
        self.local_media_directory = local_media_directory
        self.filepaths = filepaths
        self.storage_providers = storage_providers

    @defer.inlineCallbacks
    def store_file(self, source, file_info):
        """Write `source` to the on disk media store, and also any other
        configured storage providers

        Args:
            source: A file like object that should be written
            file_info (FileInfo): Info about the file to store

        Returns:
            Deferred[str]: the file path written to in the primary media store
        """
        path = self._file_info_to_path(file_info)
        fname = os.path.join(self.local_media_directory, path)

        dirname = os.path.dirname(fname)
        if not os.path.exists(dirname):
            os.makedirs(dirname)

        # Write to the main repository. The copy is blocking file I/O, so it
        # is pushed onto a worker thread.
        yield make_deferred_yieldable(threads.deferToThread(
            _write_file_synchronously, source, fname,
        ))

        # Tell the storage providers about the new file. They'll decide
        # if they should upload it and whether to do so synchronously
        # or not.
        for provider in self.storage_providers:
            yield provider.store_file(path, file_info)

        defer.returnValue(fname)

    @contextlib.contextmanager
    def store_into_file(self, file_info):
        """Context manager used to get a file like object to write into, as
        described by file_info.

        Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
        like object that can be written to, fname is the absolute path of file
        on disk, and finish_cb is a function that returns a Deferred.

        fname can be used to read the contents from after upload, e.g. to
        generate thumbnails.

        finish_cb must be called and waited on after the file has been
        successfully written to. Should not be called if there was an
        error.

        If the body of the `with` block raises an Exception, the partially
        written file is removed (best effort) and the exception is re-raised.

        Args:
            file_info (FileInfo): Info about the file to store

        Raises:
            Exception: if the `with` block completes without finish_cb having
                run to completion.

        Example:

            with media_storage.store_into_file(info) as (f, fname, finish_cb):
                # .. write into f ...
                yield finish_cb()
        """

        path = self._file_info_to_path(file_info)
        fname = os.path.join(self.local_media_directory, path)

        dirname = os.path.dirname(fname)
        if not os.path.exists(dirname):
            os.makedirs(dirname)

        # A one-element list so that the nested function can mutate the flag
        # (there is no `nonlocal` on Python 2).
        finished_called = [False]

        @defer.inlineCallbacks
        def finish():
            for provider in self.storage_providers:
                yield provider.store_file(path, file_info)

            finished_called[0] = True

        try:
            with open(fname, "wb") as f:
                yield f, fname, finish
        except Exception:
            # The cleanup lives in a separate function so that any exception
            # it swallows is handled in its own frame; the bare `raise` below
            # then re-raises the original exception with its original
            # traceback on both Python 2 and Python 3.
            self._remove_quietly(fname)
            raise

        # Check the flag itself: the list object is always truthy, so testing
        # `finished_called` directly would never detect a missing call.
        if not finished_called[0]:
            raise Exception("Finished callback not called")

    @staticmethod
    def _remove_quietly(fname):
        """Delete `fname`, deliberately ignoring any failure to do so.

        Used to clean up a partially written file while another exception is
        already propagating; a failed delete must not mask that exception.

        Args:
            fname (str): Path of the file to remove
        """
        try:
            os.remove(fname)
        except Exception:
            pass

    @defer.inlineCallbacks
    def fetch_media(self, file_info):
        """Attempts to fetch media described by file_info from the local cache
        and configured storage providers.

        Args:
            file_info (FileInfo)

        Returns:
            Deferred[Responder|None]: Returns a Responder if the file was found,
                otherwise None.
        """

        path = self._file_info_to_path(file_info)
        local_path = os.path.join(self.local_media_directory, path)
        if os.path.exists(local_path):
            # The FileResponder takes ownership of the open file and closes
            # it when its context exits.
            defer.returnValue(FileResponder(open(local_path, "rb")))

        # Not cached locally: ask each provider in turn and return the first
        # truthy result.
        for provider in self.storage_providers:
            res = yield provider.fetch(path, file_info)
            if res:
                defer.returnValue(res)

        defer.returnValue(None)

    @defer.inlineCallbacks
    def ensure_media_is_in_local_cache(self, file_info):
        """Ensures that the given file is in the local cache. Attempts to
        download it from storage providers if it isn't.

        Args:
            file_info (FileInfo)

        Returns:
            Deferred[str]: Full path to local file

        Raises:
            Exception: if the file is neither in the local cache nor available
                from any storage provider.
        """
        path = self._file_info_to_path(file_info)
        local_path = os.path.join(self.local_media_directory, path)
        if os.path.exists(local_path):
            defer.returnValue(local_path)

        dirname = os.path.dirname(local_path)
        if not os.path.exists(dirname):
            os.makedirs(dirname)

        for provider in self.storage_providers:
            res = yield provider.fetch(path, file_info)
            if res:
                with res:
                    # Media is binary data, so the cache file must be opened
                    # in binary mode.
                    consumer = BackgroundFileConsumer(open(local_path, "wb"))
                    yield res.write_to_consumer(consumer)
                    yield consumer.wait()
                defer.returnValue(local_path)

        raise Exception("file could not be found")

    def _file_info_to_path(self, file_info):
        """Converts file_info into a relative path.

        The path is suitable for storing files under a directory, e.g. used to
        store files on local FS under the base media repository directory.

        URL-cache entries are checked first, then remote media (a truthy
        `server_name`), and anything else is treated as local media. Within
        each category a thumbnail gets its own path.

        Args:
            file_info (FileInfo)

        Returns:
            str
        """
        if file_info.url_cache:
            if file_info.thumbnail:
                return self.filepaths.url_cache_thumbnail_rel(
                    media_id=file_info.file_id,
                    width=file_info.thumbnail_width,
                    height=file_info.thumbnail_height,
                    content_type=file_info.thumbnail_type,
                    method=file_info.thumbnail_method,
                )
            return self.filepaths.url_cache_filepath_rel(file_info.file_id)

        if file_info.server_name:
            if file_info.thumbnail:
                return self.filepaths.remote_media_thumbnail_rel(
                    server_name=file_info.server_name,
                    file_id=file_info.file_id,
                    width=file_info.thumbnail_width,
                    height=file_info.thumbnail_height,
                    content_type=file_info.thumbnail_type,
                    method=file_info.thumbnail_method
                )
            return self.filepaths.remote_media_filepath_rel(
                file_info.server_name, file_info.file_id,
            )

        if file_info.thumbnail:
            return self.filepaths.local_media_thumbnail_rel(
                media_id=file_info.file_id,
                width=file_info.thumbnail_width,
                height=file_info.thumbnail_height,
                content_type=file_info.thumbnail_type,
                method=file_info.thumbnail_method
            )
        return self.filepaths.local_media_filepath_rel(
            file_info.file_id,
        )
|
||||
|
||||
|
||||
def _write_file_synchronously(source, fname):
|
||||
"""Write `source` to the path `fname` synchronously. Should be called
|
||||
from a thread.
|
||||
|
||||
Args:
|
||||
source: A file like object to be written
|
||||
fname (str): Path to write to
|
||||
"""
|
||||
dirname = os.path.dirname(fname)
|
||||
if not os.path.exists(dirname):
|
||||
os.makedirs(dirname)
|
||||
|
||||
source.seek(0) # Ensure we read from the start of the file
|
||||
with open(fname, "wb") as f:
|
||||
shutil.copyfileobj(source, f)
|
||||
|
||||
|
||||
class FileResponder(Responder):
    """Responder that streams the contents of an already-open file.

    Args:
        open_file (file): A file like object to be streamed to the client.
            It is closed when the responder's context is exited.
    """

    def __init__(self, open_file):
        self.open_file = open_file

    def write_to_consumer(self, consumer):
        # Returns whatever FileSender.beginFileTransfer returns for this
        # file/consumer pair.
        sender = FileSender()
        return sender.beginFileTransfer(self.open_file, consumer)

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Close the file whether or not the transfer raised.
        self.open_file.close()
|
|
@ -12,11 +12,26 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import cgi
|
||||
import datetime
|
||||
import errno
|
||||
import fnmatch
|
||||
import itertools
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import traceback
|
||||
import ujson as json
|
||||
import urlparse
|
||||
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
from twisted.internet import defer
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from ._base import FileInfo
|
||||
|
||||
from synapse.api.errors import (
|
||||
SynapseError, Codes,
|
||||
)
|
||||
|
@ -31,25 +46,13 @@ from synapse.http.server import (
|
|||
from synapse.util.async import ObservableDeferred
|
||||
from synapse.util.stringutils import is_ascii
|
||||
|
||||
import os
|
||||
import re
|
||||
import fnmatch
|
||||
import cgi
|
||||
import ujson as json
|
||||
import urlparse
|
||||
import itertools
|
||||
import datetime
|
||||
import errno
|
||||
import shutil
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PreviewUrlResource(Resource):
|
||||
isLeaf = True
|
||||
|
||||
def __init__(self, hs, media_repo):
|
||||
def __init__(self, hs, media_repo, media_storage):
|
||||
Resource.__init__(self)
|
||||
|
||||
self.auth = hs.get_auth()
|
||||
|
@ -62,6 +65,7 @@ class PreviewUrlResource(Resource):
|
|||
self.client = SpiderHttpClient(hs)
|
||||
self.media_repo = media_repo
|
||||
self.primary_base_path = media_repo.primary_base_path
|
||||
self.media_storage = media_storage
|
||||
|
||||
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
|
||||
|
||||
|
@ -182,8 +186,10 @@ class PreviewUrlResource(Resource):
|
|||
logger.debug("got media_info of '%s'" % media_info)
|
||||
|
||||
if _is_media(media_info['media_type']):
|
||||
file_id = media_info['filesystem_id']
|
||||
dims = yield self.media_repo._generate_thumbnails(
|
||||
None, media_info['filesystem_id'], media_info, url_cache=True,
|
||||
None, file_id, file_id, media_info["media_type"],
|
||||
url_cache=True,
|
||||
)
|
||||
|
||||
og = {
|
||||
|
@ -228,8 +234,10 @@ class PreviewUrlResource(Resource):
|
|||
|
||||
if _is_media(image_info['media_type']):
|
||||
# TODO: make sure we don't choke on white-on-transparent images
|
||||
file_id = image_info['filesystem_id']
|
||||
dims = yield self.media_repo._generate_thumbnails(
|
||||
None, image_info['filesystem_id'], image_info, url_cache=True,
|
||||
None, file_id, file_id, image_info["media_type"],
|
||||
url_cache=True,
|
||||
)
|
||||
if dims:
|
||||
og["og:image:width"] = dims['width']
|
||||
|
@ -273,21 +281,34 @@ class PreviewUrlResource(Resource):
|
|||
|
||||
file_id = datetime.date.today().isoformat() + '_' + random_string(16)
|
||||
|
||||
fpath = self.filepaths.url_cache_filepath_rel(file_id)
|
||||
fname = os.path.join(self.primary_base_path, fpath)
|
||||
self.media_repo._makedirs(fname)
|
||||
file_info = FileInfo(
|
||||
server_name=None,
|
||||
file_id=file_id,
|
||||
url_cache=True,
|
||||
)
|
||||
|
||||
try:
|
||||
with open(fname, "wb") as f:
|
||||
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
|
||||
try:
|
||||
logger.debug("Trying to get url '%s'" % url)
|
||||
length, headers, uri, code = yield self.client.get_file(
|
||||
url, output_stream=f, max_size=self.max_spider_size,
|
||||
)
|
||||
except Exception as e:
|
||||
# FIXME: pass through 404s and other error messages nicely
|
||||
logger.warn("Error downloading %s: %r", url, e)
|
||||
raise SynapseError(
|
||||
500, "Failed to download content: %s" % (
|
||||
traceback.format_exception_only(sys.exc_type, e),
|
||||
),
|
||||
Codes.UNKNOWN,
|
||||
)
|
||||
yield finish()
|
||||
|
||||
yield self.media_repo.copy_to_backup(fpath)
|
||||
|
||||
media_type = headers["Content-Type"][0]
|
||||
try:
|
||||
if "Content-Type" in headers:
|
||||
media_type = headers["Content-Type"][0]
|
||||
else:
|
||||
media_type = "application/octet-stream"
|
||||
time_now_ms = self.clock.time_msec()
|
||||
|
||||
content_disposition = headers.get("Content-Disposition", None)
|
||||
|
@ -327,11 +348,11 @@ class PreviewUrlResource(Resource):
|
|||
)
|
||||
|
||||
except Exception as e:
|
||||
os.remove(fname)
|
||||
raise SynapseError(
|
||||
500, ("Failed to download content: %s" % e),
|
||||
Codes.UNKNOWN
|
||||
)
|
||||
logger.error("Error handling downloaded %s: %r", url, e)
|
||||
# TODO: we really ought to delete the downloaded file in this
|
||||
# case, since we won't have recorded it in the db, and will
|
||||
# therefore not expire it.
|
||||
raise
|
||||
|
||||
defer.returnValue({
|
||||
"media_type": media_type,
|
||||
|
|
140
synapse/rest/media/v1/storage_provider.py
Normal file
140
synapse/rest/media/v1/storage_provider.py
Normal file
|
@ -0,0 +1,140 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer, threads
|
||||
|
||||
from .media_storage import FileResponder
|
||||
|
||||
from synapse.config._base import Config
|
||||
from synapse.util.logcontext import preserve_fn
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StorageProvider(object):
    """Base interface for a service that can store uploaded media and
    retrieve it again.

    Both methods here are no-op placeholders that return None; concrete
    providers override them.
    """

    def store_file(self, path, file_info):
        """Store the file described by file_info.

        Args:
            path (str): Relative path of file in local cache
            file_info (FileInfo)

        Returns:
            Deferred
        """
        return None

    def fetch(self, path, file_info):
        """Attempt to fetch the file described by file_info.

        Args:
            path (str): Relative path of file in local cache
            file_info (FileInfo)

        Returns:
            Deferred(Responder): Returns a Responder if the provider has the file,
                otherwise returns None.
        """
        return None
|
||||
|
||||
|
||||
class StorageProviderWrapper(StorageProvider):
    """Wraps a storage provider and applies the configured storage policy
    before delegating to it.

    Args:
        backend (StorageProvider)
        store_local (bool): Whether to store new local files or not.
        store_synchronous (bool): Whether to wait for the file to be
            successfully uploaded, or to do the upload in the background.
        store_remote (bool): Whether remote media should be uploaded
    """

    def __init__(self, backend, store_local, store_synchronous, store_remote):
        self.backend = backend
        self.store_local = store_local
        self.store_synchronous = store_synchronous
        self.store_remote = store_remote

    def store_file(self, path, file_info):
        # Media with a server_name is remote, anything else is local; each
        # kind has its own on/off switch.
        is_remote = bool(file_info.server_name)
        wanted = self.store_remote if is_remote else self.store_local
        if not wanted:
            return defer.succeed(None)

        if not self.store_synchronous:
            # Kick the upload off without waiting for it, and report success
            # to the caller straight away.
            # TODO: Handle errors.
            preserve_fn(self.backend.store_file)(path, file_info)
            return defer.succeed(None)

        return self.backend.store_file(path, file_info)

    def fetch(self, path, file_info):
        # Fetching is never filtered; always defer to the backend.
        return self.backend.fetch(path, file_info)
|
||||
|
||||
|
||||
class FileStorageProviderBackend(StorageProvider):
    """A storage provider that keeps copies of files in a directory on a
    filesystem.

    Args:
        hs (HomeServer)
        config: The config returned by `parse_config`.
    """

    def __init__(self, hs, config):
        # Where the primary (local cache) copies live.
        self.cache_directory = hs.config.media_store_path
        # Where this provider keeps its own copies.
        self.base_directory = config

    def store_file(self, path, file_info):
        """See StorageProvider.store_file"""

        source = os.path.join(self.cache_directory, path)
        destination = os.path.join(self.base_directory, path)

        parent = os.path.dirname(destination)
        if not os.path.exists(parent):
            os.makedirs(parent)

        # The copy itself is blocking I/O, so it runs on a worker thread.
        return threads.deferToThread(shutil.copyfile, source, destination)

    def fetch(self, path, file_info):
        """See StorageProvider.fetch"""

        candidate = os.path.join(self.base_directory, path)
        if not os.path.isfile(candidate):
            return None
        return FileResponder(open(candidate, "rb"))

    @staticmethod
    def parse_config(config):
        """Called on startup to parse config supplied. This should parse
        the config and raise if there is a problem.

        The returned value is passed into the constructor.

        Only one parameter matters here, the directory, so that is the only
        thing pulled out of the config.
        """
        directory = config["directory"]
        return Config.ensure_directory(directory)
|
|
@ -14,7 +14,10 @@
|
|||
# limitations under the License.
|
||||
|
||||
|
||||
from ._base import parse_media_id, respond_404, respond_with_file
|
||||
from ._base import (
|
||||
parse_media_id, respond_404, respond_with_file, FileInfo,
|
||||
respond_with_responder,
|
||||
)
|
||||
from twisted.web.resource import Resource
|
||||
from synapse.http.servlet import parse_string, parse_integer
|
||||
from synapse.http.server import request_handler, set_cors_headers
|
||||
|
@ -30,12 +33,12 @@ logger = logging.getLogger(__name__)
|
|||
class ThumbnailResource(Resource):
|
||||
isLeaf = True
|
||||
|
||||
def __init__(self, hs, media_repo):
|
||||
def __init__(self, hs, media_repo, media_storage):
|
||||
Resource.__init__(self)
|
||||
|
||||
self.store = hs.get_datastore()
|
||||
self.filepaths = media_repo.filepaths
|
||||
self.media_repo = media_repo
|
||||
self.media_storage = media_storage
|
||||
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
|
||||
self.server_name = hs.hostname
|
||||
self.version_string = hs.version_string
|
||||
|
@ -64,6 +67,7 @@ class ThumbnailResource(Resource):
|
|||
yield self._respond_local_thumbnail(
|
||||
request, media_id, width, height, method, m_type
|
||||
)
|
||||
self.media_repo.mark_recently_accessed(None, media_id)
|
||||
else:
|
||||
if self.dynamic_thumbnails:
|
||||
yield self._select_or_generate_remote_thumbnail(
|
||||
|
@ -75,20 +79,20 @@ class ThumbnailResource(Resource):
|
|||
request, server_name, media_id,
|
||||
width, height, method, m_type
|
||||
)
|
||||
self.media_repo.mark_recently_accessed(server_name, media_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _respond_local_thumbnail(self, request, media_id, width, height,
|
||||
method, m_type):
|
||||
media_info = yield self.store.get_local_media(media_id)
|
||||
|
||||
if not media_info or media_info["quarantined_by"]:
|
||||
if not media_info:
|
||||
respond_404(request)
|
||||
return
|
||||
if media_info["quarantined_by"]:
|
||||
logger.info("Media is quarantined")
|
||||
respond_404(request)
|
||||
return
|
||||
|
||||
# if media_info["media_type"] == "image/svg+xml":
|
||||
# file_path = self.filepaths.local_media_filepath(media_id)
|
||||
# yield respond_with_file(request, media_info["media_type"], file_path)
|
||||
# return
|
||||
|
||||
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
|
||||
|
||||
|
@ -96,42 +100,39 @@ class ThumbnailResource(Resource):
|
|||
thumbnail_info = self._select_thumbnail(
|
||||
width, height, method, m_type, thumbnail_infos
|
||||
)
|
||||
t_width = thumbnail_info["thumbnail_width"]
|
||||
t_height = thumbnail_info["thumbnail_height"]
|
||||
t_type = thumbnail_info["thumbnail_type"]
|
||||
t_method = thumbnail_info["thumbnail_method"]
|
||||
|
||||
if media_info["url_cache"]:
|
||||
# TODO: Check the file still exists, if it doesn't we can redownload
|
||||
# it from the url `media_info["url_cache"]`
|
||||
file_path = self.filepaths.url_cache_thumbnail(
|
||||
media_id, t_width, t_height, t_type, t_method,
|
||||
)
|
||||
else:
|
||||
file_path = self.filepaths.local_media_thumbnail(
|
||||
media_id, t_width, t_height, t_type, t_method,
|
||||
)
|
||||
yield respond_with_file(request, t_type, file_path)
|
||||
|
||||
else:
|
||||
yield self._respond_default_thumbnail(
|
||||
request, media_info, width, height, method, m_type,
|
||||
file_info = FileInfo(
|
||||
server_name=None, file_id=media_id,
|
||||
url_cache=media_info["url_cache"],
|
||||
thumbnail=True,
|
||||
thumbnail_width=thumbnail_info["thumbnail_width"],
|
||||
thumbnail_height=thumbnail_info["thumbnail_height"],
|
||||
thumbnail_type=thumbnail_info["thumbnail_type"],
|
||||
thumbnail_method=thumbnail_info["thumbnail_method"],
|
||||
)
|
||||
|
||||
t_type = file_info.thumbnail_type
|
||||
t_length = thumbnail_info["thumbnail_length"]
|
||||
|
||||
responder = yield self.media_storage.fetch_media(file_info)
|
||||
yield respond_with_responder(request, responder, t_type, t_length)
|
||||
else:
|
||||
logger.info("Couldn't find any generated thumbnails")
|
||||
respond_404(request)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _select_or_generate_local_thumbnail(self, request, media_id, desired_width,
|
||||
desired_height, desired_method,
|
||||
desired_type):
|
||||
media_info = yield self.store.get_local_media(media_id)
|
||||
|
||||
if not media_info or media_info["quarantined_by"]:
|
||||
if not media_info:
|
||||
respond_404(request)
|
||||
return
|
||||
if media_info["quarantined_by"]:
|
||||
logger.info("Media is quarantined")
|
||||
respond_404(request)
|
||||
return
|
||||
|
||||
# if media_info["media_type"] == "image/svg+xml":
|
||||
# file_path = self.filepaths.local_media_filepath(media_id)
|
||||
# yield respond_with_file(request, media_info["media_type"], file_path)
|
||||
# return
|
||||
|
||||
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
|
||||
for info in thumbnail_infos:
|
||||
|
@ -141,46 +142,43 @@ class ThumbnailResource(Resource):
|
|||
t_type = info["thumbnail_type"] == desired_type
|
||||
|
||||
if t_w and t_h and t_method and t_type:
|
||||
if media_info["url_cache"]:
|
||||
# TODO: Check the file still exists, if it doesn't we can redownload
|
||||
# it from the url `media_info["url_cache"]`
|
||||
file_path = self.filepaths.url_cache_thumbnail(
|
||||
media_id, desired_width, desired_height, desired_type,
|
||||
desired_method,
|
||||
)
|
||||
else:
|
||||
file_path = self.filepaths.local_media_thumbnail(
|
||||
media_id, desired_width, desired_height, desired_type,
|
||||
desired_method,
|
||||
)
|
||||
yield respond_with_file(request, desired_type, file_path)
|
||||
return
|
||||
file_info = FileInfo(
|
||||
server_name=None, file_id=media_id,
|
||||
url_cache=media_info["url_cache"],
|
||||
thumbnail=True,
|
||||
thumbnail_width=info["thumbnail_width"],
|
||||
thumbnail_height=info["thumbnail_height"],
|
||||
thumbnail_type=info["thumbnail_type"],
|
||||
thumbnail_method=info["thumbnail_method"],
|
||||
)
|
||||
|
||||
logger.debug("We don't have a local thumbnail of that size. Generating")
|
||||
t_type = file_info.thumbnail_type
|
||||
t_length = info["thumbnail_length"]
|
||||
|
||||
responder = yield self.media_storage.fetch_media(file_info)
|
||||
if responder:
|
||||
yield respond_with_responder(request, responder, t_type, t_length)
|
||||
return
|
||||
|
||||
logger.debug("We don't have a thumbnail of that size. Generating")
|
||||
|
||||
# Okay, so we generate one.
|
||||
file_path = yield self.media_repo.generate_local_exact_thumbnail(
|
||||
media_id, desired_width, desired_height, desired_method, desired_type
|
||||
media_id, desired_width, desired_height, desired_method, desired_type,
|
||||
url_cache=media_info["url_cache"],
|
||||
)
|
||||
|
||||
if file_path:
|
||||
yield respond_with_file(request, desired_type, file_path)
|
||||
else:
|
||||
yield self._respond_default_thumbnail(
|
||||
request, media_info, desired_width, desired_height,
|
||||
desired_method, desired_type,
|
||||
)
|
||||
logger.warn("Failed to generate thumbnail")
|
||||
respond_404(request)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _select_or_generate_remote_thumbnail(self, request, server_name, media_id,
|
||||
desired_width, desired_height,
|
||||
desired_method, desired_type):
|
||||
media_info = yield self.media_repo.get_remote_media(server_name, media_id)
|
||||
|
||||
# if media_info["media_type"] == "image/svg+xml":
|
||||
# file_path = self.filepaths.remote_media_filepath(server_name, media_id)
|
||||
# yield respond_with_file(request, media_info["media_type"], file_path)
|
||||
# return
|
||||
media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
|
||||
|
||||
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
|
||||
server_name, media_id,
|
||||
|
@ -195,14 +193,24 @@ class ThumbnailResource(Resource):
|
|||
t_type = info["thumbnail_type"] == desired_type
|
||||
|
||||
if t_w and t_h and t_method and t_type:
|
||||
file_path = self.filepaths.remote_media_thumbnail(
|
||||
server_name, file_id, desired_width, desired_height,
|
||||
desired_type, desired_method,
|
||||
file_info = FileInfo(
|
||||
server_name=server_name, file_id=media_info["filesystem_id"],
|
||||
thumbnail=True,
|
||||
thumbnail_width=info["thumbnail_width"],
|
||||
thumbnail_height=info["thumbnail_height"],
|
||||
thumbnail_type=info["thumbnail_type"],
|
||||
thumbnail_method=info["thumbnail_method"],
|
||||
)
|
||||
yield respond_with_file(request, desired_type, file_path)
|
||||
return
|
||||
|
||||
logger.debug("We don't have a local thumbnail of that size. Generating")
|
||||
t_type = file_info.thumbnail_type
|
||||
t_length = info["thumbnail_length"]
|
||||
|
||||
responder = yield self.media_storage.fetch_media(file_info)
|
||||
if responder:
|
||||
yield respond_with_responder(request, responder, t_type, t_length)
|
||||
return
|
||||
|
||||
logger.debug("We don't have a thumbnail of that size. Generating")
|
||||
|
||||
# Okay, so we generate one.
|
||||
file_path = yield self.media_repo.generate_remote_exact_thumbnail(
|
||||
|
@ -213,22 +221,16 @@ class ThumbnailResource(Resource):
|
|||
if file_path:
|
||||
yield respond_with_file(request, desired_type, file_path)
|
||||
else:
|
||||
yield self._respond_default_thumbnail(
|
||||
request, media_info, desired_width, desired_height,
|
||||
desired_method, desired_type,
|
||||
)
|
||||
logger.warn("Failed to generate thumbnail")
|
||||
respond_404(request)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _respond_remote_thumbnail(self, request, server_name, media_id, width,
|
||||
height, method, m_type):
|
||||
# TODO: Don't download the whole remote file
|
||||
# We should proxy the thumbnail from the remote server instead.
|
||||
media_info = yield self.media_repo.get_remote_media(server_name, media_id)
|
||||
|
||||
# if media_info["media_type"] == "image/svg+xml":
|
||||
# file_path = self.filepaths.remote_media_filepath(server_name, media_id)
|
||||
# yield respond_with_file(request, media_info["media_type"], file_path)
|
||||
# return
|
||||
# We should proxy the thumbnail from the remote server instead of
|
||||
# downloading the remote file and generating our own thumbnails.
|
||||
media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
|
||||
|
||||
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
|
||||
server_name, media_id,
|
||||
|
@ -238,59 +240,23 @@ class ThumbnailResource(Resource):
|
|||
thumbnail_info = self._select_thumbnail(
|
||||
width, height, method, m_type, thumbnail_infos
|
||||
)
|
||||
t_width = thumbnail_info["thumbnail_width"]
|
||||
t_height = thumbnail_info["thumbnail_height"]
|
||||
t_type = thumbnail_info["thumbnail_type"]
|
||||
t_method = thumbnail_info["thumbnail_method"]
|
||||
file_id = thumbnail_info["filesystem_id"]
|
||||
file_info = FileInfo(
|
||||
server_name=server_name, file_id=media_info["filesystem_id"],
|
||||
thumbnail=True,
|
||||
thumbnail_width=thumbnail_info["thumbnail_width"],
|
||||
thumbnail_height=thumbnail_info["thumbnail_height"],
|
||||
thumbnail_type=thumbnail_info["thumbnail_type"],
|
||||
thumbnail_method=thumbnail_info["thumbnail_method"],
|
||||
)
|
||||
|
||||
t_type = file_info.thumbnail_type
|
||||
t_length = thumbnail_info["thumbnail_length"]
|
||||
|
||||
file_path = self.filepaths.remote_media_thumbnail(
|
||||
server_name, file_id, t_width, t_height, t_type, t_method,
|
||||
)
|
||||
yield respond_with_file(request, t_type, file_path, t_length)
|
||||
responder = yield self.media_storage.fetch_media(file_info)
|
||||
yield respond_with_responder(request, responder, t_type, t_length)
|
||||
else:
|
||||
yield self._respond_default_thumbnail(
|
||||
request, media_info, width, height, method, m_type,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _respond_default_thumbnail(self, request, media_info, width, height,
|
||||
method, m_type):
|
||||
# XXX: how is this meant to work? store.get_default_thumbnails
|
||||
# appears to always return [] so won't this always 404?
|
||||
media_type = media_info["media_type"]
|
||||
top_level_type = media_type.split("/")[0]
|
||||
sub_type = media_type.split("/")[-1].split(";")[0]
|
||||
thumbnail_infos = yield self.store.get_default_thumbnails(
|
||||
top_level_type, sub_type,
|
||||
)
|
||||
if not thumbnail_infos:
|
||||
thumbnail_infos = yield self.store.get_default_thumbnails(
|
||||
top_level_type, "_default",
|
||||
)
|
||||
if not thumbnail_infos:
|
||||
thumbnail_infos = yield self.store.get_default_thumbnails(
|
||||
"_default", "_default",
|
||||
)
|
||||
if not thumbnail_infos:
|
||||
logger.info("Failed to find any generated thumbnails")
|
||||
respond_404(request)
|
||||
return
|
||||
|
||||
thumbnail_info = self._select_thumbnail(
|
||||
width, height, "crop", m_type, thumbnail_infos
|
||||
)
|
||||
|
||||
t_width = thumbnail_info["thumbnail_width"]
|
||||
t_height = thumbnail_info["thumbnail_height"]
|
||||
t_type = thumbnail_info["thumbnail_type"]
|
||||
t_method = thumbnail_info["thumbnail_method"]
|
||||
t_length = thumbnail_info["thumbnail_length"]
|
||||
|
||||
file_path = self.filepaths.default_thumbnail(
|
||||
top_level_type, sub_type, t_width, t_height, t_type, t_method,
|
||||
)
|
||||
yield respond_with_file(request, t_type, file_path, t_length)
|
||||
|
||||
def _select_thumbnail(self, desired_width, desired_height, desired_method,
|
||||
desired_type, thumbnail_infos):
|
||||
|
|
|
@ -55,6 +55,7 @@ from synapse.handlers.read_marker import ReadMarkerHandler
|
|||
from synapse.handlers.user_directory import UserDirectoryHandler
|
||||
from synapse.handlers.groups_local import GroupsLocalHandler
|
||||
from synapse.handlers.profile import ProfileHandler
|
||||
from synapse.handlers.message import EventCreationHandler
|
||||
from synapse.groups.groups_server import GroupsServerHandler
|
||||
from synapse.groups.attestations import GroupAttestionRenewer, GroupAttestationSigning
|
||||
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
|
||||
|
@ -66,7 +67,7 @@ from synapse.rest.media.v1.media_repository import (
|
|||
MediaRepository,
|
||||
MediaRepositoryResource,
|
||||
)
|
||||
from synapse.state import StateHandler
|
||||
from synapse.state import StateHandler, StateResolutionHandler
|
||||
from synapse.storage import DataStore
|
||||
from synapse.streams.events import EventSources
|
||||
from synapse.util import Clock
|
||||
|
@ -102,6 +103,7 @@ class HomeServer(object):
|
|||
'v1auth',
|
||||
'auth',
|
||||
'state_handler',
|
||||
'state_resolution_handler',
|
||||
'presence_handler',
|
||||
'sync_handler',
|
||||
'typing_handler',
|
||||
|
@ -117,6 +119,7 @@ class HomeServer(object):
|
|||
'application_service_handler',
|
||||
'device_message_handler',
|
||||
'profile_handler',
|
||||
'event_creation_handler',
|
||||
'deactivate_account_handler',
|
||||
'set_password_handler',
|
||||
'notifier',
|
||||
|
@ -224,6 +227,9 @@ class HomeServer(object):
|
|||
def build_state_handler(self):
|
||||
return StateHandler(self)
|
||||
|
||||
def build_state_resolution_handler(self):
|
||||
return StateResolutionHandler(self)
|
||||
|
||||
def build_presence_handler(self):
|
||||
return PresenceHandler(self)
|
||||
|
||||
|
@ -272,6 +278,9 @@ class HomeServer(object):
|
|||
def build_profile_handler(self):
|
||||
return ProfileHandler(self)
|
||||
|
||||
def build_event_creation_handler(self):
|
||||
return EventCreationHandler(self)
|
||||
|
||||
def build_deactivate_account_handler(self):
|
||||
return DeactivateAccountHandler(self)
|
||||
|
||||
|
@ -307,6 +316,23 @@ class HomeServer(object):
|
|||
**self.db_config.get("args", {})
|
||||
)
|
||||
|
||||
def get_db_conn(self, run_new_connection=True):
    """Makes a new connection to the database, skipping the db pool

    Args:
        run_new_connection (bool): if True, the freshly opened connection
            is handed to the database engine's `on_new_connection` hook
            before it is returned.

    Returns:
        Connection: a connection object implementing the PEP-249 spec
    """
    # Any param beginning with cp_ is a parameter for adbapi, and should
    # not be passed to the database engine.
    db_params = {
        k: v for k, v in self.db_config.get("args", {}).items()
        if not k.startswith("cp_")
    }

    # Connect directly through the engine's DB-API module rather than via
    # the connection pool.
    db_conn = self.database_engine.module.connect(**db_params)
    if run_new_connection:
        self.database_engine.on_new_connection(db_conn)
    return db_conn
|
||||
|
||||
def build_media_repository_resource(self):
|
||||
# build the media repo resource. This indirects through the HomeServer
|
||||
# to ensure that we only have a single instance of
|
||||
|
|
|
@ -34,6 +34,9 @@ class HomeServer(object):
|
|||
def get_state_handler(self) -> synapse.state.StateHandler:
|
||||
pass
|
||||
|
||||
def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler:
|
||||
pass
|
||||
|
||||
def get_deactivate_account_handler(self) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
|
||||
pass
|
||||
|
||||
|
|
310
synapse/state.py
310
synapse/state.py
|
@ -58,7 +58,11 @@ class _StateCacheEntry(object):
|
|||
__slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
|
||||
|
||||
def __init__(self, state, state_group, prev_group=None, delta_ids=None):
|
||||
# dict[(str, str), str] map from (type, state_key) to event_id
|
||||
self.state = frozendict(state)
|
||||
|
||||
# the ID of a state group if one and only one is involved.
|
||||
# otherwise, None?
|
||||
self.state_group = state_group
|
||||
|
||||
self.prev_group = prev_group
|
||||
|
@ -81,31 +85,19 @@ class _StateCacheEntry(object):
|
|||
|
||||
|
||||
class StateHandler(object):
|
||||
""" Responsible for doing state conflict resolution.
|
||||
"""Fetches bits of state from the stores, and does state resolution
|
||||
where necessary
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
self.hs = hs
|
||||
|
||||
# dict of set of event_ids -> _StateCacheEntry.
|
||||
self._state_cache = None
|
||||
self.resolve_linearizer = Linearizer(name="state_resolve_lock")
|
||||
self._state_resolution_handler = hs.get_state_resolution_handler()
|
||||
|
||||
def start_caching(self):
|
||||
logger.debug("start_caching")
|
||||
|
||||
self._state_cache = ExpiringCache(
|
||||
cache_name="state_cache",
|
||||
clock=self.clock,
|
||||
max_len=SIZE_OF_CACHE,
|
||||
expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
|
||||
iterable=True,
|
||||
reset_expiry_on_get=True,
|
||||
)
|
||||
|
||||
self._state_cache.start()
|
||||
# TODO: remove this shim
|
||||
self._state_resolution_handler.start_caching()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_current_state(self, room_id, event_type=None, state_key="",
|
||||
|
@ -127,7 +119,7 @@ class StateHandler(object):
|
|||
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||
|
||||
logger.debug("calling resolve_state_groups from get_current_state")
|
||||
ret = yield self.resolve_state_groups(room_id, latest_event_ids)
|
||||
ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||
state = ret.state
|
||||
|
||||
if event_type:
|
||||
|
@ -146,19 +138,27 @@ class StateHandler(object):
|
|||
defer.returnValue(state)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_current_state_ids(self, room_id, event_type=None, state_key="",
|
||||
latest_event_ids=None):
|
||||
def get_current_state_ids(self, room_id, latest_event_ids=None):
|
||||
"""Get the current state, or the state at a set of events, for a room
|
||||
|
||||
Args:
|
||||
room_id (str):
|
||||
|
||||
latest_event_ids (iterable[str]|None): if given, the forward
|
||||
extremities to resolve. If None, we look them up from the
|
||||
database (via a cache)
|
||||
|
||||
Returns:
|
||||
Deferred[dict[(str, str), str]]: the state dict, mapping from
|
||||
(event_type, state_key) -> event_id
|
||||
"""
|
||||
if not latest_event_ids:
|
||||
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||
|
||||
logger.debug("calling resolve_state_groups from get_current_state_ids")
|
||||
ret = yield self.resolve_state_groups(room_id, latest_event_ids)
|
||||
ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||
state = ret.state
|
||||
|
||||
if event_type:
|
||||
defer.returnValue(state.get((event_type, state_key)))
|
||||
return
|
||||
|
||||
defer.returnValue(state)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -166,7 +166,7 @@ class StateHandler(object):
|
|||
if not latest_event_ids:
|
||||
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||
logger.debug("calling resolve_state_groups from get_current_user_in_room")
|
||||
entry = yield self.resolve_state_groups(room_id, latest_event_ids)
|
||||
entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||
joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
|
||||
defer.returnValue(joined_users)
|
||||
|
||||
|
@ -175,7 +175,7 @@ class StateHandler(object):
|
|||
if not latest_event_ids:
|
||||
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||
logger.debug("calling resolve_state_groups from get_current_hosts_in_room")
|
||||
entry = yield self.resolve_state_groups(room_id, latest_event_ids)
|
||||
entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||
joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
|
||||
defer.returnValue(joined_hosts)
|
||||
|
||||
|
@ -183,8 +183,15 @@ class StateHandler(object):
|
|||
def compute_event_context(self, event, old_state=None):
|
||||
"""Build an EventContext structure for the event.
|
||||
|
||||
This works out what the current state should be for the event, and
|
||||
generates a new state group if necessary.
|
||||
|
||||
Args:
|
||||
event (synapse.events.EventBase):
|
||||
old_state (dict|None): The state at the event if it can't be
|
||||
calculated from existing events. This is normally only specified
|
||||
when receiving an event from federation where we don't have the
|
||||
prev events for, e.g. when backfilling.
|
||||
Returns:
|
||||
synapse.events.snapshot.EventContext:
|
||||
"""
|
||||
|
@ -208,15 +215,22 @@ class StateHandler(object):
|
|||
context.current_state_ids = {}
|
||||
context.prev_state_ids = {}
|
||||
context.prev_state_events = []
|
||||
context.state_group = self.store.get_next_state_group()
|
||||
|
||||
# We don't store state for outliers, so we don't generate a state
|
||||
# group for it.
|
||||
context.state_group = None
|
||||
|
||||
defer.returnValue(context)
|
||||
|
||||
if old_state:
|
||||
# We already have the state, so we don't need to calculate it.
|
||||
# Let's just correctly fill out the context and create a
|
||||
# new state group for it.
|
||||
|
||||
context = EventContext()
|
||||
context.prev_state_ids = {
|
||||
(s.type, s.state_key): s.event_id for s in old_state
|
||||
}
|
||||
context.state_group = self.store.get_next_state_group()
|
||||
|
||||
if event.is_state():
|
||||
key = (event.type, event.state_key)
|
||||
|
@ -229,11 +243,19 @@ class StateHandler(object):
|
|||
else:
|
||||
context.current_state_ids = context.prev_state_ids
|
||||
|
||||
context.state_group = yield self.store.store_state_group(
|
||||
event.event_id,
|
||||
event.room_id,
|
||||
prev_group=None,
|
||||
delta_ids=None,
|
||||
current_state_ids=context.current_state_ids,
|
||||
)
|
||||
|
||||
context.prev_state_events = []
|
||||
defer.returnValue(context)
|
||||
|
||||
logger.debug("calling resolve_state_groups from compute_event_context")
|
||||
entry = yield self.resolve_state_groups(
|
||||
entry = yield self.resolve_state_groups_for_events(
|
||||
event.room_id, [e for e, _ in event.prev_events],
|
||||
)
|
||||
|
||||
|
@ -242,7 +264,8 @@ class StateHandler(object):
|
|||
context = EventContext()
|
||||
context.prev_state_ids = curr_state
|
||||
if event.is_state():
|
||||
context.state_group = self.store.get_next_state_group()
|
||||
# If this is a state event then we need to create a new state
|
||||
# group for the state after this event.
|
||||
|
||||
key = (event.type, event.state_key)
|
||||
if key in context.prev_state_ids:
|
||||
|
@ -253,38 +276,57 @@ class StateHandler(object):
|
|||
context.current_state_ids[key] = event.event_id
|
||||
|
||||
if entry.state_group:
|
||||
# If the state at the event has a state group assigned then
|
||||
# we can use that as the prev group
|
||||
context.prev_group = entry.state_group
|
||||
context.delta_ids = {
|
||||
key: event.event_id
|
||||
}
|
||||
elif entry.prev_group:
|
||||
# If the state at the event only has a prev group, then we can
|
||||
# use that as a prev group too.
|
||||
context.prev_group = entry.prev_group
|
||||
context.delta_ids = dict(entry.delta_ids)
|
||||
context.delta_ids[key] = event.event_id
|
||||
else:
|
||||
if entry.state_group is None:
|
||||
entry.state_group = self.store.get_next_state_group()
|
||||
entry.state_id = entry.state_group
|
||||
|
||||
context.state_group = entry.state_group
|
||||
context.state_group = yield self.store.store_state_group(
|
||||
event.event_id,
|
||||
event.room_id,
|
||||
prev_group=context.prev_group,
|
||||
delta_ids=context.delta_ids,
|
||||
current_state_ids=context.current_state_ids,
|
||||
)
|
||||
else:
|
||||
context.current_state_ids = context.prev_state_ids
|
||||
context.prev_group = entry.prev_group
|
||||
context.delta_ids = entry.delta_ids
|
||||
|
||||
if entry.state_group is None:
|
||||
entry.state_group = yield self.store.store_state_group(
|
||||
event.event_id,
|
||||
event.room_id,
|
||||
prev_group=entry.prev_group,
|
||||
delta_ids=entry.delta_ids,
|
||||
current_state_ids=context.current_state_ids,
|
||||
)
|
||||
entry.state_id = entry.state_group
|
||||
|
||||
context.state_group = entry.state_group
|
||||
|
||||
context.prev_state_events = []
|
||||
defer.returnValue(context)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def resolve_state_groups(self, room_id, event_ids):
|
||||
def resolve_state_groups_for_events(self, room_id, event_ids):
|
||||
""" Given a list of event_ids this method fetches the state at each
|
||||
event, resolves conflicts between them and returns them.
|
||||
|
||||
Args:
|
||||
room_id (str):
|
||||
event_ids (list[str]):
|
||||
|
||||
Returns:
|
||||
a Deferred tuple of (`state_group`, `state`, `prev_state`).
|
||||
`state_group` is the name of a state group if one and only one is
|
||||
involved. `state` is a map from (type, state_key) to event, and
|
||||
`prev_state` is a list of event ids.
|
||||
Deferred[_StateCacheEntry]: resolved state
|
||||
"""
|
||||
logger.debug("resolve_state_groups event_ids %s", event_ids)
|
||||
|
||||
|
@ -295,13 +337,7 @@ class StateHandler(object):
|
|||
room_id, event_ids
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"resolve_state_groups state_groups %s",
|
||||
state_groups_ids.keys()
|
||||
)
|
||||
|
||||
group_names = frozenset(state_groups_ids.keys())
|
||||
if len(group_names) == 1:
|
||||
if len(state_groups_ids) == 1:
|
||||
name, state_list = state_groups_ids.items().pop()
|
||||
|
||||
prev_group, delta_ids = yield self.store.get_state_group_delta(name)
|
||||
|
@ -313,6 +349,92 @@ class StateHandler(object):
|
|||
delta_ids=delta_ids,
|
||||
))
|
||||
|
||||
result = yield self._state_resolution_handler.resolve_state_groups(
|
||||
room_id, state_groups_ids, self._state_map_factory,
|
||||
)
|
||||
defer.returnValue(result)
|
||||
|
||||
def _state_map_factory(self, ev_ids):
|
||||
return self.store.get_events(
|
||||
ev_ids, get_prev_content=False, check_redacted=False,
|
||||
)
|
||||
|
||||
def resolve_events(self, state_sets, event):
    """Resolve several sets of state events into a single state map.

    Args:
        state_sets: a sized collection of iterables of events; each event
            must expose `type`, `state_key` and `event_id`. It is iterated
            more than once, so it must not be a one-shot iterator.
        event: only used for its `room_id`, which is included in the log
            line.

    Returns:
        dict: a map from (type, state_key) to the winning event object
        (taken from `state_sets`).
    """
    logger.info(
        "Resolving state for %s with %d groups", event.room_id, len(state_sets)
    )

    # One (type, state_key) -> event_id map per input state set.
    state_set_ids = [{
        (ev.type, ev.state_key): ev.event_id
        for ev in st
    } for st in state_sets]

    # event_id -> event, across all of the input state sets.
    state_map = {
        ev.event_id: ev
        for st in state_sets
        for ev in st
    }

    with Measure(self.clock, "state._resolve_events"):
        new_state = resolve_events_with_state_map(state_set_ids, state_map)

    # The resolver deals in event IDs; map them back to the event objects.
    new_state = {
        key: state_map[ev_id] for key, ev_id in new_state.items()
    }

    return new_state
|
||||
|
||||
|
||||
class StateResolutionHandler(object):
|
||||
"""Responsible for doing state conflict resolution.
|
||||
|
||||
Note that the storage layer depends on this handler, so all functions must
|
||||
be storage-independent.
|
||||
"""
|
||||
def __init__(self, hs):
    """
    Args:
        hs: the homeserver object; only `get_clock()` is used here.
    """
    self.clock = hs.get_clock()

    # dict of set of event_ids -> _StateCacheEntry.
    # Stays None (caching disabled) until start_caching() is called.
    self._state_cache = None
    self.resolve_linearizer = Linearizer(name="state_resolve_lock")
|
||||
|
||||
def start_caching(self):
    """Create and start the expiring cache of resolved state.

    Until this is called `self._state_cache` is None and no resolution
    results are cached.
    """
    logger.debug("start_caching")

    self._state_cache = ExpiringCache(
        cache_name="state_cache",
        clock=self.clock,
        max_len=SIZE_OF_CACHE,
        expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
        iterable=True,
        reset_expiry_on_get=True,
    )

    self._state_cache.start()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def resolve_state_groups(self, room_id, state_groups_ids, state_map_factory):
|
||||
"""Resolves conflicts between a set of state groups
|
||||
|
||||
Always generates a new state group (unless we hit the cache), so should
|
||||
not be called for a single state group
|
||||
|
||||
Args:
|
||||
room_id (str): room we are resolving for (used for logging)
|
||||
state_groups_ids (dict[int, dict[(str, str), str]]):
|
||||
map from state group id to the state in that state group
|
||||
(where 'state' is a map from state key to event id)
|
||||
|
||||
Returns:
|
||||
Deferred[_StateCacheEntry]: resolved state
|
||||
"""
|
||||
logger.debug(
|
||||
"resolve_state_groups state_groups %s",
|
||||
state_groups_ids.keys()
|
||||
)
|
||||
|
||||
group_names = frozenset(state_groups_ids.keys())
|
||||
|
||||
with (yield self.resolve_linearizer.queue(group_names)):
|
||||
if self._state_cache is not None:
|
||||
cache = self._state_cache.get(group_names, None)
|
||||
|
@ -341,17 +463,19 @@ class StateHandler(object):
|
|||
if conflicted_state:
|
||||
logger.info("Resolving conflicted state for %r", room_id)
|
||||
with Measure(self.clock, "state._resolve_events"):
|
||||
new_state = yield resolve_events(
|
||||
new_state = yield resolve_events_with_factory(
|
||||
state_groups_ids.values(),
|
||||
state_map_factory=lambda ev_ids: self.store.get_events(
|
||||
ev_ids, get_prev_content=False, check_redacted=False,
|
||||
),
|
||||
state_map_factory=state_map_factory,
|
||||
)
|
||||
else:
|
||||
new_state = {
|
||||
key: e_ids.pop() for key, e_ids in state.items()
|
||||
}
|
||||
|
||||
# if the new state matches any of the input state groups, we can
|
||||
# use that state group again. Otherwise we will generate a state_id
|
||||
# which will be used as a cache key for future resolutions, but
|
||||
# not get persisted.
|
||||
state_group = None
|
||||
new_state_event_ids = frozenset(new_state.values())
|
||||
for sg, events in state_groups_ids.items():
|
||||
|
@ -388,30 +512,6 @@ class StateHandler(object):
|
|||
|
||||
defer.returnValue(cache)
|
||||
|
||||
def resolve_events(self, state_sets, event):
|
||||
logger.info(
|
||||
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
|
||||
)
|
||||
state_set_ids = [{
|
||||
(ev.type, ev.state_key): ev.event_id
|
||||
for ev in st
|
||||
} for st in state_sets]
|
||||
|
||||
state_map = {
|
||||
ev.event_id: ev
|
||||
for st in state_sets
|
||||
for ev in st
|
||||
}
|
||||
|
||||
with Measure(self.clock, "state._resolve_events"):
|
||||
new_state = resolve_events(state_set_ids, state_map)
|
||||
|
||||
new_state = {
|
||||
key: state_map[ev_id] for key, ev_id in new_state.items()
|
||||
}
|
||||
|
||||
return new_state
|
||||
|
||||
|
||||
def _ordered_events(events):
|
||||
def key_func(e):
|
||||
|
@ -420,19 +520,17 @@ def _ordered_events(events):
|
|||
return sorted(events, key=key_func)
|
||||
|
||||
|
||||
def resolve_events(state_sets, state_map_factory):
|
||||
def resolve_events_with_state_map(state_sets, state_map):
|
||||
"""
|
||||
Args:
|
||||
state_sets(list): List of dicts of (type, state_key) -> event_id,
|
||||
which are the different state groups to resolve.
|
||||
state_map_factory(dict|callable): If callable, then will be called
|
||||
with a list of event_ids that are needed, and should return with
|
||||
a Deferred of dict of event_id to event. Otherwise, should be
|
||||
a dict from event_id to event of all events in state_sets.
|
||||
state_map(dict): a dict from event_id to event, for all events in
|
||||
state_sets.
|
||||
|
||||
Returns
|
||||
dict[(str, str), synapse.events.FrozenEvent] is a map from
|
||||
(type, state_key) to event.
|
||||
dict[(str, str), str]:
|
||||
a map from (type, state_key) to event_id.
|
||||
"""
|
||||
if len(state_sets) == 1:
|
||||
return state_sets[0]
|
||||
|
@ -441,13 +539,6 @@ def resolve_events(state_sets, state_map_factory):
|
|||
state_sets,
|
||||
)
|
||||
|
||||
if callable(state_map_factory):
|
||||
return _resolve_with_state_fac(
|
||||
unconflicted_state, conflicted_state, state_map_factory
|
||||
)
|
||||
|
||||
state_map = state_map_factory
|
||||
|
||||
auth_events = _create_auth_events_from_maps(
|
||||
unconflicted_state, conflicted_state, state_map
|
||||
)
|
||||
|
@ -461,6 +552,21 @@ def _seperate(state_sets):
|
|||
"""Takes the state_sets and figures out which keys are conflicted and
|
||||
which aren't. i.e., which have multiple different event_ids associated
|
||||
with them in different state sets.
|
||||
|
||||
Args:
|
||||
state_sets(list[dict[(str, str), str]]):
|
||||
List of dicts of (type, state_key) -> event_id, which are the
|
||||
different state groups to resolve.
|
||||
|
||||
Returns:
|
||||
(dict[(str, str), str], dict[(str, str), set[str]]):
|
||||
A tuple of (unconflicted_state, conflicted_state), where:
|
||||
|
||||
unconflicted_state is a dict mapping (type, state_key)->event_id
|
||||
for unconflicted state keys.
|
||||
|
||||
conflicted_state is a dict mapping (type, state_key) to a set of
|
||||
event ids for conflicted state keys.
|
||||
"""
|
||||
unconflicted_state = dict(state_sets[0])
|
||||
conflicted_state = {}
|
||||
|
@ -491,8 +597,26 @@ def _seperate(state_sets):
|
|||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _resolve_with_state_fac(unconflicted_state, conflicted_state,
|
||||
state_map_factory):
|
||||
def resolve_events_with_factory(state_sets, state_map_factory):
|
||||
"""
|
||||
Args:
|
||||
state_sets(list): List of dicts of (type, state_key) -> event_id,
|
||||
which are the different state groups to resolve.
|
||||
state_map_factory(func): will be called
|
||||
with a list of event_ids that are needed, and should return with
|
||||
a Deferred of dict of event_id to event.
|
||||
|
||||
Returns
|
||||
Deferred[dict[(str, str), str]]:
|
||||
a map from (type, state_key) to event_id.
|
||||
"""
|
||||
if len(state_sets) == 1:
|
||||
defer.returnValue(state_sets[0])
|
||||
|
||||
unconflicted_state, conflicted_state = _seperate(
|
||||
state_sets,
|
||||
)
|
||||
|
||||
needed_events = set(
|
||||
event_id
|
||||
for event_ids in conflicted_state.itervalues()
|
||||
|
|
|
@ -124,7 +124,6 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||
)
|
||||
|
||||
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
|
||||
self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
|
||||
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
|
||||
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
|
||||
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
|
||||
|
|
|
@ -291,33 +291,33 @@ class SQLBaseStore(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def runInteraction(self, desc, func, *args, **kwargs):
|
||||
"""Wraps the .runInteraction() method on the underlying db_pool."""
|
||||
current_context = LoggingContext.current_context()
|
||||
"""Starts a transaction on the database and runs a given function
|
||||
|
||||
start_time = time.time() * 1000
|
||||
Arguments:
|
||||
desc (str): description of the transaction, for logging and metrics
|
||||
func (func): callback function, which will be called with a
|
||||
database transaction (twisted.enterprise.adbapi.Transaction) as
|
||||
its first argument, followed by `args` and `kwargs`.
|
||||
|
||||
args (list): positional args to pass to `func`
|
||||
kwargs (dict): named args to pass to `func`
|
||||
|
||||
Returns:
|
||||
Deferred: The result of func
|
||||
"""
|
||||
current_context = LoggingContext.current_context()
|
||||
|
||||
after_callbacks = []
|
||||
final_callbacks = []
|
||||
|
||||
def inner_func(conn, *args, **kwargs):
|
||||
with LoggingContext("runInteraction") as context:
|
||||
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
|
||||
|
||||
if self.database_engine.is_connection_closed(conn):
|
||||
logger.debug("Reconnecting closed database connection")
|
||||
conn.reconnect()
|
||||
|
||||
current_context.copy_to(context)
|
||||
return self._new_transaction(
|
||||
conn, desc, after_callbacks, final_callbacks, current_context,
|
||||
func, *args, **kwargs
|
||||
)
|
||||
return self._new_transaction(
|
||||
conn, desc, after_callbacks, final_callbacks, current_context,
|
||||
func, *args, **kwargs
|
||||
)
|
||||
|
||||
try:
|
||||
with PreserveLoggingContext():
|
||||
result = yield self._db_pool.runWithConnection(
|
||||
inner_func, *args, **kwargs
|
||||
)
|
||||
result = yield self.runWithConnection(inner_func, *args, **kwargs)
|
||||
|
||||
for after_callback, after_args, after_kwargs in after_callbacks:
|
||||
after_callback(*after_args, **after_kwargs)
|
||||
|
@ -329,14 +329,27 @@ class SQLBaseStore(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def runWithConnection(self, func, *args, **kwargs):
|
||||
"""Wraps the .runInteraction() method on the underlying db_pool."""
|
||||
"""Wraps the .runWithConnection() method on the underlying db_pool.
|
||||
|
||||
Arguments:
|
||||
func (func): callback function, which will be called with a
|
||||
database connection (twisted.enterprise.adbapi.Connection) as
|
||||
its first argument, followed by `args` and `kwargs`.
|
||||
args (list): positional args to pass to `func`
|
||||
kwargs (dict): named args to pass to `func`
|
||||
|
||||
Returns:
|
||||
Deferred: The result of func
|
||||
"""
|
||||
current_context = LoggingContext.current_context()
|
||||
|
||||
start_time = time.time() * 1000
|
||||
|
||||
def inner_func(conn, *args, **kwargs):
|
||||
with LoggingContext("runWithConnection") as context:
|
||||
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
|
||||
sched_duration_ms = time.time() * 1000 - start_time
|
||||
sql_scheduling_timer.inc_by(sched_duration_ms)
|
||||
current_context.add_database_scheduled(sched_duration_ms)
|
||||
|
||||
if self.database_engine.is_connection_closed(conn):
|
||||
logger.debug("Reconnecting closed database connection")
|
||||
|
|
|
@ -62,3 +62,9 @@ class PostgresEngine(object):
|
|||
|
||||
def lock_table(self, txn, table):
|
||||
txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))
|
||||
|
||||
def get_next_state_group_id(self, txn):
    """Returns an int that can be used as a new state_group ID

    Args:
        txn: a database cursor/transaction supporting `execute` and
            `fetchone`.

    Returns:
        the next value of the `state_group_id_seq` sequence.
    """
    txn.execute("SELECT nextval('state_group_id_seq')")
    return txn.fetchone()[0]
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
from synapse.storage.prepare_database import prepare_database
|
||||
|
||||
import struct
|
||||
import threading
|
||||
|
||||
|
||||
class Sqlite3Engine(object):
|
||||
|
@ -24,6 +25,11 @@ class Sqlite3Engine(object):
|
|||
def __init__(self, database_module, database_config):
|
||||
self.module = database_module
|
||||
|
||||
# The current max state_group, or None if we haven't looked
|
||||
# in the DB yet.
|
||||
self._current_state_group_id = None
|
||||
self._current_state_group_id_lock = threading.Lock()
|
||||
|
||||
def check_database(self, txn):
|
||||
pass
|
||||
|
||||
|
@ -43,6 +49,19 @@ class Sqlite3Engine(object):
|
|||
def lock_table(self, txn, table):
|
||||
return
|
||||
|
||||
def get_next_state_group_id(self, txn):
    """Returns an int that can be used as a new state_group ID

    Args:
        txn: a database cursor/transaction supporting `execute` and
            `fetchone`; only queried the first time this is called.

    Returns:
        int: one more than the previously returned ID (or, on first use,
        one more than the largest `id` in `state_groups`, 0 if empty).
    """
    # We do application locking here since if we're using sqlite then
    # we are a single process synapse.
    with self._current_state_group_id_lock:
        if self._current_state_group_id is None:
            # First call: seed the counter from the database.
            txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
            self._current_state_group_id = txn.fetchone()[0]

        self._current_state_group_id += 1
        return self._current_state_group_id
|
||||
|
||||
|
||||
# Following functions taken from: https://github.com/coleifer/peewee
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ from synapse.util.logutils import log_function
|
|||
from synapse.util.metrics import Measure
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.state import resolve_events
|
||||
from synapse.state import resolve_events_with_factory
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.types import get_domain_from_id
|
||||
|
||||
|
@ -110,7 +110,7 @@ class _EventPeristenceQueue(object):
|
|||
end_item.events_and_contexts.extend(events_and_contexts)
|
||||
return end_item.deferred.observe()
|
||||
|
||||
deferred = ObservableDeferred(defer.Deferred())
|
||||
deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
|
||||
|
||||
queue.append(self._EventPersistQueueItem(
|
||||
events_and_contexts=events_and_contexts,
|
||||
|
@ -146,18 +146,25 @@ class _EventPeristenceQueue(object):
|
|||
try:
|
||||
queue = self._get_drainining_queue(room_id)
|
||||
for item in queue:
|
||||
# handle_queue_loop runs in the sentinel logcontext, so
|
||||
# there is no need to preserve_fn when running the
|
||||
# callbacks on the deferred.
|
||||
try:
|
||||
ret = yield per_item_callback(item)
|
||||
item.deferred.callback(ret)
|
||||
except Exception as e:
|
||||
item.deferred.errback(e)
|
||||
except Exception:
|
||||
item.deferred.errback()
|
||||
finally:
|
||||
queue = self._event_persist_queues.pop(room_id, None)
|
||||
if queue:
|
||||
self._event_persist_queues[room_id] = queue
|
||||
self._currently_persisting_rooms.discard(room_id)
|
||||
|
||||
preserve_fn(handle_queue_loop)()
|
||||
# set handle_queue_loop off on the background. We don't want to
|
||||
# attribute work done in it to the current request, so we drop the
|
||||
# logcontext altogether.
|
||||
with PreserveLoggingContext():
|
||||
handle_queue_loop()
|
||||
|
||||
def _get_drainining_queue(self, room_id):
|
||||
queue = self._event_persist_queues.setdefault(room_id, deque())
|
||||
|
@ -335,8 +342,20 @@ class EventsStore(SQLBaseStore):
|
|||
|
||||
# NB: Assumes that we are only persisting events for one room
|
||||
# at a time.
|
||||
|
||||
# map room_id->list[event_ids] giving the new forward
|
||||
# extremities in each room
|
||||
new_forward_extremeties = {}
|
||||
|
||||
# map room_id->(type,state_key)->event_id tracking the full
|
||||
# state in each room after adding these events
|
||||
current_state_for_room = {}
|
||||
|
||||
# map room_id->(to_delete, to_insert) where each entry is
|
||||
# a map (type,key)->event_id giving the state delta in each
|
||||
# room
|
||||
state_delta_for_room = {}
|
||||
|
||||
if not backfilled:
|
||||
with Measure(self._clock, "_calculate_state_and_extrem"):
|
||||
# Work out the new "current state" for each room.
|
||||
|
@ -379,11 +398,19 @@ class EventsStore(SQLBaseStore):
|
|||
if all_single_prev_not_state:
|
||||
continue
|
||||
|
||||
state = yield self._calculate_state_delta(
|
||||
room_id, ev_ctx_rm, new_latest_event_ids
|
||||
logger.info(
|
||||
"Calculating state delta for room %s", room_id,
|
||||
)
|
||||
if state:
|
||||
current_state_for_room[room_id] = state
|
||||
current_state = yield self._get_new_state_after_events(
|
||||
ev_ctx_rm, new_latest_event_ids,
|
||||
)
|
||||
if current_state is not None:
|
||||
current_state_for_room[room_id] = current_state
|
||||
delta = yield self._calculate_state_delta(
|
||||
room_id, current_state,
|
||||
)
|
||||
if delta is not None:
|
||||
state_delta_for_room[room_id] = delta
|
||||
|
||||
yield self.runInteraction(
|
||||
"persist_events",
|
||||
|
@ -391,7 +418,7 @@ class EventsStore(SQLBaseStore):
|
|||
events_and_contexts=chunk,
|
||||
backfilled=backfilled,
|
||||
delete_existing=delete_existing,
|
||||
current_state_for_room=current_state_for_room,
|
||||
state_delta_for_room=state_delta_for_room,
|
||||
new_forward_extremeties=new_forward_extremeties,
|
||||
)
|
||||
persist_event_counter.inc_by(len(chunk))
|
||||
|
@ -408,7 +435,7 @@ class EventsStore(SQLBaseStore):
|
|||
|
||||
event_counter.inc(event.type, origin_type, origin_entity)
|
||||
|
||||
for room_id, (_, _, new_state) in current_state_for_room.iteritems():
|
||||
for room_id, new_state in current_state_for_room.iteritems():
|
||||
self.get_current_state_ids.prefill(
|
||||
(room_id, ), new_state
|
||||
)
|
||||
|
@ -460,20 +487,22 @@ class EventsStore(SQLBaseStore):
|
|||
defer.returnValue(new_latest_event_ids)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids):
|
||||
"""Calculate the new state deltas for a room.
|
||||
def _get_new_state_after_events(self, events_context, new_latest_event_ids):
|
||||
"""Calculate the current state dict after adding some new events to
|
||||
a room
|
||||
|
||||
Assumes that we are only persisting events for one room at a time.
|
||||
Args:
|
||||
events_context (list[(EventBase, EventContext)]):
|
||||
events and contexts which are being added to the room
|
||||
|
||||
new_latest_event_ids (iterable[str]):
|
||||
the new forward extremities for the room.
|
||||
|
||||
Returns:
|
||||
3-tuple (to_delete, to_insert, new_state) where both are state dicts,
|
||||
i.e. (type, state_key) -> event_id. `to_delete` are the entries to
|
||||
first be deleted from current_state_events, `to_insert` are entries
|
||||
to insert. `new_state` is the full set of state.
|
||||
May return None if there are no changes to be applied.
|
||||
Deferred[dict[(str,str), str]|None]:
|
||||
None if there are no changes to the room state, or
|
||||
a dict of (type, state_key) -> event_id].
|
||||
"""
|
||||
# Now we need to work out the different state sets for
|
||||
# each state extremities
|
||||
state_sets = []
|
||||
state_groups = set()
|
||||
missing_event_ids = []
|
||||
|
@ -516,18 +545,23 @@ class EventsStore(SQLBaseStore):
|
|||
state_sets.extend(group_to_state.itervalues())
|
||||
|
||||
if not new_latest_event_ids:
|
||||
current_state = {}
|
||||
defer.returnValue({})
|
||||
elif was_updated:
|
||||
if len(state_sets) == 1:
|
||||
# If there is only one state set, then we know what the current
|
||||
# state is.
|
||||
current_state = state_sets[0]
|
||||
defer.returnValue(state_sets[0])
|
||||
else:
|
||||
# We work out the current state by passing the state sets to the
|
||||
# state resolution algorithm. It may ask for some events, including
|
||||
# the events we have yet to persist, so we need a slightly more
|
||||
# complicated event lookup function than simply looking the events
|
||||
# up in the db.
|
||||
|
||||
logger.info(
|
||||
"Resolving state with %i state sets", len(state_sets),
|
||||
)
|
||||
|
||||
events_map = {ev.event_id: ev for ev, _ in events_context}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -550,13 +584,26 @@ class EventsStore(SQLBaseStore):
|
|||
to_return.update(evs)
|
||||
defer.returnValue(to_return)
|
||||
|
||||
current_state = yield resolve_events(
|
||||
current_state = yield resolve_events_with_factory(
|
||||
state_sets,
|
||||
state_map_factory=get_events,
|
||||
)
|
||||
defer.returnValue(current_state)
|
||||
else:
|
||||
return
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _calculate_state_delta(self, room_id, current_state):
|
||||
"""Calculate the new state deltas for a room.
|
||||
|
||||
Assumes that we are only persisting events for one room at a time.
|
||||
|
||||
Returns:
|
||||
2-tuple (to_delete, to_insert) where both are state dicts,
|
||||
i.e. (type, state_key) -> event_id. `to_delete` are the entries to
|
||||
first be deleted from current_state_events, `to_insert` are entries
|
||||
to insert.
|
||||
"""
|
||||
existing_state = yield self.get_current_state_ids(room_id)
|
||||
|
||||
existing_events = set(existing_state.itervalues())
|
||||
|
@ -576,7 +623,7 @@ class EventsStore(SQLBaseStore):
|
|||
if ev_id in events_to_insert
|
||||
}
|
||||
|
||||
defer.returnValue((to_delete, to_insert, current_state))
|
||||
defer.returnValue((to_delete, to_insert))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_event(self, event_id, check_redacted=True,
|
||||
|
@ -636,7 +683,7 @@ class EventsStore(SQLBaseStore):
|
|||
|
||||
@log_function
|
||||
def _persist_events_txn(self, txn, events_and_contexts, backfilled,
|
||||
delete_existing=False, current_state_for_room={},
|
||||
delete_existing=False, state_delta_for_room={},
|
||||
new_forward_extremeties={}):
|
||||
"""Insert some number of room events into the necessary database tables.
|
||||
|
||||
|
@ -652,7 +699,7 @@ class EventsStore(SQLBaseStore):
|
|||
delete_existing (bool): True to purge existing table rows for the
|
||||
events from the database. This is useful when retrying due to
|
||||
IntegrityError.
|
||||
current_state_for_room (dict[str, (list[str], list[str])]):
|
||||
state_delta_for_room (dict[str, (list[str], list[str])]):
|
||||
The current-state delta for each room. For each room, a tuple
|
||||
(to_delete, to_insert), being a list of event ids to be removed
|
||||
from the current state, and a list of event ids to be added to
|
||||
|
@ -664,7 +711,7 @@ class EventsStore(SQLBaseStore):
|
|||
"""
|
||||
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
|
||||
|
||||
self._update_current_state_txn(txn, current_state_for_room, max_stream_order)
|
||||
self._update_current_state_txn(txn, state_delta_for_room, max_stream_order)
|
||||
|
||||
self._update_forward_extremities_txn(
|
||||
txn,
|
||||
|
@ -708,9 +755,8 @@ class EventsStore(SQLBaseStore):
|
|||
events_and_contexts=events_and_contexts,
|
||||
)
|
||||
|
||||
# Insert into the state_groups, state_groups_state, and
|
||||
# event_to_state_groups tables.
|
||||
self._store_mult_state_groups_txn(txn, events_and_contexts)
|
||||
# Insert into event_to_state_groups.
|
||||
self._store_event_state_mappings_txn(txn, events_and_contexts)
|
||||
|
||||
# _store_rejected_events_txn filters out any events which were
|
||||
# rejected, and returns the filtered list.
|
||||
|
@ -730,7 +776,7 @@ class EventsStore(SQLBaseStore):
|
|||
|
||||
def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
|
||||
for room_id, current_state_tuple in state_delta_by_room.iteritems():
|
||||
to_delete, to_insert, _ = current_state_tuple
|
||||
to_delete, to_insert = current_state_tuple
|
||||
txn.executemany(
|
||||
"DELETE FROM current_state_events WHERE event_id = ?",
|
||||
[(ev_id,) for ev_id in to_delete.itervalues()],
|
||||
|
@ -945,10 +991,9 @@ class EventsStore(SQLBaseStore):
|
|||
# an outlier in the database. We now have some state at that
|
||||
# so we need to update the state_groups table with that state.
|
||||
|
||||
# insert into the state_group, state_groups_state and
|
||||
# event_to_state_groups tables.
|
||||
# insert into event_to_state_groups.
|
||||
try:
|
||||
self._store_mult_state_groups_txn(txn, ((event, context),))
|
||||
self._store_event_state_mappings_txn(txn, ((event, context),))
|
||||
except Exception:
|
||||
logger.exception("")
|
||||
raise
|
||||
|
@ -2018,16 +2063,32 @@ class EventsStore(SQLBaseStore):
|
|||
)
|
||||
return self.runInteraction("get_all_new_events", get_all_new_events_txn)
|
||||
|
||||
def delete_old_state(self, room_id, topological_ordering):
|
||||
return self.runInteraction(
|
||||
"delete_old_state",
|
||||
self._delete_old_state_txn, room_id, topological_ordering
|
||||
)
|
||||
def purge_history(
|
||||
self, room_id, topological_ordering, delete_local_events,
|
||||
):
|
||||
"""Deletes room history before a certain point
|
||||
|
||||
def _delete_old_state_txn(self, txn, room_id, topological_ordering):
|
||||
"""Deletes old room state
|
||||
Args:
|
||||
room_id (str):
|
||||
|
||||
topological_ordering (int):
|
||||
minimum topo ordering to preserve
|
||||
|
||||
delete_local_events (bool):
|
||||
if True, we will delete local events as well as remote ones
|
||||
(instead of just marking them as outliers and deleting their
|
||||
state groups).
|
||||
"""
|
||||
|
||||
return self.runInteraction(
|
||||
"purge_history",
|
||||
self._purge_history_txn, room_id, topological_ordering,
|
||||
delete_local_events,
|
||||
)
|
||||
|
||||
def _purge_history_txn(
|
||||
self, txn, room_id, topological_ordering, delete_local_events,
|
||||
):
|
||||
# Tables that should be pruned:
|
||||
# event_auth
|
||||
# event_backward_extremities
|
||||
|
@ -2068,7 +2129,7 @@ class EventsStore(SQLBaseStore):
|
|||
400, "topological_ordering is greater than forward extremeties"
|
||||
)
|
||||
|
||||
logger.debug("[purge] looking for events to delete")
|
||||
logger.info("[purge] looking for events to delete")
|
||||
|
||||
txn.execute(
|
||||
"SELECT event_id, state_key FROM events"
|
||||
|
@ -2080,16 +2141,16 @@ class EventsStore(SQLBaseStore):
|
|||
|
||||
to_delete = [
|
||||
(event_id,) for event_id, state_key in event_rows
|
||||
if state_key is None and not self.hs.is_mine_id(event_id)
|
||||
if state_key is None and (
|
||||
delete_local_events or not self.hs.is_mine_id(event_id)
|
||||
)
|
||||
]
|
||||
logger.info(
|
||||
"[purge] found %i events before cutoff, of which %i are remote"
|
||||
" non-state events to delete", len(event_rows), len(to_delete))
|
||||
"[purge] found %i events before cutoff, of which %i can be deleted",
|
||||
len(event_rows), len(to_delete),
|
||||
)
|
||||
|
||||
for event_id, state_key in event_rows:
|
||||
txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
|
||||
|
||||
logger.debug("[purge] Finding new backward extremities")
|
||||
logger.info("[purge] Finding new backward extremities")
|
||||
|
||||
# We calculate the new entries for the backward extremities by finding
|
||||
# all events that point to events that are to be purged
|
||||
|
@ -2103,7 +2164,7 @@ class EventsStore(SQLBaseStore):
|
|||
)
|
||||
new_backwards_extrems = txn.fetchall()
|
||||
|
||||
logger.debug("[purge] replacing backward extremities: %r", new_backwards_extrems)
|
||||
logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems)
|
||||
|
||||
txn.execute(
|
||||
"DELETE FROM event_backward_extremities WHERE room_id = ?",
|
||||
|
@ -2119,7 +2180,7 @@ class EventsStore(SQLBaseStore):
|
|||
]
|
||||
)
|
||||
|
||||
logger.debug("[purge] finding redundant state groups")
|
||||
logger.info("[purge] finding redundant state groups")
|
||||
|
||||
# Get all state groups that are only referenced by events that are
|
||||
# to be deleted.
|
||||
|
@ -2136,15 +2197,15 @@ class EventsStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
state_rows = txn.fetchall()
|
||||
logger.debug("[purge] found %i redundant state groups", len(state_rows))
|
||||
logger.info("[purge] found %i redundant state groups", len(state_rows))
|
||||
|
||||
# make a set of the redundant state groups, so that we can look them up
|
||||
# efficiently
|
||||
state_groups_to_delete = set([sg for sg, in state_rows])
|
||||
|
||||
# Now we get all the state groups that rely on these state groups
|
||||
logger.debug("[purge] finding state groups which depend on redundant"
|
||||
" state groups")
|
||||
logger.info("[purge] finding state groups which depend on redundant"
|
||||
" state groups")
|
||||
remaining_state_groups = []
|
||||
for i in xrange(0, len(state_rows), 100):
|
||||
chunk = [sg for sg, in state_rows[i:i + 100]]
|
||||
|
@ -2169,7 +2230,7 @@ class EventsStore(SQLBaseStore):
|
|||
# Now we turn the state groups that reference to-be-deleted state
|
||||
# groups to non delta versions.
|
||||
for sg in remaining_state_groups:
|
||||
logger.debug("[purge] de-delta-ing remaining state group %s", sg)
|
||||
logger.info("[purge] de-delta-ing remaining state group %s", sg)
|
||||
curr_state = self._get_state_groups_from_groups_txn(
|
||||
txn, [sg], types=None
|
||||
)
|
||||
|
@ -2206,7 +2267,7 @@ class EventsStore(SQLBaseStore):
|
|||
],
|
||||
)
|
||||
|
||||
logger.debug("[purge] removing redundant state groups")
|
||||
logger.info("[purge] removing redundant state groups")
|
||||
txn.executemany(
|
||||
"DELETE FROM state_groups_state WHERE state_group = ?",
|
||||
state_rows
|
||||
|
@ -2216,18 +2277,15 @@ class EventsStore(SQLBaseStore):
|
|||
state_rows
|
||||
)
|
||||
|
||||
# Delete all non-state
|
||||
logger.debug("[purge] removing events from event_to_state_groups")
|
||||
logger.info("[purge] removing events from event_to_state_groups")
|
||||
txn.executemany(
|
||||
"DELETE FROM event_to_state_groups WHERE event_id = ?",
|
||||
[(event_id,) for event_id, _ in event_rows]
|
||||
)
|
||||
|
||||
logger.debug("[purge] updating room_depth")
|
||||
txn.execute(
|
||||
"UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
|
||||
(topological_ordering, room_id,)
|
||||
)
|
||||
for event_id, _ in event_rows:
|
||||
txn.call_after(self._get_state_group_for_event.invalidate, (
|
||||
event_id,
|
||||
))
|
||||
|
||||
# Delete all remote non-state events
|
||||
for table in (
|
||||
|
@ -2245,7 +2303,8 @@ class EventsStore(SQLBaseStore):
|
|||
"event_signatures",
|
||||
"rejections",
|
||||
):
|
||||
logger.debug("[purge] removing remote non-state events from %s", table)
|
||||
logger.info("[purge] removing remote non-state events from %s",
|
||||
table)
|
||||
|
||||
txn.executemany(
|
||||
"DELETE FROM %s WHERE event_id = ?" % (table,),
|
||||
|
@ -2253,16 +2312,30 @@ class EventsStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
# Mark all state and own events as outliers
|
||||
logger.debug("[purge] marking remaining events as outliers")
|
||||
logger.info("[purge] marking remaining events as outliers")
|
||||
txn.executemany(
|
||||
"UPDATE events SET outlier = ?"
|
||||
" WHERE event_id = ?",
|
||||
[
|
||||
(True, event_id,) for event_id, state_key in event_rows
|
||||
if state_key is not None or self.hs.is_mine_id(event_id)
|
||||
if state_key is not None or (
|
||||
not delete_local_events and self.hs.is_mine_id(event_id)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# synapse tries to take out an exclusive lock on room_depth whenever it
|
||||
# persists events (because upsert), and once we run this update, we
|
||||
# will block that for the rest of our transaction.
|
||||
#
|
||||
# So, let's stick it at the end so that we don't block event
|
||||
# persistence.
|
||||
logger.info("[purge] updating room_depth")
|
||||
txn.execute(
|
||||
"UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
|
||||
(topological_ordering, room_id,)
|
||||
)
|
||||
|
||||
logger.info("[purge] done")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
|
|
@ -29,9 +29,6 @@ class MediaRepositoryStore(BackgroundUpdateStore):
|
|||
where_clause='url_cache IS NOT NULL',
|
||||
)
|
||||
|
||||
def get_default_thumbnails(self, top_level_type, sub_type):
|
||||
return []
|
||||
|
||||
def get_local_media(self, media_id):
|
||||
"""Get the metadata for a local piece of media
|
||||
Returns:
|
||||
|
@ -176,7 +173,14 @@ class MediaRepositoryStore(BackgroundUpdateStore):
|
|||
desc="store_cached_remote_media",
|
||||
)
|
||||
|
||||
def update_cached_last_access_time(self, origin_id_tuples, time_ts):
|
||||
def update_cached_last_access_time(self, local_media, remote_media, time_ms):
|
||||
"""Updates the last access time of the given media
|
||||
|
||||
Args:
|
||||
local_media (iterable[str]): Set of media_ids
|
||||
remote_media (iterable[(str, str)]): Set of (server_name, media_id)
|
||||
time_ms: Current time in milliseconds
|
||||
"""
|
||||
def update_cache_txn(txn):
|
||||
sql = (
|
||||
"UPDATE remote_media_cache SET last_access_ts = ?"
|
||||
|
@ -184,8 +188,18 @@ class MediaRepositoryStore(BackgroundUpdateStore):
|
|||
)
|
||||
|
||||
txn.executemany(sql, (
|
||||
(time_ts, media_origin, media_id)
|
||||
for media_origin, media_id in origin_id_tuples
|
||||
(time_ms, media_origin, media_id)
|
||||
for media_origin, media_id in remote_media
|
||||
))
|
||||
|
||||
sql = (
|
||||
"UPDATE local_media_repository SET last_access_ts = ?"
|
||||
" WHERE media_id = ?"
|
||||
)
|
||||
|
||||
txn.executemany(sql, (
|
||||
(time_ms, media_id)
|
||||
for media_id in local_media
|
||||
))
|
||||
|
||||
return self.runInteraction("update_cached_last_access_time", update_cache_txn)
|
||||
|
|
|
@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
# Remember to update this number every time a change is made to database
|
||||
# schema files, so the users will be informed on server restarts.
|
||||
SCHEMA_VERSION = 46
|
||||
SCHEMA_VERSION = 47
|
||||
|
||||
dir_path = os.path.abspath(os.path.dirname(__file__))
|
||||
|
||||
|
|
|
@ -16,11 +16,9 @@
|
|||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.storage.search import SearchStore
|
||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
from .engines import PostgresEngine, Sqlite3Engine
|
||||
|
||||
import collections
|
||||
import logging
|
||||
import ujson as json
|
||||
|
@ -40,7 +38,7 @@ RatelimitOverride = collections.namedtuple(
|
|||
)
|
||||
|
||||
|
||||
class RoomStore(SQLBaseStore):
|
||||
class RoomStore(SearchStore):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def store_room(self, room_id, room_creator_user_id, is_public):
|
||||
|
@ -263,8 +261,8 @@ class RoomStore(SQLBaseStore):
|
|||
},
|
||||
)
|
||||
|
||||
self._store_event_search_txn(
|
||||
txn, event, "content.topic", event.content["topic"]
|
||||
self.store_event_search_txn(
|
||||
txn, event, "content.topic", event.content["topic"],
|
||||
)
|
||||
|
||||
def _store_room_name_txn(self, txn, event):
|
||||
|
@ -279,14 +277,14 @@ class RoomStore(SQLBaseStore):
|
|||
}
|
||||
)
|
||||
|
||||
self._store_event_search_txn(
|
||||
txn, event, "content.name", event.content["name"]
|
||||
self.store_event_search_txn(
|
||||
txn, event, "content.name", event.content["name"],
|
||||
)
|
||||
|
||||
def _store_room_message_txn(self, txn, event):
|
||||
if hasattr(event, "content") and "body" in event.content:
|
||||
self._store_event_search_txn(
|
||||
txn, event, "content.body", event.content["body"]
|
||||
self.store_event_search_txn(
|
||||
txn, event, "content.body", event.content["body"],
|
||||
)
|
||||
|
||||
def _store_history_visibility_txn(self, txn, event):
|
||||
|
@ -308,31 +306,6 @@ class RoomStore(SQLBaseStore):
|
|||
event.content[key]
|
||||
))
|
||||
|
||||
def _store_event_search_txn(self, txn, event, key, value):
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
sql = (
|
||||
"INSERT INTO event_search"
|
||||
" (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
|
||||
" VALUES (?,?,?,to_tsvector('english', ?),?,?)"
|
||||
)
|
||||
txn.execute(
|
||||
sql,
|
||||
(
|
||||
event.event_id, event.room_id, key, value,
|
||||
event.internal_metadata.stream_ordering,
|
||||
event.origin_server_ts,
|
||||
)
|
||||
)
|
||||
elif isinstance(self.database_engine, Sqlite3Engine):
|
||||
sql = (
|
||||
"INSERT INTO event_search (event_id, room_id, key, value)"
|
||||
" VALUES (?,?,?,?)"
|
||||
)
|
||||
txn.execute(sql, (event.event_id, event.room_id, key, value,))
|
||||
else:
|
||||
# This should be unreachable.
|
||||
raise Exception("Unrecognized database engine")
|
||||
|
||||
def add_event_report(self, room_id, event_id, user_id, reason, content,
|
||||
received_ts):
|
||||
next_id = self._event_reports_id_gen.get_next()
|
||||
|
@ -533,73 +506,114 @@ class RoomStore(SQLBaseStore):
|
|||
)
|
||||
self.is_room_blocked.invalidate((room_id,))
|
||||
|
||||
def get_media_mxcs_in_room(self, room_id):
    """Fetch the MXC URIs of all media referenced by events in a room.

    Args:
        room_id (str)

    Returns:
        The result of running the interaction; the transaction function
        yields a pair of lists of "mxc://<hostname>/<media_id>" strings,
        the first built from this server's hostname (local media) and the
        second from the remote hostnames (remote media).
    """
    def _get_media_mxcs_in_room_txn(txn):
        local_ids, remote_pairs = self._get_media_mxcs_in_room_txn(txn, room_id)

        # Convert the IDs to MXC URIs
        local_uris = [
            "mxc://%s/%s" % (self.hostname, media_id)
            for media_id in local_ids
        ]
        remote_uris = [
            "mxc://%s/%s" % (hostname, media_id)
            for hostname, media_id in remote_pairs
        ]

        return local_uris, remote_uris

    return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn)
|
||||
|
||||
def quarantine_media_ids_in_room(self, room_id, quarantined_by):
|
||||
"""For a room loops through all events with media and quarantines
|
||||
the associated media
|
||||
"""
|
||||
def _get_media_ids_in_room(txn):
|
||||
mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
|
||||
|
||||
next_token = self.get_current_events_token() + 1
|
||||
|
||||
def _quarantine_media_in_room_txn(txn):
|
||||
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
|
||||
total_media_quarantined = 0
|
||||
|
||||
while next_token:
|
||||
sql = """
|
||||
SELECT stream_ordering, content FROM events
|
||||
WHERE room_id = ?
|
||||
AND stream_ordering < ?
|
||||
AND contains_url = ? AND outlier = ?
|
||||
ORDER BY stream_ordering DESC
|
||||
LIMIT ?
|
||||
# Now update all the tables to set the quarantined_by flag
|
||||
|
||||
txn.executemany("""
|
||||
UPDATE local_media_repository
|
||||
SET quarantined_by = ?
|
||||
WHERE media_id = ?
|
||||
""", ((quarantined_by, media_id) for media_id in local_mxcs))
|
||||
|
||||
txn.executemany(
|
||||
"""
|
||||
txn.execute(sql, (room_id, next_token, True, False, 100))
|
||||
|
||||
next_token = None
|
||||
local_media_mxcs = []
|
||||
remote_media_mxcs = []
|
||||
for stream_ordering, content_json in txn:
|
||||
next_token = stream_ordering
|
||||
content = json.loads(content_json)
|
||||
|
||||
content_url = content.get("url")
|
||||
thumbnail_url = content.get("info", {}).get("thumbnail_url")
|
||||
|
||||
for url in (content_url, thumbnail_url):
|
||||
if not url:
|
||||
continue
|
||||
matches = mxc_re.match(url)
|
||||
if matches:
|
||||
hostname = matches.group(1)
|
||||
media_id = matches.group(2)
|
||||
if hostname == self.hostname:
|
||||
local_media_mxcs.append(media_id)
|
||||
else:
|
||||
remote_media_mxcs.append((hostname, media_id))
|
||||
|
||||
# Now update all the tables to set the quarantined_by flag
|
||||
|
||||
txn.executemany("""
|
||||
UPDATE local_media_repository
|
||||
UPDATE remote_media_cache
|
||||
SET quarantined_by = ?
|
||||
WHERE media_id = ?
|
||||
""", ((quarantined_by, media_id) for media_id in local_media_mxcs))
|
||||
|
||||
txn.executemany(
|
||||
"""
|
||||
UPDATE remote_media_cache
|
||||
SET quarantined_by = ?
|
||||
WHERE media_origin AND media_id = ?
|
||||
""",
|
||||
(
|
||||
(quarantined_by, origin, media_id)
|
||||
for origin, media_id in remote_media_mxcs
|
||||
)
|
||||
WHERE media_origin = ? AND media_id = ?
|
||||
""",
|
||||
(
|
||||
(quarantined_by, origin, media_id)
|
||||
for origin, media_id in remote_mxcs
|
||||
)
|
||||
)
|
||||
|
||||
total_media_quarantined += len(local_media_mxcs)
|
||||
total_media_quarantined += len(remote_media_mxcs)
|
||||
total_media_quarantined += len(local_mxcs)
|
||||
total_media_quarantined += len(remote_mxcs)
|
||||
|
||||
return total_media_quarantined
|
||||
|
||||
return self.runInteraction("get_media_ids_in_room", _get_media_ids_in_room)
|
||||
return self.runInteraction(
|
||||
"quarantine_media_in_room",
|
||||
_quarantine_media_in_room_txn,
|
||||
)
|
||||
|
||||
def _get_media_mxcs_in_room_txn(self, txn, room_id):
|
||||
"""Retrieves all the local and remote media MXC URIs in a given room
|
||||
|
||||
Args:
|
||||
txn (cursor)
|
||||
room_id (str)
|
||||
|
||||
Returns:
|
||||
The local and remote media as a lists of tuples where the key is
|
||||
the hostname and the value is the media ID.
|
||||
"""
|
||||
mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
|
||||
|
||||
next_token = self.get_current_events_token() + 1
|
||||
local_media_mxcs = []
|
||||
remote_media_mxcs = []
|
||||
|
||||
while next_token:
|
||||
sql = """
|
||||
SELECT stream_ordering, content FROM events
|
||||
WHERE room_id = ?
|
||||
AND stream_ordering < ?
|
||||
AND contains_url = ? AND outlier = ?
|
||||
ORDER BY stream_ordering DESC
|
||||
LIMIT ?
|
||||
"""
|
||||
txn.execute(sql, (room_id, next_token, True, False, 100))
|
||||
|
||||
next_token = None
|
||||
for stream_ordering, content_json in txn:
|
||||
next_token = stream_ordering
|
||||
content = json.loads(content_json)
|
||||
|
||||
content_url = content.get("url")
|
||||
thumbnail_url = content.get("info", {}).get("thumbnail_url")
|
||||
|
||||
for url in (content_url, thumbnail_url):
|
||||
if not url:
|
||||
continue
|
||||
matches = mxc_re.match(url)
|
||||
if matches:
|
||||
hostname = matches.group(1)
|
||||
media_id = matches.group(2)
|
||||
if hostname == self.hostname:
|
||||
local_media_mxcs.append(media_id)
|
||||
else:
|
||||
remote_media_mxcs.append((hostname, media_id))
|
||||
|
||||
return local_media_mxcs, remote_media_mxcs
|
||||
|
|
16
synapse/storage/schema/delta/47/last_access_media.sql
Normal file
16
synapse/storage/schema/delta/47/last_access_media.sql
Normal file
|
@ -0,0 +1,16 @@
|
|||
/* Copyright 2018 New Vector Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
ALTER TABLE local_media_repository ADD COLUMN last_access_ts BIGINT;
|
37
synapse/storage/schema/delta/47/state_group_seq.py
Normal file
37
synapse/storage/schema/delta/47/state_group_seq.py
Normal file
|
@ -0,0 +1,37 @@
|
|||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
|
||||
|
||||
def run_create(cur, database_engine, *args, **kwargs):
    """Create the `state_group_id_seq` sequence when running on PostgreSQL.

    Does nothing for any other database engine.

    Args:
        cur: database cursor used to run the statements
        database_engine: the engine in use; only a PostgresEngine triggers
            any work
    """
    if not isinstance(database_engine, PostgresEngine):
        return

    # if we already have some state groups, we want to start making new
    # ones with a higher id.
    cur.execute("SELECT max(id) FROM state_groups")
    highest_id = cur.fetchone()[0]

    # max() gives NULL on an empty table, in which case we start from 1.
    start_val = 1 if highest_id is None else highest_id + 1

    cur.execute(
        "CREATE SEQUENCE state_group_id_seq START WITH %s",
        (start_val, ),
    )
|
||||
|
||||
|
||||
def run_upgrade(*args, **kwargs):
    """Do nothing: this delta has no upgrade step."""
    return None
|
|
@ -13,19 +13,26 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import namedtuple
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
import ujson as json
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from .background_updates import BackgroundUpdateStore
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||
|
||||
import logging
|
||||
import re
|
||||
import ujson as json
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SearchEntry = namedtuple('SearchEntry', [
|
||||
'key', 'value', 'event_id', 'room_id', 'stream_ordering',
|
||||
'origin_server_ts',
|
||||
])
|
||||
|
||||
|
||||
class SearchStore(BackgroundUpdateStore):
|
||||
|
||||
|
@ -60,16 +67,17 @@ class SearchStore(BackgroundUpdateStore):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def _background_reindex_search(self, progress, batch_size):
|
||||
# we work through the events table from highest stream id to lowest
|
||||
target_min_stream_id = progress["target_min_stream_id_inclusive"]
|
||||
max_stream_id = progress["max_stream_id_exclusive"]
|
||||
rows_inserted = progress.get("rows_inserted", 0)
|
||||
|
||||
INSERT_CLUMP_SIZE = 1000
|
||||
TYPES = ["m.room.name", "m.room.message", "m.room.topic"]
|
||||
|
||||
def reindex_search_txn(txn):
|
||||
sql = (
|
||||
"SELECT stream_ordering, event_id, room_id, type, content FROM events"
|
||||
"SELECT stream_ordering, event_id, room_id, type, content, "
|
||||
" origin_server_ts FROM events"
|
||||
" WHERE ? <= stream_ordering AND stream_ordering < ?"
|
||||
" AND (%s)"
|
||||
" ORDER BY stream_ordering DESC"
|
||||
|
@ -78,6 +86,10 @@ class SearchStore(BackgroundUpdateStore):
|
|||
|
||||
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
|
||||
|
||||
# we could stream straight from the results into
|
||||
# store_search_entries_txn with a generator function, but that
|
||||
# would mean having two cursors open on the database at once.
|
||||
# Instead we just build a list of results.
|
||||
rows = self.cursor_to_dict(txn)
|
||||
if not rows:
|
||||
return 0
|
||||
|
@ -90,6 +102,8 @@ class SearchStore(BackgroundUpdateStore):
|
|||
event_id = row["event_id"]
|
||||
room_id = row["room_id"]
|
||||
etype = row["type"]
|
||||
stream_ordering = row["stream_ordering"]
|
||||
origin_server_ts = row["origin_server_ts"]
|
||||
try:
|
||||
content = json.loads(row["content"])
|
||||
except Exception:
|
||||
|
@ -104,6 +118,8 @@ class SearchStore(BackgroundUpdateStore):
|
|||
elif etype == "m.room.name":
|
||||
key = "content.name"
|
||||
value = content["name"]
|
||||
else:
|
||||
raise Exception("unexpected event type %s" % etype)
|
||||
except (KeyError, AttributeError):
|
||||
# If the event is missing a necessary field then
|
||||
# skip over it.
|
||||
|
@ -114,25 +130,16 @@ class SearchStore(BackgroundUpdateStore):
|
|||
# then skip over it
|
||||
continue
|
||||
|
||||
event_search_rows.append((event_id, room_id, key, value))
|
||||
event_search_rows.append(SearchEntry(
|
||||
key=key,
|
||||
value=value,
|
||||
event_id=event_id,
|
||||
room_id=room_id,
|
||||
stream_ordering=stream_ordering,
|
||||
origin_server_ts=origin_server_ts,
|
||||
))
|
||||
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
sql = (
|
||||
"INSERT INTO event_search (event_id, room_id, key, vector)"
|
||||
" VALUES (?,?,?,to_tsvector('english', ?))"
|
||||
)
|
||||
elif isinstance(self.database_engine, Sqlite3Engine):
|
||||
sql = (
|
||||
"INSERT INTO event_search (event_id, room_id, key, value)"
|
||||
" VALUES (?,?,?,?)"
|
||||
)
|
||||
else:
|
||||
# This should be unreachable.
|
||||
raise Exception("Unrecognized database engine")
|
||||
|
||||
for index in range(0, len(event_search_rows), INSERT_CLUMP_SIZE):
|
||||
clump = event_search_rows[index:index + INSERT_CLUMP_SIZE]
|
||||
txn.executemany(sql, clump)
|
||||
self.store_search_entries_txn(txn, event_search_rows)
|
||||
|
||||
progress = {
|
||||
"target_min_stream_id_inclusive": target_min_stream_id,
|
||||
|
@ -276,6 +283,92 @@ class SearchStore(BackgroundUpdateStore):
|
|||
|
||||
defer.returnValue(num_rows)
|
||||
|
||||
def store_event_search_txn(self, txn, event, key, value):
    """Index a single event in the search table

    Bundles the event's identifying and ordering fields together with
    ``key`` and ``value`` into a SearchEntry and passes it, as a
    one-element tuple, to store_search_entries_txn.

    Args:
        txn (cursor):
        event (EventBase): the event being indexed
        key (str): name of the indexed field
        value (str): the text to be indexed
    """
    entry = SearchEntry(
        key=key,
        value=value,
        event_id=event.event_id,
        room_id=event.room_id,
        stream_ordering=event.internal_metadata.stream_ordering,
        origin_server_ts=event.origin_server_ts,
    )
    self.store_search_entries_txn(txn, (entry,))
|
||||
|
||||
def store_search_entries_txn(self, txn, entries):
    """Add entries to the search table

    Args:
        txn (cursor):
        entries (iterable[SearchEntry]):
            entries to be added to the table

    Raises:
        Exception: if the database engine is neither postgres nor sqlite.
            Errors raised by the insert itself propagate unchanged.
    """
    if isinstance(self.database_engine, PostgresEngine):
        sql = (
            "INSERT INTO event_search"
            " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
            " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
        )

        args = ((
            entry.event_id, entry.room_id, entry.key, entry.value,
            entry.stream_ordering, entry.origin_server_ts,
        ) for entry in entries)

        def reset_work_mem_after_failure():
            # Best-effort RESET used while an insert error is propagating.
            #
            # This deliberately lives in its own function: any error from
            # the RESET is caught and discarded inside this frame, so the
            # caller's bare `raise` re-raises the *original* insert error
            # (with its traceback) on both Python 2 and Python 3. The
            # previous three-argument `raise t, v, tb` form is a syntax
            # error on Python 3.
            try:
                txn.execute("RESET work_mem")
            except Exception as e:
                logger.warning(
                    "exception resetting work_mem during exception "
                    "handling: %r",
                    e,
                )

        # inserts to a GIN index are normally batched up into a pending
        # list, and then all committed together once the list gets to a
        # certain size. The trouble with that is that postgres (pre-9.5)
        # uses work_mem to determine the length of the list, and work_mem
        # is typically very large.
        #
        # We therefore reduce work_mem while we do the insert.
        #
        # (postgres 9.5 uses the separate gin_pending_list_limit setting,
        # so doesn't suffer the same problem, but changing work_mem will
        # be harmless)

        txn.execute("SET work_mem='256kB'")
        try:
            txn.executemany(sql, args)
        except Exception:
            # we need to reset work_mem, but doing so may throw a new
            # exception and we want to preserve the original
            reset_work_mem_after_failure()
            raise
        else:
            txn.execute("RESET work_mem")

    elif isinstance(self.database_engine, Sqlite3Engine):
        # the sqlite table has no stream_ordering / origin_server_ts
        # columns, so only the first four fields are written
        sql = (
            "INSERT INTO event_search (event_id, room_id, key, value)"
            " VALUES (?,?,?,?)"
        )
        args = ((
            entry.event_id, entry.room_id, entry.key, entry.value,
        ) for entry in entries)

        txn.executemany(sql, args)
    else:
        # This should be unreachable.
        raise Exception("Unrecognized database engine")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def search_msgs(self, room_ids, search_term, keys):
|
||||
"""Performs a full text search over events with given keys.
|
||||
|
|
|
@ -42,11 +42,8 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt
|
|||
return len(self.delta_ids) if self.delta_ids else 0
|
||||
|
||||
|
||||
class StateGroupReadStore(SQLBaseStore):
|
||||
"""The read-only parts of StateGroupStore
|
||||
|
||||
None of these functions write to the state tables, so are suitable for
|
||||
including in the SlavedStores.
|
||||
class StateGroupWorkerStore(SQLBaseStore):
|
||||
"""The parts of StateGroupStore that can be called from workers.
|
||||
"""
|
||||
|
||||
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
|
||||
|
@ -54,7 +51,7 @@ class StateGroupReadStore(SQLBaseStore):
|
|||
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
|
||||
|
||||
def __init__(self, db_conn, hs):
|
||||
super(StateGroupReadStore, self).__init__(db_conn, hs)
|
||||
super(StateGroupWorkerStore, self).__init__(db_conn, hs)
|
||||
|
||||
self._state_group_cache = DictionaryCache(
|
||||
"*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
|
||||
|
@ -549,8 +546,117 @@ class StateGroupReadStore(SQLBaseStore):
|
|||
|
||||
defer.returnValue(results)
|
||||
|
||||
def store_state_group(self, event_id, room_id, prev_group, delta_ids,
|
||||
current_state_ids):
|
||||
"""Store a new set of state, returning a newly assigned state group.
|
||||
|
||||
class StateStore(StateGroupReadStore, BackgroundUpdateStore):
|
||||
Args:
|
||||
event_id (str): The event ID for which the state was calculated
|
||||
room_id (str)
|
||||
prev_group (int|None): A previous state group for the room, optional.
|
||||
delta_ids (dict|None): The delta between state at `prev_group` and
|
||||
`current_state_ids`, if `prev_group` was given. Same format as
|
||||
`current_state_ids`.
|
||||
current_state_ids (dict): The state to store. Map of (type, state_key)
|
||||
to event_id.
|
||||
|
||||
Returns:
|
||||
Deferred[int]: The state group ID
|
||||
"""
|
||||
def _store_state_group_txn(txn):
    """Allocate a new state group and write its state inside ``txn``.

    Reads ``event_id``, ``room_id``, ``prev_group``, ``delta_ids`` and
    ``current_state_ids`` from the enclosing scope.

    Returns:
        the id of the newly created state group
    """
    if current_state_ids is None:
        # AFAIK, this can never happen
        raise Exception("current_state_ids cannot be None")

    state_group = self.database_engine.get_next_state_group_id(txn)

    self._simple_insert_txn(
        txn,
        table="state_groups",
        values={
            "id": state_group,
            "room_id": room_id,
            "event_id": event_id,
        },
    )

    # We persist as a delta if we can, while also ensuring the chain
    # of deltas isn't too long, as otherwise read performance degrades.
    persist_as_delta = False
    if prev_group:
        is_in_db = self._simple_select_one_onecol_txn(
            txn,
            table="state_groups",
            keyvalues={"id": prev_group},
            retcol="id",
            allow_none=True,
        )
        if not is_in_db:
            raise Exception(
                "Trying to persist state with unpersisted prev_group: %r"
                % (prev_group,)
            )

        potential_hops = self._count_state_group_hops_txn(
            txn, prev_group
        )
        persist_as_delta = potential_hops < MAX_STATE_DELTA_HOPS

    if persist_as_delta:
        # link the new group to its predecessor and store only the delta
        self._simple_insert_txn(
            txn,
            table="state_group_edges",
            values={
                "state_group": state_group,
                "prev_state_group": prev_group,
            },
        )
        state_to_store = delta_ids
    else:
        # no usable predecessor: store the full state
        state_to_store = current_state_ids

    self._simple_insert_many_txn(
        txn,
        table="state_groups_state",
        values=[
            {
                "state_group": state_group,
                "room_id": room_id,
                "type": key[0],
                "state_key": key[1],
                "event_id": state_id,
            }
            for key, state_id in state_to_store.iteritems()
        ],
    )

    # Prefill the state group cache with this group.
    # It's fine to use the sequence like this as the state group map
    # is immutable. (If the map wasn't immutable then this prefill could
    # race with another update)
    txn.call_after(
        self._state_group_cache.update,
        self._state_group_cache.sequence,
        key=state_group,
        value=dict(current_state_ids),
        full=True,
    )

    return state_group
|
||||
|
||||
return self.runInteraction("store_state_group", _store_state_group_txn)
|
||||
|
||||
|
||||
class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
|
||||
""" Keeps track of the state at a given event.
|
||||
|
||||
This is done by the concept of `state groups`. Every event is a assigned
|
||||
|
@ -591,27 +697,12 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
|
|||
where_clause="type='m.room.member'",
|
||||
)
|
||||
|
||||
def _have_persisted_state_group_txn(self, txn, state_group):
|
||||
txn.execute(
|
||||
"SELECT count(*) FROM state_groups WHERE id = ?",
|
||||
(state_group,)
|
||||
)
|
||||
row = txn.fetchone()
|
||||
return row and row[0]
|
||||
|
||||
def _store_mult_state_groups_txn(self, txn, events_and_contexts):
|
||||
def _store_event_state_mappings_txn(self, txn, events_and_contexts):
|
||||
state_groups = {}
|
||||
for event, context in events_and_contexts:
|
||||
if event.internal_metadata.is_outlier():
|
||||
continue
|
||||
|
||||
if context.current_state_ids is None:
|
||||
# AFAIK, this can never happen
|
||||
logger.error(
|
||||
"Non-outlier event %s had current_state_ids==None",
|
||||
event.event_id)
|
||||
continue
|
||||
|
||||
# if the event was rejected, just give it the same state as its
|
||||
# predecessor.
|
||||
if context.rejected:
|
||||
|
@ -620,90 +711,6 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
|
|||
|
||||
state_groups[event.event_id] = context.state_group
|
||||
|
||||
if self._have_persisted_state_group_txn(txn, context.state_group):
|
||||
continue
|
||||
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="state_groups",
|
||||
values={
|
||||
"id": context.state_group,
|
||||
"room_id": event.room_id,
|
||||
"event_id": event.event_id,
|
||||
},
|
||||
)
|
||||
|
||||
# We persist as a delta if we can, while also ensuring the chain
|
||||
# of deltas isn't tooo long, as otherwise read performance degrades.
|
||||
if context.prev_group:
|
||||
is_in_db = self._simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="state_groups",
|
||||
keyvalues={"id": context.prev_group},
|
||||
retcol="id",
|
||||
allow_none=True,
|
||||
)
|
||||
if not is_in_db:
|
||||
raise Exception(
|
||||
"Trying to persist state with unpersisted prev_group: %r"
|
||||
% (context.prev_group,)
|
||||
)
|
||||
|
||||
potential_hops = self._count_state_group_hops_txn(
|
||||
txn, context.prev_group
|
||||
)
|
||||
if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="state_group_edges",
|
||||
values={
|
||||
"state_group": context.state_group,
|
||||
"prev_state_group": context.prev_group,
|
||||
},
|
||||
)
|
||||
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
table="state_groups_state",
|
||||
values=[
|
||||
{
|
||||
"state_group": context.state_group,
|
||||
"room_id": event.room_id,
|
||||
"type": key[0],
|
||||
"state_key": key[1],
|
||||
"event_id": state_id,
|
||||
}
|
||||
for key, state_id in context.delta_ids.iteritems()
|
||||
],
|
||||
)
|
||||
else:
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
table="state_groups_state",
|
||||
values=[
|
||||
{
|
||||
"state_group": context.state_group,
|
||||
"room_id": event.room_id,
|
||||
"type": key[0],
|
||||
"state_key": key[1],
|
||||
"event_id": state_id,
|
||||
}
|
||||
for key, state_id in context.current_state_ids.iteritems()
|
||||
],
|
||||
)
|
||||
|
||||
# Prefill the state group cache with this group.
|
||||
# It's fine to use the sequence like this as the state group map
|
||||
# is immutable. (If the map wasn't immutable then this prefill could
|
||||
# race with another update)
|
||||
txn.call_after(
|
||||
self._state_group_cache.update,
|
||||
self._state_group_cache.sequence,
|
||||
key=context.state_group,
|
||||
value=dict(context.current_state_ids),
|
||||
full=True,
|
||||
)
|
||||
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
table="event_to_state_groups",
|
||||
|
@ -763,9 +770,6 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
|
|||
|
||||
return count
|
||||
|
||||
def get_next_state_group(self):
|
||||
return self._state_groups_id_gen.get_next()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _background_deduplicate_state(self, progress, batch_size):
|
||||
"""This background update will slowly deduplicate state by reencoding
|
||||
|
|
|
@ -641,8 +641,12 @@ class UserDirectoryStore(SQLBaseStore):
|
|||
"""
|
||||
|
||||
if self.hs.config.user_directory_search_all_users:
|
||||
join_clause = ""
|
||||
where_clause = "?<>''" # naughty hack to keep the same number of binds
|
||||
# make s.user_id null to keep the ordering algorithm happy
|
||||
join_clause = """
|
||||
CROSS JOIN (SELECT NULL as user_id) AS s
|
||||
"""
|
||||
join_args = ()
|
||||
where_clause = "1=1"
|
||||
else:
|
||||
join_clause = """
|
||||
LEFT JOIN users_in_public_rooms AS p USING (user_id)
|
||||
|
@ -651,6 +655,7 @@ class UserDirectoryStore(SQLBaseStore):
|
|||
WHERE user_id = ? AND share_private
|
||||
) AS s USING (user_id)
|
||||
"""
|
||||
join_args = (user_id,)
|
||||
where_clause = "(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)"
|
||||
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
|
@ -692,7 +697,7 @@ class UserDirectoryStore(SQLBaseStore):
|
|||
avatar_url IS NULL
|
||||
LIMIT ?
|
||||
""" % (join_clause, where_clause)
|
||||
args = (user_id, full_query, exact_query, prefix_query, limit + 1,)
|
||||
args = join_args + (full_query, exact_query, prefix_query, limit + 1,)
|
||||
elif isinstance(self.database_engine, Sqlite3Engine):
|
||||
search_query = _parse_query_sqlite(search_term)
|
||||
|
||||
|
@ -710,7 +715,7 @@ class UserDirectoryStore(SQLBaseStore):
|
|||
avatar_url IS NULL
|
||||
LIMIT ?
|
||||
""" % (join_clause, where_clause)
|
||||
args = (user_id, search_query, limit + 1)
|
||||
args = join_args + (search_query, limit + 1)
|
||||
else:
|
||||
# This should be unreachable.
|
||||
raise Exception("Unrecognized database engine")
|
||||
|
|
|
@ -75,6 +75,7 @@ class Cache(object):
|
|||
self.cache = LruCache(
|
||||
max_size=max_entries, keylen=keylen, cache_type=cache_type,
|
||||
size_callback=(lambda d: len(d)) if iterable else None,
|
||||
evicted_callback=self._on_evicted,
|
||||
)
|
||||
|
||||
self.name = name
|
||||
|
@ -83,6 +84,9 @@ class Cache(object):
|
|||
self.thread = None
|
||||
self.metrics = register_cache(name, self.cache)
|
||||
|
||||
def _on_evicted(self, evicted_count):
    """Report an eviction to this cache's metrics.

    Args:
        evicted_count: passed through unchanged to
            ``self.metrics.inc_evictions``
    """
    cache_metrics = self.metrics
    cache_metrics.inc_evictions(evicted_count)
|
||||
|
||||
def check_thread(self):
|
||||
expected_thread = self.thread
|
||||
if expected_thread is None:
|
||||
|
|
|
@ -79,7 +79,11 @@ class ExpiringCache(object):
|
|||
while self._max_len and len(self) > self._max_len:
|
||||
_key, value = self._cache.popitem(last=False)
|
||||
if self.iterable:
|
||||
self._size_estimate -= len(value.value)
|
||||
removed_len = len(value.value)
|
||||
self.metrics.inc_evictions(removed_len)
|
||||
self._size_estimate -= removed_len
|
||||
else:
|
||||
self.metrics.inc_evictions()
|
||||
|
||||
def __getitem__(self, key):
|
||||
try:
|
||||
|
|
|
@ -49,7 +49,24 @@ class LruCache(object):
|
|||
Can also set callbacks on objects when getting/setting which are fired
|
||||
when that key gets invalidated/evicted.
|
||||
"""
|
||||
def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None):
|
||||
def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None,
|
||||
evicted_callback=None):
|
||||
"""
|
||||
Args:
|
||||
max_size (int):
|
||||
|
||||
keylen (int):
|
||||
|
||||
cache_type (type):
|
||||
type of underlying cache to be used. Typically one of dict
|
||||
or TreeCache.
|
||||
|
||||
size_callback (func(V) -> int | None):
|
||||
|
||||
evicted_callback (func(int)|None):
|
||||
if not None, called on eviction with the size of the evicted
|
||||
entry
|
||||
"""
|
||||
cache = cache_type()
|
||||
self.cache = cache # Used for introspection.
|
||||
list_root = _Node(None, None, None, None)
|
||||
|
@ -61,8 +78,10 @@ class LruCache(object):
|
|||
def evict():
|
||||
while cache_len() > max_size:
|
||||
todelete = list_root.prev_node
|
||||
delete_node(todelete)
|
||||
evicted_len = delete_node(todelete)
|
||||
cache.pop(todelete.key, None)
|
||||
if evicted_callback:
|
||||
evicted_callback(evicted_len)
|
||||
|
||||
def synchronized(f):
|
||||
@wraps(f)
|
||||
|
@ -111,12 +130,15 @@ class LruCache(object):
|
|||
prev_node.next_node = next_node
|
||||
next_node.prev_node = prev_node
|
||||
|
||||
deleted_len = 1
|
||||
if size_callback:
|
||||
cached_cache_len[0] -= size_callback(node.value)
|
||||
deleted_len = size_callback(node.value)
|
||||
cached_cache_len[0] -= deleted_len
|
||||
|
||||
for cb in node.callbacks:
|
||||
cb()
|
||||
node.callbacks.clear()
|
||||
return deleted_len
|
||||
|
||||
@synchronized
|
||||
def cache_get(key, default=None, callbacks=[]):
|
||||
|
|
139
synapse/util/file_consumer.py
Normal file
139
synapse/util/file_consumer.py
Normal file
|
@ -0,0 +1,139 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import threads, reactor
|
||||
|
||||
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||
|
||||
import Queue
|
||||
|
||||
|
||||
class BackgroundFileConsumer(object):
    """A consumer that writes to a file like object. Supports both push
    and pull producers

    Writes are queued by `write` and performed by `_writer`, which
    `registerProducer` starts via `threads.deferToThread`.

    Args:
        file_obj (file): The file like object to write to. Closed when
            finished.
    """

    # For PushProducers pause if we have this many unwritten slices
    _PAUSE_ON_QUEUE_SIZE = 5
    # And resume once the size of the queue is less than this
    _RESUME_ON_QUEUE_SIZE = 2

    def __init__(self, file_obj):
        self._file_obj = file_obj

        # Producer we're registered with
        self._producer = None

        # True if PushProducer, false if PullProducer
        self.streaming = False

        # For PushProducers, indicates whether we've paused the producer and
        # need to call resumeProducing before we get more data.
        self._paused_producer = False

        # Queue of slices of bytes to be written. When producer calls
        # unregister a final None is sent.
        self._bytes_queue = Queue.Queue()

        # Deferred that is resolved when finished writing. None until a
        # producer has been registered.
        self._finished_deferred = None

        # If the _writer thread throws an exception it gets stored here.
        self._write_exception = None

    def registerProducer(self, producer, streaming):
        """Part of IConsumer interface

        Starts the background writer and, for pull producers, asks for the
        first chunk of data.

        Args:
            producer (IProducer)
            streaming (bool): True if push based producer, False if pull
                based.

        Raises:
            Exception: if a producer is already registered
        """
        if self._producer:
            raise Exception("registerProducer called twice")

        self._producer = producer
        self.streaming = streaming
        self._finished_deferred = preserve_fn(threads.deferToThread)(self._writer)
        if not streaming:
            self._producer.resumeProducing()

    def unregisterProducer(self):
        """Part of IConsumer interface

        Clears the producer and, if the writer is still running, queues a
        final None so that it wakes up and notices there is nothing more to
        wait for.
        """
        self._producer = None
        # _finished_deferred is still None if no producer was ever
        # registered; there is then no writer to wake up.
        if self._finished_deferred is None:
            return
        if not self._finished_deferred.called:
            self._bytes_queue.put_nowait(None)

    def write(self, bytes):
        """Part of IConsumer interface

        Queues the data for the background writer, pausing a push producer
        if the queue has grown to _PAUSE_ON_QUEUE_SIZE.

        Raises:
            Exception: the exception raised by the writer, if it has failed,
                or a generic Exception if the writer has already finished.
        """
        if self._write_exception:
            raise self._write_exception

        if self._finished_deferred.called:
            raise Exception("consumer has closed")

        self._bytes_queue.put_nowait(bytes)

        # If this is a PushProducer and the queue is getting behind
        # then we pause the producer.
        if self.streaming and self._bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE:
            self._paused_producer = True
            self._producer.pauseProducing()

    def _writer(self):
        """This is run in a background thread to write to the file.

        Loops until the producer has been unregistered and the queue is
        empty. Any exception is stored in _write_exception and re-raised;
        the file is closed in all cases.
        """
        try:
            while self._producer or not self._bytes_queue.empty():
                # If we've paused the producer check if we should resume the
                # producer.
                if self._producer and self._paused_producer:
                    if self._bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE:
                        reactor.callFromThread(self._resume_paused_producer)

                chunk = self._bytes_queue.get()

                # If we get a None (or empty list) then that's a signal used
                # to indicate we should check if we should stop.
                if chunk:
                    self._file_obj.write(chunk)

                # If it's a pull producer then we need to explicitly ask for
                # more stuff.
                #
                # Take a snapshot of the producer first: unregisterProducer
                # may set self._producer to None from another thread between
                # the check and the attribute lookup.
                producer = self._producer
                if not self.streaming and producer:
                    reactor.callFromThread(producer.resumeProducing)
        except Exception as e:
            self._write_exception = e
            raise
        finally:
            self._file_obj.close()

    def wait(self):
        """Returns a deferred that resolves when finished writing to file
        """
        return make_deferred_yieldable(self._finished_deferred)

    def _resume_paused_producer(self):
        """Gets called if we should resume producing after being paused
        """
        if self._paused_producer and self._producer:
            self._paused_producer = False
            self._producer.resumeProducing()
|
|
@ -52,13 +52,17 @@ except Exception:
|
|||
class LoggingContext(object):
|
||||
"""Additional context for log formatting. Contexts are scoped within a
|
||||
"with" block.
|
||||
|
||||
Args:
|
||||
name (str): Name for the context for debugging.
|
||||
"""
|
||||
|
||||
__slots__ = [
|
||||
"previous_context", "name", "usage_start", "usage_end", "main_thread",
|
||||
"__dict__", "tag", "alive",
|
||||
"previous_context", "name", "ru_stime", "ru_utime",
|
||||
"db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms",
|
||||
"usage_start", "usage_end",
|
||||
"main_thread", "alive",
|
||||
"request", "tag",
|
||||
]
|
||||
|
||||
thread_local = threading.local()
|
||||
|
@ -83,6 +87,9 @@ class LoggingContext(object):
|
|||
def add_database_transaction(self, duration_ms):
    """No-op: the reported transaction time is discarded.

    Args:
        duration_ms: ignored
    """
    return None
|
||||
|
||||
def add_database_scheduled(self, sched_ms):
    """No-op: the reported scheduling time is discarded.

    Args:
        sched_ms: ignored
    """
    return None
|
||||
|
||||
def __nonzero__(self):
    """Truth value of this object: always False (Python 2 protocol)."""
    return False

# Python 3 looks up __bool__ rather than __nonzero__; alias it so that the
# object is falsy under both interpreters.
__bool__ = __nonzero__
|
||||
|
||||
|
@ -94,9 +101,17 @@ class LoggingContext(object):
|
|||
self.ru_stime = 0.
|
||||
self.ru_utime = 0.
|
||||
self.db_txn_count = 0
|
||||
self.db_txn_duration = 0.
|
||||
|
||||
# ms spent waiting for db txns, excluding scheduling time
|
||||
self.db_txn_duration_ms = 0
|
||||
|
||||
# ms spent waiting for db txns to be scheduled
|
||||
self.db_sched_duration_ms = 0
|
||||
|
||||
self.usage_start = None
|
||||
self.usage_end = None
|
||||
self.main_thread = threading.current_thread()
|
||||
self.request = None
|
||||
self.tag = ""
|
||||
self.alive = True
|
||||
|
||||
|
@ -105,7 +120,11 @@ class LoggingContext(object):
|
|||
|
||||
@classmethod
|
||||
def current_context(cls):
|
||||
"""Get the current logging context from thread local storage"""
|
||||
"""Get the current logging context from thread local storage
|
||||
|
||||
Returns:
|
||||
LoggingContext: the current logging context
|
||||
"""
|
||||
return getattr(cls.thread_local, "current_context", cls.sentinel)
|
||||
|
||||
@classmethod
|
||||
|
@ -155,11 +174,13 @@ class LoggingContext(object):
|
|||
self.alive = False
|
||||
|
||||
def copy_to(self, record):
|
||||
"""Copy fields from this context to the record"""
|
||||
for key, value in self.__dict__.items():
|
||||
setattr(record, key, value)
|
||||
"""Copy logging fields from this context to a log record or
|
||||
another LoggingContext
|
||||
"""
|
||||
|
||||
record.ru_utime, record.ru_stime = self.get_resource_usage()
|
||||
# 'request' is the only field we currently use in the logger, so that's
|
||||
# all we need to copy
|
||||
record.request = self.request
|
||||
|
||||
def start(self):
|
||||
if threading.current_thread() is not self.main_thread:
|
||||
|
@ -194,7 +215,16 @@ class LoggingContext(object):
|
|||
|
||||
def add_database_transaction(self, duration_ms):
|
||||
self.db_txn_count += 1
|
||||
self.db_txn_duration += duration_ms / 1000.
|
||||
self.db_txn_duration_ms += duration_ms
|
||||
|
||||
def add_database_scheduled(self, sched_ms):
    """Account for time spent waiting on the database pool

    Adds ``sched_ms`` to the running ``db_sched_duration_ms`` total.

    Args:
        sched_ms (int): number of milliseconds it took us to get a
            connection
    """
    waited_so_far_ms = self.db_sched_duration_ms
    self.db_sched_duration_ms = waited_so_far_ms + sched_ms
|
||||
|
||||
|
||||
class LoggingContextFilter(logging.Filter):
|
||||
|
|
|
@ -27,25 +27,62 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
metrics = synapse.metrics.get_metrics_for(__name__)
|
||||
|
||||
block_timer = metrics.register_distribution(
|
||||
"block_timer",
|
||||
labels=["block_name"]
|
||||
# total number of times we have hit this block
|
||||
block_counter = metrics.register_counter(
|
||||
"block_count",
|
||||
labels=["block_name"],
|
||||
alternative_names=(
|
||||
# the following are all deprecated aliases for the same metric
|
||||
metrics.name_prefix + x for x in (
|
||||
"_block_timer:count",
|
||||
"_block_ru_utime:count",
|
||||
"_block_ru_stime:count",
|
||||
"_block_db_txn_count:count",
|
||||
"_block_db_txn_duration:count",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
block_ru_utime = metrics.register_distribution(
|
||||
"block_ru_utime", labels=["block_name"]
|
||||
block_timer = metrics.register_counter(
|
||||
"block_time_seconds",
|
||||
labels=["block_name"],
|
||||
alternative_names=(
|
||||
metrics.name_prefix + "_block_timer:total",
|
||||
),
|
||||
)
|
||||
|
||||
block_ru_stime = metrics.register_distribution(
|
||||
"block_ru_stime", labels=["block_name"]
|
||||
block_ru_utime = metrics.register_counter(
|
||||
"block_ru_utime_seconds", labels=["block_name"],
|
||||
alternative_names=(
|
||||
metrics.name_prefix + "_block_ru_utime:total",
|
||||
),
|
||||
)
|
||||
|
||||
block_db_txn_count = metrics.register_distribution(
|
||||
"block_db_txn_count", labels=["block_name"]
|
||||
block_ru_stime = metrics.register_counter(
|
||||
"block_ru_stime_seconds", labels=["block_name"],
|
||||
alternative_names=(
|
||||
metrics.name_prefix + "_block_ru_stime:total",
|
||||
),
|
||||
)
|
||||
|
||||
block_db_txn_duration = metrics.register_distribution(
|
||||
"block_db_txn_duration", labels=["block_name"]
|
||||
block_db_txn_count = metrics.register_counter(
|
||||
"block_db_txn_count", labels=["block_name"],
|
||||
alternative_names=(
|
||||
metrics.name_prefix + "_block_db_txn_count:total",
|
||||
),
|
||||
)
|
||||
|
||||
# seconds spent waiting for db txns, excluding scheduling time, in this block
|
||||
block_db_txn_duration = metrics.register_counter(
|
||||
"block_db_txn_duration_seconds", labels=["block_name"],
|
||||
alternative_names=(
|
||||
metrics.name_prefix + "_block_db_txn_duration:total",
|
||||
),
|
||||
)
|
||||
|
||||
# seconds spent waiting for a db connection, in this block
|
||||
block_db_sched_duration = metrics.register_counter(
|
||||
"block_db_sched_duration_seconds", labels=["block_name"],
|
||||
)
|
||||
|
||||
|
||||
|
@ -64,7 +101,9 @@ def measure_func(name):
|
|||
class Measure(object):
|
||||
__slots__ = [
|
||||
"clock", "name", "start_context", "start", "new_context", "ru_utime",
|
||||
"ru_stime", "db_txn_count", "db_txn_duration", "created_context"
|
||||
"ru_stime",
|
||||
"db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms",
|
||||
"created_context",
|
||||
]
|
||||
|
||||
def __init__(self, clock, name):
|
||||
|
@ -84,13 +123,16 @@ class Measure(object):
|
|||
|
||||
self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
|
||||
self.db_txn_count = self.start_context.db_txn_count
|
||||
self.db_txn_duration = self.start_context.db_txn_duration
|
||||
self.db_txn_duration_ms = self.start_context.db_txn_duration_ms
|
||||
self.db_sched_duration_ms = self.start_context.db_sched_duration_ms
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if isinstance(exc_type, Exception) or not self.start_context:
|
||||
return
|
||||
|
||||
duration = self.clock.time_msec() - self.start
|
||||
|
||||
block_counter.inc(self.name)
|
||||
block_timer.inc_by(duration, self.name)
|
||||
|
||||
context = LoggingContext.current_context()
|
||||
|
@ -114,7 +156,12 @@ class Measure(object):
|
|||
context.db_txn_count - self.db_txn_count, self.name
|
||||
)
|
||||
block_db_txn_duration.inc_by(
|
||||
context.db_txn_duration - self.db_txn_duration, self.name
|
||||
(context.db_txn_duration_ms - self.db_txn_duration_ms) / 1000.,
|
||||
self.name
|
||||
)
|
||||
block_db_sched_duration.inc_by(
|
||||
(context.db_sched_duration_ms - self.db_sched_duration_ms) / 1000.,
|
||||
self.name
|
||||
)
|
||||
|
||||
if self.created_context:
|
||||
|
|
|
@ -26,6 +26,18 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class NotRetryingDestination(Exception):
|
||||
def __init__(self, retry_last_ts, retry_interval, destination):
|
||||
"""Raised by the limiter (and federation client) to indicate that we are
|
||||
deliberately not attempting to contact a given server.
|
||||
|
||||
Args:
|
||||
retry_last_ts (int): the unix ts in milliseconds of our last attempt
|
||||
to contact the server. 0 indicates that the last attempt was
|
||||
successful or that we've never actually attempted to connect.
|
||||
retry_interval (int): the time in milliseconds to wait until the next
|
||||
attempt.
|
||||
destination (str): the domain in question
|
||||
"""
|
||||
|
||||
msg = "Not retrying server %s." % (destination,)
|
||||
super(NotRetryingDestination, self).__init__(msg)
|
||||
|
||||
|
|
48
synapse/util/threepids.py
Normal file
48
synapse/util/threepids.py
Normal file
|
@ -0,0 +1,48 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import re
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def check_3pid_allowed(hs, medium, address):
    """Decide whether a 3PID may be used on this homeserver

    If ``hs.config.allowed_local_3pids`` is empty or unset, everything is
    allowed. Otherwise the 3PID is allowed only if some constraint has the
    same medium and a pattern that ``re.match``es the address.

    Args:
        hs (synapse.server.HomeServer): server
        medium (str): 3pid medium - e.g. email, msisdn
        address (str): address within that medium (e.g. "wotan@matrix.org")
            msisdns need to first have been canonicalised
    Returns:
        bool: whether the 3PID medium/address is allowed to be added to this HS
    """
    constraints = hs.config.allowed_local_3pids
    if not constraints:
        # no restrictions configured
        return True

    for rule in constraints:
        logger.debug(
            "Checking 3PID %s (%s) against %s (%s)",
            address, medium, rule['pattern'], rule['medium'],
        )
        if medium != rule['medium']:
            continue
        if re.match(rule['pattern'], address):
            return True

    return False
|
Loading…
Add table
Add a link
Reference in a new issue