Merge branch 'develop' of github.com:matrix-org/synapse into erikj/patch_inner

This commit is contained in:
Erik Johnston 2019-10-10 11:51:50 +01:00
commit 9970f955ce
48 changed files with 556 additions and 163 deletions

1
changelog.d/6137.misc Normal file
View File

@ -0,0 +1 @@
Refactor configuration loading to allow better typechecking.

1
changelog.d/6155.bugfix Normal file
View File

@ -0,0 +1 @@
Fix transferring notifications and tags when joining an upgraded room that is new to your server.

1
changelog.d/6184.misc Normal file
View File

@ -0,0 +1 @@
Update `user_filters` table to have a unique index, and non-null columns. Thanks to @pik for contributing this.

1
changelog.d/6187.bugfix Normal file
View File

@ -0,0 +1 @@
Fix occasional missed updates in the room and user directories.

View File

@ -4,10 +4,6 @@ plugins=mypy_zope:plugin
follow_imports=skip follow_imports=skip
mypy_path=stubs mypy_path=stubs
[mypy-synapse.config.homeserver]
# this is a mess because of the metaclass shenanigans
ignore_errors = True
[mypy-zope] [mypy-zope]
ignore_missing_imports = True ignore_missing_imports = True
@ -52,3 +48,15 @@ ignore_missing_imports = True
[mypy-signedjson.*] [mypy-signedjson.*]
ignore_missing_imports = True ignore_missing_imports = True
[mypy-prometheus_client.*]
ignore_missing_imports = True
[mypy-service_identity.*]
ignore_missing_imports = True
[mypy-daemonize]
ignore_missing_imports = True
[mypy-sentry_sdk]
ignore_missing_imports = True

View File

