Merge branch 'develop' of github.com:matrix-org/synapse into erikj/disable_sql_bytes

This commit is contained in:
Erik Johnston 2019-10-10 13:10:57 +01:00
commit 91f43dca39
135 changed files with 2482 additions and 1700 deletions

View file

@ -179,7 +179,6 @@ class Auth(object):
def get_public_keys(self, invite_event):
return event_auth.get_public_keys(invite_event)
@opentracing.trace
@defer.inlineCallbacks
def get_user_by_req(
self, request, allow_guest=False, rights="access", allow_expired=False
@ -212,6 +211,7 @@ class Auth(object):
if user_id:
request.authenticated_entity = user_id
opentracing.set_tag("authenticated_entity", user_id)
opentracing.set_tag("appservice_id", app_service.id)
if ip_addr and self.hs.config.track_appservice_user_ips:
yield self.store.insert_client_ip(
@ -263,6 +263,8 @@ class Auth(object):
request.authenticated_entity = user.to_string()
opentracing.set_tag("authenticated_entity", user.to_string())
if device_id:
opentracing.set_tag("device_id", device_id)
return synapse.types.create_requester(
user, token_id, is_guest, device_id, app_service=app_service

View file

@ -17,6 +17,7 @@
"""Contains exceptions and error codes."""
import logging
from typing import Dict
from six import iteritems
from six.moves import http_client
@ -111,7 +112,7 @@ class ProxiedRequestError(SynapseError):
def __init__(self, code, msg, errcode=Codes.UNKNOWN, additional_fields=None):
super(ProxiedRequestError, self).__init__(code, msg, errcode)
if additional_fields is None:
self._additional_fields = {}
self._additional_fields = {} # type: Dict
else:
self._additional_fields = dict(additional_fields)

View file

@ -12,6 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
import attr
@ -102,4 +105,4 @@ KNOWN_ROOM_VERSIONS = {
RoomVersions.V4,
RoomVersions.V5,
)
} # type: dict[str, RoomVersion]
} # type: Dict[str, RoomVersion]

View file

@ -263,7 +263,9 @@ def start(hs, listeners=None):
refresh_certificate(hs)
# Start the tracer
synapse.logging.opentracing.init_tracer(hs.config)
synapse.logging.opentracing.init_tracer( # type: ignore[attr-defined] # noqa
hs.config
)
# It is now safe to start your Synapse.
hs.start_listening(listeners)

View file

@ -18,7 +18,9 @@
import argparse
import errno
import os
from collections import OrderedDict
from textwrap import dedent
from typing import Any, MutableMapping, Optional
from six import integer_types
@ -51,7 +53,56 @@ Missing mandatory `server_name` config option.
"""
def path_exists(file_path):
"""Check if a file exists
Unlike os.path.exists, this throws an exception if there is an error
checking if the file exists (for example, if there is a perms error on
the parent dir).
Returns:
bool: True if the file exists; False if not.
"""
try:
os.stat(file_path)
return True
except OSError as e:
if e.errno != errno.ENOENT:
raise e
return False
class Config(object):
"""
A configuration section, containing configuration keys and values.
Attributes:
section (str): The section title of this config object, such as
"tls" or "logger". This is used to refer to it on the root
logger (for example, `config.tls.some_option`). Must be
defined in subclasses.
"""
section = None
def __init__(self, root_config=None):
self.root = root_config
def __getattr__(self, item: str) -> Any:
"""
Try and fetch a configuration option that does not exist on this class.
This is so that existing configs that rely on `self.value`, where value
is actually from a different config section, continue to work.
"""
if item in ["generate_config_section", "read_config"]:
raise AttributeError(item)
if self.root is None:
raise AttributeError(item)
else:
return self.root._get_unclassed_config(self.section, item)
@staticmethod
def parse_size(value):
if isinstance(value, integer_types):
@ -88,22 +139,7 @@ class Config(object):
@classmethod
def path_exists(cls, file_path):
"""Check if a file exists
Unlike os.path.exists, this throws an exception if there is an error
checking if the file exists (for example, if there is a perms error on
the parent dir).
Returns:
bool: True if the file exists; False if not.
"""
try:
os.stat(file_path)
return True
except OSError as e:
if e.errno != errno.ENOENT:
raise e
return False
return path_exists(file_path)
@classmethod
def check_file(cls, file_path, config_name):
@ -136,42 +172,106 @@ class Config(object):
with open(file_path) as file_stream:
return file_stream.read()
def invoke_all(self, name, *args, **kargs):
"""Invoke all instance methods with the given name and arguments in the
class's MRO.
class RootConfig(object):
"""
Holder of an application's configuration.
What configuration this object holds is defined by `config_classes`, a list
of Config classes that will be instantiated and given the contents of a
configuration file to read. They can then be accessed on this class by their
section name, defined in the Config or dynamically set to be the name of the
class, lower-cased and with "Config" removed.
"""
config_classes = []
def __init__(self):
self._configs = OrderedDict()
for config_class in self.config_classes:
if config_class.section is None:
raise ValueError("%r requires a section name" % (config_class,))
try:
conf = config_class(self)
except Exception as e:
raise Exception("Failed making %s: %r" % (config_class.section, e))
self._configs[config_class.section] = conf
def __getattr__(self, item: str) -> Any:
"""
Redirect lookups on this object either to config objects, or values on
config objects, so that `config.tls.blah` works, as well as legacy uses
of things like `config.server_name`. It will first look up the config
section name, and then values on those config classes.
"""
if item in self._configs.keys():
return self._configs[item]
return self._get_unclassed_config(None, item)
def _get_unclassed_config(self, asking_section: Optional[str], item: str):
"""
Fetch a config value from one of the instantiated config classes that
has not been fetched directly.
Args:
name (str): Name of function to invoke
asking_section: If this check is coming from a Config child, which
one? This section will not be asked if it has the value.
item: The configuration value key.
Raises:
AttributeError if no config classes have the config key. The body
will contain what sections were checked.
"""
for key, val in self._configs.items():
if key == asking_section:
continue
if item in dir(val):
return getattr(val, item)
raise AttributeError(item, "not found in %s" % (list(self._configs.keys()),))
def invoke_all(self, func_name: str, *args, **kwargs) -> MutableMapping[str, Any]:
"""
Invoke a function on all instantiated config objects this RootConfig is
configured to use.
Args:
func_name: Name of function to invoke
*args
**kwargs
Returns:
list: The list of the return values from each method called
ordered dictionary of config section name and the result of the
function from it.
"""
results = []
for cls in type(self).mro():
if name in cls.__dict__:
results.append(getattr(cls, name)(self, *args, **kargs))
return results
res = OrderedDict()
for name, config in self._configs.items():
if hasattr(config, func_name):
res[name] = getattr(config, func_name)(*args, **kwargs)
return res
@classmethod
def invoke_all_static(cls, name, *args, **kargs):
"""Invoke all static methods with the given name and arguments in the
class's MRO.
def invoke_all_static(cls, func_name: str, *args, **kwargs):
"""
Invoke a static function on config objects this RootConfig is
configured to use.
Args:
name (str): Name of function to invoke
func_name: Name of function to invoke
*args
**kwargs
Returns:
list: The list of the return values from each method called
ordered dictionary of config section name and the result of the
function from it.
"""
results = []
for c in cls.mro():
if name in c.__dict__:
results.append(getattr(c, name)(*args, **kargs))
return results
for config in cls.config_classes:
if hasattr(config, func_name):
getattr(config, func_name)(*args, **kwargs)
def generate_config(
self,
@ -187,7 +287,8 @@ class Config(object):
tls_private_key_path=None,
acme_domain=None,
):
"""Build a default configuration file
"""
Build a default configuration file
This is used when the user explicitly asks us to generate a config file
(eg with --generate_config).
@ -242,6 +343,7 @@ class Config(object):
Returns:
str: the yaml config file
"""
return "\n\n".join(
dedent(conf)
for conf in self.invoke_all(
@ -257,7 +359,7 @@ class Config(object):
tls_certificate_path=tls_certificate_path,
tls_private_key_path=tls_private_key_path,
acme_domain=acme_domain,
)
).values()
)
@classmethod
@ -444,7 +546,7 @@ class Config(object):
)
(config_path,) = config_files
if not cls.path_exists(config_path):
if not path_exists(config_path):
print("Generating config file %s" % (config_path,))
if config_args.data_directory:
@ -469,7 +571,7 @@ class Config(object):
open_private_ports=config_args.open_private_ports,
)
if not cls.path_exists(config_dir_path):
if not path_exists(config_dir_path):
os.makedirs(config_dir_path)
with open(config_path, "w") as config_file:
config_file.write("# vim:ft=yaml\n\n")
@ -518,7 +620,7 @@ class Config(object):
return obj
def parse_config_dict(self, config_dict, config_dir_path, data_dir_path):
def parse_config_dict(self, config_dict, config_dir_path=None, data_dir_path=None):
"""Read the information from the config dict into this Config object.
Args:
@ -607,3 +709,6 @@ def find_config_files(search_paths):
else:
config_files.append(config_path)
return config_files
__all__ = ["Config", "RootConfig"]

135
synapse/config/_base.pyi Normal file
View file

@ -0,0 +1,135 @@
from typing import Any, List, Optional
from synapse.config import (
api,
appservice,
captcha,
cas,
consent_config,
database,
emailconfig,
groups,
jwt_config,
key,
logger,
metrics,
password,
password_auth_providers,
push,
ratelimiting,
registration,
repository,
room_directory,
saml2_config,
server,
server_notices_config,
spam_checker,
stats,
third_party_event_rules,
tls,
tracer,
user_directory,
voip,
workers,
)
class ConfigError(Exception): ...
MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS: str
MISSING_REPORT_STATS_SPIEL: str
MISSING_SERVER_NAME: str
def path_exists(file_path: str): ...
class RootConfig:
server: server.ServerConfig
tls: tls.TlsConfig
database: database.DatabaseConfig
logging: logger.LoggingConfig
ratelimit: ratelimiting.RatelimitConfig
media: repository.ContentRepositoryConfig
captcha: captcha.CaptchaConfig
voip: voip.VoipConfig
registration: registration.RegistrationConfig
metrics: metrics.MetricsConfig
api: api.ApiConfig
appservice: appservice.AppServiceConfig
key: key.KeyConfig
saml2: saml2_config.SAML2Config
cas: cas.CasConfig
jwt: jwt_config.JWTConfig
password: password.PasswordConfig
email: emailconfig.EmailConfig
worker: workers.WorkerConfig
authproviders: password_auth_providers.PasswordAuthProviderConfig
push: push.PushConfig
spamchecker: spam_checker.SpamCheckerConfig
groups: groups.GroupsConfig
userdirectory: user_directory.UserDirectoryConfig
consent: consent_config.ConsentConfig
stats: stats.StatsConfig
servernotices: server_notices_config.ServerNoticesConfig
roomdirectory: room_directory.RoomDirectoryConfig
thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig
tracer: tracer.TracerConfig
config_classes: List = ...
def __init__(self) -> None: ...
def invoke_all(self, func_name: str, *args: Any, **kwargs: Any): ...
@classmethod
def invoke_all_static(cls, func_name: str, *args: Any, **kwargs: Any) -> None: ...
def __getattr__(self, item: str): ...
def parse_config_dict(
self,
config_dict: Any,
config_dir_path: Optional[Any] = ...,
data_dir_path: Optional[Any] = ...,
) -> None: ...
read_config: Any = ...
def generate_config(
self,
config_dir_path: str,
data_dir_path: str,
server_name: str,
generate_secrets: bool = ...,
report_stats: Optional[str] = ...,
open_private_ports: bool = ...,
listeners: Optional[Any] = ...,
database_conf: Optional[Any] = ...,
tls_certificate_path: Optional[str] = ...,
tls_private_key_path: Optional[str] = ...,
acme_domain: Optional[str] = ...,
): ...
@classmethod
def load_or_generate_config(cls, description: Any, argv: Any): ...
@classmethod
def load_config(cls, description: Any, argv: Any): ...
@classmethod
def add_arguments_to_parser(cls, config_parser: Any) -> None: ...
@classmethod
def load_config_with_parser(cls, parser: Any, argv: Any): ...
def generate_missing_files(
self, config_dict: dict, config_dir_path: str
) -> None: ...
class Config:
root: RootConfig
def __init__(self, root_config: Optional[RootConfig] = ...) -> None: ...
def __getattr__(self, item: str, from_root: bool = ...): ...
@staticmethod
def parse_size(value: Any): ...
@staticmethod
def parse_duration(value: Any): ...
@staticmethod
def abspath(file_path: Optional[str]): ...
@classmethod
def path_exists(cls, file_path: str): ...
@classmethod
def check_file(cls, file_path: str, config_name: str): ...
@classmethod
def ensure_directory(cls, dir_path: str): ...
@classmethod
def read_file(cls, file_path: str, config_name: str): ...
def read_config_files(config_files: List[str]): ...
def find_config_files(search_paths: List[str]): ...

View file

@ -18,6 +18,8 @@ from ._base import Config
class ApiConfig(Config):
section = "api"
def read_config(self, config, **kwargs):
self.room_invite_state_types = config.get(
"room_invite_state_types",

View file

@ -13,6 +13,7 @@
# limitations under the License.
import logging
from typing import Dict
from six import string_types
from six.moves.urllib import parse as urlparse
@ -29,6 +30,8 @@ logger = logging.getLogger(__name__)
class AppServiceConfig(Config):
section = "appservice"
def read_config(self, config, **kwargs):
self.app_service_config_files = config.get("app_service_config_files", [])
self.notify_appservices = config.get("notify_appservices", True)
@ -56,8 +59,8 @@ def load_appservices(hostname, config_files):
return []
# Dicts of value -> filename
seen_as_tokens = {}
seen_ids = {}
seen_as_tokens = {} # type: Dict[str, str]
seen_ids = {} # type: Dict[str, str]
appservices = []

View file

@ -16,6 +16,8 @@ from ._base import Config
class CaptchaConfig(Config):
section = "captcha"
def read_config(self, config, **kwargs):
self.recaptcha_private_key = config.get("recaptcha_private_key")
self.recaptcha_public_key = config.get("recaptcha_public_key")

View file

@ -22,6 +22,8 @@ class CasConfig(Config):
cas_server_url: URL of CAS server
"""
section = "cas"
def read_config(self, config, **kwargs):
cas_config = config.get("cas_config", None)
if cas_config:

View file

@ -73,8 +73,11 @@ DEFAULT_CONFIG = """\
class ConsentConfig(Config):
def __init__(self):
super(ConsentConfig, self).__init__()
section = "consent"
def __init__(self, *args):
super(ConsentConfig, self).__init__(*args)
self.user_consent_version = None
self.user_consent_template_dir = None

View file

@ -21,6 +21,8 @@ from ._base import Config
class DatabaseConfig(Config):
section = "database"
def read_config(self, config, **kwargs):
self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K"))

View file

@ -28,6 +28,8 @@ from ._base import Config, ConfigError
class EmailConfig(Config):
section = "email"
def read_config(self, config, **kwargs):
# TODO: We should separate better the email configuration from the notification
# and account validity config.

View file

@ -17,6 +17,8 @@ from ._base import Config
class GroupsConfig(Config):
section = "groups"
def read_config(self, config, **kwargs):
self.enable_group_creation = config.get("enable_group_creation", False)
self.group_creation_prefix = config.get("group_creation_prefix", "")

View file

@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import RootConfig
from .api import ApiConfig
from .appservice import AppServiceConfig
from .captcha import CaptchaConfig
@ -46,36 +47,37 @@ from .voip import VoipConfig
from .workers import WorkerConfig
class HomeServerConfig(
ServerConfig,
TlsConfig,
DatabaseConfig,
LoggingConfig,
RatelimitConfig,
ContentRepositoryConfig,
CaptchaConfig,
VoipConfig,
RegistrationConfig,
MetricsConfig,
ApiConfig,
AppServiceConfig,
KeyConfig,
SAML2Config,
CasConfig,
JWTConfig,
PasswordConfig,
EmailConfig,
WorkerConfig,
PasswordAuthProviderConfig,
PushConfig,
SpamCheckerConfig,
GroupsConfig,
UserDirectoryConfig,
ConsentConfig,
StatsConfig,
ServerNoticesConfig,
RoomDirectoryConfig,
ThirdPartyRulesConfig,
TracerConfig,
):
pass
class HomeServerConfig(RootConfig):
config_classes = [
ServerConfig,
TlsConfig,
DatabaseConfig,
LoggingConfig,
RatelimitConfig,
ContentRepositoryConfig,
CaptchaConfig,
VoipConfig,
RegistrationConfig,
MetricsConfig,
ApiConfig,
AppServiceConfig,
KeyConfig,
SAML2Config,
CasConfig,
JWTConfig,
PasswordConfig,
EmailConfig,
WorkerConfig,
PasswordAuthProviderConfig,
PushConfig,
SpamCheckerConfig,
GroupsConfig,
UserDirectoryConfig,
ConsentConfig,
StatsConfig,
ServerNoticesConfig,
RoomDirectoryConfig,
ThirdPartyRulesConfig,
TracerConfig,
]

View file

@ -23,6 +23,8 @@ MISSING_JWT = """Missing jwt library. This is required for jwt login.
class JWTConfig(Config):
section = "jwt"
def read_config(self, config, **kwargs):
jwt_config = config.get("jwt_config", None)
if jwt_config:

View file

@ -92,6 +92,8 @@ class TrustedKeyServer(object):
class KeyConfig(Config):
section = "key"
def read_config(self, config, config_dir_path, **kwargs):
# the signing key can be specified inline or in a separate file
if "signing_key" in config:

View file

@ -84,6 +84,8 @@ root:
class LoggingConfig(Config):
section = "logging"
def read_config(self, config, **kwargs):
self.log_config = self.abspath(config.get("log_config"))
self.no_redirect_stdio = config.get("no_redirect_stdio", False)

View file

@ -34,6 +34,8 @@ class MetricsFlags(object):
class MetricsConfig(Config):
section = "metrics"
def read_config(self, config, **kwargs):
self.enable_metrics = config.get("enable_metrics", False)
self.report_stats = config.get("report_stats", None)

View file

@ -20,6 +20,8 @@ class PasswordConfig(Config):
"""Password login configuration
"""
section = "password"
def read_config(self, config, **kwargs):
password_config = config.get("password_config", {})
if password_config is None:

View file

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List
from synapse.util.module_loader import load_module
from ._base import Config
@ -21,8 +23,10 @@ LDAP_PROVIDER = "ldap_auth_provider.LdapAuthProvider"
class PasswordAuthProviderConfig(Config):
section = "authproviders"
def read_config(self, config, **kwargs):
self.password_providers = []
self.password_providers = [] # type: List[Any]
providers = []
# We want to be backwards compatible with the old `ldap_config`

View file

