Merge remote-tracking branch 'upstream/release-v1.46'

Tulir Asokan 2021-10-27 15:42:34 +03:00
commit cf45cfd314
172 changed files with 5549 additions and 2350 deletions


@ -185,19 +185,26 @@ class ApplicationServicesHandler:
new_token: Optional[int],
users: Optional[Collection[Union[str, UserID]]] = None,
) -> None:
"""This is called by the notifier in the background
when a ephemeral event handled by the homeserver.
"""
This is called by the notifier in the background when an ephemeral event is handled
by the homeserver.
This will determine which appservices
are interested in the event, and submit them.
Events will only be pushed to appservices
that have opted into ephemeral events
This will determine which appservices are interested in the event, and submit them.
Args:
stream_key: The stream the event came from.
new_token: The latest stream token
users: The user(s) involved with the event.
`stream_key` can be "typing_key", "receipt_key" or "presence_key". Any other
value for `stream_key` will cause this function to return early.
Ephemeral events will only be pushed to appservices that have opted into
them.
Appservices will only receive ephemeral events that fall within their
registered user and room namespaces.
new_token: The latest stream token.
users: The users that should be informed of the new event, if any.
"""
if not self.notify_appservices:
return
@ -232,21 +239,32 @@ class ApplicationServicesHandler:
for service in services:
# Only handle typing if we have the latest token
if stream_key == "typing_key" and new_token is not None:
# Note that we don't persist the token (via set_type_stream_id_for_appservice)
# for typing_key due to performance reasons and due to their highly
# ephemeral nature.
#
# Instead we simply grab the latest typing updates in _handle_typing
# and, if they apply to this application service, send it off.
events = await self._handle_typing(service, new_token)
if events:
self.scheduler.submit_ephemeral_events_for_as(service, events)
elif stream_key == "receipt_key":
events = await self._handle_receipts(service)
if events:
self.scheduler.submit_ephemeral_events_for_as(service, events)
# Persist the latest handled stream token for this appservice
await self.store.set_type_stream_id_for_appservice(
service, "read_receipt", new_token
)
elif stream_key == "presence_key":
events = await self._handle_presence(service, users)
if events:
self.scheduler.submit_ephemeral_events_for_as(service, events)
# Persist the latest handled stream token for this appservice
await self.store.set_type_stream_id_for_appservice(
service, "presence", new_token
)
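For illustration, a minimal sketch of the dispatch contract above, as seen from the caller's side (the handler accessor and all values are assumptions for the example, not code from this diff):

    # Only "typing_key", "receipt_key" and "presence_key" are handled;
    # any other stream key makes the method return early.
    handler = hs.get_application_service_handler()
    await handler.notify_interested_services_ephemeral(
        stream_key="receipt_key",
        new_token=42,  # latest stream token for the receipts stream
        users=["@alice:example.org"],
    )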
@ -254,18 +272,54 @@ class ApplicationServicesHandler:
async def _handle_typing(
self, service: ApplicationService, new_token: int
) -> List[JsonDict]:
"""
Return the typing events since the given stream token that the given application
service should receive.
First fetch all typing events between the given typing stream token (non-inclusive)
and the latest typing event stream token (inclusive). Then return only those typing
events that the given application service may be interested in.
Args:
service: The application service to check for which events it should receive.
new_token: A typing event stream token.
Returns:
A list of JSON dictionaries containing data derived from the typing events that
should be sent to the given application service.
"""
typing_source = self.event_sources.sources.typing
# Get the typing events from just before current
typing, _ = await typing_source.get_new_events_as(
service=service,
# For performance reasons, we don't persist the previous
# token in the DB and instead fetch the latest typing event
# for appservices.
# TODO: It'd likely be more efficient to simply fetch the
# typing event with the given 'new_token' stream token and
# check if the given service was interested, rather than
# iterating over all typing events and only grabbing the
# latest few.
from_key=new_token - 1,
)
return typing
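A hedged note on the `from_key=new_token - 1` trick above, assuming the usual exclusive-from semantics of the typing source (token values invented):

    # Typing stream tokens are monotonically increasing integers, and the
    # source returns events strictly *after* from_key. Subtracting one
    # therefore guarantees the update at `new_token` itself is included.
    new_token = 42
    from_key = new_token - 1  # fetches the half-open interval (41, latest]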
async def _handle_receipts(self, service: ApplicationService) -> List[JsonDict]:
"""
Return the latest read receipts that the given application service should receive.
First fetch all read receipts between the last receipt stream token that this
application service should have previously received (non-inclusive) and the
latest read receipt stream token (inclusive). Then from that set, return only
those read receipts that the given application service may be interested in.
Args:
service: The application service to check for which events it should receive.
Returns:
A list of JSON dictionaries containing data derived from the read receipts that
should be sent to the given application service.
"""
from_key = await self.store.get_type_stream_id_for_appservice(
service, "read_receipt"
)
@ -278,6 +332,22 @@ class ApplicationServicesHandler:
async def _handle_presence(
self, service: ApplicationService, users: Collection[Union[str, UserID]]
) -> List[JsonDict]:
"""
Return the latest presence updates that the given application service should receive.
First, filter the given users list to those that the application service is
interested in. Then retrieve the latest presence updates since the last-known
previously received presence stream token for the given application service.
Return those presence updates.
Args:
service: The application service that ephemeral events are being sent to.
users: The users that should receive the presence update.
Returns:
A list of JSON dictionaries containing data derived from the presence events
that should be sent to the given application service.
"""
events: List[JsonDict] = []
presence_source = self.event_sources.sources.presence
from_key = await self.store.get_type_stream_id_for_appservice(
@ -290,9 +360,9 @@ class ApplicationServicesHandler:
interested = await service.is_interested_in_presence(user, self.store)
if not interested:
continue
presence_events, _ = await presence_source.get_new_events(
user=user,
service=service,
from_key=from_key,
)
time_now = self.clock.time_msec()


@ -62,7 +62,6 @@ from synapse.http.server import finish_request, respond_with_html
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.roommember import ProfileInfo
from synapse.types import JsonDict, Requester, UserID
from synapse.util import stringutils as stringutils
@ -73,6 +72,7 @@ from synapse.util.stringutils import base62_encode
from synapse.util.threepids import canonicalise_email
if TYPE_CHECKING:
from synapse.module_api import ModuleApi
from synapse.rest.client.login import LoginResponse
from synapse.server import HomeServer
@ -200,46 +200,13 @@ class AuthHandler:
self.bcrypt_rounds = hs.config.registration.bcrypt_rounds
self.password_auth_provider = hs.get_password_auth_provider()
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.auth.password_enabled
self._password_localdb_enabled = hs.config.auth.password_localdb_enabled
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`.
self._failed_uia_attempts_ratelimiter = Ratelimiter(
@ -427,11 +394,10 @@ class AuthHandler:
ui_auth_types.add(LoginType.PASSWORD)
# also allow auth from password providers
for t in self.password_auth_provider.get_supported_login_types().keys():
if t == LoginType.PASSWORD and not self._password_enabled:
continue
ui_auth_types.add(t)
# if sso is enabled, allow the user to log in via SSO iff they have a mapping
# from sso to mxid.
@ -1038,7 +1004,25 @@ class AuthHandler:
Returns:
login types
"""
# Load any login types registered by modules
# This is stored in the password_auth_provider so this doesn't trigger
# any callbacks
types = list(self.password_auth_provider.get_supported_login_types().keys())
# This list should include PASSWORD if either _password_localdb_enabled is
# true or one of the modules registered it, AND _password_enabled is true.
# Also:
# Some clients just pick the first type in the list. In this case, we want
# them to use PASSWORD (rather than token or whatever), so we want to make sure
# that comes first, where it's present.
if LoginType.PASSWORD in types:
types.remove(LoginType.PASSWORD)
if self._password_enabled:
types.insert(0, LoginType.PASSWORD)
elif self._password_localdb_enabled and self._password_enabled:
types.insert(0, LoginType.PASSWORD)
return types
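A small sketch of the ordering contract this method now provides (`auth_handler` stands for an instance of this class; the second login type is invented):

    types = auth_handler.get_supported_login_types()
    # e.g. ["m.login.password", "com.example.custom"]: when password login
    # is enabled, m.login.password always sorts first, so naive clients
    # that pick the first entry do the right thing.
    if "m.login.password" in types:
        assert types[0] == "m.login.password"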
async def validate_login(
self,
@ -1217,15 +1201,20 @@ class AuthHandler:
known_login_type = False
# Check if login_type matches a type registered by one of the modules
# We don't need to remove LoginType.PASSWORD from the list if password login is
# disabled, since if that were the case then by this point we know that the
# login_type is not LoginType.PASSWORD
supported_login_types = self.password_auth_provider.get_supported_login_types()
# check if the login type being used is supported by a module
if login_type in supported_login_types:
# Make a note that this login type is supported by the server
known_login_type = True
# Get all the fields expected for this login type
login_fields = supported_login_types[login_type]
# go through the login submission and keep track of which required fields are
# provided/not provided
missing_fields = []
login_dict = {}
for f in login_fields:
@ -1233,6 +1222,7 @@ class AuthHandler:
missing_fields.append(f)
else:
login_dict[f] = login_submission[f]
# raise an error if any of the expected fields for that login type weren't provided
if missing_fields:
raise SynapseError(
400,
@ -1240,10 +1230,15 @@ class AuthHandler:
% (login_type, missing_fields),
)
# call all of the check_auth hooks for that login_type
# it will return a result once the first success is found (or None otherwise)
result = await self.password_auth_provider.check_auth(
username, login_type, login_dict
)
if result:
return result
# if no module managed to authenticate the user, then fallback to built in password based auth
if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
known_login_type = True
@ -1282,11 +1277,16 @@ class AuthHandler:
completed login/registration, or `None`. If authentication was
unsuccessful, `user_id` and `callback` are both `None`.
"""
# call all of the check_3pid_auth callbacks
# Result will be from the first callback that returns something other than None
# If all the callbacks return None, then result is also set to None
result = await self.password_auth_provider.check_3pid_auth(
medium, address, password
)
if result:
return result
# if result is None then return (None, None)
return None, None
async def _check_local_password(self, user_id: str, password: str) -> Optional[str]:
@ -1365,13 +1365,12 @@ class AuthHandler:
user_info = await self.auth.get_user_by_access_token(access_token)
await self.store.delete_access_token(access_token)
# see if any modules want to know about this
await self.password_auth_provider.on_logged_out(
user_id=user_info.user_id,
device_id=user_info.device_id,
access_token=access_token,
)
# delete pushers associated with this access token
if user_info.token_id is not None:
@ -1398,12 +1397,11 @@ class AuthHandler:
user_id, except_token_id=except_token_id, device_id=device_id
)
# see if any modules want to know about this
for token, _, device_id in tokens_and_devices:
await self.password_auth_provider.on_logged_out(
user_id=user_id, device_id=device_id, access_token=token
)
# delete pushers associated with the access tokens
await self.hs.get_pusherpool().remove_pushers_by_access_token(
@ -1811,40 +1809,230 @@ class MacaroonGenerator:
return macaroon
def load_legacy_password_auth_providers(hs: "HomeServer") -> None:
module_api = hs.get_module_api()
for module, config in hs.config.authproviders.password_providers:
load_single_legacy_password_auth_provider(
module=module, config=config, api=module_api
)
def load_single_legacy_password_auth_provider(
module: Type,
config: JsonDict,
api: "ModuleApi",
) -> None:
try:
provider = module(config=config, account_handler=api)
except Exception as e:
logger.error("Error while initializing %r: %s", module, e)
raise
# The known hooks. If a module implements a method whose name appears in this set
# we'll want to register it
password_auth_provider_methods = {
"check_3pid_auth",
"on_logged_out",
}
# All methods that the module provides should be async, but this wasn't enforced
# in the old module system, so we wrap them if needed
def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
# f might be None if the callback isn't implemented by the module. In this
# case we don't want to register a callback at all so we return None.
if f is None:
return None
# We need to wrap check_password because its old form would return a boolean
# but we now want it to behave just like check_auth() and return the matrix id of
# the user if authentication succeeded or None otherwise
if f.__name__ == "check_password":
async def wrapped_check_password(
username: str, login_type: str, login_dict: JsonDict
) -> Optional[Tuple[str, Optional[Callable]]]:
# We've already made sure f is not None above, but mypy doesn't do well
# across function boundaries so we need to tell it f is definitely not
# None.
assert f is not None
matrix_user_id = api.get_qualified_user_id(username)
password = login_dict["password"]
is_valid = await f(matrix_user_id, password)
if is_valid:
return matrix_user_id, None
return None
return wrapped_check_password
# We need to wrap check_auth as in the old form it could return
# just a str, but now it must return Optional[Tuple[str, Optional[Callable]]]
if f.__name__ == "check_auth":
async def wrapped_check_auth(
username: str, login_type: str, login_dict: JsonDict
) -> Optional[Tuple[str, Optional[Callable]]]:
# We've already made sure f is not None above, but mypy doesn't do well
# across function boundaries so we need to tell it f is definitely not
# None.
assert f is not None
result = await f(username, login_type, login_dict)
if isinstance(result, str):
return result, None
return result
return wrapped_check_auth
# We need to wrap check_3pid_auth as in the old form it could return
# just a str, but now it must return Optional[Tuple[str, Optional[Callable]]]
if f.__name__ == "check_3pid_auth":
async def wrapped_check_3pid_auth(
medium: str, address: str, password: str
) -> Optional[Tuple[str, Optional[Callable]]]:
# We've already made sure f is not None above, but mypy doesn't do well
# across function boundaries so we need to tell it f is definitely not
# None.
assert f is not None
result = await f(medium, address, password)
if isinstance(result, str):
return result, None
return result
return wrapped_check_3pid_auth
def run(*args: Tuple, **kwargs: Dict) -> Awaitable:
# mypy doesn't do well across function boundaries so we need to tell it
# f is definitely not None.
assert f is not None
return maybe_awaitable(f(*args, **kwargs))
return run
# populate hooks with the implemented methods, wrapped with async_wrapper
hooks = {
hook: async_wrapper(getattr(provider, hook, None))
for hook in password_auth_provider_methods
}
supported_login_types = {}
# call get_supported_login_types and add that to the dict
g = getattr(provider, "get_supported_login_types", None)
if g is not None:
# Note the old module style also called get_supported_login_types at loading time
# and it is synchronous
supported_login_types.update(g())
auth_checkers = {}
# Legacy modules have a check_auth method which expects to be called with one of
# the keys returned by get_supported_login_types. New style modules register a
# dictionary of login_type->check_auth_method mappings
check_auth = async_wrapper(getattr(provider, "check_auth", None))
if check_auth is not None:
for login_type, fields in supported_login_types.items():
# need tuple(fields) since fields can be any Iterable type (so may not be hashable)
auth_checkers[(login_type, tuple(fields))] = check_auth
# if it has a "check_password" method then it should handle all auth checks
# with login type of LoginType.PASSWORD
check_password = async_wrapper(getattr(provider, "check_password", None))
if check_password is not None:
# need to use a tuple here for ("password",) not a list since lists aren't hashable
auth_checkers[(LoginType.PASSWORD, ("password",))] = check_password
api.register_password_auth_provider_callbacks(hooks, auth_checkers=auth_checkers)
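For context, a minimal sketch of the kind of legacy provider this loader wraps; the class name, login type and field names are invented for illustration:

    class ExampleLegacyProvider:
        """A hypothetical v1-style password provider module."""

        def __init__(self, config, account_handler):
            self._api = account_handler

        def get_supported_login_types(self):
            # map of login type -> fields the client must supply to /login
            return {"com.example.token_login": ("token",)}

        def check_auth(self, username, login_type, login_dict):
            # may be sync or async: async_wrapper() above copes with both,
            # and a bare str return is normalised to (user_id, None)
            if login_dict.get("token") == "opensesame":
                return self._api.get_qualified_user_id(username)
            return None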
CHECK_3PID_AUTH_CALLBACK = Callable[
[str, str, str],
Awaitable[
Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
],
]
ON_LOGGED_OUT_CALLBACK = Callable[[str, Optional[str], str], Awaitable]
CHECK_AUTH_CALLBACK = Callable[
[str, str, JsonDict],
Awaitable[
Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
],
]
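By contrast, a new-style module registers callbacks matching these aliases directly. A hedged sketch (the checker and user ID are invented; the commented-out registration uses the module API call named in the loader above):

    async def my_auth_checker(
        username: str, login_type: str, login_dict: JsonDict
    ) -> Optional[Tuple[str, Optional[Callable]]]:
        if login_dict.get("password") == "s3cret":
            return "@alice:example.org", None  # success, no post-login callback
        return None  # decline: the next registered checker is tried

    # registered keyed by (login_type, expected_fields), e.g.:
    # api.register_password_auth_provider_callbacks(
    #     auth_checkers={("m.login.password", ("password",)): my_auth_checker},
    # )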
class PasswordAuthProvider:
"""
A class that the AuthHandler calls when authenticating users
It allows modules to provide alternative methods for authentication
"""
def __init__(self) -> None:
# lists of callbacks
self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = []
self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = []
# Mapping from login type to login parameters
self._supported_login_types: Dict[str, Iterable[str]] = {}
# Mapping from login type to auth checker callbacks
self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {}
def register_password_auth_provider_callbacks(
self,
check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None,
on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None,
auth_checkers: Optional[Dict[Tuple[str, Tuple], CHECK_AUTH_CALLBACK]] = None,
) -> None:
# Register check_3pid_auth callback
if check_3pid_auth is not None:
self.check_3pid_auth_callbacks.append(check_3pid_auth)
g = getattr(self._pp, "get_supported_login_types", None)
if g:
self._supported_login_types.update(g())
# register on_logged_out callback
if on_logged_out is not None:
self.on_logged_out_callbacks.append(on_logged_out)
if auth_checkers is not None:
# register a new supported login_type
# Iterate through all of the types being registered
for (login_type, fields), callback in auth_checkers.items():
# Note: fields may be empty here. This would allow a module's auth checker to
# be called with just 'login_type' and no password or other secrets
# Need to check that all the field names are strings or we may get nasty errors later
for f in fields:
if not isinstance(f, str):
raise RuntimeError(
"A module tried to register support for login type: %s with parameters %s"
" but all parameter names must be strings"
% (login_type, fields)
)
# 2 modules supporting the same login type must expect the same fields
# e.g. 1 can't expect "pass" if the other expects "password"
# so throw an exception if that happens
if login_type not in self._supported_login_types:
self._supported_login_types[login_type] = fields
else:
fields_currently_supported = self._supported_login_types.get(
login_type
)
if fields_currently_supported != fields:
raise RuntimeError(
"A module tried to register support for login type: %s with parameters %s"
" but another module had already registered support for that type with parameters %s"
% (login_type, fields, fields_currently_supported)
)
# Add the new method to the list of auth_checker_callbacks for this login type
self.auth_checker_callbacks.setdefault(login_type, []).append(callback)
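To illustrate the conflict rule above (the login type and field names are invented; `checker_a` and `checker_b` stand for any CHECK_AUTH_CALLBACKs):

    provider = PasswordAuthProvider()
    provider.register_password_auth_provider_callbacks(
        auth_checkers={("com.example.login", ("password",)): checker_a},
    )
    # Registering the same login type with identical fields is fine: both
    # checkers are kept and tried in order. Different fields, such as
    # ("pass",), would raise RuntimeError here instead:
    provider.register_password_auth_provider_callbacks(
        auth_checkers={("com.example.login", ("password",)): checker_b},
    )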
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
"""Get the login types supported by this password provider
@ -1852,20 +2040,15 @@ class PasswordProvider:
Returns a map from a login type identifier (such as m.login.password) to an
iterable giving the fields which must be provided by the user in the submission
to the /login API.
"""
return self._supported_login_types
async def check_auth(
self, username: str, login_type: str, login_dict: JsonDict
) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]:
"""Check if the user has presented valid login credentials
Args:
username: user id presented by the client. Either an MXID or an unqualified
username.
@ -1879,63 +2062,130 @@ class PasswordProvider:
user, and `callback` is an optional callback which will be called with the
result from the /login call (including access_token, device_id, etc.)
"""
# Go through all callbacks for the login type until one returns with a value
# other than None (i.e. until a callback returns a success)
for callback in self.auth_checker_callbacks[login_type]:
try:
result = await callback(username, login_type, login_dict)
except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e)
continue
if result is not None:
# Check that the callback returned a Tuple[str, Optional[Callable]]
# "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks
# result is always the right type, but as it is 3rd party code it might not be
if not isinstance(result, tuple) or len(result) != 2:
logger.warning(
"Wrong type returned by module API callback %s: %s, expected"
" Optional[Tuple[str, Optional[Callable]]]",
callback,
result,
)
continue
# pull out the two parts of the tuple so we can do type checking
str_result, callback_result = result
# the 1st item in the tuple should be a str
if not isinstance(str_result, str):
logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected"
" Optional[Tuple[str, Optional[Callable]]]",
callback,
result,
)
continue
# the second should be Optional[Callable]
if callback_result is not None:
if not callable(callback_result):
logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected"
" Optional[Tuple[str, Optional[Callable]]]",
callback,
result,
)
continue
# The result is a (str, Optional[callback]) tuple so return the successful result
return result
# If this point has been reached then none of the callbacks successfully authenticated
# the user so return None
return None
async def check_3pid_auth(
self, medium: str, address: str, password: str
) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]:
for callback in self.check_3pid_auth_callbacks:
try:
result = await callback(medium, address, password)
except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e)
continue
if result is not None:
# Check that the callback returned a Tuple[str, Optional[Callable]]
# "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks
# result is always the right type, but as it is 3rd party code it might not be
if not isinstance(result, tuple) or len(result) != 2:
logger.warning(
"Wrong type returned by module API callback %s: %s, expected"
" Optional[Tuple[str, Optional[Callable]]]",
callback,
result,
)
continue
# pull out the two parts of the tuple so we can do type checking
str_result, callback_result = result
# the 1st item in the tuple should be a str
if not isinstance(str_result, str):
logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected"
" Optional[Tuple[str, Optional[Callable]]]",
callback,
result,
)
continue
# the second should be Optional[Callable]
if callback_result is not None:
if not callable(callback_result):
logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected"
" Optional[Tuple[str, Optional[Callable]]]",
callback,
result,
)
continue
# The result is a (str, Optional[callback]) tuple so return the successful result
return result
# If this point has been reached then none of the callbacks successfully authenticated
# the user so return None
return None
async def on_logged_out(
self, user_id: str, device_id: Optional[str], access_token: str
) -> None:
g = getattr(self._pp, "on_logged_out", None)
if not g:
return
# This might return an awaitable, if it does block the log out
# until it completes.
await maybe_awaitable(
g(
user_id=user_id,
device_id=device_id,
access_token=access_token,
)
)
# call all of the on_logged_out callbacks
for callback in self.on_logged_out_callbacks:
try:
# ON_LOGGED_OUT_CALLBACK is async, so each callback must be awaited
await callback(user_id, device_id, access_token)
except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e)
continue
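A sketch of a callback matching ON_LOGGED_OUT_CALLBACK (the body is an invented example):

    async def my_on_logged_out(
        user_id: str, device_id: Optional[str], access_token: str
    ) -> None:
        # e.g. revoke the corresponding session in an external identity system
        logger.info("user %s logged out of device %s", user_id, device_id)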


@ -131,10 +131,6 @@ class DeactivateAccountHandler:
# delete from user directory
await self.user_directory_handler.handle_local_user_deactivated(user_id)
# Mark the user as erased, if they asked for that
if erase_data:
user = UserID.from_string(user_id)


@ -14,7 +14,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
Iterable,
List,
Mapping,
Optional,
Set,
Tuple,
)
from synapse.api import errors
from synapse.api.constants import EventTypes
@ -443,6 +454,10 @@ class DeviceHandler(DeviceWorkerHandler):
) -> None:
"""Notify that a user's device(s) has changed. Pokes the notifier, and
remote servers if the user is local.
Args:
user_id: The Matrix ID of the user whose device list has been updated.
device_ids: The device IDs that have changed.
"""
if not device_ids:
# No changes to notify about, so this is a no-op.
@ -595,7 +610,7 @@ class DeviceHandler(DeviceWorkerHandler):
def _update_device_from_client_ips(
device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
) -> None:
ip = client_ips.get((device["user_id"], device["device_id"]), {})
device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})


@ -147,7 +147,7 @@ class DirectoryHandler:
if not self.config.roomdirectory.is_alias_creation_allowed(
user_id, room_id, room_alias_str
):
# Let's just return a generic message, as there may be all sorts of
# reasons why we said no. TODO: Allow configurable error messages
# per alias creation rule?
raise SynapseError(403, "Not allowed to create alias")
@ -463,7 +463,7 @@ class DirectoryHandler:
if not self.config.roomdirectory.is_publishing_room_allowed(
user_id, room_id, room_aliases
):
# Let's just return a generic message, as there may be all sorts of
# reasons why we said no. TODO: Allow configurable error messages
# per alias creation rule?
raise SynapseError(403, "Not allowed to publish room")


@ -55,8 +55,7 @@ class EventAuthHandler:
"""Check an event passes the auth rules at its own auth events"""
auth_event_ids = event.auth_event_ids()
auth_events_by_id = await self._store.get_events(auth_event_ids)
check_auth_rules_for_event(room_version_obj, event, auth_events_by_id.values())
def compute_auth_events(
self,


@ -15,7 +15,6 @@
"""Contains handlers for federation events."""
import itertools
import logging
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
@ -27,12 +26,7 @@ from unpaddedbase64 import decode_base64
from twisted.internet import defer
from synapse import event_auth
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.api.errors import (
AuthError,
CodeMessageException,
@ -43,12 +37,9 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.crypto.event_signing import compute_event_signature
from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
@ -238,18 +229,10 @@ class FederationHandler:
)
return False
# We ignore extremities that have a greater depth than our current depth
# as:
# 1. we don't really care about getting events that have happened
# after our current position; and
# 2. we have likely previously tried and failed to backfill from that
# extremity, so to avoid getting "stuck" requesting the same
# backfill repeatedly we drop those extremities.
@ -257,9 +240,19 @@ class FederationHandler:
t for t in sorted_extremeties_tuple if int(t[1]) <= current_depth
]
logger.debug(
"room_id: %s, backfill: current_depth: %s, limit: %s, max_depth: %s, extrems: %s filtered_sorted_extremeties_tuple: %s",
room_id,
current_depth,
limit,
max_depth,
sorted_extremeties_tuple,
filtered_sorted_extremeties_tuple,
)
# However, we need to check that the filtered extremities are non-empty.
# If they are empty then either we can a) bail or b) still attempt to
# backfill. We opt to try backfilling anyway just in case we do get
# relevant events.
if filtered_sorted_extremeties_tuple:
sorted_extremeties_tuple = filtered_sorted_extremeties_tuple
@ -389,7 +382,7 @@ class FederationHandler:
for key, state_dict in states.items()
}
for e_id in event_ids:
likely_extremeties_domains = get_domains_from_state(states[e_id])
success = await try_backfill(
@ -517,7 +510,7 @@ class FederationHandler:
auth_events=auth_chain,
)
max_stream_id = await self._federation_event_handler.process_remote_join(
origin, room_id, auth_chain, state, event, room_version_obj
)
@ -1093,119 +1086,6 @@ class FederationHandler:
else:
return None
async def on_get_missing_events(
self,
origin: str,


@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import logging
from http import HTTPStatus
from typing import (
@ -45,7 +46,7 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion, RoomVersions
from synapse.event_auth import (
auth_types_for_event,
check_auth_rules_for_event,
@ -64,7 +65,6 @@ from synapse.replication.http.federation import (
from synapse.state import StateResolutionStore
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.types import (
PersistedEventPosition,
RoomStreamToken,
StateMap,
@ -214,7 +214,7 @@ class FederationEventHandler:
if missing_prevs:
# We only backfill backwards to the min depth.
min_depth = await self._store.get_min_depth(pdu.room_id)
logger.debug("min_depth: %d", min_depth)
if min_depth is not None and pdu.depth > min_depth:
@ -361,6 +361,7 @@ class FederationEventHandler:
# need to.
await self._event_creation_handler.cache_joined_hosts_for_event(event, context)
await self._check_for_soft_fail(event, None, origin=origin)
await self._run_push_actions_and_persist_event(event, context)
return event, context
@ -390,9 +391,93 @@ class FederationEventHandler:
prev_member_event,
)
async def process_remote_join(
self,
origin: str,
room_id: str,
auth_events: List[EventBase],
state: List[EventBase],
event: EventBase,
room_version: RoomVersion,
) -> int:
"""Persists the events returned by a send_join
Checks the auth chain is valid (and passes auth checks) for the
state and event. Then persists all of the events.
Notifies about the persisted events where appropriate.
Args:
origin: Where the events came from
room_id:
auth_events
state
event
room_version: The room version we expect this room to have, and
will raise if it doesn't match the version in the create event.
Returns:
The stream ID after which all events have been persisted.
Raises:
SynapseError if the response is in some way invalid.
"""
for e in itertools.chain(auth_events, state):
e.internal_metadata.outlier = True
event_map = {e.event_id: e for e in itertools.chain(auth_events, state)}
create_event = None
for e in auth_events:
if (e.type, e.state_key) == (EventTypes.Create, ""):
create_event = e
break
if create_event is None:
# If the state doesn't have a create event then the room is
# invalid, and it would fail auth checks anyway.
raise SynapseError(400, "No create event in state")
room_version_id = create_event.content.get(
"room_version", RoomVersions.V1.identifier
)
if room_version.identifier != room_version_id:
raise SynapseError(400, "Room version mismatch")
# filter out any events we have already seen
seen_remotes = await self._store.have_seen_events(room_id, event_map.keys())
for s in seen_remotes:
event_map.pop(s, None)
# persist the auth chain and state events.
#
# any invalid events here will be marked as rejected, and we'll carry on.
#
# any events whose auth events are missing (ie, not in the send_join response,
# and not already in our db) will just be ignored. This is correct behaviour,
# because the reason that auth_events are missing might be due to us being
# unable to validate their signatures. The fact that we can't validate their
# signatures right now doesn't mean that we will *never* be able to, so it
# is premature to reject them.
#
await self._auth_and_persist_outliers(room_id, event_map.values())
# and now persist the join event itself.
logger.info("Peristing join-via-remote %s", event)
with nested_logging_context(suffix=event.event_id):
context = await self._state_handler.compute_event_context(
event, old_state=state
)
context = await self._check_event_auth(origin, event, context)
if context.rejected:
raise SynapseError(400, "Join event was rejected")
return await self.persist_events_and_notify(room_id, [(event, context)])
@log_function
async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Iterable[str]
) -> None:
"""Trigger a backfill request to `dest` for the given `room_id`
@ -861,9 +946,15 @@ class FederationEventHandler:
) -> None:
"""Called when we have a new non-outlier event.
This is called when we have a new event to add to the room DAG. This can be
due to:
* events received directly via a /send request
* events retrieved via get_missing_events after a /send request
* events backfilled after a client request.
It's not currently used for events received from incoming send_{join,knock,leave}
requests (which go via on_send_membership_event), nor for joins created by a
remote join dance (which go via process_remote_join).
We need to do auth checks and put it through the StateHandler.
@ -899,11 +990,19 @@ class FederationEventHandler:
logger.exception("Unexpected AuthError from _check_event_auth")
raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
if not backfilled and not context.rejected:
# For new (non-backfilled and non-outlier) events we check if the event
# passes auth based on the current state. If it doesn't then we
# "soft-fail" the event.
await self._check_for_soft_fail(event, state, origin=origin)
await self._run_push_actions_and_persist_event(event, context, backfilled)
if backfilled or context.rejected:
return
await self._maybe_kick_guest_users(event)
# For encrypted messages we check that we know about the sending device,
# if we don't then we mark the device cache for that user as stale.
if event.type == EventTypes.Encrypted:
@ -1116,14 +1215,12 @@ class FederationEventHandler:
await concurrently_execute(get_event, event_ids, 5)
logger.info("Fetched %i events of %i requested", len(events), len(event_ids))
await self._auth_and_persist_outliers(room_id, events)
async def _auth_and_persist_outliers(
self, room_id: str, events: Iterable[EventBase]
) -> None:
"""Persist the events fetched by _get_events_and_persist or _get_remote_auth_chain_for_event
The events to be persisted must be outliers.
"""Persist a batch of outlier events fetched from remote servers.
We first sort the events to make sure that we process each event's auth_events
before the event itself, and then auth and persist them.
@ -1131,7 +1228,6 @@ class FederationEventHandler:
Notifies about the events where appropriate.
Params:
room_id: the room that the events are meant to be in (though this has
not yet been checked)
events: the events that have been fetched
@ -1167,15 +1263,15 @@ class FederationEventHandler:
shortstr(e.event_id for e in roots),
)
await self._auth_and_persist_outliers_inner(room_id, roots)
for ev in roots:
del event_map[ev.event_id]
async def _auth_and_persist_outliers_inner(
self, room_id: str, fetched_events: Collection[EventBase]
) -> None:
"""Helper for _auth_and_persist_fetched_events
"""Helper for _auth_and_persist_outliers
Persists a batch of events where we have (theoretically) already persisted all
of their auth events.
@ -1203,20 +1299,20 @@ class FederationEventHandler:
def prep(event: EventBase) -> Optional[Tuple[EventBase, EventContext]]:
with nested_logging_context(suffix=event.event_id):
auth = []
for auth_event_id in event.auth_event_ids():
ae = persisted_events.get(auth_event_id)
if not ae:
# the fact we can't find the auth event doesn't mean it doesn't
# exist, which means it is premature to reject `event`. Instead we
# just ignore it for now.
logger.warning(
"Dropping event %s, which relies on auth_event %s, which could not be found",
event,
auth_event_id,
)
return None
auth.append(ae)
context = EventContext.for_outlier()
try:
@ -1256,6 +1352,10 @@ class FederationEventHandler:
Returns:
The updated context object.
Raises:
AuthError if we were unable to find copies of the event's auth events.
(Most other failures just cause us to set `context.rejected`.)
"""
# This method should only be used for non-outliers
assert not event.internal_metadata.outlier
@ -1272,7 +1372,26 @@ class FederationEventHandler:
context.rejected = RejectedReason.AUTH_ERROR
return context
# next, check that we have all of the event's auth events.
#
# Note that this can raise AuthError, which we want to propagate to the
# caller rather than swallow with `context.rejected` (since we cannot be
# certain that there is a permanent problem with the event).
claimed_auth_events = await self._load_or_fetch_auth_events_for_event(
origin, event
)
# ... and check that the event passes auth at those auth events.
try:
check_auth_rules_for_event(room_version_obj, event, claimed_auth_events)
except AuthError as e:
logger.warning(
"While checking auth of %r against auth_events: %s", event, e
)
context.rejected = RejectedReason.AUTH_ERROR
return context
# now check auth against what we think the auth events *should* be.
prev_state_ids = await context.get_prev_state_ids()
auth_events_ids = self._event_auth_handler.compute_auth_events(
event, prev_state_ids, for_verification=True
@ -1283,13 +1402,8 @@ class FederationEventHandler:
}
try:
updated_auth_events = await self._update_auth_events_for_auth(
event,
calculated_auth_event_map=calculated_auth_event_map,
)
except Exception:
@ -1302,17 +1416,23 @@ class FederationEventHandler:
"Ignoring failure and continuing processing of event.",
event.event_id,
)
updated_auth_events = None
if updated_auth_events:
context = await self._update_context_for_auth_events(
event, context, updated_auth_events
)
auth_events_for_auth = updated_auth_events
else:
auth_events_for_auth = calculated_auth_event_map
try:
check_auth_rules_for_event(
room_version_obj, event, auth_events_for_auth.values()
)
except AuthError as e:
logger.warning("Failed auth resolution for %r because %s", event, e)
context.rejected = RejectedReason.AUTH_ERROR
return context
return context
@ -1332,7 +1452,6 @@ class FederationEventHandler:
self,
event: EventBase,
state: Optional[Iterable[EventBase]],
origin: str,
) -> None:
"""Checks if we should soft fail the event; if so, marks the event as
@ -1341,15 +1460,8 @@ class FederationEventHandler:
Args:
event
state: The state at the event if we don't have all the event's prev events
origin: The host the event originates from.
"""
extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id)
extrem_ids = set(extrem_ids_list)
prev_event_ids = set(event.prev_event_ids())
@ -1403,11 +1515,9 @@ class FederationEventHandler:
current_state_ids_list = [
e for k, e in current_state_ids.items() if k in auth_types
]
current_auth_events = await self._store.get_events_as_list(
current_state_ids_list
)
try:
check_auth_rules_for_event(room_version_obj, event, current_auth_events)
@ -1426,13 +1536,11 @@ class FederationEventHandler:
soft_failed_event_counter.inc()
event.internal_metadata.soft_failed = True
async def _update_auth_events_for_auth(
self,
event: EventBase,
calculated_auth_event_map: StateMap[EventBase],
) -> Optional[StateMap[EventBase]]:
"""Helper for _check_event_auth. See there for docs.
Checks whether a given event has the expected auth events. If it
@ -1445,93 +1553,27 @@ class FederationEventHandler:
processing of the event.
Args:
event:
calculated_auth_event_map:
Our calculated auth_events based on the state of the room
at the event's position in the DAG.
Returns:
updated auth event map, or None if no changes are needed.
"""
assert not event.internal_metadata.outlier
# check for events which are in the event's claimed auth_events, but not
# in our calculated event map.
event_auth_events = set(event.auth_event_ids())
different_auth = event_auth_events.difference(
e.event_id for e in calculated_auth_event_map.values()
)
if not different_auth:
return None
logger.info(
"auth_events refers to events which are not in our calculated auth "
@ -1543,27 +1585,18 @@ class FederationEventHandler:
# necessary?
different_events = await self._store.get_events_as_list(different_auth)
assert (
d.room_id == event.room_id
), f"Event {event.event_id} refers to auth_event {d.event_id} which is in a different room"
# now we state-resolve between our own idea of the auth events, and the remote's
# idea of them.
local_state = calculated_auth_event_map.values()
remote_auth_events = dict(calculated_auth_event_map)
remote_auth_events.update({(d.type, d.state_key): d for d in different_events})
remote_state = remote_auth_events.values()
@ -1571,23 +1604,93 @@ class FederationEventHandler:
new_state = await self._state_handler.resolve_events(
room_version, (local_state, remote_state), event
)
different_state = {
(d.type, d.state_key): d
for d in new_state.values()
if calculated_auth_event_map.get((d.type, d.state_key)) != d
}
if not different_state:
logger.info("State res returned no new state")
return None
logger.info(
"After state res: updating auth_events with new state %s",
different_state.values(),
)
# take a copy of calculated_auth_event_map before we modify it.
auth_events = dict(calculated_auth_event_map)
auth_events.update(different_state)
return auth_events
async def _load_or_fetch_auth_events_for_event(
self, destination: str, event: EventBase
) -> Collection[EventBase]:
"""Fetch this event's auth_events, from database or remote
Loads any of the auth_events that we already have from the database/cache. If
there are any that are missing, calls /event_auth to get the complete auth
chain for the event (and then attempts to load the auth_events again).
If any of the auth_events cannot be found, raises an AuthError. This can happen
for a number of reasons; eg: the events don't exist, or we were unable to talk
to `destination`, or we couldn't validate the signature on the event (which
in turn has multiple potential causes).
Args:
destination: where to send the /event_auth request. Typically the server
that sent us `event` in the first place.
event: the event whose auth_events we want
Returns:
all of the events in `event.auth_events`, after deduplication
Raises:
AuthError if we were unable to fetch the auth_events for any reason.
"""
event_auth_event_ids = set(event.auth_event_ids())
event_auth_events = await self._store.get_events(
event_auth_event_ids, allow_rejected=True
)
missing_auth_event_ids = event_auth_event_ids.difference(
event_auth_events.keys()
)
if not missing_auth_event_ids:
return event_auth_events.values()
return context, auth_events
logger.info(
"Event %s refers to unknown auth events %s: fetching auth chain",
event,
missing_auth_event_ids,
)
try:
await self._get_remote_auth_chain_for_event(
destination, event.room_id, event.event_id
)
except Exception as e:
logger.warning("Failed to get auth chain for %s: %s", event, e)
# in this case, it's very likely we still won't have all the auth
# events - but we pick that up below.
# try to fetch the auth events we missed last time.
extra_auth_events = await self._store.get_events(
missing_auth_event_ids, allow_rejected=True
)
missing_auth_event_ids.difference_update(extra_auth_events.keys())
event_auth_events.update(extra_auth_events)
if not missing_auth_event_ids:
return event_auth_events.values()
# we still don't have all the auth events.
logger.warning(
"Missing auth events for %s: %s",
event,
shortstr(missing_auth_event_ids),
)
# the fact we can't find the auth event doesn't mean it doesn't
# exist, which means it is premature to store `event` as rejected.
# instead we raise an AuthError, which will make the caller ignore it.
raise AuthError(code=HTTPStatus.FORBIDDEN, msg="Auth events could not be found")
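The load-then-fetch-then-retry shape of this function can be sketched in isolation; `store` and `fetch_remote` below are hypothetical stand-ins for the event store and the /event_auth request:

import asyncio
from typing import Awaitable, Callable, Collection, Dict, Set

async def load_or_fetch(
    store: Dict[str, str],
    wanted: Set[str],
    fetch_remote: Callable[[Set[str]], Awaitable[None]],
) -> Collection[str]:
    found = {e: store[e] for e in wanted if e in store}
    missing = wanted - found.keys()
    if not missing:
        return found.values()
    await fetch_remote(missing)          # may persist events into the store
    extra = {e: store[e] for e in missing if e in store}
    found.update(extra)
    missing -= extra.keys()
    if missing:                          # still absent: refuse, don't reject
        raise LookupError(f"auth events not found: {missing}")
    return found.values()

async def main() -> None:
    store = {"$a": "event-a"}
    async def fetch_remote(ids: Set[str]) -> None:
        store["$b"] = "event-b"          # pretend /event_auth filled the gap
    print(sorted(await load_or_fetch(store, {"$a", "$b"}, fetch_remote)))

asyncio.run(main())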
async def _get_remote_auth_chain_for_event(
self, destination: str, room_id: str, event_id: str
@@ -1624,9 +1727,7 @@ class FederationEventHandler:
for s in seen_remotes:
remote_event_map.pop(s, None)
await self._auth_and_persist_fetched_events(
destination, room_id, remote_event_map.values()
)
await self._auth_and_persist_outliers(room_id, remote_event_map.values())
async def _update_context_for_auth_events(
self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase]
@@ -1696,16 +1797,27 @@ class FederationEventHandler:
# persist_events_and_notify directly.)
assert not event.internal_metadata.outlier
try:
if (
not backfilled
and not context.rejected
and (await self._store.get_min_depth(event.room_id)) <= event.depth
):
if not backfilled and not context.rejected:
min_depth = await self._store.get_min_depth(event.room_id)
if min_depth is None or min_depth > event.depth:
# XXX richvdh 2021/10/07: I don't really understand what this
# condition is doing. I think it's trying not to send pushes
# for events that predate our join - but that's not really what
# min_depth means, and anyway ancient events are a more general
# problem.
#
# for now I'm just going to log about it.
logger.info(
"Skipping push actions for old event with depth %s < %s",
event.depth,
min_depth,
)
else:
await self._action_generator.handle_push_actions_for_event(
event, context
)
try:
await self.persist_events_and_notify(
event.room_id, [(event, context)], backfilled=backfilled
)
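The depth gate introduced above boils down to a small predicate; a sketch with hypothetical values:

from typing import Optional

def should_generate_push_actions(
    min_depth: Optional[int], event_depth: int,
    backfilled: bool = False, rejected: bool = False,
) -> bool:
    if backfilled or rejected:
        return False
    if min_depth is None or min_depth > event_depth:
        return False                     # event predates min_depth: skip pushes
    return True

assert should_generate_push_actions(5, 10) is True
assert should_generate_push_actions(5, 3) is False   # "old" event
assert should_generate_push_actions(None, 3) is False
assert should_generate_push_actions(5, 10, backfilled=True) is False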
@@ -1837,6 +1949,3 @@ class FederationEventHandler:
len(ev.auth_event_ids()),
)
raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events")
async def get_min_depth_for_context(self, context: str) -> int:
return await self._store.get_min_depth(context)

View file

@@ -54,7 +54,9 @@ class IdentityHandler:
self.http_client = SimpleHttpClient(hs)
# An HTTP client for contacting identity servers specified by clients.
self.blacklisting_http_client = SimpleHttpClient(
hs, ip_blacklist=hs.config.server.federation_ip_range_blacklist
hs,
ip_blacklist=hs.config.server.federation_ip_range_blacklist,
ip_whitelist=hs.config.server.federation_ip_range_whitelist,
)
self.federation_http_client = hs.get_federation_http_client()
self.hs = hs
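One plausible reading of passing both ranges is whitelist-over-blacklist precedence; a sketch of that idea only (the actual behaviour lives inside SimpleHttpClient):

from ipaddress import ip_address, ip_network

blacklist = [ip_network("10.0.0.0/8")]       # cf. federation_ip_range_blacklist
whitelist = [ip_network("10.1.0.0/16")]      # cf. federation_ip_range_whitelist

def is_allowed(addr: str) -> bool:
    ip = ip_address(addr)
    if any(ip in net for net in whitelist):
        return True                          # whitelist punches a hole
    return not any(ip in net for net in blacklist)

assert is_allowed("10.1.2.3") is True        # whitelisted despite the blacklist
assert is_allowed("10.2.0.1") is False       # blacklisted
assert is_allowed("8.8.8.8") is True         # unaffected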

View file

@@ -609,29 +609,6 @@ class EventCreationHandler:
builder.internal_metadata.historical = historical
# Strip down the auth_event_ids to only what we need to auth the event.
# For example, we don't need extra m.room.member events that don't match event.sender
if auth_event_ids is not None:
# If auth events are provided, prev events must be also.
assert prev_event_ids is not None
temp_event = await builder.build(
prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids,
depth=depth,
)
auth_events = await self.store.get_events_as_list(auth_event_ids)
# Create a StateMap[str]
auth_event_state_map = {
(e.type, e.state_key): e.event_id for e in auth_events
}
# Actually strip down and use the necessary auth events
auth_event_ids = self._event_auth_handler.compute_auth_events(
event=temp_event,
current_state_ids=auth_event_state_map,
for_verification=False,
)
event, context = await self.create_new_client_event(
builder=builder,
requester=requester,
@@ -938,6 +915,33 @@ class EventCreationHandler:
Tuple of created event, context
"""
# Strip down the auth_event_ids to only what we need to auth the event.
# For example, we don't need extra m.room.member events that don't match event.sender
full_state_ids_at_event = None
if auth_event_ids is not None:
# If auth events are provided, prev events must be also.
assert prev_event_ids is not None
# Copy the full auth state before it is stripped down
full_state_ids_at_event = auth_event_ids.copy()
temp_event = await builder.build(
prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids,
depth=depth,
)
auth_events = await self.store.get_events_as_list(auth_event_ids)
# Create a StateMap[str]
auth_event_state_map = {
(e.type, e.state_key): e.event_id for e in auth_events
}
# Actually strip down and use the necessary auth events
auth_event_ids = self._event_auth_handler.compute_auth_events(
event=temp_event,
current_state_ids=auth_event_state_map,
for_verification=False,
)
if prev_event_ids is not None:
assert (
len(prev_event_ids) <= 10
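Illustratively, the stripping step keeps only the auth events relevant to this sender; a toy version (not Synapse's compute_auth_events):

from typing import Dict, Tuple

def strip_auth_events(
    state_map: Dict[Tuple[str, str], str], sender: str
) -> Dict[Tuple[str, str], str]:
    return {
        (etype, skey): event_id
        for (etype, skey), event_id in state_map.items()
        # drop membership events for anyone other than the sender
        if etype != "m.room.member" or skey == sender
    }

state_map = {
    ("m.room.create", ""): "$create",
    ("m.room.member", "@alice:x"): "$alice",
    ("m.room.member", "@bob:x"): "$bob",
}
assert strip_auth_events(state_map, "@alice:x") == {
    ("m.room.create", ""): "$create",
    ("m.room.member", "@alice:x"): "$alice",
}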
@@ -967,6 +971,13 @@ class EventCreationHandler:
if builder.internal_metadata.outlier:
event.internal_metadata.outlier = True
context = EventContext.for_outlier()
elif (
event.type == EventTypes.MSC2716_INSERTION
and full_state_ids_at_event
and builder.internal_metadata.is_historical()
):
old_state = await self.store.get_events_as_list(full_state_ids_at_event)
context = await self.state.compute_event_context(event, old_state=old_state)
else:
context = await self.state.compute_event_context(event)

View file

@@ -86,19 +86,22 @@ class PaginationHandler:
self._event_serializer = hs.get_event_client_serializer()
self._retention_default_max_lifetime = (
hs.config.server.retention_default_max_lifetime
hs.config.retention.retention_default_max_lifetime
)
self._retention_allowed_lifetime_min = (
hs.config.server.retention_allowed_lifetime_min
hs.config.retention.retention_allowed_lifetime_min
)
self._retention_allowed_lifetime_max = (
hs.config.server.retention_allowed_lifetime_max
hs.config.retention.retention_allowed_lifetime_max
)
if hs.config.worker.run_background_tasks and hs.config.server.retention_enabled:
if (
hs.config.worker.run_background_tasks
and hs.config.retention.retention_enabled
):
# Run the purge jobs described in the configuration file.
for job in hs.config.server.retention_purge_jobs:
for job in hs.config.retention.retention_purge_jobs:
logger.info("Setting up purge job with config: %s", job)
self.clock.looping_call(

View file

@@ -52,7 +52,6 @@ import synapse.metrics
from synapse.api.constants import EventTypes, Membership, PresenceState
from synapse.api.errors import SynapseError
from synapse.api.presence import UserPresenceState
from synapse.appservice import ApplicationService
from synapse.events.presence_router import PresenceRouter
from synapse.logging.context import run_in_background
from synapse.logging.utils import log_function
@@ -1483,13 +1482,39 @@ def should_notify(old_state: UserPresenceState, new_state: UserPresenceState) ->
def format_user_presence_state(
state: UserPresenceState, now: int, include_user_id: bool = True
) -> JsonDict:
"""Convert UserPresenceState to a format that can be sent down to clients
"""Convert UserPresenceState to a JSON format that can be sent down to clients
and to other servers.
The "user_id" is optional so that this function can be used to format presence
updates for client /sync responses and for federation /send requests.
Args:
state: The user presence state to format.
now: The current timestamp since the epoch in ms.
include_user_id: Whether to include `user_id` in the returned dictionary.
As this function can be used both to format presence updates for client /sync
responses and for federation /send requests, only the latter needs to include
the `user_id` field.
Returns:
A JSON dictionary with the following keys:
* presence: The presence state as a str.
* user_id: Optional. Included if `include_user_id` is truthy. The canonical
Matrix ID of the user.
* last_active_ago: Optional. Included if `last_active_ts` is set on `state`.
How long ago the user was last active, in milliseconds.
* status_msg: Optional. Included if `status_msg` is set on `state`. The user's
status.
* currently_active: Optional. Included only if `state.state` is "online".
Example:
{
"presence": "online",
"user_id": "@alice:example.com",
"last_active_ago": 16783813918,
"status_msg": "Hello world!",
"currently_active": True
}
"""
content = {"presence": state.state}
content: JsonDict = {"presence": state.state}
if include_user_id:
content["user_id"] = state.user_id
if state.last_active_ts:
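A standalone re-implementation of the documented output shape (a sketch based on the docstring above, using a plain dict in place of UserPresenceState):

import time

def format_presence(state: dict, now: int, include_user_id: bool = True) -> dict:
    content = {"presence": state["state"]}
    if include_user_id:
        content["user_id"] = state["user_id"]
    if state.get("last_active_ts"):
        content["last_active_ago"] = now - state["last_active_ts"]
    if state.get("status_msg"):
        content["status_msg"] = state["status_msg"]
    if state["state"] == "online":
        content["currently_active"] = state.get("currently_active", False)
    return content

now_ms = int(time.time() * 1000)
print(format_presence(
    {"state": "online", "user_id": "@alice:example.com",
     "last_active_ts": now_ms - 5000, "status_msg": "Hello world!",
     "currently_active": True},
    now_ms,
))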
@@ -1526,7 +1551,6 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
is_guest: bool = False,
explicit_room_id: Optional[str] = None,
include_offline: bool = True,
service: Optional[ApplicationService] = None,
) -> Tuple[List[UserPresenceState], int]:
# The process for getting presence events are:
# 1. Get the rooms the user is in.

View file

@@ -242,12 +242,18 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
async def get_new_events_as(
self, from_key: int, service: ApplicationService
) -> Tuple[List[JsonDict], int]:
"""Returns a set of new receipt events that an appservice
"""Returns a set of new read receipt events that an appservice
may be interested in.
Args:
from_key: the stream position at which events should be fetched from
service: The appservice which may be interested
Returns:
A two-tuple containing the following:
* A list of json dictionaries derived from read receipts that the
appservice may be interested in.
* The current read receipt stream token.
"""
from_key = int(from_key)
to_key = self.get_current_key()

View file

@@ -465,17 +465,35 @@ class RoomCreationHandler:
# the room has been created
# Calculate the minimum power level needed to clone the room
event_power_levels = power_levels.get("events", {})
if not isinstance(event_power_levels, dict):
event_power_levels = {}
state_default = power_levels.get("state_default", 50)
try:
state_default_int = int(state_default) # type: ignore[arg-type]
except (TypeError, ValueError):
state_default_int = 50
ban = power_levels.get("ban", 50)
needed_power_level = max(state_default, ban, max(event_power_levels.values()))
try:
ban = int(ban) # type: ignore[arg-type]
except (TypeError, ValueError):
ban = 50
needed_power_level = max(
state_default_int, ban, max(event_power_levels.values())
)
# Get the user's current power level, this matches the logic in get_user_power_level,
# but without the entire state map.
user_power_levels = power_levels.setdefault("users", {})
if not isinstance(user_power_levels, dict):
user_power_levels = {}
users_default = power_levels.get("users_default", 0)
current_power_level = user_power_levels.get(user_id, users_default)
try:
current_power_level_int = int(current_power_level) # type: ignore[arg-type]
except (TypeError, ValueError):
current_power_level_int = 0
# Raise the requester's power level in the new room if necessary
if current_power_level < needed_power_level:
if current_power_level_int < needed_power_level:
user_power_levels[user_id] = needed_power_level
await self._send_events_for_new_room(
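The repeated try/except above is the same defensive coercion each time; a sketch of it as a helper (hypothetical, the diff inlines it):

def coerce_power_level(value, default):
    # Power levels in room state may be malformed (None, dicts, etc.);
    # fall back to a sane default rather than crashing the room upgrade.
    try:
        return int(value)
    except (TypeError, ValueError):
        return default

assert coerce_power_level(75, 50) == 75
assert coerce_power_level("100", 50) == 100   # numeric strings coerce
assert coerce_power_level(None, 50) == 50
assert coerce_power_level({}, 50) == 50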
@@ -765,6 +783,15 @@ class RoomCreationHandler:
if not allowed_by_third_party_rules:
raise SynapseError(403, "Room visibility value not allowed.")
if is_public:
if not self.config.roomdirectory.is_publishing_room_allowed(
user_id, room_id, room_alias
):
# Let's just return a generic message, as there may be all sorts of
# reasons why we said no. TODO: Allow configurable error messages
# per alias creation rule?
raise SynapseError(403, "Not allowed to publish room")
directory_handler = self.hs.get_directory_handler()
if room_alias:
await directory_handler.create_association(
@@ -775,15 +802,6 @@ class RoomCreationHandler:
check_membership=False,
)
if is_public:
if not self.config.roomdirectory.is_publishing_room_allowed(
user_id, room_id, room_alias
):
# Let's just return a generic message, as there may be all sorts of
# reasons why we said no. TODO: Allow configurable error messages
# per alias creation rule?
raise SynapseError(403, "Not allowed to publish room")
preset_config = config.get(
"preset",
RoomCreationPreset.PRIVATE_CHAT

View file

@@ -13,6 +13,10 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def generate_fake_event_id() -> str:
return "$fake_" + random_string(43)
class RoomBatchHandler:
def __init__(self, hs: "HomeServer"):
self.hs = hs
@@ -180,6 +184,11 @@ class RoomBatchHandler:
state_event_ids_at_start = []
auth_event_ids = initial_auth_event_ids.copy()
# Make the state events float off on their own so we don't have a
# bunch of `@mxid joined the room` noise between each batch
prev_event_id_for_state_chain = generate_fake_event_id()
for state_event in state_events_at_start:
assert_params_in_dict(
state_event, ["type", "origin_server_ts", "content", "sender"]
@@ -203,10 +212,6 @@ class RoomBatchHandler:
# Mark all events as historical
event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True
# Make the state events float off on their own so we don't have a
# bunch of `@mxid joined the room` noise between each batch
fake_prev_event_id = "$" + random_string(43)
# TODO: This is pretty much the same as some other code to handle inserting state in this file
if event_dict["type"] == EventTypes.Member:
membership = event_dict["content"].get("membership", None)
@@ -220,7 +225,7 @@ class RoomBatchHandler:
action=membership,
content=event_dict["content"],
outlier=True,
prev_event_ids=[fake_prev_event_id],
prev_event_ids=[prev_event_id_for_state_chain],
# Make sure to use a copy of this list because we modify it
# later in the loop here. Otherwise it will be the same
# reference and also update in the event when we append later.
@@ -240,7 +245,7 @@ class RoomBatchHandler:
),
event_dict,
outlier=True,
prev_event_ids=[fake_prev_event_id],
prev_event_ids=[prev_event_id_for_state_chain],
# Make sure to use a copy of this list because we modify it
# later in the loop here. Otherwise it will be the same
# reference and also update in the event when we append later.
@@ -250,6 +255,8 @@ class RoomBatchHandler:
state_event_ids_at_start.append(event_id)
auth_event_ids.append(event_id)
# Connect all the state in a floating chain
prev_event_id_for_state_chain = event_id
return state_event_ids_at_start
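The chaining above can be pictured in isolation: each historical state event names its predecessor as its only prev_event, starting from a fake ID so the chain floats free of the live DAG (IDs below are hypothetical):

import itertools

counter = itertools.count(1)

def send_state_event(name: str, prev_event_ids: list) -> str:
    event_id = f"${name}-{next(counter)}"
    print(f"{event_id}: prev_events={prev_event_ids}")
    return event_id

prev_event_id_for_state_chain = "$fake_" + "0" * 43  # floating start
for name in ("member-alice", "member-bob", "power-levels"):
    event_id = send_state_event(name, [prev_event_id_for_state_chain])
    prev_event_id_for_state_chain = event_id         # connect the chain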
@@ -296,6 +303,10 @@ class RoomBatchHandler:
for ev in events_to_create:
assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"])
assert self.hs.is_mine_id(ev["sender"]), "User must be our own: %s" % (
ev["sender"],
)
event_dict = {
"type": ev["type"],
"origin_server_ts": ev["origin_server_ts"],
@@ -318,6 +329,19 @@ class RoomBatchHandler:
historical=True,
depth=inherited_depth,
)
assert context._state_group
# Normally this is done when persisting the event but we have to
# pre-emptively do it here because we create all the events first,
# then persist them in another pass below. And we want to share
# state_groups across the whole batch so this lookup needs to work
# for the next event in the batch in this loop.
await self.store.store_state_group_id_for_event_id(
event_id=event.event_id,
state_group_id=context._state_group,
)
logger.debug(
"RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s, auth_event_ids=%s",
event,
@@ -325,10 +349,6 @@ class RoomBatchHandler:
auth_event_ids,
)
assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % (
event.sender,
)
events_to_persist.append((event, context))
event_id = event.event_id

View file

@@ -465,17 +465,23 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
may be interested in.
Args:
from_key: the stream position at which events should be fetched from
service: The appservice which may be interested
from_key: the stream position at which events should be fetched from.
service: The appservice which may be interested.
Returns:
A two-tuple containing the following:
* A list of json dictionaries derived from typing events that the
appservice may be interested in.
* The latest known room serial.
"""
with Measure(self.clock, "typing.get_new_events_as"):
from_key = int(from_key)
handler = self.get_typing_handler()
events = []
for room_id in handler._room_serials.keys():
if handler._room_serials[room_id] <= from_key:
continue
if not await service.matches_user_in_member_list(
room_id, handler.store
):

View file

@@ -196,63 +196,12 @@ class UserDirectoryHandler(StateDeltasHandler):
room_id, prev_event_id, event_id, typ
)
elif typ == EventTypes.Member:
change = await self._get_key_change(
await self._handle_room_membership_event(
room_id,
prev_event_id,
event_id,
key_name="membership",
public_value=Membership.JOIN,
state_key,
)
is_remote = not self.is_mine_id(state_key)
if change is MatchChange.now_false:
# Need to check if the server left the room entirely, if so
# we might need to remove all the users in that room
is_in_room = await self.store.is_host_joined(
room_id, self.server_name
)
if not is_in_room:
logger.debug("Server left room: %r", room_id)
# Fetch all the users that we marked as being in the user
# directory due to being in the room and then check if
# we need to remove those users or not
user_ids = await self.store.get_users_in_dir_due_to_room(
room_id
)
for user_id in user_ids:
await self._handle_remove_user(room_id, user_id)
continue
else:
logger.debug("Server is still in room: %r", room_id)
include_in_dir = (
is_remote
or await self.store.should_include_local_user_in_dir(state_key)
)
if include_in_dir:
if change is MatchChange.no_change:
# Handle any profile changes for remote users.
# (For local users we are not forced to scan membership
# events; instead the rest of the application calls
# `handle_local_profile_change`.)
if is_remote:
await self._handle_profile_change(
state_key, room_id, prev_event_id, event_id
)
continue
if change is MatchChange.now_true: # The user joined
# This may be the first time we've seen a remote user. If
# so, ensure we have a directory entry for them. (We don't
# need to do this for local users: their directory entry
# is created at the point of registration.
if is_remote:
await self._upsert_directory_entry_for_remote_user(
state_key, event_id
)
await self._track_user_joined_room(room_id, state_key)
else: # The user left
await self._handle_remove_user(room_id, state_key)
else:
logger.debug("Ignoring irrelevant type: %r", typ)
@@ -317,14 +266,83 @@ class UserDirectoryHandler(StateDeltasHandler):
for user_id in users_in_room:
await self.store.remove_user_who_share_room(user_id, room_id)
# Then, re-add them to the tables.
# Then, re-add all remote users and some local users to the tables.
# NOTE: this is not the most efficient method, as _track_user_joined_room sets
# up local_user -> other_user and other_user_whos_local -> local_user,
# which, when run over an entire room, will result in the same values
# being added multiple times. The batching upserts shouldn't make this
# too bad, though.
for user_id in users_in_room:
await self._track_user_joined_room(room_id, user_id)
if not self.is_mine_id(
user_id
) or await self.store.should_include_local_user_in_dir(user_id):
await self._track_user_joined_room(room_id, user_id)
async def _handle_room_membership_event(
self,
room_id: str,
prev_event_id: str,
event_id: str,
state_key: str,
) -> None:
"""Process a single room membershp event.
We have to do two things:
1. Update the room-sharing tables.
This applies to remote users and non-excluded local users.
2. Update the user_directory and user_directory_search tables.
This applies to remote users only, because we only become aware of
them (and any profile changes) by listening to these events.
The rest of the application knows exactly when local users are
created or their profile changed---it will directly call methods
on this class.
"""
joined = await self._get_key_change(
prev_event_id,
event_id,
key_name="membership",
public_value=Membership.JOIN,
)
# Both cases ignore excluded local users, so start by discarding them.
is_remote = not self.is_mine_id(state_key)
if not is_remote and not await self.store.should_include_local_user_in_dir(
state_key
):
return
if joined is MatchChange.now_false:
# Need to check if the server left the room entirely, if so
# we might need to remove all the users in that room
is_in_room = await self.store.is_host_joined(room_id, self.server_name)
if not is_in_room:
logger.debug("Server left room: %r", room_id)
# Fetch all the users that we marked as being in the user
# directory due to being in the room and then check if
# we need to remove those users or not
user_ids = await self.store.get_users_in_dir_due_to_room(room_id)
for user_id in user_ids:
await self._handle_remove_user(room_id, user_id)
else:
logger.debug("Server is still in room: %r", room_id)
await self._handle_remove_user(room_id, state_key)
elif joined is MatchChange.no_change:
# Handle any profile changes for remote users.
# (For local users the rest of the application calls
# `handle_local_profile_change`.)
if is_remote:
await self._handle_possible_remote_profile_change(
state_key, room_id, prev_event_id, event_id
)
elif joined is MatchChange.now_true: # The user joined
# This may be the first time we've seen a remote user. If
# so, ensure we have a directory entry for them. (For local users,
# the rest of the application calls `handle_local_profile_change`.)
if is_remote:
await self._upsert_directory_entry_for_remote_user(state_key, event_id)
await self._track_user_joined_room(room_id, state_key)
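The tri-state dispatch above is the crux of the refactor; a compressed sketch (MatchChange is real in Synapse, but this standalone Enum and the string outcomes are stand-ins):

from enum import Enum, auto

class MatchChange(Enum):
    no_change = auto()
    now_true = auto()
    now_false = auto()

def dispatch(joined: MatchChange, is_remote: bool) -> str:
    if joined is MatchChange.now_false:
        return "leave: remove user (or whole room if we left entirely)"
    if joined is MatchChange.no_change:
        return "maybe remote profile change" if is_remote else "nothing"
    # now_true: the user joined
    return ("upsert directory entry, then " if is_remote else "") + "track join"

assert dispatch(MatchChange.now_true, is_remote=True).endswith("track join")
assert dispatch(MatchChange.no_change, is_remote=False) == "nothing"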
async def _upsert_directory_entry_for_remote_user(
self, user_id: str, event_id: str
@@ -349,61 +367,67 @@ class UserDirectoryHandler(StateDeltasHandler):
"""Someone's just joined a room. Update `users_in_public_rooms` or
`users_who_share_private_rooms` as appropriate.
The caller is responsible for ensuring that the given user is not excluded
from the user directory.
The caller is responsible for ensuring that the given user should be
included in the user directory.
"""
is_public = await self.store.is_room_world_readable_or_publicly_joinable(
room_id
)
other_users_in_room = await self.store.get_users_in_room(room_id)
if is_public:
await self.store.add_users_in_public_rooms(room_id, (user_id,))
else:
users_in_room = await self.store.get_users_in_room(room_id)
other_users_in_room = [
other
for other in users_in_room
if other != user_id
and (
not self.is_mine_id(other)
or await self.store.should_include_local_user_in_dir(other)
)
]
to_insert = set()
# First, if they're our user then we need to update for every user
if self.is_mine_id(user_id):
if await self.store.should_include_local_user_in_dir(user_id):
for other_user_id in other_users_in_room:
if user_id == other_user_id:
continue
to_insert.add((user_id, other_user_id))
for other_user_id in other_users_in_room:
to_insert.add((user_id, other_user_id))
# Next we need to update for every local user in the room
for other_user_id in other_users_in_room:
if user_id == other_user_id:
continue
include_other_user = self.is_mine_id(
other_user_id
) and await self.store.should_include_local_user_in_dir(other_user_id)
if include_other_user:
if self.is_mine_id(other_user_id):
to_insert.add((other_user_id, user_id))
if to_insert:
await self.store.add_users_who_share_private_room(room_id, to_insert)
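The pair bookkeeping above, in isolation (hypothetical users; `is_mine` stands in for is_mine_id): pairs are only recorded with a local user on the left-hand side.

def pairs_to_insert(user_id, others, is_mine):
    to_insert = set()
    if is_mine(user_id):                 # local joiner pairs with everyone else
        for other in others:
            to_insert.add((user_id, other))
    for other in others:                 # every local user pairs with the joiner
        if is_mine(other):
            to_insert.add((other, user_id))
    return to_insert

got = pairs_to_insert(
    "@joiner:x", ["@local:x", "@remote:y"],
    is_mine=lambda u: u.endswith(":x"),
)
assert got == {
    ("@joiner:x", "@local:x"),
    ("@joiner:x", "@remote:y"),
    ("@local:x", "@joiner:x"),
}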
async def _handle_remove_user(self, room_id: str, user_id: str) -> None:
"""Called when we might need to remove user from directory
"""Called when when someone leaves a room. The user may be local or remote.
(If the person who left was the last local user in this room, the server
is no longer in the room. We call this function to forget that the remaining
remote users are in the room, even though they haven't left. So the name is
a little misleading!)
Args:
room_id: The room ID that the user left, or that stopped being public.
user_id: The user to (possibly) remove from the directory.
"""
logger.debug("Removing user %r", user_id)
logger.debug("Removing user %r from room %r", user_id, room_id)
# Remove user from sharing tables
await self.store.remove_user_who_share_room(user_id, room_id)
# Are they still in any rooms? If not, remove them entirely.
rooms_user_is_in = await self.store.get_user_dir_rooms_user_is_in(user_id)
# Additionally, if they're a remote user and we're no longer joined
# to any rooms they're in, remove them from the user directory.
if not self.is_mine_id(user_id):
rooms_user_is_in = await self.store.get_user_dir_rooms_user_is_in(user_id)
if len(rooms_user_is_in) == 0:
await self.store.remove_from_user_dir(user_id)
if len(rooms_user_is_in) == 0:
logger.debug("Removing user %r from directory", user_id)
await self.store.remove_from_user_dir(user_id)
async def _handle_profile_change(
async def _handle_possible_remote_profile_change(
self,
user_id: str,
room_id: str,
@@ -411,7 +435,8 @@ class UserDirectoryHandler(StateDeltasHandler):
event_id: Optional[str],
) -> None:
"""Check member event changes for any profile changes and update the
database if there are.
database if there are. This is intended for remote users only. The caller
is responsible for checking that the given user is remote.
"""
if not prev_event_id or not event_id:
return