@ -18,7 +18,9 @@
import argparse import argparse
import errno import errno
import os import os
from collections import OrderedDict
from textwrap import dedent from textwrap import dedent
from typing import Any, MutableMapping, Optional
from six import integer_types from six import integer_types
@ -51,7 +53,56 @@ Missing mandatory `server_name` config option.
""" """
def path_exists(file_path):
"""Check if a file exists
Unlike os.path.exists, this throws an exception if there is an error
checking if the file exists (for example, if there is a perms error on
the parent dir).
Returns:
bool: True if the file exists; False if not.
"""
try:
os.stat(file_path)
return True
except OSError as e:
if e.errno != errno.ENOENT:
raise e
return False
class Config(object): class Config(object):
"""
A configuration section, containing configuration keys and values.
Attributes:
section (str): The section title of this config object, such as
"tls" or "logger". This is used to refer to it on the root
logger (for example, `config.tls.some_option`). Must be
defined in subclasses.
"""
section = None
def __init__(self, root_config=None):
self.root = root_config
def __getattr__(self, item: str) -> Any:
"""
Try and fetch a configuration option that does not exist on this class.
This is so that existing configs that rely on `self.value`, where value
is actually from a different config section, continue to work.
"""
if item in ["generate_config_section", "read_config"]:
raise AttributeError(item)
if self.root is None:
raise AttributeError(item)
else:
return self.root._get_unclassed_config(self.section, item)
@staticmethod @staticmethod
def parse_size(value): def parse_size(value):
if isinstance(value, integer_types): if isinstance(value, integer_types):
@ -88,22 +139,7 @@ class Config(object):
@classmethod @classmethod
def path_exists(cls, file_path): def path_exists(cls, file_path):
"""Check if a file exists return path_exists(file_path)
Unlike os.path.exists, this throws an exception if there is an error
checking if the file exists (for example, if there is a perms error on
the parent dir).
Returns:
bool: True if the file exists; False if not.
"""
try:
os.stat(file_path)
return True
except OSError as e:
if e.errno != errno.ENOENT:
raise e
return False
@classmethod @classmethod
def check_file(cls, file_path, config_name): def check_file(cls, file_path, config_name):
@ -136,42 +172,106 @@ class Config(object):
with open(file_path) as file_stream: with open(file_path) as file_stream:
return file_stream.read() return file_stream.read()
def invoke_all(self, name, *args, **kargs):
"""Invoke all instance methods with the given name and arguments in the class RootConfig(object):
class's MRO. """
Holder of an application's configuration.
What configuration this object holds is defined by `config_classes`, a list
of Config classes that will be instantiated and given the contents of a
configuration file to read. They can then be accessed on this class by their
section name, defined in the Config or dynamically set to be the name of the
class, lower-cased and with "Config" removed.
"""
config_classes = []
def __init__(self):
self._configs = OrderedDict()
for config_class in self.config_classes:
if config_class.section is None:
raise ValueError("%r requires a section name" % (config_class,))
try:
conf = config_class(self)
except Exception as e:
raise Exception("Failed making %s: %r" % (config_class.section, e))
self._configs[config_class.section] = conf
def __getattr__(self, item: str) -> Any:
"""
Redirect lookups on this object either to config objects, or values on
config objects, so that `config.tls.blah` works, as well as legacy uses
of things like `config.server_name`. It will first look up the config
section name, and then values on those config classes.
"""
if item in self._configs.keys():
return self._configs[item]
return self._get_unclassed_config(None, item)
def _get_unclassed_config(self, asking_section: Optional[str], item: str):
"""
Fetch a config value from one of the instantiated config classes that
has not been fetched directly.
Args: Args:
name (str): Name of function to invoke asking_section: If this check is coming from a Config child, which
one? This section will not be asked if it has the value.
item: The configuration value key.
Raises:
AttributeError if no config classes have the config key. The body
will contain what sections were checked.
"""
for key, val in self._configs.items():
if key == asking_section:
continue
if item in dir(val):
return getattr(val, item)
raise AttributeError(item, "not found in %s" % (list(self._configs.keys()),))
def invoke_all(self, func_name: str, *args, **kwargs) -> MutableMapping[str, Any]:
"""
Invoke a function on all instantiated config objects this RootConfig is
configured to use.
Args:
func_name: Name of function to invoke
*args *args
**kwargs **kwargs
Returns: Returns:
list: The list of the return values from each method called ordered dictionary of config section name and the result of the
function from it.
""" """
results = [] res = OrderedDict()
for cls in type(self).mro():
if name in cls.__dict__: for name, config in self._configs.items():
results.append(getattr(cls, name)(self, *args, **kargs)) if hasattr(config, func_name):
return results res[name] = getattr(config, func_name)(*args, **kwargs)
return res
@classmethod @classmethod
def invoke_all_static(cls, name, *args, **kargs): def invoke_all_static(cls, func_name: str, *args, **kwargs):
"""Invoke all static methods with the given name and arguments in the """
class's MRO. Invoke a static function on config objects this RootConfig is
configured to use.
Args: Args:
name (str): Name of function to invoke func_name: Name of function to invoke
*args *args
**kwargs **kwargs
Returns: Returns:
list: The list of the return values from each method called ordered dictionary of config section name and the result of the
function from it.
""" """
results = [] for config in cls.config_classes:
for c in cls.mro(): if hasattr(config, func_name):
if name in c.__dict__: getattr(config, func_name)(*args, **kwargs)
results.append(getattr(c, name)(*args, **kargs))
return results
def generate_config( def generate_config(
self, self,
@ -187,7 +287,8 @@ class Config(object):
tls_private_key_path=None, tls_private_key_path=None,
acme_domain=None, acme_domain=None,
): ):
"""Build a default configuration file """
Build a default configuration file
This is used when the user explicitly asks us to generate a config file This is used when the user explicitly asks us to generate a config file
(eg with --generate_config). (eg with --generate_config).
@ -242,6 +343,7 @@ class Config(object):
Returns: Returns:
str: the yaml config file str: the yaml config file
""" """
return "\n\n".join( return "\n\n".join(
dedent(conf) dedent(conf)
for conf in self.invoke_all( for conf in self.invoke_all(
@ -257,7 +359,7 @@ class Config(object):
tls_certificate_path=tls_certificate_path, tls_certificate_path=tls_certificate_path,
tls_private_key_path=tls_private_key_path, tls_private_key_path=tls_private_key_path,
acme_domain=acme_domain, acme_domain=acme_domain,
) ).values()
) )
@classmethod @classmethod
@ -444,7 +546,7 @@ class Config(object):
) )
(config_path,) = config_files (config_path,) = config_files
if not cls.path_exists(config_path): if not path_exists(config_path):
print("Generating config file %s" % (config_path,)) print("Generating config file %s" % (config_path,))
if config_args.data_directory: if config_args.data_directory:
@ -469,7 +571,7 @@ class Config(object):
open_private_ports=config_args.open_private_ports, open_private_ports=config_args.open_private_ports,
) )
if not cls.path_exists(config_dir_path): if not path_exists(config_dir_path):
os.makedirs(config_dir_path) os.makedirs(config_dir_path)
with open(config_path, "w") as config_file: with open(config_path, "w") as config_file:
config_file.write("# vim:ft=yaml\n\n") config_file.write("# vim:ft=yaml\n\n")
@ -518,7 +620,7 @@ class Config(object):
return obj return obj
def parse_config_dict(self, config_dict, config_dir_path, data_dir_path): def parse_config_dict(self, config_dict, config_dir_path=None, data_dir_path=None):
"""Read the information from the config dict into this Config object. """Read the information from the config dict into this Config object.
Args: Args:
@ -607,3 +709,6 @@ def find_config_files(search_paths):
else: else:
config_files.append(config_path) config_files.append(config_path)
return config_files return config_files
__all__ = ["Config", "RootConfig"]

135
synapse/config/_base.pyi Normal file
View File