@ -18,6 +18,8 @@ from ._base import Config
class PushConfig(Config):
section = "push"
def read_config(self, config, **kwargs):
push_config = config.get("push", {})
self.push_include_content = push_config.get("include_content", True)

View file

@ -36,6 +36,8 @@ class FederationRateLimitConfig(object):
class RatelimitConfig(Config):
section = "ratelimiting"
def read_config(self, config, **kwargs):
# Load the new-style messages config if it exists. Otherwise fall back

View file

@ -24,6 +24,8 @@ from synapse.util.stringutils import random_string_with_symbols
class AccountValidityConfig(Config):
section = "accountvalidity"
def __init__(self, config, synapse_config):
self.enabled = config.get("enabled", False)
self.renew_by_email_enabled = "renew_at" in config
@ -77,6 +79,8 @@ class AccountValidityConfig(Config):
class RegistrationConfig(Config):
section = "registration"
def read_config(self, config, **kwargs):
self.enable_registration = bool(
strtobool(str(config.get("enable_registration", False)))

View file

@ -15,6 +15,7 @@
import os
from collections import namedtuple
from typing import Dict, List
from synapse.python_dependencies import DependencyException, check_requirements
from synapse.util.module_loader import load_module
@ -61,7 +62,7 @@ def parse_thumbnail_requirements(thumbnail_sizes):
Dictionary mapping from media type string to list of
ThumbnailRequirement tuples.
"""
requirements = {}
requirements = {} # type: Dict[str, List]
for size in thumbnail_sizes:
width = size["width"]
height = size["height"]
@ -77,6 +78,8 @@ def parse_thumbnail_requirements(thumbnail_sizes):
class ContentRepositoryConfig(Config):
section = "media"
def read_config(self, config, **kwargs):
# Only enable the media repo if either the media repo is enabled or the
@ -130,7 +133,7 @@ class ContentRepositoryConfig(Config):
#
# We don't create the storage providers here as not all workers need
# them to be started.
self.media_storage_providers = []
self.media_storage_providers = [] # type: List[tuple]
for provider_config in storage_providers:
# We special case the module "file_system" so as not to need to

View file

@ -19,6 +19,8 @@ from ._base import Config, ConfigError
class RoomDirectoryConfig(Config):
section = "roomdirectory"
def read_config(self, config, **kwargs):
self.enable_room_list_search = config.get("enable_room_list_search", True)

View file

@ -55,6 +55,8 @@ def _dict_merge(merge_dict, into_dict):
class SAML2Config(Config):
section = "saml2"
def read_config(self, config, **kwargs):
self.saml2_enabled = False

View file

@ -19,6 +19,7 @@ import logging
import os.path
import re
from textwrap import indent
from typing import List
import attr
import yaml
@ -57,6 +58,8 @@ on how to configure the new listener.
class ServerConfig(Config):
section = "server"
def read_config(self, config, **kwargs):
self.server_name = config["server_name"]
self.server_context = config.get("server_context", None)
@ -243,7 +246,7 @@ class ServerConfig(Config):
# events with profile information that differ from the target's global profile.
self.allow_per_room_profiles = config.get("allow_per_room_profiles", True)
self.listeners = []
self.listeners = [] # type: List[dict]
for listener in config.get("listeners", []):
if not isinstance(listener.get("port", None), int):
raise ConfigError(
@ -287,7 +290,10 @@ class ServerConfig(Config):
validator=attr.validators.instance_of(bool), default=False
)
complexity = attr.ib(
validator=attr.validators.instance_of((int, float)), default=1.0
validator=attr.validators.instance_of(
(float, int) # type: ignore[arg-type] # noqa
),
default=1.0,
)
complexity_error = attr.ib(
validator=attr.validators.instance_of(str),
@ -366,7 +372,7 @@ class ServerConfig(Config):
"cleanup_extremities_with_dummy_events", True
)
def has_tls_listener(self):
def has_tls_listener(self) -> bool:
return any(l["tls"] for l in self.listeners)
def generate_config_section(

View file

@ -59,8 +59,10 @@ class ServerNoticesConfig(Config):
None if server notices are not enabled.
"""
def __init__(self):
super(ServerNoticesConfig, self).__init__()
section = "servernotices"
def __init__(self, *args):
super(ServerNoticesConfig, self).__init__(*args)
self.server_notices_mxid = None
self.server_notices_mxid_display_name = None
self.server_notices_mxid_avatar_url = None

View file

@ -19,6 +19,8 @@ from ._base import Config
class SpamCheckerConfig(Config):
section = "spamchecker"
def read_config(self, config, **kwargs):
self.spam_checker = None

View file

@ -25,6 +25,8 @@ class StatsConfig(Config):
Configuration for the behaviour of synapse's stats engine
"""
section = "stats"
def read_config(self, config, **kwargs):
self.stats_enabled = True
self.stats_bucket_size = 86400 * 1000

View file

@ -19,6 +19,8 @@ from ._base import Config
class ThirdPartyRulesConfig(Config):
section = "thirdpartyrules"
def read_config(self, config, **kwargs):
self.third_party_event_rules = None

View file

@ -18,6 +18,7 @@ import os
import warnings
from datetime import datetime
from hashlib import sha256
from typing import List
import six
@ -33,7 +34,9 @@ logger = logging.getLogger(__name__)
class TlsConfig(Config):
def read_config(self, config, config_dir_path, **kwargs):
section = "tls"
def read_config(self, config: dict, config_dir_path: str, **kwargs):
acme_config = config.get("acme", None)
if acme_config is None:
@ -57,7 +60,7 @@ class TlsConfig(Config):
self.tls_certificate_file = self.abspath(config.get("tls_certificate_path"))
self.tls_private_key_file = self.abspath(config.get("tls_private_key_path"))
if self.has_tls_listener():
if self.root.server.has_tls_listener():
if not self.tls_certificate_file:
raise ConfigError(
"tls_certificate_path must be specified if TLS-enabled listeners are "
@ -108,7 +111,7 @@ class TlsConfig(Config):
)
# Support globs (*) in whitelist values
self.federation_certificate_verification_whitelist = []
self.federation_certificate_verification_whitelist = [] # type: List[str]
for entry in fed_whitelist_entries:
try:
entry_regex = glob_to_regex(entry.encode("ascii").decode("ascii"))

View file

@ -19,6 +19,8 @@ from ._base import Config, ConfigError
class TracerConfig(Config):
section = "tracing"
def read_config(self, config, **kwargs):
opentracing_config = config.get("opentracing")
if opentracing_config is None:

View file

@ -21,6 +21,8 @@ class UserDirectoryConfig(Config):
Configuration for the behaviour of the /user_directory API
"""
section = "userdirectory"
def read_config(self, config, **kwargs):
self.user_directory_search_enabled = True
self.user_directory_search_all_users = False

View file

@ -16,6 +16,8 @@ from ._base import Config
class VoipConfig(Config):
section = "voip"
def read_config(self, config, **kwargs):
self.turn_uris = config.get("turn_uris", [])
self.turn_shared_secret = config.get("turn_shared_secret")

View file

@ -21,6 +21,8 @@ class WorkerConfig(Config):
They have their own pid_file and listener configuration. They use the
replication_url to talk to the main synapse process."""
section = "worker"
def read_config(self, config, **kwargs):
self.worker_app = config.get("worker_app")

View file

@ -36,7 +36,6 @@ from synapse.api.errors import (
UnsupportedRoomVersionError,
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.crypto.event_signing import compute_event_signature
from synapse.events import room_version_to_event_format
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.federation.persistence import TransactionActions
@ -322,18 +321,6 @@ class FederationServer(FederationBase):
pdus = yield self.handler.get_state_for_pdu(room_id, event_id)
auth_chain = yield self.store.get_auth_chain([pdu.event_id for pdu in pdus])
for event in auth_chain:
# We sign these again because there was a bug where we
# incorrectly signed things the first time round
if self.hs.is_mine_id(event.event_id):
event.signatures.update(
compute_event_signature(
event.get_pdu_json(),
self.hs.hostname,
self.hs.config.signing_key[0],
)
)
return {
"pdus": [pdu.get_pdu_json() for pdu in pdus],
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],

View file

@ -38,7 +38,7 @@ from synapse.metrics import (
events_processed_counter,
)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.metrics import measure_func
from synapse.util.metrics import Measure, measure_func
logger = logging.getLogger(__name__)
@ -183,8 +183,8 @@ class FederationSender(object):
# Otherwise if the last member on a server in a room is
# banned then it won't receive the event because it won't
# be in the room after the ban.
destinations = yield self.state.get_current_hosts_in_room(
event.room_id, latest_event_ids=event.prev_event_ids()
destinations = yield self.state.get_hosts_in_room_at_events(
event.room_id, event_ids=event.prev_event_ids()
)
except Exception:
logger.exception(
@ -207,8 +207,9 @@ class FederationSender(object):
@defer.inlineCallbacks
def handle_room_events(events):
for event in events:
yield handle_event(event)
with Measure(self.clock, "handle_room_events"):
for event in events:
yield handle_event(event)
events_by_room = {}
for event in events:

View file

@ -765,6 +765,10 @@ class PublicRoomList(BaseFederationServlet):
else:
network_tuple = ThirdPartyInstanceID(None, None)
if limit == 0:
# zero is a special value which corresponds to no limit.
limit = None
data = await maybeDeferred(
self.handler.get_local_public_room_list,
limit,
@ -800,6 +804,10 @@ class PublicRoomList(BaseFederationServlet):
if search_filter is None:
logger.warning("Nonefilter")
if limit == 0:
# zero is a special value which corresponds to no limit.
limit = None
data = await self.handler.get_local_public_room_list(
limit=limit,
since_token=since_token,

View file

@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
# Copyright 2018 New Vector Ltd
# Copyright 2019 Michael Telatynski <7t3chguy@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -20,16 +21,16 @@ from six import string_types
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.api.errors import Codes, SynapseError
from synapse.types import GroupID, RoomID, UserID, get_domain_from_id
from synapse.util.async_helpers import concurrently_execute
logger = logging.getLogger(__name__)
# TODO: Allow users to "knock" or simpkly join depending on rules
# TODO: Allow users to "knock" or simply join depending on rules
# TODO: Federation admin APIs
# TODO: is_priveged flag to users and is_public to users and rooms
# TODO: is_privileged flag to users and is_public to users and rooms
# TODO: Audit log for admins (profile updates, membership changes, users who tried
# to join but were rejected, etc)
# TODO: Flairs
@ -590,7 +591,18 @@ class GroupsServerHandler(object):
)
# TODO: Check if user knocked
# TODO: Check if user is already invited
invited_users = yield self.store.get_invited_users_in_group(group_id)
if user_id in invited_users:
raise SynapseError(
400, "User already invited to group", errcode=Codes.BAD_STATE
)
user_results = yield self.store.get_users_in_group(
group_id, include_private=True
)
if user_id in [user_result["user_id"] for user_result in user_results]:
raise SynapseError(400, "User already in group")
content = {
"profile": {"name": group["name"], "avatar_url": group["avatar_url"]},

View file

@ -120,6 +120,10 @@ class DeactivateAccountHandler(BaseHandler):
# parts users from rooms (if it isn't already running)
self._start_user_parting()
# Reject all pending invites for the user, so that the user doesn't show up in the
# "invited" section of rooms' members list.
yield self._reject_pending_invites_for_user(user_id)
# Remove all information on the user from the account_validity table.
if self._account_validity_enabled:
yield self.store.delete_account_validity_for_user(user_id)
@ -129,6 +133,39 @@ class DeactivateAccountHandler(BaseHandler):
return identity_server_supports_unbinding
@defer.inlineCallbacks
def _reject_pending_invites_for_user(self, user_id):
"""Reject pending invites addressed to a given user ID.
Args:
user_id (str): The user ID to reject pending invites for.
"""
user = UserID.from_string(user_id)
pending_invites = yield self.store.get_invited_rooms_for_user(user_id)
for room in pending_invites:
try:
yield self._room_member_handler.update_membership(
create_requester(user),
user,
room.room_id,
"leave",
ratelimit=False,
require_consent=False,
)
logger.info(
"Rejected invite for deactivated user %r in room %r",
user_id,
room.room_id,
)
except Exception:
logger.exception(
"Failed to reject invite for user %r in room %r:"
" ignoring and continuing",
user_id,
room.room_id,
)
def _start_user_parting(self):
"""
Start the process that goes through the table of users

View file

@ -2570,7 +2570,7 @@ class FederationHandler(BaseHandler):
)
try:
self.auth.check_from_context(room_version, event, context)
yield self.auth.check_from_context(room_version, event, context)
except AuthError as e:
logger.warn("Denying third party invite %r because %s", event, e)
raise e
@ -2599,7 +2599,12 @@ class FederationHandler(BaseHandler):
original_invite_id, allow_none=True
)
if original_invite:
display_name = original_invite.content["display_name"]
# If the m.room.third_party_invite event's content is empty, it means the
# invite has been revoked. In this case, we don't have to raise an error here
# because the auth check will fail on the invite (because it's not able to
# fetch public keys from the m.room.third_party_invite event's content, which
# is empty).
display_name = original_invite.content.get("display_name")
event_dict["content"]["third_party_invite"]["display_name"] = display_name
else:
logger.info(

View file

@ -21,11 +21,15 @@ import logging
import urllib
from canonicaljson import json
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64
from twisted.internet import defer
from twisted.internet.error import TimeoutError
from synapse.api.errors import (
AuthError,
CodeMessageException,
Codes,
HttpResponseException,
@ -33,12 +37,15 @@ from synapse.api.errors import (
)
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.http.client import SimpleHttpClient
from synapse.util.hash import sha256_and_url_safe_base64
from synapse.util.stringutils import random_string
from ._base import BaseHandler
logger = logging.getLogger(__name__)
id_server_scheme = "https://"
class IdentityHandler(BaseHandler):
def __init__(self, hs):
@ -557,6 +564,352 @@ class IdentityHandler(BaseHandler):
logger.warning("Error contacting msisdn account_threepid_delegate: %s", e)
raise SynapseError(400, "Error contacting the identity server")
@defer.inlineCallbacks
def lookup_3pid(self, id_server, medium, address, id_access_token=None):
"""Looks up a 3pid in the passed identity server.
Args:
id_server (str): The server name (including port, if required)
of the identity server to use.
medium (str): The type of the third party identifier (e.g. "email").
address (str): The third party identifier (e.g. "foo@example.com").
id_access_token (str|None): The access token to authenticate to the identity
server with
Returns:
str|None: the matrix ID of the 3pid, or None if it is not recognized.
"""
if id_access_token is not None:
try:
results = yield self._lookup_3pid_v2(
id_server, id_access_token, medium, address
)
return results
except Exception as e:
# Catch HttpResponseExcept for a non-200 response code
# Check if this identity server does not know about v2 lookups
if isinstance(e, HttpResponseException) and e.code == 404:
# This is an old identity server that does not yet support v2 lookups
logger.warning(
"Attempted v2 lookup on v1 identity server %s. Falling "
"back to v1",
id_server,
)
else:
logger.warning("Error when looking up hashing details: %s", e)
return None
return (yield self._lookup_3pid_v1(id_server, medium, address))
@defer.inlineCallbacks
def _lookup_3pid_v1(self, id_server, medium, address):
"""Looks up a 3pid in the passed identity server using v1 lookup.
Args:
id_server (str): The server name (including port, if required)
of the identity server to use.
medium (str): The type of the third party identifier (e.g. "email").
address (str): The third party identifier (e.g. "foo@example.com").
Returns:
str: the matrix ID of the 3pid, or None if it is not recognized.
"""
try:
data = yield self.blacklisting_http_client.get_json(
"%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server),
{"medium": medium, "address": address},
)
if "mxid" in data:
if "signatures" not in data:
raise AuthError(401, "No signatures on 3pid binding")
yield self._verify_any_signature(data, id_server)
return data["mxid"]
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
except IOError as e:
logger.warning("Error from v1 identity server lookup: %s" % (e,))
return None
@defer.inlineCallbacks
def _lookup_3pid_v2(self, id_server, id_access_token, medium, address):
"""Looks up a 3pid in the passed identity server using v2 lookup.
Args:
id_server (str): The server name (including port, if required)
of the identity server to use.
id_access_token (str): The access token to authenticate to the identity server with
medium (str): The type of the third party identifier (e.g. "email").
address (str): The third party identifier (e.g. "foo@example.com").
Returns:
Deferred[str|None]: the matrix ID of the 3pid, or None if it is not recognised.
"""
# Check what hashing details are supported by this identity server
try:
hash_details = yield self.blacklisting_http_client.get_json(
"%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server),
{"access_token": id_access_token},
)
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
if not isinstance(hash_details, dict):
logger.warning(
"Got non-dict object when checking hash details of %s%s: %s",
id_server_scheme,
id_server,
hash_details,
)
raise SynapseError(
400,
"Non-dict object from %s%s during v2 hash_details request: %s"
% (id_server_scheme, id_server, hash_details),
)
# Extract information from hash_details
supported_lookup_algorithms = hash_details.get("algorithms")
lookup_pepper = hash_details.get("lookup_pepper")
if (
not supported_lookup_algorithms
or not isinstance(supported_lookup_algorithms, list)
or not lookup_pepper
or not isinstance(lookup_pepper, str)
):
raise SynapseError(
400,
"Invalid hash details received from identity server %s%s: %s"
% (id_server_scheme, id_server, hash_details),
)
# Check if any of the supported lookup algorithms are present
if LookupAlgorithm.SHA256 in supported_lookup_algorithms:
# Perform a hashed lookup
lookup_algorithm = LookupAlgorithm.SHA256
# Hash address, medium and the pepper with sha256
to_hash = "%s %s %s" % (address, medium, lookup_pepper)
lookup_value = sha256_and_url_safe_base64(to_hash)
elif LookupAlgorithm.NONE in supported_lookup_algorithms:
# Perform a non-hashed lookup
lookup_algorithm = LookupAlgorithm.NONE
# Combine together plaintext address and medium
lookup_value = "%s %s" % (address, medium)
else:
logger.warning(
"None of the provided lookup algorithms of %s are supported: %s",
id_server,
supported_lookup_algorithms,
)
raise SynapseError(
400,
"Provided identity server does not support any v2 lookup "
"algorithms that this homeserver supports.",
)
# Authenticate with identity server given the access token from the client
headers = {"Authorization": create_id_access_token_header(id_access_token)}
try:
lookup_results = yield self.blacklisting_http_client.post_json_get_json(
"%s%s/_matrix/identity/v2/lookup" % (id_server_scheme, id_server),
{
"addresses": [lookup_value],
"algorithm": lookup_algorithm,
"pepper": lookup_pepper,
},
headers=headers,
)
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
except Exception as e:
logger.warning("Error when performing a v2 3pid lookup: %s", e)
raise SynapseError(
500, "Unknown error occurred during identity server lookup"
)
# Check for a mapping from what we looked up to an MXID
if "mappings" not in lookup_results or not isinstance(
lookup_results["mappings"], dict
):
logger.warning("No results from 3pid lookup")
return None
# Return the MXID if it's available, or None otherwise
mxid = lookup_results["mappings"].get(lookup_value)
return mxid
@defer.inlineCallbacks
def _verify_any_signature(self, data, server_hostname):
if server_hostname not in data["signatures"]:
raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items():
try:
key_data = yield self.blacklisting_http_client.get_json(
"%s%s/_matrix/identity/api/v1/pubkey/%s"
% (id_server_scheme, server_hostname, key_name)
)
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
if "public_key" not in key_data:
raise AuthError(
401, "No public key named %s from %s" % (key_name, server_hostname)
)
verify_signed_json(
data,
server_hostname,
decode_verify_key_bytes(
key_name, decode_base64(key_data["public_key"])
),
)
return
@defer.inlineCallbacks
def ask_id_server_for_third_party_invite(
self,
requester,
id_server,
medium,
address,
room_id,
inviter_user_id,
room_alias,
room_avatar_url,
room_join_rules,
room_name,
inviter_display_name,
inviter_avatar_url,
id_access_token=None,
):
"""
Asks an identity server for a third party invite.
Args:
requester (Requester)
id_server (str): hostname + optional port for the identity server.
medium (str): The literal string "email".
address (str): The third party address being invited.
room_id (str): The ID of the room to which the user is invited.
inviter_user_id (str): The user ID of the inviter.
room_alias (str): An alias for the room, for cosmetic notifications.
room_avatar_url (str): The URL of the room's avatar, for cosmetic
notifications.
room_join_rules (str): The join rules of the email (e.g. "public").
room_name (str): The m.room.name of the room.
inviter_display_name (str): The current display name of the
inviter.
inviter_avatar_url (str): The URL of the inviter's avatar.
id_access_token (str|None): The access token to authenticate to the identity
server with
Returns:
A deferred tuple containing:
token (str): The token which must be signed to prove authenticity.
public_keys ([{"public_key": str, "key_validity_url": str}]):
public_key is a base64-encoded ed25519 public key.
fallback_public_key: One element from public_keys.
display_name (str): A user-friendly name to represent the invited
user.
"""
invite_config = {
"medium": medium,
"address": address,
"room_id": room_id,
"room_alias": room_alias,
"room_avatar_url": room_avatar_url,
"room_join_rules": room_join_rules,
"room_name": room_name,
"sender": inviter_user_id,
"sender_display_name": inviter_display_name,
"sender_avatar_url": inviter_avatar_url,
}
# Add the identity service access token to the JSON body and use the v2
# Identity Service endpoints if id_access_token is present
data = None
base_url = "%s%s/_matrix/identity" % (id_server_scheme, id_server)
if id_access_token:
key_validity_url = "%s%s/_matrix/identity/v2/pubkey/isvalid" % (
id_server_scheme,
id_server,
)
# Attempt a v2 lookup
url = base_url + "/v2/store-invite"
try:
data = yield self.blacklisting_http_client.post_json_get_json(
url,
invite_config,
{"Authorization": create_id_access_token_header(id_access_token)},
)
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
except HttpResponseException as e:
if e.code != 404:
logger.info("Failed to POST %s with JSON: %s", url, e)
raise e
if data is None:
key_validity_url = "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
id_server_scheme,
id_server,
)
url = base_url + "/api/v1/store-invite"
try:
data = yield self.blacklisting_http_client.post_json_get_json(
url, invite_config
)
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
except HttpResponseException as e:
logger.warning(
"Error trying to call /store-invite on %s%s: %s",
id_server_scheme,
id_server,
e,
)
if data is None:
# Some identity servers may only support application/x-www-form-urlencoded
# types. This is especially true with old instances of Sydent, see
# https://github.com/matrix-org/sydent/pull/170
try:
data = yield self.blacklisting_http_client.post_urlencoded_get_json(
url, invite_config
)
except HttpResponseException as e:
logger.warning(
"Error calling /store-invite on %s%s with fallback "
"encoding: %s",
id_server_scheme,
id_server,
e,
)
raise e
# TODO: Check for success
token = data["token"]
public_keys = data.get("public_keys", [])
if "public_key" in data:
fallback_public_key = {
"public_key": data["public_key"],
"key_validity_url": key_validity_url,
}
else:
fallback_public_key = public_keys[0]
if not public_keys:
public_keys.append(fallback_public_key)
display_name = data["display_name"]
return token, public_keys, fallback_public_key, display_name
def create_id_access_token_header(id_access_token):
"""Create an Authorization header for passing to SimpleHttpClient as the header value

View file

@ -803,17 +803,25 @@ class PresenceHandler(object):
# Loop round handling deltas until we're up to date
while True:
with Measure(self.clock, "presence_delta"):
deltas = yield self.store.get_current_state_deltas(self._event_pos)
if not deltas:
room_max_stream_ordering = self.store.get_room_max_stream_ordering()
if self._event_pos == room_max_stream_ordering:
return
logger.debug(
"Processing presence stats %s->%s",
self._event_pos,
room_max_stream_ordering,
)
max_pos, deltas = yield self.store.get_current_state_deltas(
self._event_pos, room_max_stream_ordering
)
yield self._handle_state_delta(deltas)
self._event_pos = deltas[-1]["stream_id"]
self._event_pos = max_pos
# Expose current event processing position to prometheus
synapse.metrics.event_processing_positions.labels("presence").set(
self._event_pos
max_pos
)
@defer.inlineCallbacks

View file

@ -217,10 +217,9 @@ class RegistrationHandler(BaseHandler):
else:
# autogen a sequential user ID
attempts = 0
user = None
while not user:
localpart = yield self._generate_user_id(attempts > 0)
localpart = yield self._generate_user_id()
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
yield self.check_user_id_not_appservice_exclusive(user_id)
@ -238,7 +237,6 @@ class RegistrationHandler(BaseHandler):
# if user id is taken, just generate another
user = None
user_id = None
attempts += 1
if not self.hs.config.user_consent_at_registration:
yield self._auto_join_rooms(user_id)
@ -379,10 +377,10 @@ class RegistrationHandler(BaseHandler):
)
@defer.inlineCallbacks
def _generate_user_id(self, reseed=False):
if reseed or self._next_generated_user_id is None:
def _generate_user_id(self):
if self._next_generated_user_id is None:
with (yield self._generate_user_id_linearizer.queue(())):
if reseed or self._next_generated_user_id is None:
if self._next_generated_user_id is None:
self._next_generated_user_id = (
yield self.store.find_next_generated_user_id_localpart()
)

View file

@ -28,6 +28,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID
from synapse.util import stringutils
@ -554,7 +555,8 @@ class RoomCreationHandler(BaseHandler):
invite_list = config.get("invite", [])
for i in invite_list:
try:
UserID.from_string(i)
uid = UserID.from_string(i)
parse_and_validate_server_name(uid.domain)
except Exception:
raise SynapseError(400, "Invalid user_id: %s" % (i,))

View file

@ -16,8 +16,7 @@
import logging
from collections import namedtuple
from six import PY3, iteritems
from six.moves import range
from six import iteritems
import msgpack
from unpaddedbase64 import decode_base64, encode_base64
@ -27,7 +26,6 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules
from synapse.api.errors import Codes, HttpResponseException
from synapse.types import ThirdPartyInstanceID
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.caches.response_cache import ResponseCache
@ -37,7 +35,6 @@ logger = logging.getLogger(__name__)
REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
# This is used to indicate we should only return rooms published to the main list.
EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
@ -72,6 +69,8 @@ class RoomListHandler(BaseHandler):
This can be (None, None) to indicate the main list, or a particular
appservice and network id to use an appservice specific one.
Setting to None returns all public rooms across all lists.
from_federation (bool): true iff the request comes from the federation
API
"""
if not self.enable_room_list_search:
return defer.succeed({"chunk": [], "total_room_count_estimate": 0})
@ -89,16 +88,8 @@ class RoomListHandler(BaseHandler):
# appservice specific lists.
logger.info("Bypassing cache as search request.")
# XXX: Quick hack to stop room directory queries taking too long.
# Timeout request after 60s. Probably want a more fundamental
# solution at some point
timeout = self.clock.time() + 60
return self._get_public_room_list(
limit,
since_token,
search_filter,
network_tuple=network_tuple,
timeout=timeout,
limit, since_token, search_filter, network_tuple=network_tuple
)
key = (limit, since_token, network_tuple)
@ -119,7 +110,6 @@ class RoomListHandler(BaseHandler):
search_filter=None,
network_tuple=EMPTY_THIRD_PARTY_ID,
from_federation=False,
timeout=None,
):
"""Generate a public room list.
Args:
@ -132,240 +122,116 @@ class RoomListHandler(BaseHandler):
Setting to None returns all public rooms across all lists.
from_federation (bool): Whether this request originated from a
federating server or a client. Used for room filtering.
timeout (int|None): Amount of seconds to wait for a response before
timing out.
"""
if since_token and since_token != "END":
since_token = RoomListNextBatch.from_token(since_token)
else:
since_token = None
rooms_to_order_value = {}
rooms_to_num_joined = {}
# Pagination tokens work by storing the room ID sent in the last batch,
# plus the direction (forwards or backwards). Next batch tokens always
# go forwards, prev batch tokens always go backwards.
newly_visible = []
newly_unpublished = []
if since_token:
stream_token = since_token.stream_ordering
current_public_id = yield self.store.get_current_public_room_stream_id()
public_room_stream_id = since_token.public_room_stream_id
newly_visible, newly_unpublished = yield self.store.get_public_room_changes(
public_room_stream_id, current_public_id, network_tuple=network_tuple
)
else:
stream_token = yield self.store.get_room_max_stream_ordering()
public_room_stream_id = yield self.store.get_current_public_room_stream_id()
batch_token = RoomListNextBatch.from_token(since_token)
room_ids = yield self.store.get_public_room_ids_at_stream_id(
public_room_stream_id, network_tuple=network_tuple
bounds = (batch_token.last_joined_members, batch_token.last_room_id)
forwards = batch_token.direction_is_forward
else:
batch_token = None
bounds = None
forwards = True
# we request one more than wanted to see if there are more pages to come
probing_limit = limit + 1 if limit is not None else None
results = yield self.store.get_largest_public_rooms(
network_tuple,
search_filter,
probing_limit,
bounds=bounds,
forwards=forwards,
ignore_non_federatable=from_federation,
)
# We want to return rooms in a particular order: the number of joined
# users. We then arbitrarily use the room_id as a tie breaker.
def build_room_entry(room):
entry = {
"room_id": room["room_id"],
"name": room["name"],
"topic": room["topic"],
"canonical_alias": room["canonical_alias"],
"num_joined_members": room["joined_members"],
"avatar_url": room["avatar"],
"world_readable": room["history_visibility"] == "world_readable",
"guest_can_join": room["guest_access"] == "can_join",
}
@defer.inlineCallbacks
def get_order_for_room(room_id):
# Most of the rooms won't have changed between the since token and
# now (especially if the since token is "now"). So, we can ask what
# the current users are in a room (that will hit a cache) and then
# check if the room has changed since the since token. (We have to
# do it in that order to avoid races).
# If things have changed then fall back to getting the current state
# at the since token.
joined_users = yield self.store.get_users_in_room(room_id)
if self.store.has_room_changed_since(room_id, stream_token):
latest_event_ids = yield self.store.get_forward_extremeties_for_room(
room_id, stream_token
)
# Filter out Nones rather omit the field altogether
return {k: v for k, v in entry.items() if v is not None}
if not latest_event_ids:
return
results = [build_room_entry(r) for r in results]
joined_users = yield self.state_handler.get_current_users_in_room(
room_id, latest_event_ids
)
response = {}
num_results = len(results)
if limit is not None:
more_to_come = num_results == probing_limit
num_joined_users = len(joined_users)
rooms_to_num_joined[room_id] = num_joined_users
# Depending on direction we trim either the front or back.
if forwards:
results = results[:limit]
else:
results = results[-limit:]
else:
more_to_come = False
if num_joined_users == 0:
return
if num_results > 0:
final_entry = results[-1]
initial_entry = results[0]
# We want larger rooms to be first, hence negating num_joined_users
rooms_to_order_value[room_id] = (-num_joined_users, room_id)
if forwards:
if batch_token:
# If there was a token given then we assume that there
# must be previous results.
response["prev_batch"] = RoomListNextBatch(
last_joined_members=initial_entry["num_joined_members"],
last_room_id=initial_entry["room_id"],
direction_is_forward=False,
).to_token()
logger.info(
"Getting ordering for %i rooms since %s", len(room_ids), stream_token
if more_to_come:
response["next_batch"] = RoomListNextBatch(
last_joined_members=final_entry["num_joined_members"],
last_room_id=final_entry["room_id"],
direction_is_forward=True,
).to_token()
else:
if batch_token:
response["next_batch"] = RoomListNextBatch(
last_joined_members=final_entry["num_joined_members"],
last_room_id=final_entry["room_id"],
direction_is_forward=True,
).to_token()
if more_to_come:
response["prev_batch"] = RoomListNextBatch(
last_joined_members=initial_entry["num_joined_members"],
last_room_id=initial_entry["room_id"],
direction_is_forward=False,
).to_token()
for room in results:
# populate search result entries with additional fields, namely
# 'aliases'
room_id = room["room_id"]
aliases = yield self.store.get_aliases_for_room(room_id)
if aliases:
room["aliases"] = aliases
response["chunk"] = results
response["total_room_count_estimate"] = yield self.store.count_public_rooms(
network_tuple, ignore_non_federatable=from_federation
)
yield concurrently_execute(get_order_for_room, room_ids, 10)
sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1])
sorted_rooms = [room_id for room_id, _ in sorted_entries]
# `sorted_rooms` should now be a list of all public room ids that is
# stable across pagination. Therefore, we can use indices into this
# list as our pagination tokens.
# Filter out rooms that we don't want to return
rooms_to_scan = [
r
for r in sorted_rooms
if r not in newly_unpublished and rooms_to_num_joined[r] > 0
]
total_room_count = len(rooms_to_scan)
if since_token:
# Filter out rooms we've already returned previously
# `since_token.current_limit` is the index of the last room we
# sent down, so we exclude it and everything before/after it.
if since_token.direction_is_forward:
rooms_to_scan = rooms_to_scan[since_token.current_limit + 1 :]
else:
rooms_to_scan = rooms_to_scan[: since_token.current_limit]
rooms_to_scan.reverse()
logger.info("After sorting and filtering, %i rooms remain", len(rooms_to_scan))
# _append_room_entry_to_chunk will append to chunk but will stop if
# len(chunk) > limit
#
# Normally we will generate enough results on the first iteration here,
# but if there is a search filter, _append_room_entry_to_chunk may
# filter some results out, in which case we loop again.
#
# We don't want to scan over the entire range either as that
# would potentially waste a lot of work.
#
# XXX if there is no limit, we may end up DoSing the server with
# calls to get_current_state_ids for every single room on the
# server. Surely we should cap this somehow?
#
if limit:
step = limit + 1
else:
# step cannot be zero
step = len(rooms_to_scan) if len(rooms_to_scan) != 0 else 1
chunk = []
for i in range(0, len(rooms_to_scan), step):
if timeout and self.clock.time() > timeout:
raise Exception("Timed out searching room directory")
batch = rooms_to_scan[i : i + step]
logger.info("Processing %i rooms for result", len(batch))
yield concurrently_execute(
lambda r: self._append_room_entry_to_chunk(
r,
rooms_to_num_joined[r],
chunk,
limit,
search_filter,
from_federation=from_federation,
),
batch,
5,
)
logger.info("Now %i rooms in result", len(chunk))
if len(chunk) >= limit + 1:
break
chunk.sort(key=lambda e: (-e["num_joined_members"], e["room_id"]))
# Work out the new limit of the batch for pagination, or None if we
# know there are no more results that would be returned.
# i.e., [since_token.current_limit..new_limit] is the batch of rooms
# we've returned (or the reverse if we paginated backwards)
# We tried to pull out limit + 1 rooms above, so if we have <= limit
# then we know there are no more results to return
new_limit = None
if chunk and (not limit or len(chunk) > limit):
if not since_token or since_token.direction_is_forward:
if limit:
chunk = chunk[:limit]
last_room_id = chunk[-1]["room_id"]
else:
if limit:
chunk = chunk[-limit:]
last_room_id = chunk[0]["room_id"]
new_limit = sorted_rooms.index(last_room_id)
results = {"chunk": chunk, "total_room_count_estimate": total_room_count}
if since_token:
results["new_rooms"] = bool(newly_visible)
if not since_token or since_token.direction_is_forward:
if new_limit is not None:
results["next_batch"] = RoomListNextBatch(
stream_ordering=stream_token,
public_room_stream_id=public_room_stream_id,
current_limit=new_limit,
direction_is_forward=True,
).to_token()
if since_token:
results["prev_batch"] = since_token.copy_and_replace(
direction_is_forward=False,
current_limit=since_token.current_limit + 1,
).to_token()
else:
if new_limit is not None:
results["prev_batch"] = RoomListNextBatch(
stream_ordering=stream_token,
public_room_stream_id=public_room_stream_id,
current_limit=new_limit,
direction_is_forward=False,
).to_token()
if since_token:
results["next_batch"] = since_token.copy_and_replace(
direction_is_forward=True,
current_limit=since_token.current_limit - 1,
).to_token()
return results
@defer.inlineCallbacks
def _append_room_entry_to_chunk(
self,
room_id,
num_joined_users,
chunk,
limit,
search_filter,
from_federation=False,
):
"""Generate the entry for a room in the public room list and append it
to the `chunk` if it matches the search filter
Args:
room_id (str): The ID of the room.
num_joined_users (int): The number of joined users in the room.
chunk (list)
limit (int|None): Maximum amount of rooms to display. Function will
return if length of chunk is greater than limit + 1.
search_filter (dict|None)
from_federation (bool): Whether this request originated from a
federating server or a client. Used for room filtering.
"""
if limit and len(chunk) > limit + 1:
# We've already got enough, so lets just drop it.
return
result = yield self.generate_room_entry(room_id, num_joined_users)
if not result:
return
if from_federation and not result.get("m.federate", True):
# This is a room that other servers cannot join. Do not show them
# this room.
return
if _matches_room_entry(result, search_filter):
chunk.append(result)
return response
@cachedInlineCallbacks(num_args=1, cache_context=True)
def generate_room_entry(
@ -580,18 +446,15 @@ class RoomListNextBatch(
namedtuple(
"RoomListNextBatch",
(
"stream_ordering", # stream_ordering of the first public room list
"public_room_stream_id", # public room stream id for first public room list
"current_limit", # The number of previous rooms returned
"last_joined_members", # The count to get rooms after/before
"last_room_id", # The room_id to get rooms after/before
"direction_is_forward", # Bool if this is a next_batch, false if prev_batch
),
)
):
KEY_DICT = {
"stream_ordering": "s",
"public_room_stream_id": "p",
"current_limit": "n",
"last_joined_members": "m",
"last_room_id": "r",
"direction_is_forward": "d",
}
@ -599,13 +462,7 @@ class RoomListNextBatch(
@classmethod
def from_token(cls, token):
if PY3:
# The argument raw=False is only available on new versions of
# msgpack, and only really needed on Python 3. Gate it behind
# a PY3 check to avoid causing issues on Debian-packaged versions.
decoded = msgpack.loads(decode_base64(token), raw=False)
else:
decoded = msgpack.loads(decode_base64(token))
decoded = msgpack.loads(decode_base64(token), raw=False)
return RoomListNextBatch(
**{cls.REVERSE_KEY_DICT[key]: val for key, val in decoded.items()}
)

View file

@ -20,29 +20,19 @@ import logging
from six.moves import http_client
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64
from twisted.internet import defer
from twisted.internet.error import TimeoutError
from synapse import types
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes, HttpResponseException, SynapseError
from synapse.handlers.identity import LookupAlgorithm, create_id_access_token_header
from synapse.http.client import SimpleHttpClient
from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.types import RoomID, UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room, user_left_room
from synapse.util.hash import sha256_and_url_safe_base64
from ._base import BaseHandler
logger = logging.getLogger(__name__)
id_server_scheme = "https://"
class RoomMemberHandler(object):
# TODO(paul): This handler currently contains a messy conflation of
@ -63,14 +53,10 @@ class RoomMemberHandler(object):
self.auth = hs.get_auth()
self.state_handler = hs.get_state_handler()
self.config = hs.config
# We create a blacklisting instance of SimpleHttpClient for contacting identity
# servers specified by clients
self.simple_http_client = SimpleHttpClient(
hs, ip_blacklist=hs.config.federation_ip_range_blacklist
)
self.federation_handler = hs.get_handlers().federation_handler
self.directory_handler = hs.get_handlers().directory_handler
self.identity_handler = hs.get_handlers().identity_handler
self.registration_handler = hs.get_registration_handler()
self.profile_handler = hs.get_profile_handler()
self.event_creation_handler = hs.get_event_creation_handler()
@ -217,23 +203,11 @@ class RoomMemberHandler(object):
prev_member_event = yield self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
# Copy over user state if we're joining an upgraded room
yield self.copy_user_state_if_room_upgrade(
room_id, requester.user.to_string()
)
yield self._user_joined_room(target, room_id)
# Copy over direct message status and room tags if this is a join
# on an upgraded room
# Check if this is an upgraded room
predecessor = yield self.store.get_room_predecessor(room_id)
if predecessor:
# It is an upgraded room. Copy over old tags
self.copy_room_tags_and_direct_to_room(
predecessor["room_id"], room_id, user_id
)
# Move over old push rules
self.store.move_push_rules_from_room_to_room_for_user(
predecessor["room_id"], room_id, user_id
)
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
@ -477,10 +451,16 @@ class RoomMemberHandler(object):
if requester.is_guest:
content["kind"] = "guest"
ret = yield self._remote_join(
remote_join_response = yield self._remote_join(
requester, remote_room_hosts, room_id, target, content
)
return ret
# Copy over user state if this is a join on an remote upgraded room
yield self.copy_user_state_if_room_upgrade(
room_id, requester.user.to_string()
)
return remote_join_response
elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room:
@ -517,6 +497,38 @@ class RoomMemberHandler(object):
)
return res
@defer.inlineCallbacks
def copy_user_state_if_room_upgrade(self, new_room_id, user_id):
"""Copy user-specific information when they join a new room if that new room is the
result of a room upgrade
Args:
new_room_id (str): The ID of the room the user is joining
user_id (str): The ID of the user
Returns:
Deferred
"""
# Check if the new room is an upgraded room
predecessor = yield self.store.get_room_predecessor(new_room_id)
if not predecessor:
return
logger.debug(
"Found predecessor for %s: %s. Copying over room tags and push " "rules",
new_room_id,
predecessor,
)
# It is an upgraded room. Copy over old tags
yield self.copy_room_tags_and_direct_to_room(
predecessor["room_id"], new_room_id, user_id
)
# Copy over push rules
yield self.store.copy_push_rules_from_room_to_room_for_user(
predecessor["room_id"], new_room_id, user_id
)
@defer.inlineCallbacks
def send_membership_event(self, requester, event, context, ratelimit=True):
"""
@ -682,7 +694,9 @@ class RoomMemberHandler(object):
403, "Looking up third-party identifiers is denied from this server"
)
invitee = yield self._lookup_3pid(id_server, medium, address, id_access_token)
invitee = yield self.identity_handler.lookup_3pid(
id_server, medium, address, id_access_token
)
if invitee:
yield self.update_membership(
@ -700,211 +714,6 @@ class RoomMemberHandler(object):
id_access_token=id_access_token,
)
@defer.inlineCallbacks
def _lookup_3pid(self, id_server, medium, address, id_access_token=None):
"""Looks up a 3pid in the passed identity server.
Args:
id_server (str): The server name (including port, if required)
of the identity server to use.
medium (str): The type of the third party identifier (e.g. "email").
address (str): The third party identifier (e.g. "foo@example.com").
id_access_token (str|None): The access token to authenticate to the identity
server with
Returns:
str|None: the matrix ID of the 3pid, or None if it is not recognized.
"""
if id_access_token is not None:
try:
results = yield self._lookup_3pid_v2(
id_server, id_access_token, medium, address
)
return results
except Exception as e:
# Catch HttpResponseExcept for a non-200 response code
# Check if this identity server does not know about v2 lookups
if isinstance(e, HttpResponseException) and e.code == 404:
# This is an old identity server that does not yet support v2 lookups
logger.warning(
"Attempted v2 lookup on v1 identity server %s. Falling "
"back to v1",
id_server,
)
else:
logger.warning("Error when looking up hashing details: %s", e)
return None
return (yield self._lookup_3pid_v1(id_server, medium, address))
@defer.inlineCallbacks
def _lookup_3pid_v1(self, id_server, medium, address):
"""Looks up a 3pid in the passed identity server using v1 lookup.
Args:
id_server (str): The server name (including port, if required)
of the identity server to use.
medium (str): The type of the third party identifier (e.g. "email").
address (str): The third party identifier (e.g. "foo@example.com").
Returns:
str: the matrix ID of the 3pid, or None if it is not recognized.
"""
try:
data = yield self.simple_http_client.get_json(
"%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server),
{"medium": medium, "address": address},
)
if "mxid" in data:
if "signatures" not in data:
raise AuthError(401, "No signatures on 3pid binding")
yield self._verify_any_signature(data, id_server)
return data["mxid"]
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
except IOError as e:
logger.warning("Error from v1 identity server lookup: %s" % (e,))
return None
@defer.inlineCallbacks
def _lookup_3pid_v2(self, id_server, id_access_token, medium, address):
"""Looks up a 3pid in the passed identity server using v2 lookup.
Args:
id_server (str): The server name (including port, if required)
of the identity server to use.
id_access_token (str): The access token to authenticate to the identity server with
medium (str): The type of the third party identifier (e.g. "email").
address (str): The third party identifier (e.g. "foo@example.com").
Returns:
Deferred[str|None]: the matrix ID of the 3pid, or None if it is not recognised.
"""
# Check what hashing details are supported by this identity server
try:
hash_details = yield self.simple_http_client.get_json(
"%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server),
{"access_token": id_access_token},
)
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
if not isinstance(hash_details, dict):
logger.warning(
"Got non-dict object when checking hash details of %s%s: %s",
id_server_scheme,
id_server,
hash_details,
)
raise SynapseError(
400,
"Non-dict object from %s%s during v2 hash_details request: %s"
% (id_server_scheme, id_server, hash_details),
)
# Extract information from hash_details
supported_lookup_algorithms = hash_details.get("algorithms")
lookup_pepper = hash_details.get("lookup_pepper")
if (
not supported_lookup_algorithms
or not isinstance(supported_lookup_algorithms, list)
or not lookup_pepper
or not isinstance(lookup_pepper, str)
):
raise SynapseError(
400,
"Invalid hash details received from identity server %s%s: %s"
% (id_server_scheme, id_server, hash_details),
)
# Check if any of the supported lookup algorithms are present
if LookupAlgorithm.SHA256 in supported_lookup_algorithms:
# Perform a hashed lookup
lookup_algorithm = LookupAlgorithm.SHA256
# Hash address, medium and the pepper with sha256
to_hash = "%s %s %s" % (address, medium, lookup_pepper)
lookup_value = sha256_and_url_safe_base64(to_hash)
elif LookupAlgorithm.NONE in supported_lookup_algorithms:
# Perform a non-hashed lookup
lookup_algorithm = LookupAlgorithm.NONE
# Combine together plaintext address and medium
lookup_value = "%s %s" % (address, medium)
else:
logger.warning(
"None of the provided lookup algorithms of %s are supported: %s",
id_server,
supported_lookup_algorithms,
)
raise SynapseError(
400,
"Provided identity server does not support any v2 lookup "
"algorithms that this homeserver supports.",
)
# Authenticate with identity server given the access token from the client
headers = {"Authorization": create_id_access_token_header(id_access_token)}
try:
lookup_results = yield self.simple_http_client.post_json_get_json(
"%s%s/_matrix/identity/v2/lookup" % (id_server_scheme, id_server),
{
"addresses": [lookup_value],
"algorithm": lookup_algorithm,
"pepper": lookup_pepper,
},
headers=headers,
)
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
except Exception as e:
logger.warning("Error when performing a v2 3pid lookup: %s", e)
raise SynapseError(
500, "Unknown error occurred during identity server lookup"
)
# Check for a mapping from what we looked up to an MXID
if "mappings" not in lookup_results or not isinstance(
lookup_results["mappings"], dict
):
logger.warning("No results from 3pid lookup")
return None
# Return the MXID if it's available, or None otherwise
mxid = lookup_results["mappings"].get(lookup_value)
return mxid
@defer.inlineCallbacks
def _verify_any_signature(self, data, server_hostname):
if server_hostname not in data["signatures"]:
raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items():
try:
key_data = yield self.simple_http_client.get_json(
"%s%s/_matrix/identity/api/v1/pubkey/%s"
% (id_server_scheme, server_hostname, key_name)
)
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
if "public_key" not in key_data:
raise AuthError(
401, "No public key named %s from %s" % (key_name, server_hostname)
)
verify_signed_json(
data,
server_hostname,
decode_verify_key_bytes(
key_name, decode_base64(key_data["public_key"])
),
)
return
@defer.inlineCallbacks
def _make_and_store_3pid_invite(
self,
@ -951,7 +760,7 @@ class RoomMemberHandler(object):
room_avatar_url = room_avatar_event.content.get("url", "")
token, public_keys, fallback_public_key, display_name = (
yield self._ask_id_server_for_third_party_invite(
yield self.identity_handler.ask_id_server_for_third_party_invite(
requester=requester,
id_server=id_server,
medium=medium,
@ -987,147 +796,6 @@ class RoomMemberHandler(object):
txn_id=txn_id,
)
@defer.inlineCallbacks
def _ask_id_server_for_third_party_invite(
self,
requester,
id_server,
medium,
address,
room_id,
inviter_user_id,
room_alias,
room_avatar_url,
room_join_rules,
room_name,
inviter_display_name,
inviter_avatar_url,
id_access_token=None,
):
"""
Asks an identity server for a third party invite.
Args:
requester (Requester)
id_server (str): hostname + optional port for the identity server.
medium (str): The literal string "email".
address (str): The third party address being invited.
room_id (str): The ID of the room to which the user is invited.
inviter_user_id (str): The user ID of the inviter.
room_alias (str): An alias for the room, for cosmetic notifications.
room_avatar_url (str): The URL of the room's avatar, for cosmetic
notifications.
room_join_rules (str): The join rules of the email (e.g. "public").
room_name (str): The m.room.name of the room.
inviter_display_name (str): The current display name of the
inviter.
inviter_avatar_url (str): The URL of the inviter's avatar.
id_access_token (str|None): The access token to authenticate to the identity
server with
Returns:
A deferred tuple containing:
token (str): The token which must be signed to prove authenticity.
public_keys ([{"public_key": str, "key_validity_url": str}]):
public_key is a base64-encoded ed25519 public key.
fallback_public_key: One element from public_keys.
display_name (str): A user-friendly name to represent the invited
user.
"""
invite_config = {
"medium": medium,
"address": address,
"room_id": room_id,
"room_alias": room_alias,
"room_avatar_url": room_avatar_url,
"room_join_rules": room_join_rules,
"room_name": room_name,
"sender": inviter_user_id,
"sender_display_name": inviter_display_name,
"sender_avatar_url": inviter_avatar_url,
}
# Add the identity service access token to the JSON body and use the v2
# Identity Service endpoints if id_access_token is present
data = None
base_url = "%s%s/_matrix/identity" % (id_server_scheme, id_server)
if id_access_token:
key_validity_url = "%s%s/_matrix/identity/v2/pubkey/isvalid" % (
id_server_scheme,
id_server,
)
# Attempt a v2 lookup
url = base_url + "/v2/store-invite"
try:
data = yield self.simple_http_client.post_json_get_json(
url,
invite_config,
{"Authorization": create_id_access_token_header(id_access_token)},
)
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
except HttpResponseException as e:
if e.code != 404:
logger.info("Failed to POST %s with JSON: %s", url, e)
raise e
if data is None:
key_validity_url = "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
id_server_scheme,
id_server,
)
url = base_url + "/api/v1/store-invite"
try:
data = yield self.simple_http_client.post_json_get_json(
url, invite_config
)
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
except HttpResponseException as e:
logger.warning(
"Error trying to call /store-invite on %s%s: %s",
id_server_scheme,
id_server,
e,
)
if data is None:
# Some identity servers may only support application/x-www-form-urlencoded
# types. This is especially true with old instances of Sydent, see
# https://github.com/matrix-org/sydent/pull/170
try:
data = yield self.simple_http_client.post_urlencoded_get_json(
url, invite_config
)
except HttpResponseException as e:
logger.warning(
"Error calling /store-invite on %s%s with fallback "
"encoding: %s",
id_server_scheme,
id_server,
e,
)
raise e
# TODO: Check for success
token = data["token"]
public_keys = data.get("public_keys", [])
if "public_key" in data:
fallback_public_key = {
"public_key": data["public_key"],
"key_validity_url": key_validity_url,
}
else:
fallback_public_key = public_keys[0]
if not public_keys:
public_keys.append(fallback_public_key)
display_name = data["display_name"]
return token, public_keys, fallback_public_key, display_name
@defer.inlineCallbacks
def _is_host_in_room(self, current_state_ids):
# Have we just created the room, and is this about to be the very

View file

@ -87,21 +87,23 @@ class StatsHandler(StateDeltasHandler):
# Be sure to read the max stream_ordering *before* checking if there are any outstanding
# deltas, since there is otherwise a chance that we could miss updates which arrive
# after we check the deltas.
room_max_stream_ordering = yield self.store.get_room_max_stream_ordering()
room_max_stream_ordering = self.store.get_room_max_stream_ordering()
if self.pos == room_max_stream_ordering:
break
deltas = yield self.store.get_current_state_deltas(self.pos)
logger.debug(
"Processing room stats %s->%s", self.pos, room_max_stream_ordering
)
max_pos, deltas = yield self.store.get_current_state_deltas(
self.pos, room_max_stream_ordering
)
if deltas:
logger.debug("Handling %d state deltas", len(deltas))
room_deltas, user_deltas = yield self._handle_deltas(deltas)
max_pos = deltas[-1]["stream_id"]
else:
room_deltas = {}
user_deltas = {}
max_pos = room_max_stream_ordering
# Then count deltas for total_events and total_event_bytes.
room_count, user_count = yield self.store.get_changes_room_total_events_and_bytes(
@ -293,6 +295,7 @@ class StatsHandler(StateDeltasHandler):
room_state["guest_access"] = event_content.get("guest_access")
for room_id, state in room_to_state_updates.items():
logger.info("Updating room_stats_state for %s: %s", room_id, state)
yield self.store.update_room_state(room_id, state)
return room_to_stats_deltas, user_to_stats_deltas

View file

@ -138,21 +138,28 @@ class UserDirectoryHandler(StateDeltasHandler):
# Loop round handling deltas until we're up to date
while True:
with Measure(self.clock, "user_dir_delta"):
deltas = yield self.store.get_current_state_deltas(self.pos)
if not deltas:
room_max_stream_ordering = self.store.get_room_max_stream_ordering()
if self.pos == room_max_stream_ordering:
return
logger.debug(
"Processing user stats %s->%s", self.pos, room_max_stream_ordering
)
max_pos, deltas = yield self.store.get_current_state_deltas(
self.pos, room_max_stream_ordering
)
logger.info("Handling %d state deltas", len(deltas))
yield self._handle_deltas(deltas)
self.pos = deltas[-1]["stream_id"]
self.pos = max_pos
# Expose current event processing position to prometheus
synapse.metrics.event_processing_positions.labels("user_dir").set(
self.pos
max_pos
)
yield self.store.update_user_directory_stream_pos(self.pos)
yield self.store.update_user_directory_stream_pos(max_pos)
@defer.inlineCallbacks
def _handle_deltas(self, deltas):

View file

@ -327,7 +327,7 @@ class SimpleHttpClient(object):
Args:
uri (str):
args (dict[str, str|List[str]]): query params
headers (dict[str, List[str]]|None): If not None, a map from
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
Returns:
@ -371,7 +371,7 @@ class SimpleHttpClient(object):
Args:
uri (str):
post_json (object):
headers (dict[str, List[str]]|None): If not None, a map from
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
Returns:
@ -414,7 +414,7 @@ class SimpleHttpClient(object):
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
headers (dict[str, List[str]]|None): If not None, a map from
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
@ -438,7 +438,7 @@ class SimpleHttpClient(object):
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
headers (dict[str, List[str]]|None): If not None, a map from
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
@ -482,7 +482,7 @@ class SimpleHttpClient(object):
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
headers (dict[str, List[str]]|None): If not None, a map from
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
@ -516,7 +516,7 @@ class SimpleHttpClient(object):
Args:
url (str): The URL to GET
output_stream (file): File to write the response body to.
headers (dict[str, List[str]]|None): If not None, a map from
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
Returns:
A (int,dict,string,int) tuple of the file length, dict of the response

View file

@ -170,6 +170,7 @@ import inspect
import logging
import re
from functools import wraps
from typing import Dict
from canonicaljson import json
@ -547,7 +548,7 @@ def inject_active_span_twisted_headers(headers, destination, check_destination=T
return
span = opentracing.tracer.active_span
carrier = {}
carrier = {} # type: Dict[str, str]
opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
for key, value in carrier.items():
@ -584,7 +585,7 @@ def inject_active_span_byte_dict(headers, destination, check_destination=True):
span = opentracing.tracer.active_span
carrier = {}
carrier = {} # type: Dict[str, str]
opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
for key, value in carrier.items():
@ -639,7 +640,7 @@ def get_active_span_text_map(destination=None):
if destination and not whitelisted_homeserver(destination):
return {}
carrier = {}
carrier = {} # type: Dict[str, str]
opentracing.tracer.inject(
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
)
@ -653,7 +654,7 @@ def active_span_context_as_string():
Returns:
The active span context encoded as a string.
"""
carrier = {}
carrier = {} # type: Dict[str, str]
if opentracing:
opentracing.tracer.inject(
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier

View file

@ -119,7 +119,11 @@ def trace_function(f):
logger = logging.getLogger(name)
level = logging.DEBUG
s = inspect.currentframe().f_back
frame = inspect.currentframe()
if frame is None:
raise Exception("Can't get current frame!")
s = frame.f_back
to_print = [
"\t%s:%s %s. Args: args=%s, kwargs=%s"
@ -144,7 +148,7 @@ def trace_function(f):
pathname=pathname,
lineno=lineno,
msg=msg,
args=None,
args=tuple(),
exc_info=None,
)
@ -157,7 +161,12 @@ def trace_function(f):
def get_previous_frames():
s = inspect.currentframe().f_back.f_back
frame = inspect.currentframe()
if frame is None:
raise Exception("Can't get current frame!")
s = frame.f_back.f_back
to_return = []
while s:
if s.f_globals["__name__"].startswith("synapse"):
@ -174,7 +183,10 @@ def get_previous_frames():
def get_previous_frame(ignore=[]):
s = inspect.currentframe().f_back.f_back
frame = inspect.currentframe()
if frame is None:
raise Exception("Can't get current frame!")
s = frame.f_back.f_back
while s:
if s.f_globals["__name__"].startswith("synapse"):

View file

@ -125,7 +125,7 @@ class InFlightGauge(object):
)
# Counts number of in flight blocks for a given set of label values
self._registrations = {}
self._registrations = {} # type: Dict
# Protects access to _registrations
self._lock = threading.Lock()
@ -226,7 +226,7 @@ class BucketCollector(object):
# Fetch the data -- this must be synchronous!
data = self.data_collector()
buckets = {}
buckets = {} # type: Dict[float, int]
res = []
for x in data.keys():

View file

@ -36,9 +36,9 @@ from twisted.web.resource import Resource
try:
from prometheus_client.samples import Sample
except ImportError:
Sample = namedtuple(
Sample = namedtuple( # type: ignore[no-redef] # noqa
"Sample", ["name", "labels", "value", "timestamp", "exemplar"]
) # type: ignore
)
CONTENT_TYPE_LATEST = str("text/plain; version=0.0.4; charset=utf-8")

View file

@ -15,7 +15,7 @@
# limitations under the License.
import logging
from typing import Set
from typing import List, Set
from pkg_resources import (
DistributionNotFound,
@ -73,6 +73,7 @@ REQUIREMENTS = [
"netaddr>=0.7.18",
"Jinja2>=2.9",
"bleach>=1.4.3",
"typing-extensions>=3.7.4",
]
CONDITIONAL_REQUIREMENTS = {
@ -144,7 +145,11 @@ def check_requirements(for_feature=None):
deps_needed.append(dependency)
errors.append(
"Needed %s, got %s==%s"
% (dependency, e.dist.project_name, e.dist.version)
% (
dependency,
e.dist.project_name, # type: ignore[attr-defined] # noqa
e.dist.version, # type: ignore[attr-defined] # noqa
)
)
except DistributionNotFound:
deps_needed.append(dependency)
@ -159,7 +164,7 @@ def check_requirements(for_feature=None):
if not for_feature:
# Check the optional dependencies are up to date. We allow them to not be
# installed.
OPTS = sum(CONDITIONAL_REQUIREMENTS.values(), [])
OPTS = sum(CONDITIONAL_REQUIREMENTS.values(), []) # type: List[str]
for dependency in OPTS:
try:
@ -168,7 +173,11 @@ def check_requirements(for_feature=None):
deps_needed.append(dependency)
errors.append(
"Needed optional %s, got %s==%s"
% (dependency, e.dist.project_name, e.dist.version)
% (
dependency,
e.dist.project_name, # type: ignore[attr-defined] # noqa
e.dist.version, # type: ignore[attr-defined] # noqa
)
)
except DistributionNotFound:
# If it's not found, we don't care

View file

@ -39,6 +39,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
from synapse.logging.opentracing import set_tag
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.storage.state import StateFilter
@ -81,6 +82,7 @@ class RoomCreateRestServlet(TransactionRestServlet):
)
def on_PUT(self, request, txn_id):
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(request, self.on_POST, request)
@defer.inlineCallbacks
@ -181,6 +183,9 @@ class RoomStateEventRestServlet(TransactionRestServlet):
def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
requester = yield self.auth.get_user_by_req(request)
if txn_id:
set_tag("txn_id", txn_id)
content = parse_json_object_from_request(request)
event_dict = {
@ -209,6 +214,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
ret = {}
if event:
set_tag("event_id", event.event_id)
ret = {"event_id": event.event_id}
return 200, ret
@ -244,12 +250,15 @@ class RoomSendEventRestServlet(TransactionRestServlet):
requester, event_dict, txn_id=txn_id
)
set_tag("event_id", event.event_id)
return 200, {"event_id": event.event_id}
def on_GET(self, request, room_id, event_type, txn_id):
return 200, "Not implemented"
def on_PUT(self, request, room_id, event_type, txn_id):
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, event_type, txn_id
)
@ -310,6 +319,8 @@ class JoinRoomAliasServlet(TransactionRestServlet):
return 200, {"room_id": room_id}
def on_PUT(self, request, room_identifier, txn_id):
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_identifier, txn_id
)
@ -350,6 +361,10 @@ class PublicRoomListRestServlet(TransactionRestServlet):
limit = parse_integer(request, "limit", 0)
since_token = parse_string(request, "since", None)
if limit == 0:
# zero is a special value which corresponds to no limit.
limit = None
handler = self.hs.get_room_list_handler()
if server:
data = yield handler.get_remote_public_room_list(
@ -387,6 +402,10 @@ class PublicRoomListRestServlet(TransactionRestServlet):
else:
network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id)
if limit == 0:
# zero is a special value which corresponds to no limit.
limit = None
handler = self.hs.get_room_list_handler()
if server:
data = yield handler.get_remote_public_room_list(
@ -655,6 +674,8 @@ class RoomForgetRestServlet(TransactionRestServlet):
return 200, {}
def on_PUT(self, request, room_id, txn_id):
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, txn_id
)
@ -738,6 +759,8 @@ class RoomMembershipRestServlet(TransactionRestServlet):
return True
def on_PUT(self, request, room_id, membership_action, txn_id):
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, membership_action, txn_id
)
@ -771,9 +794,12 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
txn_id=txn_id,
)
set_tag("event_id", event.event_id)
return 200, {"event_id": event.event_id}
def on_PUT(self, request, room_id, event_id, txn_id):
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, event_id, txn_id
)

View file

@ -129,66 +129,6 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
return 200, ret
class MsisdnPasswordRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/account/password/msisdn/requestToken$")
def __init__(self, hs):
super(MsisdnPasswordRequestTokenRestServlet, self).__init__()
self.hs = hs
self.datastore = self.hs.get_datastore()
self.identity_handler = hs.get_handlers().identity_handler
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
assert_params_in_dict(
body, ["client_secret", "country", "phone_number", "send_attempt"]
)
client_secret = body["client_secret"]
country = body["country"]
phone_number = body["phone_number"]
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
msisdn = phone_number_to_msisdn(country, phone_number)
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
raise SynapseError(
403,
"Account phone numbers are not authorized on this server",
Codes.THREEPID_DENIED,
)
existing_user_id = yield self.datastore.get_user_id_by_threepid(
"msisdn", msisdn
)
if existing_user_id is None:
raise SynapseError(400, "MSISDN not found", Codes.THREEPID_NOT_FOUND)
if not self.hs.config.account_threepid_delegate_msisdn:
logger.warn(
"No upstream msisdn account_threepid_delegate configured on the server to "
"handle this request"
)
raise SynapseError(
400,
"Password reset by phone number is not supported on this homeserver",
)
ret = yield self.identity_handler.requestMsisdnToken(
self.hs.config.account_threepid_delegate_msisdn,
country,
phone_number,
client_secret,
send_attempt,
next_link,
)
return 200, ret
class PasswordResetSubmitTokenServlet(RestServlet):
"""Handles 3PID validation token submission"""
@ -301,9 +241,7 @@ class PasswordRestServlet(RestServlet):
else:
requester = None
result, params, _ = yield self.auth_handler.check_auth(
[[LoginType.EMAIL_IDENTITY], [LoginType.MSISDN]],
body,
self.hs.get_ip_from_request(request),
[[LoginType.EMAIL_IDENTITY]], body, self.hs.get_ip_from_request(request)
)
if LoginType.EMAIL_IDENTITY in result:
@ -843,7 +781,6 @@ class WhoamiRestServlet(RestServlet):
def register_servlets(hs, http_server):
EmailPasswordRequestTokenRestServlet(hs).register(http_server)
MsisdnPasswordRequestTokenRestServlet(hs).register(http_server)
PasswordResetSubmitTokenServlet(hs).register(http_server)
PasswordRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server)