@ -0,0 +1,135 @@
from typing import Any, List, Optional
from synapse.config import (
api,
appservice,
captcha,
cas,
consent_config,
database,
emailconfig,
groups,
jwt_config,
key,
logger,
metrics,
password,
password_auth_providers,
push,
ratelimiting,
registration,
repository,
room_directory,
saml2_config,
server,
server_notices_config,
spam_checker,
stats,
third_party_event_rules,
tls,
tracer,
user_directory,
voip,
workers,
)
class ConfigError(Exception): ...
MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS: str
MISSING_REPORT_STATS_SPIEL: str
MISSING_SERVER_NAME: str
def path_exists(file_path: str): ...
class RootConfig:
server: server.ServerConfig
tls: tls.TlsConfig
database: database.DatabaseConfig
logging: logger.LoggingConfig
ratelimit: ratelimiting.RatelimitConfig
media: repository.ContentRepositoryConfig
captcha: captcha.CaptchaConfig
voip: voip.VoipConfig
registration: registration.RegistrationConfig
metrics: metrics.MetricsConfig
api: api.ApiConfig
appservice: appservice.AppServiceConfig
key: key.KeyConfig
saml2: saml2_config.SAML2Config
cas: cas.CasConfig
jwt: jwt_config.JWTConfig
password: password.PasswordConfig
email: emailconfig.EmailConfig
worker: workers.WorkerConfig
authproviders: password_auth_providers.PasswordAuthProviderConfig
push: push.PushConfig
spamchecker: spam_checker.SpamCheckerConfig
groups: groups.GroupsConfig
userdirectory: user_directory.UserDirectoryConfig
consent: consent_config.ConsentConfig
stats: stats.StatsConfig
servernotices: server_notices_config.ServerNoticesConfig
roomdirectory: room_directory.RoomDirectoryConfig
thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig
tracer: tracer.TracerConfig
config_classes: List = ...
def __init__(self) -> None: ...
def invoke_all(self, func_name: str, *args: Any, **kwargs: Any): ...
@classmethod
def invoke_all_static(cls, func_name: str, *args: Any, **kwargs: Any) -> None: ...
def __getattr__(self, item: str): ...
def parse_config_dict(
self,
config_dict: Any,
config_dir_path: Optional[Any] = ...,
data_dir_path: Optional[Any] = ...,
) -> None: ...
read_config: Any = ...
def generate_config(
self,
config_dir_path: str,
data_dir_path: str,
server_name: str,
generate_secrets: bool = ...,
report_stats: Optional[str] = ...,
open_private_ports: bool = ...,
listeners: Optional[Any] = ...,
database_conf: Optional[Any] = ...,
tls_certificate_path: Optional[str] = ...,
tls_private_key_path: Optional[str] = ...,
acme_domain: Optional[str] = ...,
): ...
@classmethod
def load_or_generate_config(cls, description: Any, argv: Any): ...
@classmethod
def load_config(cls, description: Any, argv: Any): ...
@classmethod
def add_arguments_to_parser(cls, config_parser: Any) -> None: ...
@classmethod
def load_config_with_parser(cls, parser: Any, argv: Any): ...
def generate_missing_files(
self, config_dict: dict, config_dir_path: str
) -> None: ...
class Config:
root: RootConfig
def __init__(self, root_config: Optional[RootConfig] = ...) -> None: ...
def __getattr__(self, item: str, from_root: bool = ...): ...
@staticmethod
def parse_size(value: Any): ...
@staticmethod
def parse_duration(value: Any): ...
@staticmethod
def abspath(file_path: Optional[str]): ...
@classmethod
def path_exists(cls, file_path: str): ...
@classmethod
def check_file(cls, file_path: str, config_name: str): ...
@classmethod
def ensure_directory(cls, dir_path: str): ...
@classmethod
def read_file(cls, file_path: str, config_name: str): ...
def read_config_files(config_files: List[str]): ...
def find_config_files(search_paths: List[str]): ...

View File