View file

@ -17,7 +17,7 @@ import logging
from twisted.internet import defer
from synapse.api.errors import AuthError, Codes, StoreError, SynapseError
from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID
@ -52,13 +52,15 @@ class GetFilterRestServlet(RestServlet):
raise SynapseError(400, "Invalid filter_id")
try:
filter = yield self.filtering.get_user_filter(
filter_collection = yield self.filtering.get_user_filter(
user_localpart=target_user.localpart, filter_id=filter_id
)
except StoreError as e:
if e.code != 404:
raise
raise NotFoundError("No such filter")
return 200, filter.get_filter_json()
except (KeyError, StoreError):
raise SynapseError(400, "No such filter", errcode=Codes.NOT_FOUND)
return 200, filter_collection.get_filter_json()
class CreateFilterRestServlet(RestServlet):

View file

@ -21,7 +21,7 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.api.constants import PresenceState
from synapse.api.errors import SynapseError
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
from synapse.events.utils import (
format_event_for_client_v2_without_room_id,
@ -119,25 +119,32 @@ class SyncRestServlet(RestServlet):
request_key = (user, timeout, since, filter_id, full_state, device_id)
if filter_id:
if filter_id.startswith("{"):
try:
filter_object = json.loads(filter_id)
set_timeline_upper_limit(
filter_object, self.hs.config.filter_timeline_limit
)
except Exception:
raise SynapseError(400, "Invalid filter JSON")
self.filtering.check_valid_filter(filter_object)
filter = FilterCollection(filter_object)
else:
filter = yield self.filtering.get_user_filter(user.localpart, filter_id)
if filter_id is None:
filter_collection = DEFAULT_FILTER_COLLECTION
elif filter_id.startswith("{"):
try:
filter_object = json.loads(filter_id)
set_timeline_upper_limit(
filter_object, self.hs.config.filter_timeline_limit
)
except Exception:
raise SynapseError(400, "Invalid filter JSON")
self.filtering.check_valid_filter(filter_object)
filter_collection = FilterCollection(filter_object)
else:
filter = DEFAULT_FILTER_COLLECTION
try:
filter_collection = yield self.filtering.get_user_filter(
user.localpart, filter_id
)
except StoreError as err:
if err.code != 404:
raise
# fix up the description and errcode to be more useful
raise SynapseError(400, "No such filter", errcode=Codes.INVALID_PARAM)
sync_config = SyncConfig(
user=user,
filter_collection=filter,
filter_collection=filter_collection,
is_guest=requester.is_guest,
request_key=request_key,
device_id=device_id,
@ -171,7 +178,7 @@ class SyncRestServlet(RestServlet):
time_now = self.clock.time_msec()
response_content = yield self.encode_response(
time_now, sync_result, requester.access_token_id, filter
time_now, sync_result, requester.access_token_id, filter_collection
)
return 200, response_content

View file

@ -195,7 +195,7 @@ def respond_with_responder(request, responder, media_type, file_size, upload_nam
respond_404(request)
return
logger.debug("Responding to media request with responder %s")
logger.debug("Responding to media request with responder %s", responder)
add_file_headers(request, media_type, file_size, upload_name)
try:
with responder:

View file

@ -82,13 +82,21 @@ class Thumbnailer(object):
else:
return (max_height * self.width) // self.height, max_height
def _resize(self, width, height):
# 1-bit or 8-bit color palette images need converting to RGB
# otherwise they will be scaled using nearest neighbour which
# looks awful
if self.image.mode in ["1", "P"]:
self.image = self.image.convert("RGB")
return self.image.resize((width, height), Image.ANTIALIAS)
def scale(self, width, height, output_type):
"""Rescales the image to the given dimensions.
Returns:
BytesIO: the bytes of the encoded image ready to be written to disk
"""
scaled = self.image.resize((width, height), Image.ANTIALIAS)
scaled = self._resize(width, height)
return self._encode_image(scaled, output_type)
def crop(self, width, height, output_type):
@ -107,13 +115,13 @@ class Thumbnailer(object):
"""
if width * self.height > height * self.width:
scaled_height = (width * self.height) // self.width
scaled_image = self.image.resize((width, scaled_height), Image.ANTIALIAS)
scaled_image = self._resize(width, scaled_height)
crop_top = (scaled_height - height) // 2
crop_bottom = height + crop_top
cropped = scaled_image.crop((0, crop_top, width, crop_bottom))
else:
scaled_width = (height * self.width) // self.height
scaled_image = self.image.resize((scaled_width, height), Image.ANTIALIAS)
scaled_image = self._resize(scaled_width, height)
crop_left = (scaled_width - width) // 2
crop_right = width + crop_left
cropped = scaled_image.crop((crop_left, 0, crop_right, height))

View file

@ -17,7 +17,7 @@ import logging
from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import SynapseError
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import (
DirectServeResource,
respond_with_json,
@ -56,7 +56,11 @@ class UploadResource(DirectServeResource):
if content_length is None:
raise SynapseError(msg="Request must specify a Content-Length", code=400)
if int(content_length) > self.max_upload_size:
raise SynapseError(msg="Upload request body is too large", code=413)
raise SynapseError(
msg="Upload request body is too large",
code=413,
errcode=Codes.TOO_LARGE,
)
upload_name = parse_string(request, b"filename", encoding=None)
if upload_name:

View file

@ -33,7 +33,7 @@ from synapse.state import v1, v2
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import Measure
from synapse.util.metrics import Measure, measure_func
logger = logging.getLogger(__name__)
@ -191,11 +191,22 @@ class StateHandler(object):
return joined_users
@defer.inlineCallbacks
def get_current_hosts_in_room(self, room_id, latest_event_ids=None):
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_hosts_in_room")
entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
def get_current_hosts_in_room(self, room_id):
event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
return (yield self.get_hosts_in_room_at_events(room_id, event_ids))
@defer.inlineCallbacks
def get_hosts_in_room_at_events(self, room_id, event_ids):
"""Get the hosts that were in a room at the given event ids
Args:
room_id (str):
event_ids (list[str]):
Returns:
Deferred[list[str]]: the hosts in the room at the given events
"""
entry = yield self.resolve_state_groups_for_events(room_id, event_ids)
joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
return joined_hosts
@ -344,6 +355,7 @@ class StateHandler(object):
return context
@measure_func()
@defer.inlineCallbacks
def resolve_state_groups_for_events(self, room_id, event_ids):
""" Given a list of event_ids this method fetches the state at each

View file

@ -33,16 +33,9 @@ logger = logging.getLogger(__name__)
LAST_SEEN_GRANULARITY = 120 * 1000
class ClientIpStore(background_updates.BackgroundUpdateStore):
class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
def __init__(self, db_conn, hs):
self.client_ip_last_seen = Cache(
name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
)
super(ClientIpStore, self).__init__(db_conn, hs)
self.user_ips_max_age = hs.config.user_ips_max_age
super(ClientIpBackgroundUpdateStore, self).__init__(db_conn, hs)
self.register_background_index_update(
"user_ips_device_index",
@ -92,19 +85,6 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"devices_last_seen", self._devices_last_seen_update
)
# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
self._batch_row_update = {}
self._client_ip_looper = self._clock.looping_call(
self._update_client_ips_batch, 5 * 1000
)
self.hs.get_reactor().addSystemEventTrigger(
"before", "shutdown", self._update_client_ips_batch
)
if self.user_ips_max_age:
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
@defer.inlineCallbacks
def _remove_user_ip_nonunique(self, progress, batch_size):
def f(conn):
@ -303,6 +283,110 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
return batch_size
@defer.inlineCallbacks
def _devices_last_seen_update(self, progress, batch_size):
"""Background update to insert last seen info into devices table
"""
last_user_id = progress.get("last_user_id", "")
last_device_id = progress.get("last_device_id", "")
def _devices_last_seen_update_txn(txn):
# This consists of two queries:
#
# 1. The sub-query searches for the next N devices and joins
# against user_ips to find the max last_seen associated with
# that device.
# 2. The outer query then joins again against user_ips on
# user/device/last_seen. This *should* hopefully only
# return one row, but if it does return more than one then
# we'll just end up updating the same device row multiple
# times, which is fine.
if self.database_engine.supports_tuple_comparison:
where_clause = "(user_id, device_id) > (?, ?)"
where_args = [last_user_id, last_device_id]
else:
# We explicitly do a `user_id >= ? AND (...)` here to ensure
# that an index is used, as doing `user_id > ? OR (user_id = ? AND ...)`
# makes it hard for query optimiser to tell that it can use the
# index on user_id
where_clause = "user_id >= ? AND (user_id > ? OR device_id > ?)"
where_args = [last_user_id, last_user_id, last_device_id]
sql = """
SELECT
last_seen, ip, user_agent, user_id, device_id
FROM (
SELECT
user_id, device_id, MAX(u.last_seen) AS last_seen
FROM devices
INNER JOIN user_ips AS u USING (user_id, device_id)
WHERE %(where_clause)s
GROUP BY user_id, device_id
ORDER BY user_id ASC, device_id ASC
LIMIT ?
) c
INNER JOIN user_ips AS u USING (user_id, device_id, last_seen)
""" % {
"where_clause": where_clause
}
txn.execute(sql, where_args + [batch_size])
rows = txn.fetchall()
if not rows:
return 0
sql = """
UPDATE devices
SET last_seen = ?, ip = ?, user_agent = ?
WHERE user_id = ? AND device_id = ?
"""
txn.execute_batch(sql, rows)
_, _, _, user_id, device_id = rows[-1]
self._background_update_progress_txn(
txn,
"devices_last_seen",
{"last_user_id": user_id, "last_device_id": device_id},
)
return len(rows)
updated = yield self.runInteraction(
"_devices_last_seen_update", _devices_last_seen_update_txn
)
if not updated:
yield self._end_background_update("devices_last_seen")
return updated
class ClientIpStore(ClientIpBackgroundUpdateStore):
def __init__(self, db_conn, hs):
self.client_ip_last_seen = Cache(
name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
)
super(ClientIpStore, self).__init__(db_conn, hs)
self.user_ips_max_age = hs.config.user_ips_max_age
# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
self._batch_row_update = {}
self._client_ip_looper = self._clock.looping_call(
self._update_client_ips_batch, 5 * 1000
)
self.hs.get_reactor().addSystemEventTrigger(
"before", "shutdown", self._update_client_ips_batch
)
if self.user_ips_max_age:
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
@defer.inlineCallbacks
def insert_client_ip(
self, user_id, access_token, ip, user_agent, device_id, now=None
@ -454,85 +538,6 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
for (access_token, ip), (user_agent, last_seen) in iteritems(results)
)
@defer.inlineCallbacks
def _devices_last_seen_update(self, progress, batch_size):
"""Background update to insert last seen info into devices table
"""
last_user_id = progress.get("last_user_id", "")
last_device_id = progress.get("last_device_id", "")
def _devices_last_seen_update_txn(txn):
# This consists of two queries:
#
# 1. The sub-query searches for the next N devices and joins
# against user_ips to find the max last_seen associated with
# that device.
# 2. The outer query then joins again against user_ips on
# user/device/last_seen. This *should* hopefully only
# return one row, but if it does return more than one then
# we'll just end up updating the same device row multiple
# times, which is fine.
if self.database_engine.supports_tuple_comparison:
where_clause = "(user_id, device_id) > (?, ?)"
where_args = [last_user_id, last_device_id]
else:
# We explicitly do a `user_id >= ? AND (...)` here to ensure
# that an index is used, as doing `user_id > ? OR (user_id = ? AND ...)`
# makes it hard for query optimiser to tell that it can use the
# index on user_id
where_clause = "user_id >= ? AND (user_id > ? OR device_id > ?)"
where_args = [last_user_id, last_user_id, last_device_id]
sql = """
SELECT
last_seen, ip, user_agent, user_id, device_id
FROM (
SELECT
user_id, device_id, MAX(u.last_seen) AS last_seen
FROM devices
INNER JOIN user_ips AS u USING (user_id, device_id)
WHERE %(where_clause)s
GROUP BY user_id, device_id
ORDER BY user_id ASC, device_id ASC
LIMIT ?
) c
INNER JOIN user_ips AS u USING (user_id, device_id, last_seen)
""" % {
"where_clause": where_clause
}
txn.execute(sql, where_args + [batch_size])
rows = txn.fetchall()
if not rows:
return 0
sql = """
UPDATE devices
SET last_seen = ?, ip = ?, user_agent = ?
WHERE user_id = ? AND device_id = ?
"""
txn.execute_batch(sql, rows)
_, _, _, user_id, device_id = rows[-1]
self._background_update_progress_txn(
txn,
"devices_last_seen",
{"last_user_id": user_id, "last_device_id": device_id},
)
return len(rows)
updated = yield self.runInteraction(
"_devices_last_seen_update", _devices_last_seen_update_txn
)
if not updated:
yield self._end_background_update("devices_last_seen")
return updated
@wrap_as_background_process("prune_old_user_ips")
async def _prune_old_user_ips(self):
"""Removes entries in user IPs older than the configured period.

View file

@ -208,11 +208,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
class DeviceInboxBackgroundUpdateStore(BackgroundUpdateStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
def __init__(self, db_conn, hs):
super(DeviceInboxStore, self).__init__(db_conn, hs)
super(DeviceInboxBackgroundUpdateStore, self).__init__(db_conn, hs)
self.register_background_index_update(
"device_inbox_stream_index",
@ -225,6 +225,26 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
)
@defer.inlineCallbacks
def _background_drop_index_device_inbox(self, progress, batch_size):
def reindex_txn(conn):
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
txn.close()
yield self.runWithConnection(reindex_txn)
yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID)
return 1
class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
def __init__(self, db_conn, hs):
super(DeviceInboxStore, self).__init__(db_conn, hs)
# Map of (user_id, device_id) to the last stream_id that has been
# deleted up to. This is so that we can no op deletions.
self._last_device_delete_cache = ExpiringCache(
@ -435,16 +455,3 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
return self.runInteraction(
"get_all_new_device_messages", get_all_new_device_messages_txn
)
@defer.inlineCallbacks
def _background_drop_index_device_inbox(self, progress, batch_size):
def reindex_txn(conn):
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
txn.close()
yield self.runWithConnection(reindex_txn)
yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID)
return 1

View file

@ -512,17 +512,9 @@ class DeviceWorkerStore(SQLBaseStore):
return results
class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
class DeviceBackgroundUpdateStore(BackgroundUpdateStore):
def __init__(self, db_conn, hs):
super(DeviceStore, self).__init__(db_conn, hs)
# Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists.
self.device_id_exists_cache = Cache(
name="device_id_exists", keylen=2, max_entries=10000
)
self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
super(DeviceBackgroundUpdateStore, self).__init__(db_conn, hs)
self.register_background_index_update(
"device_lists_stream_idx",
@ -555,6 +547,31 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
self._drop_device_list_streams_non_unique_indexes,
)
@defer.inlineCallbacks
def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
def f(conn):
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
txn.close()
yield self.runWithConnection(f)
yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES)
return 1
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
def __init__(self, db_conn, hs):
super(DeviceStore, self).__init__(db_conn, hs)
# Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists.
self.device_id_exists_cache = Cache(
name="device_id_exists", keylen=2, max_entries=10000
)
self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
@defer.inlineCallbacks
def store_device(self, user_id, device_id, initial_device_display_name):
"""Ensure the given device is known; add it to the store if not
@ -910,15 +927,3 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
"_prune_old_outbound_device_pokes",
_prune_txn,
)
@defer.inlineCallbacks
def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
def f(conn):
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
txn.close()
yield self.runWithConnection(f)
yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES)
return 1

View file

@ -15,11 +15,9 @@
from synapse.storage.background_updates import BackgroundUpdateStore
class MediaRepositoryStore(BackgroundUpdateStore):
"""Persistence for attachments and avatars"""
class MediaRepositoryBackgroundUpdateStore(BackgroundUpdateStore):
def __init__(self, db_conn, hs):
super(MediaRepositoryStore, self).__init__(db_conn, hs)
super(MediaRepositoryBackgroundUpdateStore, self).__init__(db_conn, hs)
self.register_background_index_update(
update_name="local_media_repository_url_idx",
@ -29,6 +27,13 @@ class MediaRepositoryStore(BackgroundUpdateStore):
where_clause="url_cache IS NOT NULL",
)
class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"""Persistence for attachments and avatars"""
def __init__(self, db_conn, hs):
super(MediaRepositoryStore, self).__init__(db_conn, hs)
def get_local_media(self, media_id):
"""Get the metadata for a local piece of media
Returns:

View file

@ -183,8 +183,8 @@ class PushRulesWorkerStore(
return results
@defer.inlineCallbacks
def move_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
"""Move a single push rule from one room to another for a specific user.
def copy_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
"""Copy a single push rule from one room to another for a specific user.
Args:
new_room_id (str): ID of the new room.
@ -209,14 +209,11 @@ class PushRulesWorkerStore(
actions=rule["actions"],
)
# Delete push rule for the old room
yield self.delete_push_rule(user_id, rule["rule_id"])
@defer.inlineCallbacks
def move_push_rules_from_room_to_room_for_user(
def copy_push_rules_from_room_to_room_for_user(
self, old_room_id, new_room_id, user_id
):
"""Move all of the push rules from one room to another for a specific
"""Copy all of the push rules from one room to another for a specific
user.
Args:
@ -227,15 +224,14 @@ class PushRulesWorkerStore(
# Retrieve push rules for this user
user_push_rules = yield self.get_push_rules_for_user(user_id)
# Get rules relating to the old room, move them to the new room, then
# delete them from the old room
# Get rules relating to the old room and copy them to the new room
for rule in user_push_rules:
conditions = rule.get("conditions", [])
if any(
(c.get("key") == "room_id" and c.get("pattern") == old_room_id)
for c in conditions
):
self.move_push_rule_from_room_to_room(new_room_id, user_id, rule)
yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
@defer.inlineCallbacks
def bulk_get_push_rules_for_room(self, event, context):

View file

@ -493,7 +493,9 @@ class RegistrationWorkerStore(SQLBaseStore):
"""
def _find_next_generated_user_id(txn):
txn.execute("SELECT name FROM users")
# We bound between '@1' and '@a' to avoid pulling the entire table
# out.
txn.execute("SELECT name FROM users WHERE '@1' <= name AND name < '@a'")
regex = re.compile(r"^@(\d+):")
@ -785,13 +787,14 @@ class RegistrationWorkerStore(SQLBaseStore):
)
class RegistrationStore(
class RegistrationBackgroundUpdateStore(
RegistrationWorkerStore, background_updates.BackgroundUpdateStore
):
def __init__(self, db_conn, hs):
super(RegistrationStore, self).__init__(db_conn, hs)
super(RegistrationBackgroundUpdateStore, self).__init__(db_conn, hs)
self.clock = hs.get_clock()
self.config = hs.config
self.register_background_index_update(
"access_tokens_device_index",
@ -807,8 +810,6 @@ class RegistrationStore(
columns=["creation_ts"],
)
self._account_validity = hs.config.account_validity
# we no longer use refresh tokens, but it's possible that some people
# might have a background update queued to build this index. Just
# clear the background update.
@ -822,17 +823,6 @@ class RegistrationStore(
"users_set_deactivated_flag", self._background_update_set_deactivated_flag
)
# Create a background job for culling expired 3PID validity tokens
def start_cull():
# run as a background process to make sure that the database transactions
# have a logcontext to report to
return run_as_background_process(
"cull_expired_threepid_validation_tokens",
self.cull_expired_threepid_validation_tokens,
)
hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
@defer.inlineCallbacks
def _background_update_set_deactivated_flag(self, progress, batch_size):
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
@ -894,6 +884,54 @@ class RegistrationStore(
return nb_processed
@defer.inlineCallbacks
def _bg_user_threepids_grandfather(self, progress, batch_size):
"""We now track which identity servers a user binds their 3PID to, so
we need to handle the case of existing bindings where we didn't track
this.
We do this by grandfathering in existing user threepids assuming that
they used one of the server configured trusted identity servers.
"""
id_servers = set(self.config.trusted_third_party_id_servers)
def _bg_user_threepids_grandfather_txn(txn):
sql = """
INSERT INTO user_threepid_id_server
(user_id, medium, address, id_server)
SELECT user_id, medium, address, ?
FROM user_threepids
"""
txn.executemany(sql, [(id_server,) for id_server in id_servers])
if id_servers:
yield self.runInteraction(
"_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
)
yield self._end_background_update("user_threepids_grandfather")
return 1
class RegistrationStore(RegistrationBackgroundUpdateStore):
def __init__(self, db_conn, hs):
super(RegistrationStore, self).__init__(db_conn, hs)
self._account_validity = hs.config.account_validity
# Create a background job for culling expired 3PID validity tokens
def start_cull():
# run as a background process to make sure that the database transactions
# have a logcontext to report to
return run_as_background_process(
"cull_expired_threepid_validation_tokens",
self.cull_expired_threepid_validation_tokens,
)
hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
@defer.inlineCallbacks
def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms):
"""Adds an access token for the given user.
@ -1242,36 +1280,6 @@ class RegistrationStore(
desc="get_users_pending_deactivation",
)
@defer.inlineCallbacks
def _bg_user_threepids_grandfather(self, progress, batch_size):
"""We now track which identity servers a user binds their 3PID to, so
we need to handle the case of existing bindings where we didn't track
this.
We do this by grandfathering in existing user threepids assuming that
they used one of the server configured trusted identity servers.
"""
id_servers = set(self.config.trusted_third_party_id_servers)
def _bg_user_threepids_grandfather_txn(txn):
sql = """
INSERT INTO user_threepid_id_server
(user_id, medium, address, id_server)
SELECT user_id, medium, address, ?
FROM user_threepids
"""
txn.executemany(sql, [(id_server,) for id_server in id_servers])
if id_servers:
yield self.runInteraction(
"_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
)
yield self._end_background_update("user_threepids_grandfather")
return 1
def validate_threepid_session(self, session_id, client_secret, token, current_ts):
"""Attempt to validate a threepid session using a token
@ -1463,17 +1471,6 @@ class RegistrationStore(
self.clock.time_msec(),
)
def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
self._simple_update_one_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
updatevalues={"deactivated": 1 if deactivated else 0},
)
self._invalidate_cache_and_stream(
txn, self.get_user_deactivated_status, (user_id,)
)
@defer.inlineCallbacks
def set_user_deactivated_status(self, user_id, deactivated):
"""Set the `deactivated` property for the provided user to the provided value.
@ -1489,3 +1486,14 @@ class RegistrationStore(
user_id,
deactivated,
)
def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
self._simple_update_one_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
updatevalues={"deactivated": 1 if deactivated else 0},
)
self._invalidate_cache_and_stream(
txn, self.get_user_deactivated_status, (user_id,)
)

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -16,6 +17,7 @@
import collections
import logging
import re
from typing import Optional, Tuple
from canonicaljson import json
@ -24,6 +26,7 @@ from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
from synapse.storage.search import SearchStore
from synapse.types import ThirdPartyInstanceID
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
logger = logging.getLogger(__name__)
@ -63,103 +66,196 @@ class RoomWorkerStore(SQLBaseStore):
desc="get_public_room_ids",
)
@cached(num_args=2, max_entries=100)
def get_public_room_ids_at_stream_id(self, stream_id, network_tuple):
"""Get pulbic rooms for a particular list, or across all lists.
def count_public_rooms(self, network_tuple, ignore_non_federatable):
"""Counts the number of public rooms as tracked in the room_stats_current
and room_stats_state table.
Args:
stream_id (int)
network_tuple (ThirdPartyInstanceID): The list to use (None, None)
means the main list, None means all lsits.
network_tuple (ThirdPartyInstanceID|None)
ignore_non_federatable (bool): If true filters out non-federatable rooms
"""
return self.runInteraction(
"get_public_room_ids_at_stream_id",
self.get_public_room_ids_at_stream_id_txn,
stream_id,
network_tuple=network_tuple,
)
def get_public_room_ids_at_stream_id_txn(self, txn, stream_id, network_tuple):
return {
rm
for rm, vis in self.get_published_at_stream_id_txn(
txn, stream_id, network_tuple=network_tuple
).items()
if vis
}
def _count_public_rooms_txn(txn):
query_args = []
def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple):
if network_tuple:
# We want to get from a particular list. No aggregation required.
sql = """
SELECT room_id, visibility FROM public_room_list_stream
INNER JOIN (
SELECT room_id, max(stream_id) AS stream_id
FROM public_room_list_stream
WHERE stream_id <= ? %s
GROUP BY room_id
) grouped USING (room_id, stream_id)
if network_tuple:
if network_tuple.appservice_id:
published_sql = """
SELECT room_id from appservice_room_list
WHERE appservice_id = ? AND network_id = ?
"""
query_args.append(network_tuple.appservice_id)
query_args.append(network_tuple.network_id)
else:
published_sql = """
SELECT room_id FROM rooms WHERE is_public
"""
else:
published_sql = """
SELECT room_id FROM rooms WHERE is_public
UNION SELECT room_id from appservice_room_list
"""
if network_tuple.appservice_id is not None:
txn.execute(
sql % ("AND appservice_id = ? AND network_id = ?",),
(stream_id, network_tuple.appservice_id, network_tuple.network_id),
sql = """
SELECT
COALESCE(COUNT(*), 0)
FROM (
%(published_sql)s
) published
INNER JOIN room_stats_state USING (room_id)
INNER JOIN room_stats_current USING (room_id)
WHERE
(
join_rules = 'public' OR history_visibility = 'world_readable'
)
AND joined_members > 0
""" % {
"published_sql": published_sql
}
txn.execute(sql, query_args)
return txn.fetchone()[0]
return self.runInteraction("count_public_rooms", _count_public_rooms_txn)
@defer.inlineCallbacks
def get_largest_public_rooms(
self,
network_tuple: Optional[ThirdPartyInstanceID],
search_filter: Optional[dict],
limit: Optional[int],
bounds: Optional[Tuple[int, str]],
forwards: bool,
ignore_non_federatable: bool = False,
):
"""Gets the largest public rooms (where largest is in terms of joined
members, as tracked in the statistics table).
Args:
network_tuple
search_filter
limit: Maxmimum number of rows to return, unlimited otherwise.
bounds: An uppoer or lower bound to apply to result set if given,
consists of a joined member count and room_id (these are
excluded from result set).
forwards: true iff going forwards, going backwards otherwise
ignore_non_federatable: If true filters out non-federatable rooms.
Returns:
Rooms in order: biggest number of joined users first.
We then arbitrarily use the room_id as a tie breaker.
"""
where_clauses = []
query_args = []
if network_tuple:
if network_tuple.appservice_id:
published_sql = """
SELECT room_id from appservice_room_list
WHERE appservice_id = ? AND network_id = ?
"""
query_args.append(network_tuple.appservice_id)
query_args.append(network_tuple.network_id)
else:
published_sql = """
SELECT room_id FROM rooms WHERE is_public
"""
else:
published_sql = """
SELECT room_id FROM rooms WHERE is_public
UNION SELECT room_id from appservice_room_list
"""
# Work out the bounds if we're given them, these bounds look slightly
# odd, but are designed to help query planner use indices by pulling
# out a common bound.
if bounds:
last_joined_members, last_room_id = bounds
if forwards:
where_clauses.append(
"""
joined_members <= ? AND (
joined_members < ? OR room_id < ?
)
"""
)
else:
txn.execute(sql % ("AND appservice_id IS NULL",), (stream_id,))
return dict(txn)
else:
# We want to get from all lists, so we need to aggregate the results
where_clauses.append(
"""
joined_members >= ? AND (
joined_members > ? OR room_id > ?
)
"""
)
logger.info("Executing full list")
query_args += [last_joined_members, last_joined_members, last_room_id]
sql = """
SELECT room_id, visibility
FROM public_room_list_stream
INNER JOIN (
SELECT
room_id, max(stream_id) AS stream_id, appservice_id,
network_id
FROM public_room_list_stream
WHERE stream_id <= ?
GROUP BY room_id, appservice_id, network_id
) grouped USING (room_id, stream_id)
if ignore_non_federatable:
where_clauses.append("is_federatable")
if search_filter and search_filter.get("generic_search_term", None):
search_term = "%" + search_filter["generic_search_term"] + "%"
where_clauses.append(
"""
(
name LIKE ?
OR topic LIKE ?
OR canonical_alias LIKE ?
)
"""
)
query_args += [search_term, search_term, search_term]
where_clause = ""
if where_clauses:
where_clause = " AND " + " AND ".join(where_clauses)
sql = """
SELECT
room_id, name, topic, canonical_alias, joined_members,
avatar, history_visibility, joined_members, guest_access
FROM (
%(published_sql)s
) published
INNER JOIN room_stats_state USING (room_id)
INNER JOIN room_stats_current USING (room_id)
WHERE
(
join_rules = 'public' OR history_visibility = 'world_readable'
)
AND joined_members > 0
%(where_clause)s
ORDER BY joined_members %(dir)s, room_id %(dir)s
""" % {
"published_sql": published_sql,
"where_clause": where_clause,
"dir": "DESC" if forwards else "ASC",
}
if limit is not None:
query_args.append(limit)
sql += """
LIMIT ?
"""
txn.execute(sql, (stream_id,))
def _get_largest_public_rooms_txn(txn):
txn.execute(sql, query_args)
results = {}
# A room is visible if its visible on any list.
for room_id, visibility in txn:
results[room_id] = bool(visibility) or results.get(room_id, False)
results = self.cursor_to_dict(txn)
if not forwards:
results.reverse()
return results
def get_public_room_changes(self, prev_stream_id, new_stream_id, network_tuple):
def get_public_room_changes_txn(txn):
then_rooms = self.get_public_room_ids_at_stream_id_txn(
txn, prev_stream_id, network_tuple
)
now_rooms_dict = self.get_published_at_stream_id_txn(
txn, new_stream_id, network_tuple
)
now_rooms_visible = set(rm for rm, vis in now_rooms_dict.items() if vis)
now_rooms_not_visible = set(
rm for rm, vis in now_rooms_dict.items() if not vis
)
newly_visible = now_rooms_visible - then_rooms
newly_unpublished = now_rooms_not_visible & then_rooms
return newly_visible, newly_unpublished
return self.runInteraction(
"get_public_room_changes", get_public_room_changes_txn
ret_val = yield self.runInteraction(
"get_largest_public_rooms", _get_largest_public_rooms_txn
)
defer.returnValue(ret_val)
@cached(max_entries=10000)
def is_room_blocked(self, room_id):

View file

@ -27,12 +27,14 @@ from synapse.api.constants import EventTypes, Membership
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.engines import Sqlite3Engine
from synapse.storage.events_worker import EventsWorkerStore
from synapse.types import get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from synapse.util.metrics import Measure
from synapse.util.stringutils import to_ascii
logger = logging.getLogger(__name__)
@ -483,6 +485,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
return result
@defer.inlineCallbacks
def get_joined_users_from_state(self, room_id, state_entry):
state_group = state_entry.state_group
if not state_group:
@ -492,9 +495,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# To do this we set the state_group to a new object as object() != object()
state_group = object()
return self._get_joined_users_from_context(
room_id, state_group, state_entry.state, context=state_entry
)
with Measure(self._clock, "get_joined_users_from_state"):
return (
yield self._get_joined_users_from_context(
room_id, state_group, state_entry.state, context=state_entry
)
)
@cachedInlineCallbacks(
num_args=2, cache_context=True, iterable=True, max_entries=100000
@ -567,25 +573,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
missing_member_event_ids.append(event_id)
if missing_member_event_ids:
rows = yield self._simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=missing_member_event_ids,
retcols=("user_id", "display_name", "avatar_url"),
keyvalues={"membership": Membership.JOIN},
batch_size=500,
desc="_get_joined_users_from_context",
)
users_in_room.update(
{
to_ascii(row["user_id"]): ProfileInfo(
avatar_url=to_ascii(row["avatar_url"]),
display_name=to_ascii(row["display_name"]),
)
for row in rows
}
event_to_memberships = yield self._get_joined_profiles_from_event_ids(
missing_member_event_ids
)
users_in_room.update((row for row in event_to_memberships.values() if row))
if event is not None and event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
@ -597,6 +588,47 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return users_in_room
@cached(max_entries=10000)
def _get_joined_profile_from_event_id(self, event_id):
raise NotImplementedError()
@cachedList(
cached_method_name="_get_joined_profile_from_event_id",
list_name="event_ids",
inlineCallbacks=True,
)
def _get_joined_profiles_from_event_ids(self, event_ids):
"""For given set of member event_ids check if they point to a join
event and if so return the associated user and profile info.
Args:
event_ids (Iterable[str]): The member event IDs to lookup
Returns:
Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
to `user_id` and ProfileInfo (or None if not join event).
"""
rows = yield self._simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=event_ids,
retcols=("user_id", "display_name", "avatar_url", "event_id"),
keyvalues={"membership": Membership.JOIN},
batch_size=500,
desc="_get_membership_from_event_ids",
)
return {
row["event_id"]: (
row["user_id"],
ProfileInfo(
avatar_url=row["avatar_url"], display_name=row["display_name"]
),
)
for row in rows
}
@cachedInlineCallbacks(max_entries=10000)
def is_host_joined(self, room_id, host):
if "%" in host or "_" in host:
@ -669,6 +701,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return True
@defer.inlineCallbacks
def get_joined_hosts(self, room_id, state_entry):
state_group = state_entry.state_group
if not state_group:
@ -678,9 +711,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# To do this we set the state_group to a new object as object() != object()
state_group = object()
return self._get_joined_hosts(
room_id, state_group, state_entry.state, state_entry=state_entry
)
with Measure(self._clock, "get_joined_hosts"):
return (
yield self._get_joined_hosts(
room_id, state_group, state_entry.state, state_entry=state_entry
)
)
@cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
# @defer.inlineCallbacks
@ -785,9 +821,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return set(room_ids)
class RoomMemberStore(RoomMemberWorkerStore):
class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
def __init__(self, db_conn, hs):
super(RoomMemberStore, self).__init__(db_conn, hs)
super(RoomMemberBackgroundUpdateStore, self).__init__(db_conn, hs)
self.register_background_update_handler(
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
)
@ -803,112 +839,6 @@ class RoomMemberStore(RoomMemberWorkerStore):
where_clause="forgotten = 1",
)
def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database.
"""
self._simple_insert_many_txn(
txn,
table="room_memberships",
values=[
{
"event_id": event.event_id,
"user_id": event.state_key,
"sender": event.user_id,
"room_id": event.room_id,
"membership": event.membership,
"display_name": event.content.get("displayname", None),
"avatar_url": event.content.get("avatar_url", None),
}
for event in events
],
)
for event in events:
txn.call_after(
self._membership_stream_cache.entity_has_changed,
event.state_key,
event.internal_metadata.stream_ordering,
)
txn.call_after(
self.get_invited_rooms_for_user.invalidate, (event.state_key,)
)
# We update the local_invites table only if the event is "current",
# i.e., its something that has just happened. If the event is an
# outlier it is only current if its an "out of band membership",
# like a remote invite or a rejection of a remote invite.
is_new_state = not backfilled and (
not event.internal_metadata.is_outlier()
or event.internal_metadata.is_out_of_band_membership()
)
is_mine = self.hs.is_mine_id(event.state_key)
if is_new_state and is_mine:
if event.membership == Membership.INVITE:
self._simple_insert_txn(
txn,
table="local_invites",
values={
"event_id": event.event_id,
"invitee": event.state_key,
"inviter": event.sender,
"room_id": event.room_id,
"stream_id": event.internal_metadata.stream_ordering,
},
)
else:
sql = (
"UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
txn.execute(
sql,
(
event.internal_metadata.stream_ordering,
event.event_id,
event.room_id,
event.state_key,
),
)
@defer.inlineCallbacks
def locally_reject_invite(self, user_id, room_id):
sql = (
"UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
def f(txn, stream_ordering):
txn.execute(sql, (stream_ordering, True, room_id, user_id))
with self._stream_id_gen.get_next() as stream_ordering:
yield self.runInteraction("locally_reject_invite", f, stream_ordering)
def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id."""
def f(txn):
sql = (
"UPDATE"
" room_memberships"
" SET"
" forgotten = 1"
" WHERE"
" user_id = ?"
" AND"
" room_id = ?"
)
txn.execute(sql, (user_id, room_id))
self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id))
self._invalidate_cache_and_stream(
txn, self.get_forgotten_rooms_for_user, (user_id,)
)
return self.runInteraction("forget_membership", f)
@defer.inlineCallbacks
def _background_add_membership_profile(self, progress, batch_size):
target_min_stream_id = progress.get(
@ -1043,6 +973,117 @@ class RoomMemberStore(RoomMemberWorkerStore):
return row_count
class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def __init__(self, db_conn, hs):
super(RoomMemberStore, self).__init__(db_conn, hs)
def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database.
"""
self._simple_insert_many_txn(
txn,
table="room_memberships",
values=[
{
"event_id": event.event_id,
"user_id": event.state_key,
"sender": event.user_id,
"room_id": event.room_id,
"membership": event.membership,
"display_name": event.content.get("displayname", None),
"avatar_url": event.content.get("avatar_url", None),
}
for event in events
],
)
for event in events:
txn.call_after(
self._membership_stream_cache.entity_has_changed,
event.state_key,
event.internal_metadata.stream_ordering,
)
txn.call_after(
self.get_invited_rooms_for_user.invalidate, (event.state_key,)
)
# We update the local_invites table only if the event is "current",
# i.e., its something that has just happened. If the event is an
# outlier it is only current if its an "out of band membership",
# like a remote invite or a rejection of a remote invite.
is_new_state = not backfilled and (
not event.internal_metadata.is_outlier()
or event.internal_metadata.is_out_of_band_membership()
)
is_mine = self.hs.is_mine_id(event.state_key)
if is_new_state and is_mine:
if event.membership == Membership.INVITE:
self._simple_insert_txn(
txn,
table="local_invites",
values={
"event_id": event.event_id,
"invitee": event.state_key,
"inviter": event.sender,
"room_id": event.room_id,
"stream_id": event.internal_metadata.stream_ordering,
},
)
else:
sql = (
"UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
txn.execute(
sql,
(
event.internal_metadata.stream_ordering,
event.event_id,
event.room_id,
event.state_key,
),
)
@defer.inlineCallbacks
def locally_reject_invite(self, user_id, room_id):
sql = (
"UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
def f(txn, stream_ordering):
txn.execute(sql, (stream_ordering, True, room_id, user_id))
with self._stream_id_gen.get_next() as stream_ordering:
yield self.runInteraction("locally_reject_invite", f, stream_ordering)
def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id."""
def f(txn):
sql = (
"UPDATE"
" room_memberships"
" SET"
" forgotten = 1"
" WHERE"
" user_id = ?"
" AND"
" room_id = ?"
)
txn.execute(sql, (user_id, room_id))
self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id))
self._invalidate_cache_and_stream(
txn, self.get_forgotten_rooms_for_user, (user_id,)
)
return self.runInteraction("forget_membership", f)
class _JoinedHostsCache(object):
"""Cache for joined hosts in a room that is optimised to handle updates
via state deltas.

View file

@ -0,0 +1,20 @@
/* Copyright 2019 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- these tables are never used.
DROP TABLE IF EXISTS room_names;
DROP TABLE IF EXISTS topics;
DROP TABLE IF EXISTS history_visibility;
DROP TABLE IF EXISTS guest_access;

View file

@ -0,0 +1,16 @@
/* Copyright 2019 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE INDEX public_room_list_stream_network ON public_room_list_stream (appservice_id, network_id, room_id);

View file

@ -0,0 +1,52 @@
import logging
from synapse.storage.engines import PostgresEngine
logger = logging.getLogger(__name__)
"""
This migration updates the user_filters table as follows:
- drops any (user_id, filter_id) duplicates
- makes the columns NON-NULLable
- turns the index into a UNIQUE index
"""
def run_upgrade(cur, database_engine, *args, **kwargs):
pass
def run_create(cur, database_engine, *args, **kwargs):
if isinstance(database_engine, PostgresEngine):
select_clause = """
SELECT DISTINCT ON (user_id, filter_id) user_id, filter_id, filter_json
FROM user_filters
"""
else:
select_clause = """
SELECT * FROM user_filters GROUP BY user_id, filter_id
"""
sql = """
DROP TABLE IF EXISTS user_filters_migration;
DROP INDEX IF EXISTS user_filters_unique;
CREATE TABLE user_filters_migration (
user_id TEXT NOT NULL,
filter_id BIGINT NOT NULL,
filter_json BYTEA NOT NULL
);
INSERT INTO user_filters_migration (user_id, filter_id, filter_json)
%s;
CREATE UNIQUE INDEX user_filters_unique ON user_filters_migration
(user_id, filter_id);
DROP TABLE user_filters;
ALTER TABLE user_filters_migration RENAME TO user_filters;
""" % (
select_clause,
)
if isinstance(database_engine, PostgresEngine):
cur.execute(sql)
else:
cur.executescript(sql)

View file

@ -36,7 +36,7 @@ SearchEntry = namedtuple(
)
class SearchStore(BackgroundUpdateStore):
class SearchBackgroundUpdateStore(BackgroundUpdateStore):
EVENT_SEARCH_UPDATE_NAME = "event_search"
EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
@ -44,7 +44,7 @@ class SearchStore(BackgroundUpdateStore):
EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
def __init__(self, db_conn, hs):
super(SearchStore, self).__init__(db_conn, hs)
super(SearchBackgroundUpdateStore, self).__init__(db_conn, hs)
if not hs.config.enable_search:
return
@ -289,29 +289,6 @@ class SearchStore(BackgroundUpdateStore):
return num_rows
def store_event_search_txn(self, txn, event, key, value):
"""Add event to the search table
Args:
txn (cursor):
event (EventBase):
key (str):
value (str):
"""
self.store_search_entries_txn(
txn,
(
SearchEntry(
key=key,
value=value,
event_id=event.event_id,
room_id=event.room_id,
stream_ordering=event.internal_metadata.stream_ordering,
origin_server_ts=event.origin_server_ts,
),
),
)
def store_search_entries_txn(self, txn, entries):
"""Add entries to the search table
@ -358,6 +335,34 @@ class SearchStore(BackgroundUpdateStore):
# This should be unreachable.
raise Exception("Unrecognized database engine")
class SearchStore(SearchBackgroundUpdateStore):
def __init__(self, db_conn, hs):
super(SearchStore, self).__init__(db_conn, hs)
def store_event_search_txn(self, txn, event, key, value):
"""Add event to the search table
Args:
txn (cursor):
event (EventBase):
key (str):
value (str):
"""
self.store_search_entries_txn(
txn,
(
SearchEntry(
key=key,
value=value,
event_id=event.event_id,
room_id=event.room_id,
stream_ordering=event.internal_metadata.stream_ordering,
origin_server_ts=event.origin_server_ts,
),
),
)
@defer.inlineCallbacks
def search_msgs(self, room_ids, search_term, keys):
"""Performs a full text search over events with given keys.

View file

@ -353,8 +353,158 @@ class StateFilter(object):
return member_filter, non_member_filter
class StateGroupBackgroundUpdateStore(SQLBaseStore):
"""Defines functions related to state groups needed to run the state backgroud
updates.
"""
def _count_state_group_hops_txn(self, txn, state_group):
"""Given a state group, count how many hops there are in the tree.
This is used to ensure the delta chains don't get too long.
"""
if isinstance(self.database_engine, PostgresEngine):
sql = """
WITH RECURSIVE state(state_group) AS (
VALUES(?::bigint)
UNION ALL
SELECT prev_state_group FROM state_group_edges e, state s
WHERE s.state_group = e.state_group
)
SELECT count(*) FROM state;
"""
txn.execute(sql, (state_group,))
row = txn.fetchone()
if row and row[0]:
return row[0]
else:
return 0
else:
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
next_group = state_group
count = 0
while next_group:
next_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
retcol="prev_state_group",
allow_none=True,
)
if next_group:
count += 1
return count
def _get_state_groups_from_groups_txn(
self, txn, groups, state_filter=StateFilter.all()
):
results = {group: {} for group in groups}
where_clause, where_args = state_filter.make_sql_filter_clause()
# Unless the filter clause is empty, we're going to append it after an
# existing where clause
if where_clause:
where_clause = " AND (%s)" % (where_clause,)
if isinstance(self.database_engine, PostgresEngine):
# Temporarily disable sequential scans in this transaction. This is
# a temporary hack until we can add the right indices in
txn.execute("SET LOCAL enable_seqscan=off")
# The below query walks the state_group tree so that the "state"
# table includes all state_groups in the tree. It then joins
# against `state_groups_state` to fetch the latest state.
# It assumes that previous state groups are always numerically
# lesser.
# The PARTITION is used to get the event_id in the greatest state
# group for the given type, state_key.
# This may return multiple rows per (type, state_key), but last_value
# should be the same.
sql = """
WITH RECURSIVE state(state_group) AS (
VALUES(?::bigint)
UNION ALL
SELECT prev_state_group FROM state_group_edges e, state s
WHERE s.state_group = e.state_group
)
SELECT DISTINCT type, state_key, last_value(event_id) OVER (
PARTITION BY type, state_key ORDER BY state_group ASC
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS event_id FROM state_groups_state
WHERE state_group IN (
SELECT state_group FROM state
)
"""
for group in groups:
args = [group]
args.extend(where_args)
txn.execute(sql + where_clause, args)
for row in txn:
typ, state_key, event_id = row
key = (typ, state_key)
results[group][key] = event_id
else:
max_entries_returned = state_filter.max_entries_returned()
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
for group in groups:
next_group = group
while next_group:
# We did this before by getting the list of group ids, and
# then passing that list to sqlite to get latest event for
# each (type, state_key). However, that was terribly slow
# without the right indices (which we can't add until
# after we finish deduping state, which requires this func)
args = [next_group]
args.extend(where_args)
txn.execute(
"SELECT type, state_key, event_id FROM state_groups_state"
" WHERE state_group = ? " + where_clause,
args,
)
results[group].update(
((typ, state_key), event_id)
for typ, state_key, event_id in txn
if (typ, state_key) not in results[group]
)
# If the number of entries in the (type,state_key)->event_id dict
# matches the number of (type,state_keys) types we were searching
# for, then we must have found them all, so no need to go walk
# further down the tree... UNLESS our types filter contained
# wildcards (i.e. Nones) in which case we have to do an exhaustive
# search
if (
max_entries_returned is not None
and len(results[group]) == max_entries_returned
):
break
next_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
retcol="prev_state_group",
allow_none=True,
)
return results
# this inherits from EventsWorkerStore because it calls self.get_events
class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
class StateGroupWorkerStore(
EventsWorkerStore, StateGroupBackgroundUpdateStore, SQLBaseStore
):
"""The parts of StateGroupStore that can be called from workers.
"""
@ -694,107 +844,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return results
def _get_state_groups_from_groups_txn(
self, txn, groups, state_filter=StateFilter.all()
):
results = {group: {} for group in groups}
where_clause, where_args = state_filter.make_sql_filter_clause()
# Unless the filter clause is empty, we're going to append it after an
# existing where clause
if where_clause:
where_clause = " AND (%s)" % (where_clause,)
if isinstance(self.database_engine, PostgresEngine):
# Temporarily disable sequential scans in this transaction. This is
# a temporary hack until we can add the right indices in
txn.execute("SET LOCAL enable_seqscan=off")
# The below query walks the state_group tree so that the "state"
# table includes all state_groups in the tree. It then joins
# against `state_groups_state` to fetch the latest state.
# It assumes that previous state groups are always numerically
# lesser.
# The PARTITION is used to get the event_id in the greatest state
# group for the given type, state_key.
# This may return multiple rows per (type, state_key), but last_value
# should be the same.
sql = """
WITH RECURSIVE state(state_group) AS (
VALUES(?::bigint)
UNION ALL
SELECT prev_state_group FROM state_group_edges e, state s
WHERE s.state_group = e.state_group
)
SELECT DISTINCT type, state_key, last_value(event_id) OVER (
PARTITION BY type, state_key ORDER BY state_group ASC
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS event_id FROM state_groups_state
WHERE state_group IN (
SELECT state_group FROM state
)
"""
for group in groups:
args = [group]
args.extend(where_args)
txn.execute(sql + where_clause, args)
for row in txn:
typ, state_key, event_id = row
key = (typ, state_key)
results[group][key] = event_id
else:
max_entries_returned = state_filter.max_entries_returned()
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
for group in groups:
next_group = group
while next_group:
# We did this before by getting the list of group ids, and
# then passing that list to sqlite to get latest event for
# each (type, state_key). However, that was terribly slow
# without the right indices (which we can't add until
# after we finish deduping state, which requires this func)
args = [next_group]
args.extend(where_args)
txn.execute(
"SELECT type, state_key, event_id FROM state_groups_state"
" WHERE state_group = ? " + where_clause,
args,
)
results[group].update(
((typ, state_key), event_id)
for typ, state_key, event_id in txn
if (typ, state_key) not in results[group]
)
# If the number of entries in the (type,state_key)->event_id dict
# matches the number of (type,state_keys) types we were searching
# for, then we must have found them all, so no need to go walk
# further down the tree... UNLESS our types filter contained
# wildcards (i.e. Nones) in which case we have to do an exhaustive
# search
if (
max_entries_returned is not None
and len(results[group]) == max_entries_returned
):
break
next_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
retcol="prev_state_group",
allow_none=True,
)
return results
@defer.inlineCallbacks
def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
"""Given a list of event_ids and type tuples, return a list of state
@ -1238,66 +1287,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return self.runInteraction("store_state_group", _store_state_group_txn)
def _count_state_group_hops_txn(self, txn, state_group):
"""Given a state group, count how many hops there are in the tree.
This is used to ensure the delta chains don't get too long.
"""
if isinstance(self.database_engine, PostgresEngine):
sql = """
WITH RECURSIVE state(state_group) AS (
VALUES(?::bigint)
UNION ALL
SELECT prev_state_group FROM state_group_edges e, state s
WHERE s.state_group = e.state_group
)
SELECT count(*) FROM state;
"""
txn.execute(sql, (state_group,))
row = txn.fetchone()
if row and row[0]:
return row[0]
else:
return 0
else:
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
next_group = state_group
count = 0
while next_group:
next_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
retcol="prev_state_group",
allow_none=True,
)
if next_group:
count += 1
return count
class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
""" Keeps track of the state at a given event.
This is done by the concept of `state groups`. Every event is a assigned
a state group (identified by an arbitrary string), which references a
collection of state events. The current state of an event is then the
collection of state events referenced by the event's state group.
Hence, every change in the current state causes a new state group to be
generated. However, if no change happens (e.g., if we get a message event
with only one parent it inherits the state group from its parent.)
There are three tables:
* `state_groups`: Stores group name, first event with in the group and
room id.
* `event_to_state_groups`: Maps events to state groups.
* `state_groups_state`: Maps state group to state events.
"""
class StateBackgroundUpdateStore(
StateGroupBackgroundUpdateStore, BackgroundUpdateStore
):
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
@ -1305,7 +1298,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
def __init__(self, db_conn, hs):
super(StateStore, self).__init__(db_conn, hs)
super(StateBackgroundUpdateStore, self).__init__(db_conn, hs)
self.register_background_update_handler(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
self._background_deduplicate_state,
@ -1327,34 +1320,6 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
columns=["state_group"],
)
def _store_event_state_mappings_txn(self, txn, events_and_contexts):
state_groups = {}
for event, context in events_and_contexts:
if event.internal_metadata.is_outlier():
continue
# if the event was rejected, just give it the same state as its
# predecessor.
if context.rejected:
state_groups[event.event_id] = context.prev_group
continue
state_groups[event.event_id] = context.state_group
self._simple_insert_many_txn(
txn,
table="event_to_state_groups",
values=[
{"state_group": state_group_id, "event_id": event_id}
for event_id, state_group_id in iteritems(state_groups)
],
)
for event_id, state_group_id in iteritems(state_groups):
txn.call_after(
self._get_state_group_for_event.prefill, (event_id,), state_group_id
)
@defer.inlineCallbacks
def _background_deduplicate_state(self, progress, batch_size):
"""This background update will slowly deduplicate state by reencoding
@ -1527,3 +1492,54 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
return 1
class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore):
""" Keeps track of the state at a given event.
This is done by the concept of `state groups`. Every event is a assigned
a state group (identified by an arbitrary string), which references a
collection of state events. The current state of an event is then the
collection of state events referenced by the event's state group.
Hence, every change in the current state causes a new state group to be
generated. However, if no change happens (e.g., if we get a message event
with only one parent it inherits the state group from its parent.)
There are three tables:
* `state_groups`: Stores group name, first event with in the group and
room id.
* `event_to_state_groups`: Maps events to state groups.
* `state_groups_state`: Maps state group to state events.
"""
def __init__(self, db_conn, hs):
super(StateStore, self).__init__(db_conn, hs)
def _store_event_state_mappings_txn(self, txn, events_and_contexts):
state_groups = {}
for event, context in events_and_contexts:
if event.internal_metadata.is_outlier():
continue
# if the event was rejected, just give it the same state as its
# predecessor.
if context.rejected:
state_groups[event.event_id] = context.prev_group
continue
state_groups[event.event_id] = context.state_group
self._simple_insert_many_txn(
txn,
table="event_to_state_groups",
values=[
{"state_group": state_group_id, "event_id": event_id}
for event_id, state_group_id in iteritems(state_groups)
],
)
for event_id, state_group_id in iteritems(state_groups):
txn.call_after(
self._get_state_group_for_event.prefill, (event_id,), state_group_id
)

View file

@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
class StateDeltasStore(SQLBaseStore):
def get_current_state_deltas(self, prev_stream_id):
def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int):
"""Fetch a list of room state changes since the given stream id
Each entry in the result contains the following fields:
@ -36,15 +36,27 @@ class StateDeltasStore(SQLBaseStore):
Args:
prev_stream_id (int): point to get changes since (exclusive)
max_stream_id (int): the point that we know has been correctly persisted
- ie, an upper limit to return changes from.
Returns:
Deferred[list[dict]]: results
Deferred[tuple[int, list[dict]]: A tuple consisting of:
- the stream id which these results go up to
- list of current_state_delta_stream rows. If it is empty, we are
up to date.
"""
prev_stream_id = int(prev_stream_id)
# check we're not going backwards
assert prev_stream_id <= max_stream_id
if not self._curr_state_delta_stream_cache.has_any_entity_changed(
prev_stream_id
):
return []
# if the CSDs haven't changed between prev_stream_id and now, we
# know for certain that they haven't changed between prev_stream_id and
# max_stream_id.
return max_stream_id, []
def get_current_state_deltas_txn(txn):
# First we calculate the max stream id that will give us less than
@ -54,21 +66,29 @@ class StateDeltasStore(SQLBaseStore):
sql = """
SELECT stream_id, count(*)
FROM current_state_delta_stream
WHERE stream_id > ?
WHERE stream_id > ? AND stream_id <= ?
GROUP BY stream_id
ORDER BY stream_id ASC
LIMIT 100
"""
txn.execute(sql, (prev_stream_id,))
txn.execute(sql, (prev_stream_id, max_stream_id))
total = 0
max_stream_id = prev_stream_id
for max_stream_id, count in txn:
for stream_id, count in txn:
total += count
if total > 100:
# We arbitarily limit to 100 entries to ensure we don't
# select toooo many.
logger.debug(
"Clipping current_state_delta_stream rows to stream_id %i",
stream_id,
)
clipped_stream_id = stream_id
break
else:
# if there's no problem, we may as well go right up to the max_stream_id
clipped_stream_id = max_stream_id
# Now actually get the deltas
sql = """
@ -77,8 +97,8 @@ class StateDeltasStore(SQLBaseStore):
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
"""
txn.execute(sql, (prev_stream_id, max_stream_id))
return self.cursor_to_dict(txn)
txn.execute(sql, (prev_stream_id, clipped_stream_id))
return clipped_stream_id, self.cursor_to_dict(txn)
return self.runInteraction(
"get_current_state_deltas", get_current_state_deltas_txn

View file

@ -332,6 +332,9 @@ class StatsStore(StateDeltasStore):
def _bulk_update_stats_delta_txn(txn):
for stats_type, stats_updates in updates.items():
for stats_id, fields in stats_updates.items():
logger.info(
"Updating %s stats for %s: %s", stats_type, stats_id, fields
)
self._update_stats_delta_txn(
txn,
ts=ts,

View file

@ -32,14 +32,14 @@ logger = logging.getLogger(__name__)
TEMP_TABLE = "_temp_populate_user_directory"
class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore):
# How many records do we calculate before sending it to
# add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500
def __init__(self, db_conn, hs):
super(UserDirectoryStore, self).__init__(db_conn, hs)
super(UserDirectoryBackgroundUpdateStore, self).__init__(db_conn, hs)
self.server_name = hs.hostname
@ -452,55 +452,6 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
"update_profile_in_user_dir", _update_profile_in_user_dir_txn
)
def remove_from_user_dir(self, user_id):
def _remove_from_user_dir_txn(txn):
self._simple_delete_txn(
txn, table="user_directory", keyvalues={"user_id": user_id}
)
self._simple_delete_txn(
txn, table="user_directory_search", keyvalues={"user_id": user_id}
)
self._simple_delete_txn(
txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
)
self._simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"user_id": user_id},
)
self._simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"other_user_id": user_id},
)
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
@defer.inlineCallbacks
def get_users_in_dir_due_to_room(self, room_id):
"""Get all user_ids that are in the room directory because they're
in the given room_id
"""
user_ids_share_pub = yield self._simple_select_onecol(
table="users_in_public_rooms",
keyvalues={"room_id": room_id},
retcol="user_id",
desc="get_users_in_dir_due_to_room",
)
user_ids_share_priv = yield self._simple_select_onecol(
table="users_who_share_private_rooms",
keyvalues={"room_id": room_id},
retcol="other_user_id",
desc="get_users_in_dir_due_to_room",
)
user_ids = set(user_ids_share_pub)
user_ids.update(user_ids_share_priv)
return user_ids
def add_users_who_share_private_room(self, room_id, user_id_tuples):
"""Insert entries into the users_who_share_private_rooms table. The first
user should be a local user.
@ -551,6 +502,98 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
"add_users_in_public_rooms", _add_users_in_public_rooms_txn
)
def delete_all_from_user_dir(self):
"""Delete the entire user directory
"""
def _delete_all_from_user_dir_txn(txn):
txn.execute("DELETE FROM user_directory")
txn.execute("DELETE FROM user_directory_search")
txn.execute("DELETE FROM users_in_public_rooms")
txn.execute("DELETE FROM users_who_share_private_rooms")
txn.call_after(self.get_user_in_directory.invalidate_all)
return self.runInteraction(
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
)
@cached()
def get_user_in_directory(self, user_id):
return self._simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
retcols=("display_name", "avatar_url"),
allow_none=True,
desc="get_user_in_directory",
)
def update_user_directory_stream_pos(self, stream_id):
return self._simple_update_one(
table="user_directory_stream_pos",
keyvalues={},
updatevalues={"stream_id": stream_id},
desc="update_user_directory_stream_pos",
)
class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# How many records do we calculate before sending it to
# add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500
def __init__(self, db_conn, hs):
super(UserDirectoryStore, self).__init__(db_conn, hs)
def remove_from_user_dir(self, user_id):
def _remove_from_user_dir_txn(txn):
self._simple_delete_txn(
txn, table="user_directory", keyvalues={"user_id": user_id}
)
self._simple_delete_txn(
txn, table="user_directory_search", keyvalues={"user_id": user_id}
)
self._simple_delete_txn(
txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
)
self._simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"user_id": user_id},
)
self._simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"other_user_id": user_id},
)
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
@defer.inlineCallbacks
def get_users_in_dir_due_to_room(self, room_id):
"""Get all user_ids that are in the room directory because they're
in the given room_id
"""
user_ids_share_pub = yield self._simple_select_onecol(
table="users_in_public_rooms",
keyvalues={"room_id": room_id},
retcol="user_id",
desc="get_users_in_dir_due_to_room",
)
user_ids_share_priv = yield self._simple_select_onecol(
table="users_who_share_private_rooms",
keyvalues={"room_id": room_id},
retcol="other_user_id",
desc="get_users_in_dir_due_to_room",
)
user_ids = set(user_ids_share_pub)
user_ids.update(user_ids_share_priv)
return user_ids
def remove_user_who_share_room(self, user_id, room_id):
"""
Deletes entries in the users_who_share_*_rooms table. The first
@ -637,31 +680,6 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
return [room_id for room_id, in rows]
def delete_all_from_user_dir(self):
"""Delete the entire user directory
"""
def _delete_all_from_user_dir_txn(txn):
txn.execute("DELETE FROM user_directory")
txn.execute("DELETE FROM user_directory_search")
txn.execute("DELETE FROM users_in_public_rooms")
txn.execute("DELETE FROM users_who_share_private_rooms")
txn.call_after(self.get_user_in_directory.invalidate_all)
return self.runInteraction(
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
)
@cached()
def get_user_in_directory(self, user_id):
return self._simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
retcols=("display_name", "avatar_url"),
allow_none=True,
desc="get_user_in_directory",
)
def get_user_directory_stream_pos(self):
return self._simple_select_one_onecol(
table="user_directory_stream_pos",
@ -670,14 +688,6 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
desc="get_user_directory_stream_pos",
)
def update_user_directory_stream_pos(self, stream_id):
return self._simple_update_one(
table="user_directory_stream_pos",
keyvalues={},
updatevalues={"stream_id": stream_id},
desc="update_user_directory_stream_pos",
)
@defer.inlineCallbacks
def search_user_dir(self, user_id, search_term, limit):
"""Searches for users in directory

View file

@ -318,6 +318,7 @@ class StreamToken(
)
):
_SEPARATOR = "_"
START = None # type: StreamToken
@classmethod
def from_string(cls, string):
@ -402,7 +403,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
followed by the "stream_ordering" id of the event it comes after.
"""
__slots__ = []
__slots__ = [] # type: list
@classmethod
def parse(cls, string):

View file

@ -13,9 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import logging
from contextlib import contextmanager
from typing import Dict, Sequence, Set, Union
from six.moves import range
@ -213,7 +215,9 @@ class Linearizer(object):
# the first element is the number of things executing, and
# the second element is an OrderedDict, where the keys are deferreds for the
# things blocked from executing.
self.key_to_defer = {}
self.key_to_defer = (
{}
) # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]]
def queue(self, key):
# we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
@ -340,10 +344,10 @@ class ReadWriteLock(object):
def __init__(self):
# Latest readers queued
self.key_to_current_readers = {}
self.key_to_current_readers = {} # type: Dict[str, Set[defer.Deferred]]
# Latest writer queued
self.key_to_current_writer = {}
self.key_to_current_writer = {} # type: Dict[str, defer.Deferred]
@defer.inlineCallbacks
def read(self, key):

View file

@ -16,6 +16,7 @@
import logging
import os
from typing import Dict
import six
from six.moves import intern
@ -37,7 +38,7 @@ def get_cache_factor_for(cache_name):
caches_by_name = {}
collectors_by_name = {}
collectors_by_name = {} # type: Dict
cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])

View file

@ -18,10 +18,12 @@ import inspect
import logging
import threading
from collections import namedtuple
from typing import Any, cast
from six import itervalues
from prometheus_client import Gauge
from typing_extensions import Protocol
from twisted.internet import defer
@ -37,6 +39,18 @@ from . import register_cache
logger = logging.getLogger(__name__)
class _CachedFunction(Protocol):
invalidate = None # type: Any
invalidate_all = None # type: Any
invalidate_many = None # type: Any
prefill = None # type: Any
cache = None # type: Any
num_args = None # type: Any
def __name__(self):
...
cache_pending_metric = Gauge(
"synapse_util_caches_cache_pending",
"Number of lookups currently pending for this cache",
@ -245,7 +259,9 @@ class Cache(object):
class _CacheDescriptorBase(object):
def __init__(self, orig, num_args, inlineCallbacks, cache_context=False):
def __init__(
self, orig: _CachedFunction, num_args, inlineCallbacks, cache_context=False
):
self.orig = orig
if inlineCallbacks:
@ -404,7 +420,7 @@ class CacheDescriptor(_CacheDescriptorBase):
return tuple(get_cache_key_gen(args, kwargs))
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
def _wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
@ -440,6 +456,8 @@ class CacheDescriptor(_CacheDescriptorBase):
return make_deferred_yieldable(observer)
wrapped = cast(_CachedFunction, _wrapped)
if self.num_args == 1:
wrapped.invalidate = lambda key: cache.invalidate(key[0])
wrapped.prefill = lambda key, val: cache.prefill(key[0], val)

View file

@ -1,3 +1,5 @@
from typing import Dict
from six import itervalues
SENTINEL = object()
@ -12,7 +14,7 @@ class TreeCache(object):
def __init__(self):
self.size = 0
self.root = {}
self.root = {} # type: Dict
def __setitem__(self, key, value):
return self.set(key, value)

View file

@ -60,12 +60,14 @@ in_flight = InFlightGauge(
)
def measure_func(name):
def measure_func(name=None):
def wrapper(func):
block_name = func.__name__ if name is None else name
@wraps(func)
@defer.inlineCallbacks
def measured_func(self, *args, **kwargs):
with Measure(self.clock, name):
with Measure(self.clock, block_name):
r = yield func(self, *args, **kwargs)
return r

View file

@ -54,5 +54,5 @@ def load_python_module(location: str):
if spec is None:
raise Exception("Unable to load module at %s" % (location,))
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
spec.loader.exec_module(mod) # type: ignore
return mod