@ -18,6 +18,8 @@ from ._base import Config
class ApiConfig(Config): class ApiConfig(Config):
section = "api"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
self.room_invite_state_types = config.get( self.room_invite_state_types = config.get(
"room_invite_state_types", "room_invite_state_types",

View File

@ -30,6 +30,8 @@ logger = logging.getLogger(__name__)
class AppServiceConfig(Config): class AppServiceConfig(Config):
section = "appservice"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
self.app_service_config_files = config.get("app_service_config_files", []) self.app_service_config_files = config.get("app_service_config_files", [])
self.notify_appservices = config.get("notify_appservices", True) self.notify_appservices = config.get("notify_appservices", True)

View File

@ -16,6 +16,8 @@ from ._base import Config
class CaptchaConfig(Config): class CaptchaConfig(Config):
section = "captcha"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
self.recaptcha_private_key = config.get("recaptcha_private_key") self.recaptcha_private_key = config.get("recaptcha_private_key")
self.recaptcha_public_key = config.get("recaptcha_public_key") self.recaptcha_public_key = config.get("recaptcha_public_key")

View File

@ -22,6 +22,8 @@ class CasConfig(Config):
cas_server_url: URL of CAS server cas_server_url: URL of CAS server
""" """
section = "cas"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
cas_config = config.get("cas_config", None) cas_config = config.get("cas_config", None)
if cas_config: if cas_config:

View File

@ -73,6 +73,9 @@ DEFAULT_CONFIG = """\
class ConsentConfig(Config): class ConsentConfig(Config):
section = "consent"
def __init__(self, *args): def __init__(self, *args):
super(ConsentConfig, self).__init__(*args) super(ConsentConfig, self).__init__(*args)

View File

@ -21,6 +21,8 @@ from ._base import Config
class DatabaseConfig(Config): class DatabaseConfig(Config):
section = "database"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K")) self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K"))

View File

@ -28,6 +28,8 @@ from ._base import Config, ConfigError
class EmailConfig(Config): class EmailConfig(Config):
section = "email"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
# TODO: We should separate better the email configuration from the notification # TODO: We should separate better the email configuration from the notification
# and account validity config. # and account validity config.

View File

@ -17,6 +17,8 @@ from ._base import Config
class GroupsConfig(Config): class GroupsConfig(Config):
section = "groups"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
self.enable_group_creation = config.get("enable_group_creation", False) self.enable_group_creation = config.get("enable_group_creation", False)
self.group_creation_prefix = config.get("group_creation_prefix", "") self.group_creation_prefix = config.get("group_creation_prefix", "")

View File

@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import RootConfig
from .api import ApiConfig from .api import ApiConfig
from .appservice import AppServiceConfig from .appservice import AppServiceConfig
from .captcha import CaptchaConfig from .captcha import CaptchaConfig
@ -46,7 +47,9 @@ from .voip import VoipConfig
from .workers import WorkerConfig from .workers import WorkerConfig
class HomeServerConfig( class HomeServerConfig(RootConfig):
config_classes = [
ServerConfig, ServerConfig,
TlsConfig, TlsConfig,
DatabaseConfig, DatabaseConfig,
@ -77,5 +80,4 @@ class HomeServerConfig(
RoomDirectoryConfig, RoomDirectoryConfig,
ThirdPartyRulesConfig, ThirdPartyRulesConfig,
TracerConfig, TracerConfig,
): ]
pass

View File

@ -23,6 +23,8 @@ MISSING_JWT = """Missing jwt library. This is required for jwt login.
class JWTConfig(Config): class JWTConfig(Config):
section = "jwt"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
jwt_config = config.get("jwt_config", None) jwt_config = config.get("jwt_config", None)
if jwt_config: if jwt_config:

View File

@ -92,6 +92,8 @@ class TrustedKeyServer(object):
class KeyConfig(Config): class KeyConfig(Config):
section = "key"
def read_config(self, config, config_dir_path, **kwargs): def read_config(self, config, config_dir_path, **kwargs):
# the signing key can be specified inline or in a separate file # the signing key can be specified inline or in a separate file
if "signing_key" in config: if "signing_key" in config:

View File

@ -84,6 +84,8 @@ root:
class LoggingConfig(Config): class LoggingConfig(Config):
section = "logging"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
self.log_config = self.abspath(config.get("log_config")) self.log_config = self.abspath(config.get("log_config"))
self.no_redirect_stdio = config.get("no_redirect_stdio", False) self.no_redirect_stdio = config.get("no_redirect_stdio", False)

View File

@ -34,6 +34,8 @@ class MetricsFlags(object):
class MetricsConfig(Config): class MetricsConfig(Config):
section = "metrics"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
self.enable_metrics = config.get("enable_metrics", False) self.enable_metrics = config.get("enable_metrics", False)
self.report_stats = config.get("report_stats", None) self.report_stats = config.get("report_stats", None)

View File

@ -20,6 +20,8 @@ class PasswordConfig(Config):
"""Password login configuration """Password login configuration
""" """
section = "password"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
password_config = config.get("password_config", {}) password_config = config.get("password_config", {})
if password_config is None: if password_config is None:

View File

@ -23,6 +23,8 @@ LDAP_PROVIDER = "ldap_auth_provider.LdapAuthProvider"
class PasswordAuthProviderConfig(Config): class PasswordAuthProviderConfig(Config):
section = "authproviders"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
self.password_providers = [] # type: List[Any] self.password_providers = [] # type: List[Any]
providers = [] providers = []

View File

@ -18,6 +18,8 @@ from ._base import Config
class PushConfig(Config): class PushConfig(Config):
section = "push"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
push_config = config.get("push", {}) push_config = config.get("push", {})
self.push_include_content = push_config.get("include_content", True) self.push_include_content = push_config.get("include_content", True)

View File

@ -36,6 +36,8 @@ class FederationRateLimitConfig(object):
class RatelimitConfig(Config): class RatelimitConfig(Config):
section = "ratelimiting"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
# Load the new-style messages config if it exists. Otherwise fall back # Load the new-style messages config if it exists. Otherwise fall back

View File

@ -24,6 +24,8 @@ from synapse.util.stringutils import random_string_with_symbols
class AccountValidityConfig(Config): class AccountValidityConfig(Config):
section = "accountvalidity"
def __init__(self, config, synapse_config): def __init__(self, config, synapse_config):
self.enabled = config.get("enabled", False) self.enabled = config.get("enabled", False)
self.renew_by_email_enabled = "renew_at" in config self.renew_by_email_enabled = "renew_at" in config
@ -77,6 +79,8 @@ class AccountValidityConfig(Config):
class RegistrationConfig(Config): class RegistrationConfig(Config):
section = "registration"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
self.enable_registration = bool( self.enable_registration = bool(
strtobool(str(config.get("enable_registration", False))) strtobool(str(config.get("enable_registration", False)))

View File

@ -78,6 +78,8 @@ def parse_thumbnail_requirements(thumbnail_sizes):
class ContentRepositoryConfig(Config): class ContentRepositoryConfig(Config):
section = "media"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
# Only enable the media repo if either the media repo is enabled or the # Only enable the media repo if either the media repo is enabled or the

View File

@ -19,6 +19,8 @@ from ._base import Config, ConfigError
class RoomDirectoryConfig(Config): class RoomDirectoryConfig(Config):
section = "roomdirectory"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
self.enable_room_list_search = config.get("enable_room_list_search", True) self.enable_room_list_search = config.get("enable_room_list_search", True)

View File

@ -55,6 +55,8 @@ def _dict_merge(merge_dict, into_dict):
class SAML2Config(Config): class SAML2Config(Config):
section = "saml2"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
self.saml2_enabled = False self.saml2_enabled = False

View File

@ -58,6 +58,8 @@ on how to configure the new listener.
class ServerConfig(Config): class ServerConfig(Config):
section = "server"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
self.server_name = config["server_name"] self.server_name = config["server_name"]
self.server_context = config.get("server_context", None) self.server_context = config.get("server_context", None)

View File

@ -59,6 +59,8 @@ class ServerNoticesConfig(Config):
None if server notices are not enabled. None if server notices are not enabled.
""" """
section = "servernotices"
def __init__(self, *args): def __init__(self, *args):
super(ServerNoticesConfig, self).__init__(*args) super(ServerNoticesConfig, self).__init__(*args)
self.server_notices_mxid = None self.server_notices_mxid = None

View File

@ -19,6 +19,8 @@ from ._base import Config
class SpamCheckerConfig(Config): class SpamCheckerConfig(Config):
section = "spamchecker"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
self.spam_checker = None self.spam_checker = None

View File

@ -25,6 +25,8 @@ class StatsConfig(Config):
Configuration for the behaviour of synapse's stats engine Configuration for the behaviour of synapse's stats engine
""" """
section = "stats"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
self.stats_enabled = True self.stats_enabled = True
self.stats_bucket_size = 86400 * 1000 self.stats_bucket_size = 86400 * 1000

View File

@ -19,6 +19,8 @@ from ._base import Config
class ThirdPartyRulesConfig(Config): class ThirdPartyRulesConfig(Config):
section = "thirdpartyrules"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
self.third_party_event_rules = None self.third_party_event_rules = None

View File

@ -18,6 +18,7 @@ import os
import warnings import warnings
from datetime import datetime from datetime import datetime
from hashlib import sha256 from hashlib import sha256
from typing import List
import six import six
@ -33,7 +34,9 @@ logger = logging.getLogger(__name__)
class TlsConfig(Config): class TlsConfig(Config):
def read_config(self, config, config_dir_path, **kwargs): section = "tls"
def read_config(self, config: dict, config_dir_path: str, **kwargs):
acme_config = config.get("acme", None) acme_config = config.get("acme", None)
if acme_config is None: if acme_config is None:
@ -57,7 +60,7 @@ class TlsConfig(Config):
self.tls_certificate_file = self.abspath(config.get("tls_certificate_path")) self.tls_certificate_file = self.abspath(config.get("tls_certificate_path"))
self.tls_private_key_file = self.abspath(config.get("tls_private_key_path")) self.tls_private_key_file = self.abspath(config.get("tls_private_key_path"))
if self.has_tls_listener(): if self.root.server.has_tls_listener():
if not self.tls_certificate_file: if not self.tls_certificate_file:
raise ConfigError( raise ConfigError(
"tls_certificate_path must be specified if TLS-enabled listeners are " "tls_certificate_path must be specified if TLS-enabled listeners are "
@ -108,7 +111,7 @@ class TlsConfig(Config):
) )
# Support globs (*) in whitelist values # Support globs (*) in whitelist values
self.federation_certificate_verification_whitelist = [] self.federation_certificate_verification_whitelist = [] # type: List[str]
for entry in fed_whitelist_entries: for entry in fed_whitelist_entries:
try: try:
entry_regex = glob_to_regex(entry.encode("ascii").decode("ascii")) entry_regex = glob_to_regex(entry.encode("ascii").decode("ascii"))

View File

@ -19,6 +19,8 @@ from ._base import Config, ConfigError
class TracerConfig(Config): class TracerConfig(Config):
section = "tracing"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
opentracing_config = config.get("opentracing") opentracing_config = config.get("opentracing")
if opentracing_config is None: if opentracing_config is None:

View File

@ -21,6 +21,8 @@ class UserDirectoryConfig(Config):
Configuration for the behaviour of the /user_directory API Configuration for the behaviour of the /user_directory API
""" """
section = "userdirectory"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
self.user_directory_search_enabled = True self.user_directory_search_enabled = True
self.user_directory_search_all_users = False self.user_directory_search_all_users = False

View File

@ -16,6 +16,8 @@ from ._base import Config
class VoipConfig(Config): class VoipConfig(Config):
section = "voip"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
self.turn_uris = config.get("turn_uris", []) self.turn_uris = config.get("turn_uris", [])
self.turn_shared_secret = config.get("turn_shared_secret") self.turn_shared_secret = config.get("turn_shared_secret")

View File

@ -21,6 +21,8 @@ class WorkerConfig(Config):
They have their own pid_file and listener configuration. They use the They have their own pid_file and listener configuration. They use the
replication_url to talk to the main synapse process.""" replication_url to talk to the main synapse process."""
section = "worker"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
self.worker_app = config.get("worker_app") self.worker_app = config.get("worker_app")

View File

@ -803,17 +803,25 @@ class PresenceHandler(object):
# Loop round handling deltas until we're up to date # Loop round handling deltas until we're up to date
while True: while True:
with Measure(self.clock, "presence_delta"): with Measure(self.clock, "presence_delta"):
deltas = yield self.store.get_current_state_deltas(self._event_pos) room_max_stream_ordering = self.store.get_room_max_stream_ordering()
if not deltas: if self._event_pos == room_max_stream_ordering:
return return
logger.debug(
"Processing presence stats %s->%s",
self._event_pos,
room_max_stream_ordering,
)
max_pos, deltas = yield self.store.get_current_state_deltas(
self._event_pos, room_max_stream_ordering
)
yield self._handle_state_delta(deltas) yield self._handle_state_delta(deltas)
self._event_pos = deltas[-1]["stream_id"] self._event_pos = max_pos
# Expose current event processing position to prometheus # Expose current event processing position to prometheus
synapse.metrics.event_processing_positions.labels("presence").set( synapse.metrics.event_processing_positions.labels("presence").set(
self._event_pos max_pos
) )
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -203,23 +203,11 @@ class RoomMemberHandler(object):
prev_member_event = yield self.store.get_event(prev_member_event_id) prev_member_event = yield self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined: if newly_joined:
# Copy over user state if we're joining an upgraded room
yield self.copy_user_state_if_room_upgrade(
room_id, requester.user.to_string()
)
yield self._user_joined_room(target, room_id) yield self._user_joined_room(target, room_id)
# Copy over direct message status and room tags if this is a join
# on an upgraded room
# Check if this is an upgraded room
predecessor = yield self.store.get_room_predecessor(room_id)
if predecessor:
# It is an upgraded room. Copy over old tags
yield self.copy_room_tags_and_direct_to_room(
predecessor["room_id"], room_id, user_id
)
# Copy over push rules
yield self.store.copy_push_rules_from_room_to_room_for_user(
predecessor["room_id"], room_id, user_id
)
elif event.membership == Membership.LEAVE: elif event.membership == Membership.LEAVE:
if prev_member_event_id: if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id) prev_member_event = yield self.store.get_event(prev_member_event_id)
@ -463,10 +451,16 @@ class RoomMemberHandler(object):
if requester.is_guest: if requester.is_guest:
content["kind"] = "guest" content["kind"] = "guest"
ret = yield self._remote_join( remote_join_response = yield self._remote_join(
requester, remote_room_hosts, room_id, target, content requester, remote_room_hosts, room_id, target, content
) )
return ret
# Copy over user state if this is a join on an remote upgraded room
yield self.copy_user_state_if_room_upgrade(
room_id, requester.user.to_string()
)
return remote_join_response
elif effective_membership_state == Membership.LEAVE: elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room: if not is_host_in_room:
@ -503,6 +497,38 @@ class RoomMemberHandler(object):
) )
return res return res
@defer.inlineCallbacks
def copy_user_state_if_room_upgrade(self, new_room_id, user_id):
"""Copy user-specific information when they join a new room if that new room is the
result of a room upgrade
Args:
new_room_id (str): The ID of the room the user is joining
user_id (str): The ID of the user
Returns:
Deferred
"""
# Check if the new room is an upgraded room
predecessor = yield self.store.get_room_predecessor(new_room_id)
if not predecessor:
return
logger.debug(
"Found predecessor for %s: %s. Copying over room tags and push " "rules",
new_room_id,
predecessor,
)
# It is an upgraded room. Copy over old tags
yield self.copy_room_tags_and_direct_to_room(
predecessor["room_id"], new_room_id, user_id
)
# Copy over push rules
yield self.store.copy_push_rules_from_room_to_room_for_user(
predecessor["room_id"], new_room_id, user_id
)
@defer.inlineCallbacks @defer.inlineCallbacks
def send_membership_event(self, requester, event, context, ratelimit=True): def send_membership_event(self, requester, event, context, ratelimit=True):
""" """

View File

@ -87,21 +87,23 @@ class StatsHandler(StateDeltasHandler):
# Be sure to read the max stream_ordering *before* checking if there are any outstanding # Be sure to read the max stream_ordering *before* checking if there are any outstanding
# deltas, since there is otherwise a chance that we could miss updates which arrive # deltas, since there is otherwise a chance that we could miss updates which arrive
# after we check the deltas. # after we check the deltas.
room_max_stream_ordering = yield self.store.get_room_max_stream_ordering() room_max_stream_ordering = self.store.get_room_max_stream_ordering()
if self.pos == room_max_stream_ordering: if self.pos == room_max_stream_ordering:
break break
deltas = yield self.store.get_current_state_deltas(self.pos) logger.debug(
"Processing room stats %s->%s", self.pos, room_max_stream_ordering
)
max_pos, deltas = yield self.store.get_current_state_deltas(
self.pos, room_max_stream_ordering
)
if deltas: if deltas:
logger.debug("Handling %d state deltas", len(deltas)) logger.debug("Handling %d state deltas", len(deltas))
room_deltas, user_deltas = yield self._handle_deltas(deltas) room_deltas, user_deltas = yield self._handle_deltas(deltas)
max_pos = deltas[-1]["stream_id"]
else: else:
room_deltas = {} room_deltas = {}
user_deltas = {} user_deltas = {}
max_pos = room_max_stream_ordering
# Then count deltas for total_events and total_event_bytes. # Then count deltas for total_events and total_event_bytes.
room_count, user_count = yield self.store.get_changes_room_total_events_and_bytes( room_count, user_count = yield self.store.get_changes_room_total_events_and_bytes(

View File

@ -138,21 +138,28 @@ class UserDirectoryHandler(StateDeltasHandler):
# Loop round handling deltas until we're up to date # Loop round handling deltas until we're up to date
while True: while True:
with Measure(self.clock, "user_dir_delta"): with Measure(self.clock, "user_dir_delta"):
deltas = yield self.store.get_current_state_deltas(self.pos) room_max_stream_ordering = self.store.get_room_max_stream_ordering()
if not deltas: if self.pos == room_max_stream_ordering:
return return
logger.debug(
"Processing user stats %s->%s", self.pos, room_max_stream_ordering
)
max_pos, deltas = yield self.store.get_current_state_deltas(
self.pos, room_max_stream_ordering
)
logger.info("Handling %d state deltas", len(deltas)) logger.info("Handling %d state deltas", len(deltas))
yield self._handle_deltas(deltas) yield self._handle_deltas(deltas)
self.pos = deltas[-1]["stream_id"] self.pos = max_pos
# Expose current event processing position to prometheus # Expose current event processing position to prometheus
synapse.metrics.event_processing_positions.labels("user_dir").set( synapse.metrics.event_processing_positions.labels("user_dir").set(
self.pos max_pos
) )
yield self.store.update_user_directory_stream_pos(self.pos) yield self.store.update_user_directory_stream_pos(max_pos)
@defer.inlineCallbacks @defer.inlineCallbacks
def _handle_deltas(self, deltas): def _handle_deltas(self, deltas):

View File

@ -5,42 +5,48 @@ from synapse.storage.engines import PostgresEngine
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
"""
This migration updates the user_filters table as follows:
- drops any (user_id, filter_id) duplicates
- makes the columns NON-NULLable
- turns the index into a UNIQUE index
"""
def run_upgrade(cur, database_engine, *args, **kwargs): def run_upgrade(cur, database_engine, *args, **kwargs):
pass
def run_create(cur, database_engine, *args, **kwargs):
if isinstance(database_engine, PostgresEngine): if isinstance(database_engine, PostgresEngine):
select_clause = """ select_clause = """
CREATE TEMPORARY TABLE user_filters_migration AS
SELECT DISTINCT ON (user_id, filter_id) user_id, filter_id, filter_json SELECT DISTINCT ON (user_id, filter_id) user_id, filter_id, filter_json
FROM user_filters; FROM user_filters
""" """
else: else:
select_clause = """ select_clause = """
CREATE TEMPORARY TABLE user_filters_migration AS SELECT * FROM user_filters GROUP BY user_id, filter_id
SELECT * FROM user_filters GROUP BY user_id, filter_id;
""" """
sql = ( sql = """
""" DROP TABLE IF EXISTS user_filters_migration;
BEGIN; DROP INDEX IF EXISTS user_filters_unique;
%s CREATE TABLE user_filters_migration (
DROP INDEX user_filters_by_user_id_filter_id; user_id TEXT NOT NULL,
DELETE FROM user_filters; filter_id BIGINT NOT NULL,
ALTER TABLE user_filters filter_json BYTEA NOT NULL
ALTER COLUMN user_id SET NOT NULL, );
ALTER COLUMN filter_id SET NOT NULL, INSERT INTO user_filters_migration (user_id, filter_id, filter_json)
ALTER COLUMN filter_json SET NOT NULL; %s;
INSERT INTO user_filters(user_id, filter_id, filter_json) CREATE UNIQUE INDEX user_filters_unique ON user_filters_migration
SELECT * FROM user_filters_migration; (user_id, filter_id);
DROP TABLE user_filters_migration; DROP TABLE user_filters;
CREATE UNIQUE INDEX user_filters_by_user_id_filter_id_unique ALTER TABLE user_filters_migration RENAME TO user_filters;
ON user_filters(user_id, filter_id); """ % (
END; select_clause,
"""
% select_clause
) )
if isinstance(database_engine, PostgresEngine): if isinstance(database_engine, PostgresEngine):
cur.execute(sql) cur.execute(sql)
else: else:
cur.executescript(sql) cur.executescript(sql)
def run_create(cur, database_engine, *args, **kwargs):
pass

View File

@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
class StateDeltasStore(SQLBaseStore): class StateDeltasStore(SQLBaseStore):
def get_current_state_deltas(self, prev_stream_id): def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int):
"""Fetch a list of room state changes since the given stream id """Fetch a list of room state changes since the given stream id
Each entry in the result contains the following fields: Each entry in the result contains the following fields:
@ -36,15 +36,27 @@ class StateDeltasStore(SQLBaseStore):
Args: Args:
prev_stream_id (int): point to get changes since (exclusive) prev_stream_id (int): point to get changes since (exclusive)
max_stream_id (int): the point that we know has been correctly persisted
- ie, an upper limit to return changes from.
Returns: Returns:
Deferred[list[dict]]: results Deferred[tuple[int, list[dict]]: A tuple consisting of:
- the stream id which these results go up to
- list of current_state_delta_stream rows. If it is empty, we are
up to date.
""" """
prev_stream_id = int(prev_stream_id) prev_stream_id = int(prev_stream_id)
# check we're not going backwards
assert prev_stream_id <= max_stream_id
if not self._curr_state_delta_stream_cache.has_any_entity_changed( if not self._curr_state_delta_stream_cache.has_any_entity_changed(
prev_stream_id prev_stream_id
): ):
return [] # if the CSDs haven't changed between prev_stream_id and now, we
# know for certain that they haven't changed between prev_stream_id and
# max_stream_id.
return max_stream_id, []
def get_current_state_deltas_txn(txn): def get_current_state_deltas_txn(txn):
# First we calculate the max stream id that will give us less than # First we calculate the max stream id that will give us less than
@ -54,21 +66,29 @@ class StateDeltasStore(SQLBaseStore):
sql = """ sql = """
SELECT stream_id, count(*) SELECT stream_id, count(*)
FROM current_state_delta_stream FROM current_state_delta_stream
WHERE stream_id > ? WHERE stream_id > ? AND stream_id <= ?
GROUP BY stream_id GROUP BY stream_id
ORDER BY stream_id ASC ORDER BY stream_id ASC
LIMIT 100 LIMIT 100
""" """
txn.execute(sql, (prev_stream_id,)) txn.execute(sql, (prev_stream_id, max_stream_id))
total = 0 total = 0
max_stream_id = prev_stream_id
for max_stream_id, count in txn: for stream_id, count in txn:
total += count total += count
if total > 100: if total > 100:
# We arbitarily limit to 100 entries to ensure we don't # We arbitarily limit to 100 entries to ensure we don't
# select toooo many. # select toooo many.
logger.debug(
"Clipping current_state_delta_stream rows to stream_id %i",
stream_id,
)
clipped_stream_id = stream_id
break break
else:
# if there's no problem, we may as well go right up to the max_stream_id
clipped_stream_id = max_stream_id
# Now actually get the deltas # Now actually get the deltas
sql = """ sql = """
@ -77,8 +97,8 @@ class StateDeltasStore(SQLBaseStore):
WHERE ? < stream_id AND stream_id <= ? WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC ORDER BY stream_id ASC
""" """
txn.execute(sql, (prev_stream_id, max_stream_id)) txn.execute(sql, (prev_stream_id, clipped_stream_id))
return self.cursor_to_dict(txn) return clipped_stream_id, self.cursor_to_dict(txn)
return self.runInteraction( return self.runInteraction(
"get_current_state_deltas", get_current_state_deltas_txn "get_current_state_deltas", get_current_state_deltas_txn

View File

@ -21,17 +21,24 @@ import yaml
from OpenSSL import SSL from OpenSSL import SSL
from synapse.config._base import Config, RootConfig
from synapse.config.tls import ConfigError, TlsConfig from synapse.config.tls import ConfigError, TlsConfig
from synapse.crypto.context_factory import ClientTLSOptionsFactory from synapse.crypto.context_factory import ClientTLSOptionsFactory
from tests.unittest import TestCase from tests.unittest import TestCase
class TestConfig(TlsConfig): class FakeServer(Config):
section = "server"
def has_tls_listener(self): def has_tls_listener(self):
return False return False
class TestConfig(RootConfig):
config_classes = [FakeServer, TlsConfig]
class TLSConfigTests(TestCase): class TLSConfigTests(TestCase):
def test_warn_self_signed(self): def test_warn_self_signed(self):
""" """
@ -202,13 +209,13 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
conf = TestConfig() conf = TestConfig()
conf.read_config( conf.read_config(
yaml.safe_load( yaml.safe_load(
TestConfig().generate_config_section( TestConfig().generate_config(
"/config_dir_path", "/config_dir_path",
"my_super_secure_server", "my_super_secure_server",
"/data_dir_path", "/data_dir_path",
"/tls_cert_path", tls_certificate_path="/tls_cert_path",
"tls_private_key", tls_private_key_path="tls_private_key",
None, # This is the acme_domain acme_domain=None, # This is the acme_domain
) )
), ),
"/config_dir_path", "/config_dir_path",
@ -223,13 +230,13 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
conf = TestConfig() conf = TestConfig()
conf.read_config( conf.read_config(
yaml.safe_load( yaml.safe_load(
TestConfig().generate_config_section( TestConfig().generate_config(
"/config_dir_path", "/config_dir_path",
"my_super_secure_server", "my_super_secure_server",
"/data_dir_path", "/data_dir_path",
"/tls_cert_path", tls_certificate_path="/tls_cert_path",
"tls_private_key", tls_private_key_path="tls_private_key",
"my_supe_secure_server", # This is the acme_domain acme_domain="my_supe_secure_server", # This is the acme_domain
) )
), ),
"/config_dir_path", "/config_dir_path",

View File

@ -139,7 +139,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
defer.succeed(1) defer.succeed(1)
) )
self.datastore.get_current_state_deltas.return_value = None self.datastore.get_current_state_deltas.return_value = (0, None)
self.datastore.get_to_device_stream_token = lambda: 0 self.datastore.get_to_device_stream_token = lambda: 0
self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: ([], 0) self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: ([], 0)

View File

@ -62,7 +62,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
self.device_handler.check_device_registered = Mock(return_value="FAKE") self.device_handler.check_device_registered = Mock(return_value="FAKE")
self.datastore = Mock(return_value=Mock()) self.datastore = Mock(return_value=Mock())
self.datastore.get_current_state_deltas = Mock(return_value=[]) self.datastore.get_current_state_deltas = Mock(return_value=(0, []))
self.secrets = Mock() self.secrets = Mock()

View File

@ -163,10 +163,9 @@ deps =
{[base]deps} {[base]deps}
mypy mypy
mypy-zope mypy-zope
typeshed
env = env =
MYPYPATH = stubs/ MYPYPATH = stubs/
extras = all extras = all
commands = mypy --show-traceback \ commands = mypy --show-traceback --check-untyped-defs --show-error-codes --follow-imports=normal \
synapse/logging/ \ synapse/logging/ \
synapse/config/ synapse/config/