mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-08-16 18:50:30 -04:00
Merge remote-tracking branch 'upstream/release-v1.46'
This commit is contained in:
commit
cf45cfd314
172 changed files with 5549 additions and 2350 deletions
|
@ -185,19 +185,26 @@ class ApplicationServicesHandler:
|
|||
new_token: Optional[int],
|
||||
users: Optional[Collection[Union[str, UserID]]] = None,
|
||||
) -> None:
|
||||
"""This is called by the notifier in the background
|
||||
when a ephemeral event handled by the homeserver.
|
||||
"""
|
||||
This is called by the notifier in the background when an ephemeral event is handled
|
||||
by the homeserver.
|
||||
|
||||
This will determine which appservices
|
||||
are interested in the event, and submit them.
|
||||
|
||||
Events will only be pushed to appservices
|
||||
that have opted into ephemeral events
|
||||
This will determine which appservices are interested in the event, and submit them.
|
||||
|
||||
Args:
|
||||
stream_key: The stream the event came from.
|
||||
new_token: The latest stream token
|
||||
users: The user(s) involved with the event.
|
||||
|
||||
`stream_key` can be "typing_key", "receipt_key" or "presence_key". Any other
|
||||
value for `stream_key` will cause this function to return early.
|
||||
|
||||
Ephemeral events will only be pushed to appservices that have opted into
|
||||
them.
|
||||
|
||||
Appservices will only receive ephemeral events that fall within their
|
||||
registered user and room namespaces.
|
||||
|
||||
new_token: The latest stream token.
|
||||
users: The users that should be informed of the new event, if any.
|
||||
"""
|
||||
if not self.notify_appservices:
|
||||
return
|
||||
|
@ -232,21 +239,32 @@ class ApplicationServicesHandler:
|
|||
for service in services:
|
||||
# Only handle typing if we have the latest token
|
||||
if stream_key == "typing_key" and new_token is not None:
|
||||
# Note that we don't persist the token (via set_type_stream_id_for_appservice)
|
||||
# for typing_key due to performance reasons and due to their highly
|
||||
# ephemeral nature.
|
||||
#
|
||||
# Instead we simply grab the latest typing updates in _handle_typing
|
||||
# and, if they apply to this application service, send it off.
|
||||
events = await self._handle_typing(service, new_token)
|
||||
if events:
|
||||
self.scheduler.submit_ephemeral_events_for_as(service, events)
|
||||
# We don't persist the token for typing_key for performance reasons
|
||||
|
||||
elif stream_key == "receipt_key":
|
||||
events = await self._handle_receipts(service)
|
||||
if events:
|
||||
self.scheduler.submit_ephemeral_events_for_as(service, events)
|
||||
|
||||
# Persist the latest handled stream token for this appservice
|
||||
await self.store.set_type_stream_id_for_appservice(
|
||||
service, "read_receipt", new_token
|
||||
)
|
||||
|
||||
elif stream_key == "presence_key":
|
||||
events = await self._handle_presence(service, users)
|
||||
if events:
|
||||
self.scheduler.submit_ephemeral_events_for_as(service, events)
|
||||
|
||||
# Persist the latest handled stream token for this appservice
|
||||
await self.store.set_type_stream_id_for_appservice(
|
||||
service, "presence", new_token
|
||||
)
|
||||
|
@ -254,18 +272,54 @@ class ApplicationServicesHandler:
|
|||
async def _handle_typing(
|
||||
self, service: ApplicationService, new_token: int
|
||||
) -> List[JsonDict]:
|
||||
"""
|
||||
Return the typing events since the given stream token that the given application
|
||||
service should receive.
|
||||
|
||||
First fetch all typing events between the given typing stream token (non-inclusive)
|
||||
and the latest typing event stream token (inclusive). Then return only those typing
|
||||
events that the given application service may be interested in.
|
||||
|
||||
Args:
|
||||
service: The application service to check for which events it should receive.
|
||||
new_token: A typing event stream token.
|
||||
|
||||
Returns:
|
||||
A list of JSON dictionaries containing data derived from the typing events that
|
||||
should be sent to the given application service.
|
||||
"""
|
||||
typing_source = self.event_sources.sources.typing
|
||||
# Get the typing events from just before current
|
||||
typing, _ = await typing_source.get_new_events_as(
|
||||
service=service,
|
||||
# For performance reasons, we don't persist the previous
|
||||
# token in the DB and instead fetch the latest typing information
|
||||
# token in the DB and instead fetch the latest typing event
|
||||
# for appservices.
|
||||
# TODO: It'd likely be more efficient to simply fetch the
|
||||
# typing event with the given 'new_token' stream token and
|
||||
# check if the given service was interested, rather than
|
||||
# iterating over all typing events and only grabbing the
|
||||
# latest few.
|
||||
from_key=new_token - 1,
|
||||
)
|
||||
return typing
|
||||
|
||||
async def _handle_receipts(self, service: ApplicationService) -> List[JsonDict]:
|
||||
"""
|
||||
Return the latest read receipts that the given application service should receive.
|
||||
|
||||
First fetch all read receipts between the last receipt stream token that this
|
||||
application service should have previously received (non-inclusive) and the
|
||||
latest read receipt stream token (inclusive). Then from that set, return only
|
||||
those read receipts that the given application service may be interested in.
|
||||
|
||||
Args:
|
||||
service: The application service to check for which events it should receive.
|
||||
|
||||
Returns:
|
||||
A list of JSON dictionaries containing data derived from the read receipts that
|
||||
should be sent to the given application service.
|
||||
"""
|
||||
from_key = await self.store.get_type_stream_id_for_appservice(
|
||||
service, "read_receipt"
|
||||
)
|
||||
|
@ -278,6 +332,22 @@ class ApplicationServicesHandler:
|
|||
async def _handle_presence(
|
||||
self, service: ApplicationService, users: Collection[Union[str, UserID]]
|
||||
) -> List[JsonDict]:
|
||||
"""
|
||||
Return the latest presence updates that the given application service should receive.
|
||||
|
||||
First, filter the given users list to those that the application service is
|
||||
interested in. Then retrieve the latest presence updates since the
|
||||
the last-known previously received presence stream token for the given
|
||||
application service. Return those presence updates.
|
||||
|
||||
Args:
|
||||
service: The application service that ephemeral events are being sent to.
|
||||
users: The users that should receive the presence update.
|
||||
|
||||
Returns:
|
||||
A list of json dictionaries containing data derived from the presence events
|
||||
that should be sent to the given application service.
|
||||
"""
|
||||
events: List[JsonDict] = []
|
||||
presence_source = self.event_sources.sources.presence
|
||||
from_key = await self.store.get_type_stream_id_for_appservice(
|
||||
|
@ -290,9 +360,9 @@ class ApplicationServicesHandler:
|
|||
interested = await service.is_interested_in_presence(user, self.store)
|
||||
if not interested:
|
||||
continue
|
||||
|
||||
presence_events, _ = await presence_source.get_new_events(
|
||||
user=user,
|
||||
service=service,
|
||||
from_key=from_key,
|
||||
)
|
||||
time_now = self.clock.time_msec()
|
||||
|
|
|
@ -62,7 +62,6 @@ from synapse.http.server import finish_request, respond_with_html
|
|||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.context import defer_to_thread
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.module_api import ModuleApi
|
||||
from synapse.storage.roommember import ProfileInfo
|
||||
from synapse.types import JsonDict, Requester, UserID
|
||||
from synapse.util import stringutils as stringutils
|
||||
|
@ -73,6 +72,7 @@ from synapse.util.stringutils import base62_encode
|
|||
from synapse.util.threepids import canonicalise_email
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.module_api import ModuleApi
|
||||
from synapse.rest.client.login import LoginResponse
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
@ -200,46 +200,13 @@ class AuthHandler:
|
|||
|
||||
self.bcrypt_rounds = hs.config.registration.bcrypt_rounds
|
||||
|
||||
# we can't use hs.get_module_api() here, because to do so will create an
|
||||
# import loop.
|
||||
#
|
||||
# TODO: refactor this class to separate the lower-level stuff that
|
||||
# ModuleApi can use from the higher-level stuff that uses ModuleApi, as
|
||||
# better way to break the loop
|
||||
account_handler = ModuleApi(hs, self)
|
||||
|
||||
self.password_providers = [
|
||||
PasswordProvider.load(module, config, account_handler)
|
||||
for module, config in hs.config.authproviders.password_providers
|
||||
]
|
||||
|
||||
logger.info("Extra password_providers: %s", self.password_providers)
|
||||
self.password_auth_provider = hs.get_password_auth_provider()
|
||||
|
||||
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
||||
self.macaroon_gen = hs.get_macaroon_generator()
|
||||
self._password_enabled = hs.config.auth.password_enabled
|
||||
self._password_localdb_enabled = hs.config.auth.password_localdb_enabled
|
||||
|
||||
# start out by assuming PASSWORD is enabled; we will remove it later if not.
|
||||
login_types = set()
|
||||
if self._password_localdb_enabled:
|
||||
login_types.add(LoginType.PASSWORD)
|
||||
|
||||
for provider in self.password_providers:
|
||||
login_types.update(provider.get_supported_login_types().keys())
|
||||
|
||||
if not self._password_enabled:
|
||||
login_types.discard(LoginType.PASSWORD)
|
||||
|
||||
# Some clients just pick the first type in the list. In this case, we want
|
||||
# them to use PASSWORD (rather than token or whatever), so we want to make sure
|
||||
# that comes first, where it's present.
|
||||
self._supported_login_types = []
|
||||
if LoginType.PASSWORD in login_types:
|
||||
self._supported_login_types.append(LoginType.PASSWORD)
|
||||
login_types.remove(LoginType.PASSWORD)
|
||||
self._supported_login_types.extend(login_types)
|
||||
|
||||
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
|
||||
# as per `rc_login.failed_attempts`.
|
||||
self._failed_uia_attempts_ratelimiter = Ratelimiter(
|
||||
|
@ -427,11 +394,10 @@ class AuthHandler:
|
|||
ui_auth_types.add(LoginType.PASSWORD)
|
||||
|
||||
# also allow auth from password providers
|
||||
for provider in self.password_providers:
|
||||
for t in provider.get_supported_login_types().keys():
|
||||
if t == LoginType.PASSWORD and not self._password_enabled:
|
||||
continue
|
||||
ui_auth_types.add(t)
|
||||
for t in self.password_auth_provider.get_supported_login_types().keys():
|
||||
if t == LoginType.PASSWORD and not self._password_enabled:
|
||||
continue
|
||||
ui_auth_types.add(t)
|
||||
|
||||
# if sso is enabled, allow the user to log in via SSO iff they have a mapping
|
||||
# from sso to mxid.
|
||||
|
@ -1038,7 +1004,25 @@ class AuthHandler:
|
|||
Returns:
|
||||
login types
|
||||
"""
|
||||
return self._supported_login_types
|
||||
# Load any login types registered by modules
|
||||
# This is stored in the password_auth_provider so this doesn't trigger
|
||||
# any callbacks
|
||||
types = list(self.password_auth_provider.get_supported_login_types().keys())
|
||||
|
||||
# This list should include PASSWORD if (either _password_localdb_enabled is
|
||||
# true or if one of the modules registered it) AND _password_enabled is true
|
||||
# Also:
|
||||
# Some clients just pick the first type in the list. In this case, we want
|
||||
# them to use PASSWORD (rather than token or whatever), so we want to make sure
|
||||
# that comes first, where it's present.
|
||||
if LoginType.PASSWORD in types:
|
||||
types.remove(LoginType.PASSWORD)
|
||||
if self._password_enabled:
|
||||
types.insert(0, LoginType.PASSWORD)
|
||||
elif self._password_localdb_enabled and self._password_enabled:
|
||||
types.insert(0, LoginType.PASSWORD)
|
||||
|
||||
return types
|
||||
|
||||
async def validate_login(
|
||||
self,
|
||||
|
@ -1217,15 +1201,20 @@ class AuthHandler:
|
|||
|
||||
known_login_type = False
|
||||
|
||||
for provider in self.password_providers:
|
||||
supported_login_types = provider.get_supported_login_types()
|
||||
if login_type not in supported_login_types:
|
||||
# this password provider doesn't understand this login type
|
||||
continue
|
||||
|
||||
# Check if login_type matches a type registered by one of the modules
|
||||
# We don't need to remove LoginType.PASSWORD from the list if password login is
|
||||
# disabled, since if that were the case then by this point we know that the
|
||||
# login_type is not LoginType.PASSWORD
|
||||
supported_login_types = self.password_auth_provider.get_supported_login_types()
|
||||
# check if the login type being used is supported by a module
|
||||
if login_type in supported_login_types:
|
||||
# Make a note that this login type is supported by the server
|
||||
known_login_type = True
|
||||
# Get all the fields expected for this login types
|
||||
login_fields = supported_login_types[login_type]
|
||||
|
||||
# go through the login submission and keep track of which required fields are
|
||||
# provided/not provided
|
||||
missing_fields = []
|
||||
login_dict = {}
|
||||
for f in login_fields:
|
||||
|
@ -1233,6 +1222,7 @@ class AuthHandler:
|
|||
missing_fields.append(f)
|
||||
else:
|
||||
login_dict[f] = login_submission[f]
|
||||
# raise an error if any of the expected fields for that login type weren't provided
|
||||
if missing_fields:
|
||||
raise SynapseError(
|
||||
400,
|
||||
|
@ -1240,10 +1230,15 @@ class AuthHandler:
|
|||
% (login_type, missing_fields),
|
||||
)
|
||||
|
||||
result = await provider.check_auth(username, login_type, login_dict)
|
||||
# call all of the check_auth hooks for that login_type
|
||||
# it will return a result once the first success is found (or None otherwise)
|
||||
result = await self.password_auth_provider.check_auth(
|
||||
username, login_type, login_dict
|
||||
)
|
||||
if result:
|
||||
return result
|
||||
|
||||
# if no module managed to authenticate the user, then fallback to built in password based auth
|
||||
if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
|
||||
known_login_type = True
|
||||
|
||||
|
@ -1282,11 +1277,16 @@ class AuthHandler:
|
|||
completed login/registration, or `None`. If authentication was
|
||||
unsuccessful, `user_id` and `callback` are both `None`.
|
||||
"""
|
||||
for provider in self.password_providers:
|
||||
result = await provider.check_3pid_auth(medium, address, password)
|
||||
if result:
|
||||
return result
|
||||
# call all of the check_3pid_auth callbacks
|
||||
# Result will be from the first callback that returns something other than None
|
||||
# If all the callbacks return None, then result is also set to None
|
||||
result = await self.password_auth_provider.check_3pid_auth(
|
||||
medium, address, password
|
||||
)
|
||||
if result:
|
||||
return result
|
||||
|
||||
# if result is None then return (None, None)
|
||||
return None, None
|
||||
|
||||
async def _check_local_password(self, user_id: str, password: str) -> Optional[str]:
|
||||
|
@ -1365,13 +1365,12 @@ class AuthHandler:
|
|||
user_info = await self.auth.get_user_by_access_token(access_token)
|
||||
await self.store.delete_access_token(access_token)
|
||||
|
||||
# see if any of our auth providers want to know about this
|
||||
for provider in self.password_providers:
|
||||
await provider.on_logged_out(
|
||||
user_id=user_info.user_id,
|
||||
device_id=user_info.device_id,
|
||||
access_token=access_token,
|
||||
)
|
||||
# see if any modules want to know about this
|
||||
await self.password_auth_provider.on_logged_out(
|
||||
user_id=user_info.user_id,
|
||||
device_id=user_info.device_id,
|
||||
access_token=access_token,
|
||||
)
|
||||
|
||||
# delete pushers associated with this access token
|
||||
if user_info.token_id is not None:
|
||||
|
@ -1398,12 +1397,11 @@ class AuthHandler:
|
|||
user_id, except_token_id=except_token_id, device_id=device_id
|
||||
)
|
||||
|
||||
# see if any of our auth providers want to know about this
|
||||
for provider in self.password_providers:
|
||||
for token, _, device_id in tokens_and_devices:
|
||||
await provider.on_logged_out(
|
||||
user_id=user_id, device_id=device_id, access_token=token
|
||||
)
|
||||
# see if any modules want to know about this
|
||||
for token, _, device_id in tokens_and_devices:
|
||||
await self.password_auth_provider.on_logged_out(
|
||||
user_id=user_id, device_id=device_id, access_token=token
|
||||
)
|
||||
|
||||
# delete pushers associated with the access tokens
|
||||
await self.hs.get_pusherpool().remove_pushers_by_access_token(
|
||||
|
@ -1811,40 +1809,230 @@ class MacaroonGenerator:
|
|||
return macaroon
|
||||
|
||||
|
||||
class PasswordProvider:
|
||||
"""Wrapper for a password auth provider module
|
||||
def load_legacy_password_auth_providers(hs: "HomeServer") -> None:
|
||||
module_api = hs.get_module_api()
|
||||
for module, config in hs.config.authproviders.password_providers:
|
||||
load_single_legacy_password_auth_provider(
|
||||
module=module, config=config, api=module_api
|
||||
)
|
||||
|
||||
This class abstracts out all of the backwards-compatibility hacks for
|
||||
password providers, to provide a consistent interface.
|
||||
|
||||
def load_single_legacy_password_auth_provider(
|
||||
module: Type,
|
||||
config: JsonDict,
|
||||
api: "ModuleApi",
|
||||
) -> None:
|
||||
try:
|
||||
provider = module(config=config, account_handler=api)
|
||||
except Exception as e:
|
||||
logger.error("Error while initializing %r: %s", module, e)
|
||||
raise
|
||||
|
||||
# The known hooks. If a module implements a method who's name appears in this set
|
||||
# we'll want to register it
|
||||
password_auth_provider_methods = {
|
||||
"check_3pid_auth",
|
||||
"on_logged_out",
|
||||
}
|
||||
|
||||
# All methods that the module provides should be async, but this wasn't enforced
|
||||
# in the old module system, so we wrap them if needed
|
||||
def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
|
||||
# f might be None if the callback isn't implemented by the module. In this
|
||||
# case we don't want to register a callback at all so we return None.
|
||||
if f is None:
|
||||
return None
|
||||
|
||||
# We need to wrap check_password because its old form would return a boolean
|
||||
# but we now want it to behave just like check_auth() and return the matrix id of
|
||||
# the user if authentication succeeded or None otherwise
|
||||
if f.__name__ == "check_password":
|
||||
|
||||
async def wrapped_check_password(
|
||||
username: str, login_type: str, login_dict: JsonDict
|
||||
) -> Optional[Tuple[str, Optional[Callable]]]:
|
||||
# We've already made sure f is not None above, but mypy doesn't do well
|
||||
# across function boundaries so we need to tell it f is definitely not
|
||||
# None.
|
||||
assert f is not None
|
||||
|
||||
matrix_user_id = api.get_qualified_user_id(username)
|
||||
password = login_dict["password"]
|
||||
|
||||
is_valid = await f(matrix_user_id, password)
|
||||
|
||||
if is_valid:
|
||||
return matrix_user_id, None
|
||||
|
||||
return None
|
||||
|
||||
return wrapped_check_password
|
||||
|
||||
# We need to wrap check_auth as in the old form it could return
|
||||
# just a str, but now it must return Optional[Tuple[str, Optional[Callable]]
|
||||
if f.__name__ == "check_auth":
|
||||
|
||||
async def wrapped_check_auth(
|
||||
username: str, login_type: str, login_dict: JsonDict
|
||||
) -> Optional[Tuple[str, Optional[Callable]]]:
|
||||
# We've already made sure f is not None above, but mypy doesn't do well
|
||||
# across function boundaries so we need to tell it f is definitely not
|
||||
# None.
|
||||
assert f is not None
|
||||
|
||||
result = await f(username, login_type, login_dict)
|
||||
|
||||
if isinstance(result, str):
|
||||
return result, None
|
||||
|
||||
return result
|
||||
|
||||
return wrapped_check_auth
|
||||
|
||||
# We need to wrap check_3pid_auth as in the old form it could return
|
||||
# just a str, but now it must return Optional[Tuple[str, Optional[Callable]]
|
||||
if f.__name__ == "check_3pid_auth":
|
||||
|
||||
async def wrapped_check_3pid_auth(
|
||||
medium: str, address: str, password: str
|
||||
) -> Optional[Tuple[str, Optional[Callable]]]:
|
||||
# We've already made sure f is not None above, but mypy doesn't do well
|
||||
# across function boundaries so we need to tell it f is definitely not
|
||||
# None.
|
||||
assert f is not None
|
||||
|
||||
result = await f(medium, address, password)
|
||||
|
||||
if isinstance(result, str):
|
||||
return result, None
|
||||
|
||||
return result
|
||||
|
||||
return wrapped_check_3pid_auth
|
||||
|
||||
def run(*args: Tuple, **kwargs: Dict) -> Awaitable:
|
||||
# mypy doesn't do well across function boundaries so we need to tell it
|
||||
# f is definitely not None.
|
||||
assert f is not None
|
||||
|
||||
return maybe_awaitable(f(*args, **kwargs))
|
||||
|
||||
return run
|
||||
|
||||
# populate hooks with the implemented methods, wrapped with async_wrapper
|
||||
hooks = {
|
||||
hook: async_wrapper(getattr(provider, hook, None))
|
||||
for hook in password_auth_provider_methods
|
||||
}
|
||||
|
||||
supported_login_types = {}
|
||||
# call get_supported_login_types and add that to the dict
|
||||
g = getattr(provider, "get_supported_login_types", None)
|
||||
if g is not None:
|
||||
# Note the old module style also called get_supported_login_types at loading time
|
||||
# and it is synchronous
|
||||
supported_login_types.update(g())
|
||||
|
||||
auth_checkers = {}
|
||||
# Legacy modules have a check_auth method which expects to be called with one of
|
||||
# the keys returned by get_supported_login_types. New style modules register a
|
||||
# dictionary of login_type->check_auth_method mappings
|
||||
check_auth = async_wrapper(getattr(provider, "check_auth", None))
|
||||
if check_auth is not None:
|
||||
for login_type, fields in supported_login_types.items():
|
||||
# need tuple(fields) since fields can be any Iterable type (so may not be hashable)
|
||||
auth_checkers[(login_type, tuple(fields))] = check_auth
|
||||
|
||||
# if it has a "check_password" method then it should handle all auth checks
|
||||
# with login type of LoginType.PASSWORD
|
||||
check_password = async_wrapper(getattr(provider, "check_password", None))
|
||||
if check_password is not None:
|
||||
# need to use a tuple here for ("password",) not a list since lists aren't hashable
|
||||
auth_checkers[(LoginType.PASSWORD, ("password",))] = check_password
|
||||
|
||||
api.register_password_auth_provider_callbacks(hooks, auth_checkers=auth_checkers)
|
||||
|
||||
|
||||
CHECK_3PID_AUTH_CALLBACK = Callable[
|
||||
[str, str, str],
|
||||
Awaitable[
|
||||
Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
|
||||
],
|
||||
]
|
||||
ON_LOGGED_OUT_CALLBACK = Callable[[str, Optional[str], str], Awaitable]
|
||||
CHECK_AUTH_CALLBACK = Callable[
|
||||
[str, str, JsonDict],
|
||||
Awaitable[
|
||||
Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
|
||||
],
|
||||
]
|
||||
|
||||
|
||||
class PasswordAuthProvider:
|
||||
"""
|
||||
A class that the AuthHandler calls when authenticating users
|
||||
It allows modules to provide alternative methods for authentication
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
cls, module: Type, config: JsonDict, module_api: ModuleApi
|
||||
) -> "PasswordProvider":
|
||||
try:
|
||||
pp = module(config=config, account_handler=module_api)
|
||||
except Exception as e:
|
||||
logger.error("Error while initializing %r: %s", module, e)
|
||||
raise
|
||||
return cls(pp, module_api)
|
||||
def __init__(self) -> None:
|
||||
# lists of callbacks
|
||||
self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = []
|
||||
self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = []
|
||||
|
||||
def __init__(self, pp: "PasswordProvider", module_api: ModuleApi):
|
||||
self._pp = pp
|
||||
self._module_api = module_api
|
||||
# Mapping from login type to login parameters
|
||||
self._supported_login_types: Dict[str, Iterable[str]] = {}
|
||||
|
||||
self._supported_login_types = {}
|
||||
# Mapping from login type to auth checker callbacks
|
||||
self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {}
|
||||
|
||||
# grandfather in check_password support
|
||||
if hasattr(self._pp, "check_password"):
|
||||
self._supported_login_types[LoginType.PASSWORD] = ("password",)
|
||||
def register_password_auth_provider_callbacks(
|
||||
self,
|
||||
check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None,
|
||||
on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None,
|
||||
auth_checkers: Optional[Dict[Tuple[str, Tuple], CHECK_AUTH_CALLBACK]] = None,
|
||||
) -> None:
|
||||
# Register check_3pid_auth callback
|
||||
if check_3pid_auth is not None:
|
||||
self.check_3pid_auth_callbacks.append(check_3pid_auth)
|
||||
|
||||
g = getattr(self._pp, "get_supported_login_types", None)
|
||||
if g:
|
||||
self._supported_login_types.update(g())
|
||||
# register on_logged_out callback
|
||||
if on_logged_out is not None:
|
||||
self.on_logged_out_callbacks.append(on_logged_out)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self._pp)
|
||||
if auth_checkers is not None:
|
||||
# register a new supported login_type
|
||||
# Iterate through all of the types being registered
|
||||
for (login_type, fields), callback in auth_checkers.items():
|
||||
# Note: fields may be empty here. This would allow a modules auth checker to
|
||||
# be called with just 'login_type' and no password or other secrets
|
||||
|
||||
# Need to check that all the field names are strings or may get nasty errors later
|
||||
for f in fields:
|
||||
if not isinstance(f, str):
|
||||
raise RuntimeError(
|
||||
"A module tried to register support for login type: %s with parameters %s"
|
||||
" but all parameter names must be strings"
|
||||
% (login_type, fields)
|
||||
)
|
||||
|
||||
# 2 modules supporting the same login type must expect the same fields
|
||||
# e.g. 1 can't expect "pass" if the other expects "password"
|
||||
# so throw an exception if that happens
|
||||
if login_type not in self._supported_login_types.get(login_type, []):
|
||||
self._supported_login_types[login_type] = fields
|
||||
else:
|
||||
fields_currently_supported = self._supported_login_types.get(
|
||||
login_type
|
||||
)
|
||||
if fields_currently_supported != fields:
|
||||
raise RuntimeError(
|
||||
"A module tried to register support for login type: %s with parameters %s"
|
||||
" but another module had already registered support for that type with parameters %s"
|
||||
% (login_type, fields, fields_currently_supported)
|
||||
)
|
||||
|
||||
# Add the new method to the list of auth_checker_callbacks for this login type
|
||||
self.auth_checker_callbacks.setdefault(login_type, []).append(callback)
|
||||
|
||||
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
|
||||
"""Get the login types supported by this password provider
|
||||
|
@ -1852,20 +2040,15 @@ class PasswordProvider:
|
|||
Returns a map from a login type identifier (such as m.login.password) to an
|
||||
iterable giving the fields which must be provided by the user in the submission
|
||||
to the /login API.
|
||||
|
||||
This wrapper adds m.login.password to the list if the underlying password
|
||||
provider supports the check_password() api.
|
||||
"""
|
||||
|
||||
return self._supported_login_types
|
||||
|
||||
async def check_auth(
|
||||
self, username: str, login_type: str, login_dict: JsonDict
|
||||
) -> Optional[Tuple[str, Optional[Callable]]]:
|
||||
) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]:
|
||||
"""Check if the user has presented valid login credentials
|
||||
|
||||
This wrapper also calls check_password() if the underlying password provider
|
||||
supports the check_password() api and the login type is m.login.password.
|
||||
|
||||
Args:
|
||||
username: user id presented by the client. Either an MXID or an unqualified
|
||||
username.
|
||||
|
@ -1879,63 +2062,130 @@ class PasswordProvider:
|
|||
user, and `callback` is an optional callback which will be called with the
|
||||
result from the /login call (including access_token, device_id, etc.)
|
||||
"""
|
||||
# first grandfather in a call to check_password
|
||||
if login_type == LoginType.PASSWORD:
|
||||
check_password = getattr(self._pp, "check_password", None)
|
||||
if check_password:
|
||||
qualified_user_id = self._module_api.get_qualified_user_id(username)
|
||||
is_valid = await check_password(
|
||||
qualified_user_id, login_dict["password"]
|
||||
)
|
||||
if is_valid:
|
||||
return qualified_user_id, None
|
||||
|
||||
check_auth = getattr(self._pp, "check_auth", None)
|
||||
if not check_auth:
|
||||
return None
|
||||
result = await check_auth(username, login_type, login_dict)
|
||||
# Go through all callbacks for the login type until one returns with a value
|
||||
# other than None (i.e. until a callback returns a success)
|
||||
for callback in self.auth_checker_callbacks[login_type]:
|
||||
try:
|
||||
result = await callback(username, login_type, login_dict)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to run module API callback %s: %s", callback, e)
|
||||
continue
|
||||
|
||||
# Check if the return value is a str or a tuple
|
||||
if isinstance(result, str):
|
||||
# If it's a str, set callback function to None
|
||||
return result, None
|
||||
if result is not None:
|
||||
# Check that the callback returned a Tuple[str, Optional[Callable]]
|
||||
# "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks
|
||||
# result is always the right type, but as it is 3rd party code it might not be
|
||||
|
||||
return result
|
||||
if not isinstance(result, tuple) or len(result) != 2:
|
||||
logger.warning(
|
||||
"Wrong type returned by module API callback %s: %s, expected"
|
||||
" Optional[Tuple[str, Optional[Callable]]]",
|
||||
callback,
|
||||
result,
|
||||
)
|
||||
continue
|
||||
|
||||
# pull out the two parts of the tuple so we can do type checking
|
||||
str_result, callback_result = result
|
||||
|
||||
# the 1st item in the tuple should be a str
|
||||
if not isinstance(str_result, str):
|
||||
logger.warning( # type: ignore[unreachable]
|
||||
"Wrong type returned by module API callback %s: %s, expected"
|
||||
" Optional[Tuple[str, Optional[Callable]]]",
|
||||
callback,
|
||||
result,
|
||||
)
|
||||
continue
|
||||
|
||||
# the second should be Optional[Callable]
|
||||
if callback_result is not None:
|
||||
if not callable(callback_result):
|
||||
logger.warning( # type: ignore[unreachable]
|
||||
"Wrong type returned by module API callback %s: %s, expected"
|
||||
" Optional[Tuple[str, Optional[Callable]]]",
|
||||
callback,
|
||||
result,
|
||||
)
|
||||
continue
|
||||
|
||||
# The result is a (str, Optional[callback]) tuple so return the successful result
|
||||
return result
|
||||
|
||||
# If this point has been reached then none of the callbacks successfully authenticated
|
||||
# the user so return None
|
||||
return None
|
||||
|
||||
async def check_3pid_auth(
|
||||
self, medium: str, address: str, password: str
|
||||
) -> Optional[Tuple[str, Optional[Callable]]]:
|
||||
g = getattr(self._pp, "check_3pid_auth", None)
|
||||
if not g:
|
||||
return None
|
||||
|
||||
) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]:
|
||||
# This function is able to return a deferred that either
|
||||
# resolves None, meaning authentication failure, or upon
|
||||
# success, to a str (which is the user_id) or a tuple of
|
||||
# (user_id, callback_func), where callback_func should be run
|
||||
# after we've finished everything else
|
||||
result = await g(medium, address, password)
|
||||
|
||||
# Check if the return value is a str or a tuple
|
||||
if isinstance(result, str):
|
||||
# If it's a str, set callback function to None
|
||||
return result, None
|
||||
for callback in self.check_3pid_auth_callbacks:
|
||||
try:
|
||||
result = await callback(medium, address, password)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to run module API callback %s: %s", callback, e)
|
||||
continue
|
||||
|
||||
return result
|
||||
if result is not None:
|
||||
# Check that the callback returned a Tuple[str, Optional[Callable]]
|
||||
# "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks
|
||||
# result is always the right type, but as it is 3rd party code it might not be
|
||||
|
||||
if not isinstance(result, tuple) or len(result) != 2:
|
||||
logger.warning(
|
||||
"Wrong type returned by module API callback %s: %s, expected"
|
||||
" Optional[Tuple[str, Optional[Callable]]]",
|
||||
callback,
|
||||
result,
|
||||
)
|
||||
continue
|
||||
|
||||
# pull out the two parts of the tuple so we can do type checking
|
||||
str_result, callback_result = result
|
||||
|
||||
# the 1st item in the tuple should be a str
|
||||
if not isinstance(str_result, str):
|
||||
logger.warning( # type: ignore[unreachable]
|
||||
"Wrong type returned by module API callback %s: %s, expected"
|
||||
" Optional[Tuple[str, Optional[Callable]]]",
|
||||
callback,
|
||||
result,
|
||||
)
|
||||
continue
|
||||
|
||||
# the second should be Optional[Callable]
|
||||
if callback_result is not None:
|
||||
if not callable(callback_result):
|
||||
logger.warning( # type: ignore[unreachable]
|
||||
"Wrong type returned by module API callback %s: %s, expected"
|
||||
" Optional[Tuple[str, Optional[Callable]]]",
|
||||
callback,
|
||||
result,
|
||||
)
|
||||
continue
|
||||
|
||||
# The result is a (str, Optional[callback]) tuple so return the successful result
|
||||
return result
|
||||
|
||||
# If this point has been reached then none of the callbacks successfully authenticated
|
||||
# the user so return None
|
||||
return None
|
||||
|
||||
async def on_logged_out(
|
||||
self, user_id: str, device_id: Optional[str], access_token: str
|
||||
) -> None:
|
||||
g = getattr(self._pp, "on_logged_out", None)
|
||||
if not g:
|
||||
return
|
||||
|
||||
# This might return an awaitable, if it does block the log out
|
||||
# until it completes.
|
||||
await maybe_awaitable(
|
||||
g(
|
||||
user_id=user_id,
|
||||
device_id=device_id,
|
||||
access_token=access_token,
|
||||
)
|
||||
)
|
||||
# call all of the on_logged_out callbacks
|
||||
for callback in self.on_logged_out_callbacks:
|
||||
try:
|
||||
callback(user_id, device_id, access_token)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to run module API callback %s: %s", callback, e)
|
||||
continue
|
||||
|
|
|
@ -131,10 +131,6 @@ class DeactivateAccountHandler:
|
|||
# delete from user directory
|
||||
await self.user_directory_handler.handle_local_user_deactivated(user_id)
|
||||
|
||||
# If the user is present in the monthly active users table
|
||||
# remove them
|
||||
await self.store.remove_deactivated_user_from_mau_table(user_id)
|
||||
|
||||
# Mark the user as erased, if they asked for that
|
||||
if erase_data:
|
||||
user = UserID.from_string(user_id)
|
||||
|
|
|
@ -14,7 +14,18 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Collection,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
from synapse.api import errors
|
||||
from synapse.api.constants import EventTypes
|
||||
|
@ -443,6 +454,10 @@ class DeviceHandler(DeviceWorkerHandler):
|
|||
) -> None:
|
||||
"""Notify that a user's device(s) has changed. Pokes the notifier, and
|
||||
remote servers if the user is local.
|
||||
|
||||
Args:
|
||||
user_id: The Matrix ID of the user who's device list has been updated.
|
||||
device_ids: The device IDs that have changed.
|
||||
"""
|
||||
if not device_ids:
|
||||
# No changes to notify about, so this is a no-op.
|
||||
|
@ -595,7 +610,7 @@ class DeviceHandler(DeviceWorkerHandler):
|
|||
|
||||
|
||||
def _update_device_from_client_ips(
|
||||
device: JsonDict, client_ips: Dict[Tuple[str, str], JsonDict]
|
||||
device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
|
||||
) -> None:
|
||||
ip = client_ips.get((device["user_id"], device["device_id"]), {})
|
||||
device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
|
||||
|
|
|
@ -147,7 +147,7 @@ class DirectoryHandler:
|
|||
if not self.config.roomdirectory.is_alias_creation_allowed(
|
||||
user_id, room_id, room_alias_str
|
||||
):
|
||||
# Lets just return a generic message, as there may be all sorts of
|
||||
# Let's just return a generic message, as there may be all sorts of
|
||||
# reasons why we said no. TODO: Allow configurable error messages
|
||||
# per alias creation rule?
|
||||
raise SynapseError(403, "Not allowed to create alias")
|
||||
|
@ -463,7 +463,7 @@ class DirectoryHandler:
|
|||
if not self.config.roomdirectory.is_publishing_room_allowed(
|
||||
user_id, room_id, room_aliases
|
||||
):
|
||||
# Lets just return a generic message, as there may be all sorts of
|
||||
# Let's just return a generic message, as there may be all sorts of
|
||||
# reasons why we said no. TODO: Allow configurable error messages
|
||||
# per alias creation rule?
|
||||
raise SynapseError(403, "Not allowed to publish room")
|
||||
|
|
|
@ -55,8 +55,7 @@ class EventAuthHandler:
|
|||
"""Check an event passes the auth rules at its own auth events"""
|
||||
auth_event_ids = event.auth_event_ids()
|
||||
auth_events_by_id = await self._store.get_events(auth_event_ids)
|
||||
auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()}
|
||||
check_auth_rules_for_event(room_version_obj, event, auth_events)
|
||||
check_auth_rules_for_event(room_version_obj, event, auth_events_by_id.values())
|
||||
|
||||
def compute_auth_events(
|
||||
self,
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
|
||||
"""Contains handlers for federation events."""
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
@ -27,12 +26,7 @@ from unpaddedbase64 import decode_base64
|
|||
from twisted.internet import defer
|
||||
|
||||
from synapse import event_auth
|
||||
from synapse.api.constants import (
|
||||
EventContentFields,
|
||||
EventTypes,
|
||||
Membership,
|
||||
RejectedReason,
|
||||
)
|
||||
from synapse.api.constants import EventContentFields, EventTypes, Membership
|
||||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
CodeMessageException,
|
||||
|
@ -43,12 +37,9 @@ from synapse.api.errors import (
|
|||
RequestSendFailed,
|
||||
SynapseError,
|
||||
)
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion, RoomVersions
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
||||
from synapse.crypto.event_signing import compute_event_signature
|
||||
from synapse.event_auth import (
|
||||
check_auth_rules_for_event,
|
||||
validate_event_for_room_version,
|
||||
)
|
||||
from synapse.event_auth import validate_event_for_room_version
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.events.validator import EventValidator
|
||||
|
@ -238,18 +229,10 @@ class FederationHandler:
|
|||
)
|
||||
return False
|
||||
|
||||
logger.debug(
|
||||
"room_id: %s, backfill: current_depth: %s, max_depth: %s, extrems: %s",
|
||||
room_id,
|
||||
current_depth,
|
||||
max_depth,
|
||||
sorted_extremeties_tuple,
|
||||
)
|
||||
|
||||
# We ignore extremities that have a greater depth than our current depth
|
||||
# as:
|
||||
# 1. we don't really care about getting events that have happened
|
||||
# before our current position; and
|
||||
# after our current position; and
|
||||
# 2. we have likely previously tried and failed to backfill from that
|
||||
# extremity, so to avoid getting "stuck" requesting the same
|
||||
# backfill repeatedly we drop those extremities.
|
||||
|
@ -257,9 +240,19 @@ class FederationHandler:
|
|||
t for t in sorted_extremeties_tuple if int(t[1]) <= current_depth
|
||||
]
|
||||
|
||||
logger.debug(
|
||||
"room_id: %s, backfill: current_depth: %s, limit: %s, max_depth: %s, extrems: %s filtered_sorted_extremeties_tuple: %s",
|
||||
room_id,
|
||||
current_depth,
|
||||
limit,
|
||||
max_depth,
|
||||
sorted_extremeties_tuple,
|
||||
filtered_sorted_extremeties_tuple,
|
||||
)
|
||||
|
||||
# However, we need to check that the filtered extremities are non-empty.
|
||||
# If they are empty then either we can a) bail or b) still attempt to
|
||||
# backill. We opt to try backfilling anyway just in case we do get
|
||||
# backfill. We opt to try backfilling anyway just in case we do get
|
||||
# relevant events.
|
||||
if filtered_sorted_extremeties_tuple:
|
||||
sorted_extremeties_tuple = filtered_sorted_extremeties_tuple
|
||||
|
@ -389,7 +382,7 @@ class FederationHandler:
|
|||
for key, state_dict in states.items()
|
||||
}
|
||||
|
||||
for e_id, _ in sorted_extremeties_tuple:
|
||||
for e_id in event_ids:
|
||||
likely_extremeties_domains = get_domains_from_state(states[e_id])
|
||||
|
||||
success = await try_backfill(
|
||||
|
@ -517,7 +510,7 @@ class FederationHandler:
|
|||
auth_events=auth_chain,
|
||||
)
|
||||
|
||||
max_stream_id = await self._persist_auth_tree(
|
||||
max_stream_id = await self._federation_event_handler.process_remote_join(
|
||||
origin, room_id, auth_chain, state, event, room_version_obj
|
||||
)
|
||||
|
||||
|
@ -1093,119 +1086,6 @@ class FederationHandler:
|
|||
else:
|
||||
return None
|
||||
|
||||
async def _persist_auth_tree(
|
||||
self,
|
||||
origin: str,
|
||||
room_id: str,
|
||||
auth_events: List[EventBase],
|
||||
state: List[EventBase],
|
||||
event: EventBase,
|
||||
room_version: RoomVersion,
|
||||
) -> int:
|
||||
"""Checks the auth chain is valid (and passes auth checks) for the
|
||||
state and event. Then persists the auth chain and state atomically.
|
||||
Persists the event separately. Notifies about the persisted events
|
||||
where appropriate.
|
||||
|
||||
Will attempt to fetch missing auth events.
|
||||
|
||||
Args:
|
||||
origin: Where the events came from
|
||||
room_id,
|
||||
auth_events
|
||||
state
|
||||
event
|
||||
room_version: The room version we expect this room to have, and
|
||||
will raise if it doesn't match the version in the create event.
|
||||
"""
|
||||
events_to_context = {}
|
||||
for e in itertools.chain(auth_events, state):
|
||||
e.internal_metadata.outlier = True
|
||||
events_to_context[e.event_id] = EventContext.for_outlier()
|
||||
|
||||
event_map = {
|
||||
e.event_id: e for e in itertools.chain(auth_events, state, [event])
|
||||
}
|
||||
|
||||
create_event = None
|
||||
for e in auth_events:
|
||||
if (e.type, e.state_key) == (EventTypes.Create, ""):
|
||||
create_event = e
|
||||
break
|
||||
|
||||
if create_event is None:
|
||||
# If the state doesn't have a create event then the room is
|
||||
# invalid, and it would fail auth checks anyway.
|
||||
raise SynapseError(400, "No create event in state")
|
||||
|
||||
room_version_id = create_event.content.get(
|
||||
"room_version", RoomVersions.V1.identifier
|
||||
)
|
||||
|
||||
if room_version.identifier != room_version_id:
|
||||
raise SynapseError(400, "Room version mismatch")
|
||||
|
||||
missing_auth_events = set()
|
||||
for e in itertools.chain(auth_events, state, [event]):
|
||||
for e_id in e.auth_event_ids():
|
||||
if e_id not in event_map:
|
||||
missing_auth_events.add(e_id)
|
||||
|
||||
for e_id in missing_auth_events:
|
||||
m_ev = await self.federation_client.get_pdu(
|
||||
[origin],
|
||||
e_id,
|
||||
room_version=room_version,
|
||||
outlier=True,
|
||||
timeout=10000,
|
||||
)
|
||||
if m_ev and m_ev.event_id == e_id:
|
||||
event_map[e_id] = m_ev
|
||||
else:
|
||||
logger.info("Failed to find auth event %r", e_id)
|
||||
|
||||
for e in itertools.chain(auth_events, state, [event]):
|
||||
auth_for_e = {
|
||||
(event_map[e_id].type, event_map[e_id].state_key): event_map[e_id]
|
||||
for e_id in e.auth_event_ids()
|
||||
if e_id in event_map
|
||||
}
|
||||
if create_event:
|
||||
auth_for_e[(EventTypes.Create, "")] = create_event
|
||||
|
||||
try:
|
||||
validate_event_for_room_version(room_version, e)
|
||||
check_auth_rules_for_event(room_version, e, auth_for_e)
|
||||
except SynapseError as err:
|
||||
# we may get SynapseErrors here as well as AuthErrors. For
|
||||
# instance, there are a couple of (ancient) events in some
|
||||
# rooms whose senders do not have the correct sigil; these
|
||||
# cause SynapseErrors in auth.check. We don't want to give up
|
||||
# the attempt to federate altogether in such cases.
|
||||
|
||||
logger.warning("Rejecting %s because %s", e.event_id, err.msg)
|
||||
|
||||
if e == event:
|
||||
raise
|
||||
events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
|
||||
|
||||
if auth_events or state:
|
||||
await self._federation_event_handler.persist_events_and_notify(
|
||||
room_id,
|
||||
[
|
||||
(e, events_to_context[e.event_id])
|
||||
for e in itertools.chain(auth_events, state)
|
||||
],
|
||||
)
|
||||
|
||||
new_event_context = await self.state_handler.compute_event_context(
|
||||
event, old_state=state
|
||||
)
|
||||
|
||||
return await self._federation_event_handler.persist_events_and_notify(
|
||||
room_id, [(event, new_event_context)]
|
||||
)
|
||||
|
||||
async def on_get_missing_events(
|
||||
self,
|
||||
origin: str,
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
from http import HTTPStatus
|
||||
from typing import (
|
||||
|
@ -45,7 +46,7 @@ from synapse.api.errors import (
|
|||
RequestSendFailed,
|
||||
SynapseError,
|
||||
)
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion, RoomVersions
|
||||
from synapse.event_auth import (
|
||||
auth_types_for_event,
|
||||
check_auth_rules_for_event,
|
||||
|
@ -64,7 +65,6 @@ from synapse.replication.http.federation import (
|
|||
from synapse.state import StateResolutionStore
|
||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||
from synapse.types import (
|
||||
MutableStateMap,
|
||||
PersistedEventPosition,
|
||||
RoomStreamToken,
|
||||
StateMap,
|
||||
|
@ -214,7 +214,7 @@ class FederationEventHandler:
|
|||
|
||||
if missing_prevs:
|
||||
# We only backfill backwards to the min depth.
|
||||
min_depth = await self.get_min_depth_for_context(pdu.room_id)
|
||||
min_depth = await self._store.get_min_depth(pdu.room_id)
|
||||
logger.debug("min_depth: %d", min_depth)
|
||||
|
||||
if min_depth is not None and pdu.depth > min_depth:
|
||||
|
@ -361,6 +361,7 @@ class FederationEventHandler:
|
|||
# need to.
|
||||
await self._event_creation_handler.cache_joined_hosts_for_event(event, context)
|
||||
|
||||
await self._check_for_soft_fail(event, None, origin=origin)
|
||||
await self._run_push_actions_and_persist_event(event, context)
|
||||
return event, context
|
||||
|
||||
|
@ -390,9 +391,93 @@ class FederationEventHandler:
|
|||
prev_member_event,
|
||||
)
|
||||
|
||||
async def process_remote_join(
|
||||
self,
|
||||
origin: str,
|
||||
room_id: str,
|
||||
auth_events: List[EventBase],
|
||||
state: List[EventBase],
|
||||
event: EventBase,
|
||||
room_version: RoomVersion,
|
||||
) -> int:
|
||||
"""Persists the events returned by a send_join
|
||||
|
||||
Checks the auth chain is valid (and passes auth checks) for the
|
||||
state and event. Then persists all of the events.
|
||||
Notifies about the persisted events where appropriate.
|
||||
|
||||
Args:
|
||||
origin: Where the events came from
|
||||
room_id:
|
||||
auth_events
|
||||
state
|
||||
event
|
||||
room_version: The room version we expect this room to have, and
|
||||
will raise if it doesn't match the version in the create event.
|
||||
|
||||
Returns:
|
||||
The stream ID after which all events have been persisted.
|
||||
|
||||
Raises:
|
||||
SynapseError if the response is in some way invalid.
|
||||
"""
|
||||
for e in itertools.chain(auth_events, state):
|
||||
e.internal_metadata.outlier = True
|
||||
|
||||
event_map = {e.event_id: e for e in itertools.chain(auth_events, state)}
|
||||
|
||||
create_event = None
|
||||
for e in auth_events:
|
||||
if (e.type, e.state_key) == (EventTypes.Create, ""):
|
||||
create_event = e
|
||||
break
|
||||
|
||||
if create_event is None:
|
||||
# If the state doesn't have a create event then the room is
|
||||
# invalid, and it would fail auth checks anyway.
|
||||
raise SynapseError(400, "No create event in state")
|
||||
|
||||
room_version_id = create_event.content.get(
|
||||
"room_version", RoomVersions.V1.identifier
|
||||
)
|
||||
|
||||
if room_version.identifier != room_version_id:
|
||||
raise SynapseError(400, "Room version mismatch")
|
||||
|
||||
# filter out any events we have already seen
|
||||
seen_remotes = await self._store.have_seen_events(room_id, event_map.keys())
|
||||
for s in seen_remotes:
|
||||
event_map.pop(s, None)
|
||||
|
||||
# persist the auth chain and state events.
|
||||
#
|
||||
# any invalid events here will be marked as rejected, and we'll carry on.
|
||||
#
|
||||
# any events whose auth events are missing (ie, not in the send_join response,
|
||||
# and not already in our db) will just be ignored. This is correct behaviour,
|
||||
# because the reason that auth_events are missing might be due to us being
|
||||
# unable to validate their signatures. The fact that we can't validate their
|
||||
# signatures right now doesn't mean that we will *never* be able to, so it
|
||||
# is premature to reject them.
|
||||
#
|
||||
await self._auth_and_persist_outliers(room_id, event_map.values())
|
||||
|
||||
# and now persist the join event itself.
|
||||
logger.info("Peristing join-via-remote %s", event)
|
||||
with nested_logging_context(suffix=event.event_id):
|
||||
context = await self._state_handler.compute_event_context(
|
||||
event, old_state=state
|
||||
)
|
||||
|
||||
context = await self._check_event_auth(origin, event, context)
|
||||
if context.rejected:
|
||||
raise SynapseError(400, "Join event was rejected")
|
||||
|
||||
return await self.persist_events_and_notify(room_id, [(event, context)])
|
||||
|
||||
@log_function
|
||||
async def backfill(
|
||||
self, dest: str, room_id: str, limit: int, extremities: List[str]
|
||||
self, dest: str, room_id: str, limit: int, extremities: Iterable[str]
|
||||
) -> None:
|
||||
"""Trigger a backfill request to `dest` for the given `room_id`
|
||||
|
||||
|
@ -861,9 +946,15 @@ class FederationEventHandler:
|
|||
) -> None:
|
||||
"""Called when we have a new non-outlier event.
|
||||
|
||||
This is called when we have a new event to add to the room DAG - either directly
|
||||
via a /send request, retrieved via get_missing_events after a /send request, or
|
||||
backfilled after a client request.
|
||||
This is called when we have a new event to add to the room DAG. This can be
|
||||
due to:
|
||||
* events received directly via a /send request
|
||||
* events retrieved via get_missing_events after a /send request
|
||||
* events backfilled after a client request.
|
||||
|
||||
It's not currently used for events received from incoming send_{join,knock,leave}
|
||||
requests (which go via on_send_membership_event), nor for joins created by a
|
||||
remote join dance (which go via process_remote_join).
|
||||
|
||||
We need to do auth checks and put it through the StateHandler.
|
||||
|
||||
|
@ -899,11 +990,19 @@ class FederationEventHandler:
|
|||
logger.exception("Unexpected AuthError from _check_event_auth")
|
||||
raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
|
||||
|
||||
if not backfilled and not context.rejected:
|
||||
# For new (non-backfilled and non-outlier) events we check if the event
|
||||
# passes auth based on the current state. If it doesn't then we
|
||||
# "soft-fail" the event.
|
||||
await self._check_for_soft_fail(event, state, origin=origin)
|
||||
|
||||
await self._run_push_actions_and_persist_event(event, context, backfilled)
|
||||
|
||||
if backfilled:
|
||||
if backfilled or context.rejected:
|
||||
return
|
||||
|
||||
await self._maybe_kick_guest_users(event)
|
||||
|
||||
# For encrypted messages we check that we know about the sending device,
|
||||
# if we don't then we mark the device cache for that user as stale.
|
||||
if event.type == EventTypes.Encrypted:
|
||||
|
@ -1116,14 +1215,12 @@ class FederationEventHandler:
|
|||
|
||||
await concurrently_execute(get_event, event_ids, 5)
|
||||
logger.info("Fetched %i events of %i requested", len(events), len(event_ids))
|
||||
await self._auth_and_persist_fetched_events(destination, room_id, events)
|
||||
await self._auth_and_persist_outliers(room_id, events)
|
||||
|
||||
async def _auth_and_persist_fetched_events(
|
||||
self, origin: str, room_id: str, events: Iterable[EventBase]
|
||||
async def _auth_and_persist_outliers(
|
||||
self, room_id: str, events: Iterable[EventBase]
|
||||
) -> None:
|
||||
"""Persist the events fetched by _get_events_and_persist or _get_remote_auth_chain_for_event
|
||||
|
||||
The events to be persisted must be outliers.
|
||||
"""Persist a batch of outlier events fetched from remote servers.
|
||||
|
||||
We first sort the events to make sure that we process each event's auth_events
|
||||
before the event itself, and then auth and persist them.
|
||||
|
@ -1131,7 +1228,6 @@ class FederationEventHandler:
|
|||
Notifies about the events where appropriate.
|
||||
|
||||
Params:
|
||||
origin: where the events came from
|
||||
room_id: the room that the events are meant to be in (though this has
|
||||
not yet been checked)
|
||||
events: the events that have been fetched
|
||||
|
@ -1167,15 +1263,15 @@ class FederationEventHandler:
|
|||
shortstr(e.event_id for e in roots),
|
||||
)
|
||||
|
||||
await self._auth_and_persist_fetched_events_inner(origin, room_id, roots)
|
||||
await self._auth_and_persist_outliers_inner(room_id, roots)
|
||||
|
||||
for ev in roots:
|
||||
del event_map[ev.event_id]
|
||||
|
||||
async def _auth_and_persist_fetched_events_inner(
|
||||
self, origin: str, room_id: str, fetched_events: Collection[EventBase]
|
||||
async def _auth_and_persist_outliers_inner(
|
||||
self, room_id: str, fetched_events: Collection[EventBase]
|
||||
) -> None:
|
||||
"""Helper for _auth_and_persist_fetched_events
|
||||
"""Helper for _auth_and_persist_outliers
|
||||
|
||||
Persists a batch of events where we have (theoretically) already persisted all
|
||||
of their auth events.
|
||||
|
@ -1203,20 +1299,20 @@ class FederationEventHandler:
|
|||
|
||||
def prep(event: EventBase) -> Optional[Tuple[EventBase, EventContext]]:
|
||||
with nested_logging_context(suffix=event.event_id):
|
||||
auth = {}
|
||||
auth = []
|
||||
for auth_event_id in event.auth_event_ids():
|
||||
ae = persisted_events.get(auth_event_id)
|
||||
if not ae:
|
||||
logger.warning(
|
||||
"Event %s relies on auth_event %s, which could not be found.",
|
||||
event,
|
||||
auth_event_id,
|
||||
)
|
||||
# the fact we can't find the auth event doesn't mean it doesn't
|
||||
# exist, which means it is premature to reject `event`. Instead we
|
||||
# just ignore it for now.
|
||||
logger.warning(
|
||||
"Dropping event %s, which relies on auth_event %s, which could not be found",
|
||||
event,
|
||||
auth_event_id,
|
||||
)
|
||||
return None
|
||||
auth[(ae.type, ae.state_key)] = ae
|
||||
auth.append(ae)
|
||||
|
||||
context = EventContext.for_outlier()
|
||||
try:
|
||||
|
@ -1256,6 +1352,10 @@ class FederationEventHandler:
|
|||
|
||||
Returns:
|
||||
The updated context object.
|
||||
|
||||
Raises:
|
||||
AuthError if we were unable to find copies of the event's auth events.
|
||||
(Most other failures just cause us to set `context.rejected`.)
|
||||
"""
|
||||
# This method should only be used for non-outliers
|
||||
assert not event.internal_metadata.outlier
|
||||
|
@ -1272,7 +1372,26 @@ class FederationEventHandler:
|
|||
context.rejected = RejectedReason.AUTH_ERROR
|
||||
return context
|
||||
|
||||
# calculate what the auth events *should* be, to use as a basis for auth.
|
||||
# next, check that we have all of the event's auth events.
|
||||
#
|
||||
# Note that this can raise AuthError, which we want to propagate to the
|
||||
# caller rather than swallow with `context.rejected` (since we cannot be
|
||||
# certain that there is a permanent problem with the event).
|
||||
claimed_auth_events = await self._load_or_fetch_auth_events_for_event(
|
||||
origin, event
|
||||
)
|
||||
|
||||
# ... and check that the event passes auth at those auth events.
|
||||
try:
|
||||
check_auth_rules_for_event(room_version_obj, event, claimed_auth_events)
|
||||
except AuthError as e:
|
||||
logger.warning(
|
||||
"While checking auth of %r against auth_events: %s", event, e
|
||||
)
|
||||
context.rejected = RejectedReason.AUTH_ERROR
|
||||
return context
|
||||
|
||||
# now check auth against what we think the auth events *should* be.
|
||||
prev_state_ids = await context.get_prev_state_ids()
|
||||
auth_events_ids = self._event_auth_handler.compute_auth_events(
|
||||
event, prev_state_ids, for_verification=True
|
||||
|
@ -1283,13 +1402,8 @@ class FederationEventHandler:
|
|||
}
|
||||
|
||||
try:
|
||||
(
|
||||
context,
|
||||
auth_events_for_auth,
|
||||
) = await self._update_auth_events_and_context_for_auth(
|
||||
origin,
|
||||
updated_auth_events = await self._update_auth_events_for_auth(
|
||||
event,
|
||||
context,
|
||||
calculated_auth_event_map=calculated_auth_event_map,
|
||||
)
|
||||
except Exception:
|
||||
|
@ -1302,17 +1416,23 @@ class FederationEventHandler:
|
|||
"Ignoring failure and continuing processing of event.",
|
||||
event.event_id,
|
||||
)
|
||||
updated_auth_events = None
|
||||
|
||||
if updated_auth_events:
|
||||
context = await self._update_context_for_auth_events(
|
||||
event, context, updated_auth_events
|
||||
)
|
||||
auth_events_for_auth = updated_auth_events
|
||||
else:
|
||||
auth_events_for_auth = calculated_auth_event_map
|
||||
|
||||
try:
|
||||
check_auth_rules_for_event(room_version_obj, event, auth_events_for_auth)
|
||||
check_auth_rules_for_event(
|
||||
room_version_obj, event, auth_events_for_auth.values()
|
||||
)
|
||||
except AuthError as e:
|
||||
logger.warning("Failed auth resolution for %r because %s", event, e)
|
||||
context.rejected = RejectedReason.AUTH_ERROR
|
||||
return context
|
||||
|
||||
await self._check_for_soft_fail(event, state, backfilled, origin=origin)
|
||||
await self._maybe_kick_guest_users(event)
|
||||
|
||||
return context
|
||||
|
||||
|
@ -1332,7 +1452,6 @@ class FederationEventHandler:
|
|||
self,
|
||||
event: EventBase,
|
||||
state: Optional[Iterable[EventBase]],
|
||||
backfilled: bool,
|
||||
origin: str,
|
||||
) -> None:
|
||||
"""Checks if we should soft fail the event; if so, marks the event as
|
||||
|
@ -1341,15 +1460,8 @@ class FederationEventHandler:
|
|||
Args:
|
||||
event
|
||||
state: The state at the event if we don't have all the event's prev events
|
||||
backfilled: Whether the event is from backfill
|
||||
origin: The host the event originates from.
|
||||
"""
|
||||
# For new (non-backfilled and non-outlier) events we check if the event
|
||||
# passes auth based on the current state. If it doesn't then we
|
||||
# "soft-fail" the event.
|
||||
if backfilled or event.internal_metadata.is_outlier():
|
||||
return
|
||||
|
||||
extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id)
|
||||
extrem_ids = set(extrem_ids_list)
|
||||
prev_event_ids = set(event.prev_event_ids())
|
||||
|
@ -1403,11 +1515,9 @@ class FederationEventHandler:
|
|||
current_state_ids_list = [
|
||||
e for k, e in current_state_ids.items() if k in auth_types
|
||||
]
|
||||
|
||||
auth_events_map = await self._store.get_events(current_state_ids_list)
|
||||
current_auth_events = {
|
||||
(e.type, e.state_key): e for e in auth_events_map.values()
|
||||
}
|
||||
current_auth_events = await self._store.get_events_as_list(
|
||||
current_state_ids_list
|
||||
)
|
||||
|
||||
try:
|
||||
check_auth_rules_for_event(room_version_obj, event, current_auth_events)
|
||||
|
@ -1426,13 +1536,11 @@ class FederationEventHandler:
|
|||
soft_failed_event_counter.inc()
|
||||
event.internal_metadata.soft_failed = True
|
||||
|
||||
async def _update_auth_events_and_context_for_auth(
|
||||
async def _update_auth_events_for_auth(
|
||||
self,
|
||||
origin: str,
|
||||
event: EventBase,
|
||||
context: EventContext,
|
||||
calculated_auth_event_map: StateMap[EventBase],
|
||||
) -> Tuple[EventContext, StateMap[EventBase]]:
|
||||
) -> Optional[StateMap[EventBase]]:
|
||||
"""Helper for _check_event_auth. See there for docs.
|
||||
|
||||
Checks whether a given event has the expected auth events. If it
|
||||
|
@ -1445,93 +1553,27 @@ class FederationEventHandler:
|
|||
processing of the event.
|
||||
|
||||
Args:
|
||||
origin:
|
||||
event:
|
||||
context:
|
||||
|
||||
calculated_auth_event_map:
|
||||
Our calculated auth_events based on the state of the room
|
||||
at the event's position in the DAG.
|
||||
|
||||
Returns:
|
||||
updated context, updated auth event map
|
||||
updated auth event map, or None if no changes are needed.
|
||||
|
||||
"""
|
||||
assert not event.internal_metadata.outlier
|
||||
|
||||
# take a copy of calculated_auth_event_map before we modify it.
|
||||
auth_events: MutableStateMap[EventBase] = dict(calculated_auth_event_map)
|
||||
|
||||
# check for events which are in the event's claimed auth_events, but not
|
||||
# in our calculated event map.
|
||||
event_auth_events = set(event.auth_event_ids())
|
||||
|
||||
# missing_auth is the set of the event's auth_events which we don't yet have
|
||||
# in auth_events.
|
||||
missing_auth = event_auth_events.difference(
|
||||
e.event_id for e in auth_events.values()
|
||||
)
|
||||
|
||||
# if we have missing events, we need to fetch those events from somewhere.
|
||||
#
|
||||
# we start by checking if they are in the store, and then try calling /event_auth/.
|
||||
if missing_auth:
|
||||
have_events = await self._store.have_seen_events(
|
||||
event.room_id, missing_auth
|
||||
)
|
||||
logger.debug("Events %s are in the store", have_events)
|
||||
missing_auth.difference_update(have_events)
|
||||
|
||||
# missing_auth is now the set of event_ids which:
|
||||
# a. are listed in event.auth_events, *and*
|
||||
# b. are *not* part of our calculated auth events based on room state, *and*
|
||||
# c. are *not* yet in our database.
|
||||
|
||||
if missing_auth:
|
||||
# If we don't have all the auth events, we need to get them.
|
||||
logger.info("auth_events contains unknown events: %s", missing_auth)
|
||||
try:
|
||||
await self._get_remote_auth_chain_for_event(
|
||||
origin, event.room_id, event.event_id
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to get auth chain")
|
||||
else:
|
||||
# load any auth events we might have persisted from the database. This
|
||||
# has the side-effect of correctly setting the rejected_reason on them.
|
||||
auth_events.update(
|
||||
{
|
||||
(ae.type, ae.state_key): ae
|
||||
for ae in await self._store.get_events_as_list(
|
||||
missing_auth, allow_rejected=True
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
# auth_events now contains
|
||||
# 1. our *calculated* auth events based on the room state, plus:
|
||||
# 2. any events which:
|
||||
# a. are listed in `event.auth_events`, *and*
|
||||
# b. are not part of our calculated auth events, *and*
|
||||
# c. were not in our database before the call to /event_auth
|
||||
# d. have since been added to our database (most likely by /event_auth).
|
||||
|
||||
different_auth = event_auth_events.difference(
|
||||
e.event_id for e in auth_events.values()
|
||||
e.event_id for e in calculated_auth_event_map.values()
|
||||
)
|
||||
|
||||
# different_auth is the set of events which *are* in `event.auth_events`, but
|
||||
# which are *not* in `auth_events`. Comparing with (2.) above, this means
|
||||
# exclusively the set of `event.auth_events` which we already had in our
|
||||
# database before any call to /event_auth.
|
||||
#
|
||||
# I'm reasonably sure that the fact that events returned by /event_auth are
|
||||
# blindly added to auth_events (and hence excluded from different_auth) is a bug
|
||||
# - though it's a very long-standing one (see
|
||||
# https://github.com/matrix-org/synapse/commit/78015948a7febb18e000651f72f8f58830a55b93#diff-0bc92da3d703202f5b9be2d3f845e375f5b1a6bc6ba61705a8af9be1121f5e42R786
|
||||
# from Jan 2015 which seems to add it, though it actually just moves it from
|
||||
# elsewhere (before that, it gets lost in a mess of huge "various bug fixes"
|
||||
# PRs).
|
||||
|
||||
if not different_auth:
|
||||
return context, auth_events
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
"auth_events refers to events which are not in our calculated auth "
|
||||
|
@ -1543,27 +1585,18 @@ class FederationEventHandler:
|
|||
# necessary?
|
||||
different_events = await self._store.get_events_as_list(different_auth)
|
||||
|
||||
# double-check they're all in the same room - we should already have checked
|
||||
# this but it doesn't hurt to check again.
|
||||
for d in different_events:
|
||||
if d.room_id != event.room_id:
|
||||
logger.warning(
|
||||
"Event %s refers to auth_event %s which is in a different room",
|
||||
event.event_id,
|
||||
d.event_id,
|
||||
)
|
||||
|
||||
# don't attempt to resolve the claimed auth events against our own
|
||||
# in this case: just use our own auth events.
|
||||
#
|
||||
# XXX: should we reject the event in this case? It feels like we should,
|
||||
# but then shouldn't we also do so if we've failed to fetch any of the
|
||||
# auth events?
|
||||
return context, auth_events
|
||||
assert (
|
||||
d.room_id == event.room_id
|
||||
), f"Event {event.event_id} refers to auth_event {d.event_id} which is in a different room"
|
||||
|
||||
# now we state-resolve between our own idea of the auth events, and the remote's
|
||||
# idea of them.
|
||||
|
||||
local_state = auth_events.values()
|
||||
remote_auth_events = dict(auth_events)
|
||||
local_state = calculated_auth_event_map.values()
|
||||
remote_auth_events = dict(calculated_auth_event_map)
|
||||
remote_auth_events.update({(d.type, d.state_key): d for d in different_events})
|
||||
remote_state = remote_auth_events.values()
|
||||
|
||||
|
@ -1571,23 +1604,93 @@ class FederationEventHandler:
|
|||
new_state = await self._state_handler.resolve_events(
|
||||
room_version, (local_state, remote_state), event
|
||||
)
|
||||
different_state = {
|
||||
(d.type, d.state_key): d
|
||||
for d in new_state.values()
|
||||
if calculated_auth_event_map.get((d.type, d.state_key)) != d
|
||||
}
|
||||
if not different_state:
|
||||
logger.info("State res returned no new state")
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
"After state res: updating auth_events with new state %s",
|
||||
{
|
||||
(d.type, d.state_key): d.event_id
|
||||
for d in new_state.values()
|
||||
if auth_events.get((d.type, d.state_key)) != d
|
||||
},
|
||||
different_state.values(),
|
||||
)
|
||||
|
||||
auth_events.update(new_state)
|
||||
# take a copy of calculated_auth_event_map before we modify it.
|
||||
auth_events = dict(calculated_auth_event_map)
|
||||
auth_events.update(different_state)
|
||||
return auth_events
|
||||
|
||||
context = await self._update_context_for_auth_events(
|
||||
event, context, auth_events
|
||||
async def _load_or_fetch_auth_events_for_event(
|
||||
self, destination: str, event: EventBase
|
||||
) -> Collection[EventBase]:
|
||||
"""Fetch this event's auth_events, from database or remote
|
||||
|
||||
Loads any of the auth_events that we already have from the database/cache. If
|
||||
there are any that are missing, calls /event_auth to get the complete auth
|
||||
chain for the event (and then attempts to load the auth_events again).
|
||||
|
||||
If any of the auth_events cannot be found, raises an AuthError. This can happen
|
||||
for a number of reasons; eg: the events don't exist, or we were unable to talk
|
||||
to `destination`, or we couldn't validate the signature on the event (which
|
||||
in turn has multiple potential causes).
|
||||
|
||||
Args:
|
||||
destination: where to send the /event_auth request. Typically the server
|
||||
that sent us `event` in the first place.
|
||||
event: the event whose auth_events we want
|
||||
|
||||
Returns:
|
||||
all of the events in `event.auth_events`, after deduplication
|
||||
|
||||
Raises:
|
||||
AuthError if we were unable to fetch the auth_events for any reason.
|
||||
"""
|
||||
event_auth_event_ids = set(event.auth_event_ids())
|
||||
event_auth_events = await self._store.get_events(
|
||||
event_auth_event_ids, allow_rejected=True
|
||||
)
|
||||
missing_auth_event_ids = event_auth_event_ids.difference(
|
||||
event_auth_events.keys()
|
||||
)
|
||||
if not missing_auth_event_ids:
|
||||
return event_auth_events.values()
|
||||
|
||||
return context, auth_events
|
||||
logger.info(
|
||||
"Event %s refers to unknown auth events %s: fetching auth chain",
|
||||
event,
|
||||
missing_auth_event_ids,
|
||||
)
|
||||
try:
|
||||
await self._get_remote_auth_chain_for_event(
|
||||
destination, event.room_id, event.event_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to get auth chain for %s: %s", event, e)
|
||||
# in this case, it's very likely we still won't have all the auth
|
||||
# events - but we pick that up below.
|
||||
|
||||
# try to fetch the auth events we missed list time.
|
||||
extra_auth_events = await self._store.get_events(
|
||||
missing_auth_event_ids, allow_rejected=True
|
||||
)
|
||||
missing_auth_event_ids.difference_update(extra_auth_events.keys())
|
||||
event_auth_events.update(extra_auth_events)
|
||||
if not missing_auth_event_ids:
|
||||
return event_auth_events.values()
|
||||
|
||||
# we still don't have all the auth events.
|
||||
logger.warning(
|
||||
"Missing auth events for %s: %s",
|
||||
event,
|
||||
shortstr(missing_auth_event_ids),
|
||||
)
|
||||
# the fact we can't find the auth event doesn't mean it doesn't
|
||||
# exist, which means it is premature to store `event` as rejected.
|
||||
# instead we raise an AuthError, which will make the caller ignore it.
|
||||
raise AuthError(code=HTTPStatus.FORBIDDEN, msg="Auth events could not be found")
|
||||
|
||||
async def _get_remote_auth_chain_for_event(
|
||||
self, destination: str, room_id: str, event_id: str
|
||||
|
@ -1624,9 +1727,7 @@ class FederationEventHandler:
|
|||
for s in seen_remotes:
|
||||
remote_event_map.pop(s, None)
|
||||
|
||||
await self._auth_and_persist_fetched_events(
|
||||
destination, room_id, remote_event_map.values()
|
||||
)
|
||||
await self._auth_and_persist_outliers(room_id, remote_event_map.values())
|
||||
|
||||
async def _update_context_for_auth_events(
|
||||
self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase]
|
||||
|
@ -1696,16 +1797,27 @@ class FederationEventHandler:
|
|||
# persist_events_and_notify directly.)
|
||||
assert not event.internal_metadata.outlier
|
||||
|
||||
try:
|
||||
if (
|
||||
not backfilled
|
||||
and not context.rejected
|
||||
and (await self._store.get_min_depth(event.room_id)) <= event.depth
|
||||
):
|
||||
if not backfilled and not context.rejected:
|
||||
min_depth = await self._store.get_min_depth(event.room_id)
|
||||
if min_depth is None or min_depth > event.depth:
|
||||
# XXX richvdh 2021/10/07: I don't really understand what this
|
||||
# condition is doing. I think it's trying not to send pushes
|
||||
# for events that predate our join - but that's not really what
|
||||
# min_depth means, and anyway ancient events are a more general
|
||||
# problem.
|
||||
#
|
||||
# for now I'm just going to log about it.
|
||||
logger.info(
|
||||
"Skipping push actions for old event with depth %s < %s",
|
||||
event.depth,
|
||||
min_depth,
|
||||
)
|
||||
else:
|
||||
await self._action_generator.handle_push_actions_for_event(
|
||||
event, context
|
||||
)
|
||||
|
||||
try:
|
||||
await self.persist_events_and_notify(
|
||||
event.room_id, [(event, context)], backfilled=backfilled
|
||||
)
|
||||
|
@ -1837,6 +1949,3 @@ class FederationEventHandler:
|
|||
len(ev.auth_event_ids()),
|
||||
)
|
||||
raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events")
|
||||
|
||||
async def get_min_depth_for_context(self, context: str) -> int:
|
||||
return await self._store.get_min_depth(context)
|
||||
|
|
|
@ -54,7 +54,9 @@ class IdentityHandler:
|
|||
self.http_client = SimpleHttpClient(hs)
|
||||
# An HTTP client for contacting identity servers specified by clients.
|
||||
self.blacklisting_http_client = SimpleHttpClient(
|
||||
hs, ip_blacklist=hs.config.server.federation_ip_range_blacklist
|
||||
hs,
|
||||
ip_blacklist=hs.config.server.federation_ip_range_blacklist,
|
||||
ip_whitelist=hs.config.server.federation_ip_range_whitelist,
|
||||
)
|
||||
self.federation_http_client = hs.get_federation_http_client()
|
||||
self.hs = hs
|
||||
|
|
|
@ -609,29 +609,6 @@ class EventCreationHandler:
|
|||
|
||||
builder.internal_metadata.historical = historical
|
||||
|
||||
# Strip down the auth_event_ids to only what we need to auth the event.
|
||||
# For example, we don't need extra m.room.member that don't match event.sender
|
||||
if auth_event_ids is not None:
|
||||
# If auth events are provided, prev events must be also.
|
||||
assert prev_event_ids is not None
|
||||
|
||||
temp_event = await builder.build(
|
||||
prev_event_ids=prev_event_ids,
|
||||
auth_event_ids=auth_event_ids,
|
||||
depth=depth,
|
||||
)
|
||||
auth_events = await self.store.get_events_as_list(auth_event_ids)
|
||||
# Create a StateMap[str]
|
||||
auth_event_state_map = {
|
||||
(e.type, e.state_key): e.event_id for e in auth_events
|
||||
}
|
||||
# Actually strip down and use the necessary auth events
|
||||
auth_event_ids = self._event_auth_handler.compute_auth_events(
|
||||
event=temp_event,
|
||||
current_state_ids=auth_event_state_map,
|
||||
for_verification=False,
|
||||
)
|
||||
|
||||
event, context = await self.create_new_client_event(
|
||||
builder=builder,
|
||||
requester=requester,
|
||||
|
@ -938,6 +915,33 @@ class EventCreationHandler:
|
|||
Tuple of created event, context
|
||||
"""
|
||||
|
||||
# Strip down the auth_event_ids to only what we need to auth the event.
|
||||
# For example, we don't need extra m.room.member that don't match event.sender
|
||||
full_state_ids_at_event = None
|
||||
if auth_event_ids is not None:
|
||||
# If auth events are provided, prev events must be also.
|
||||
assert prev_event_ids is not None
|
||||
|
||||
# Copy the full auth state before it stripped down
|
||||
full_state_ids_at_event = auth_event_ids.copy()
|
||||
|
||||
temp_event = await builder.build(
|
||||
prev_event_ids=prev_event_ids,
|
||||
auth_event_ids=auth_event_ids,
|
||||
depth=depth,
|
||||
)
|
||||
auth_events = await self.store.get_events_as_list(auth_event_ids)
|
||||
# Create a StateMap[str]
|
||||
auth_event_state_map = {
|
||||
(e.type, e.state_key): e.event_id for e in auth_events
|
||||
}
|
||||
# Actually strip down and use the necessary auth events
|
||||
auth_event_ids = self._event_auth_handler.compute_auth_events(
|
||||
event=temp_event,
|
||||
current_state_ids=auth_event_state_map,
|
||||
for_verification=False,
|
||||
)
|
||||
|
||||
if prev_event_ids is not None:
|
||||
assert (
|
||||
len(prev_event_ids) <= 10
|
||||
|
@ -967,6 +971,13 @@ class EventCreationHandler:
|
|||
if builder.internal_metadata.outlier:
|
||||
event.internal_metadata.outlier = True
|
||||
context = EventContext.for_outlier()
|
||||
elif (
|
||||
event.type == EventTypes.MSC2716_INSERTION
|
||||
and full_state_ids_at_event
|
||||
and builder.internal_metadata.is_historical()
|
||||
):
|
||||
old_state = await self.store.get_events_as_list(full_state_ids_at_event)
|
||||
context = await self.state.compute_event_context(event, old_state=old_state)
|
||||
else:
|
||||
context = await self.state.compute_event_context(event)
|
||||
|
||||
|
|
|
@ -86,19 +86,22 @@ class PaginationHandler:
|
|||
self._event_serializer = hs.get_event_client_serializer()
|
||||
|
||||
self._retention_default_max_lifetime = (
|
||||
hs.config.server.retention_default_max_lifetime
|
||||
hs.config.retention.retention_default_max_lifetime
|
||||
)
|
||||
|
||||
self._retention_allowed_lifetime_min = (
|
||||
hs.config.server.retention_allowed_lifetime_min
|
||||
hs.config.retention.retention_allowed_lifetime_min
|
||||
)
|
||||
self._retention_allowed_lifetime_max = (
|
||||
hs.config.server.retention_allowed_lifetime_max
|
||||
hs.config.retention.retention_allowed_lifetime_max
|
||||
)
|
||||
|
||||
if hs.config.worker.run_background_tasks and hs.config.server.retention_enabled:
|
||||
if (
|
||||
hs.config.worker.run_background_tasks
|
||||
and hs.config.retention.retention_enabled
|
||||
):
|
||||
# Run the purge jobs described in the configuration file.
|
||||
for job in hs.config.server.retention_purge_jobs:
|
||||
for job in hs.config.retention.retention_purge_jobs:
|
||||
logger.info("Setting up purge job with config: %s", job)
|
||||
|
||||
self.clock.looping_call(
|
||||
|
|
|
@ -52,7 +52,6 @@ import synapse.metrics
|
|||
from synapse.api.constants import EventTypes, Membership, PresenceState
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.presence import UserPresenceState
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.events.presence_router import PresenceRouter
|
||||
from synapse.logging.context import run_in_background
|
||||
from synapse.logging.utils import log_function
|
||||
|
@ -1483,13 +1482,39 @@ def should_notify(old_state: UserPresenceState, new_state: UserPresenceState) ->
|
|||
def format_user_presence_state(
|
||||
state: UserPresenceState, now: int, include_user_id: bool = True
|
||||
) -> JsonDict:
|
||||
"""Convert UserPresenceState to a format that can be sent down to clients
|
||||
"""Convert UserPresenceState to a JSON format that can be sent down to clients
|
||||
and to other servers.
|
||||
|
||||
The "user_id" is optional so that this function can be used to format presence
|
||||
updates for client /sync responses and for federation /send requests.
|
||||
Args:
|
||||
state: The user presence state to format.
|
||||
now: The current timestamp since the epoch in ms.
|
||||
include_user_id: Whether to include `user_id` in the returned dictionary.
|
||||
As this function can be used both to format presence updates for client /sync
|
||||
responses and for federation /send requests, only the latter needs the include
|
||||
the `user_id` field.
|
||||
|
||||
Returns:
|
||||
A JSON dictionary with the following keys:
|
||||
* presence: The presence state as a str.
|
||||
* user_id: Optional. Included if `include_user_id` is truthy. The canonical
|
||||
Matrix ID of the user.
|
||||
* last_active_ago: Optional. Included if `last_active_ts` is set on `state`.
|
||||
The timestamp that the user was last active.
|
||||
* status_msg: Optional. Included if `status_msg` is set on `state`. The user's
|
||||
status.
|
||||
* currently_active: Optional. Included only if `state.state` is "online".
|
||||
|
||||
Example:
|
||||
|
||||
{
|
||||
"presence": "online",
|
||||
"user_id": "@alice:example.com",
|
||||
"last_active_ago": 16783813918,
|
||||
"status_msg": "Hello world!",
|
||||
"currently_active": True
|
||||
}
|
||||
"""
|
||||
content = {"presence": state.state}
|
||||
content: JsonDict = {"presence": state.state}
|
||||
if include_user_id:
|
||||
content["user_id"] = state.user_id
|
||||
if state.last_active_ts:
|
||||
|
@ -1526,7 +1551,6 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
|
|||
is_guest: bool = False,
|
||||
explicit_room_id: Optional[str] = None,
|
||||
include_offline: bool = True,
|
||||
service: Optional[ApplicationService] = None,
|
||||
) -> Tuple[List[UserPresenceState], int]:
|
||||
# The process for getting presence events are:
|
||||
# 1. Get the rooms the user is in.
|
||||
|
|
|
@ -242,12 +242,18 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
|
|||
async def get_new_events_as(
|
||||
self, from_key: int, service: ApplicationService
|
||||
) -> Tuple[List[JsonDict], int]:
|
||||
"""Returns a set of new receipt events that an appservice
|
||||
"""Returns a set of new read receipt events that an appservice
|
||||
may be interested in.
|
||||
|
||||
Args:
|
||||
from_key: the stream position at which events should be fetched from
|
||||
service: The appservice which may be interested
|
||||
|
||||
Returns:
|
||||
A two-tuple containing the following:
|
||||
* A list of json dictionaries derived from read receipts that the
|
||||
appservice may be interested in.
|
||||
* The current read receipt stream token.
|
||||
"""
|
||||
from_key = int(from_key)
|
||||
to_key = self.get_current_key()
|
||||
|
|
|
@ -465,17 +465,35 @@ class RoomCreationHandler:
|
|||
# the room has been created
|
||||
# Calculate the minimum power level needed to clone the room
|
||||
event_power_levels = power_levels.get("events", {})
|
||||
if not isinstance(event_power_levels, dict):
|
||||
event_power_levels = {}
|
||||
state_default = power_levels.get("state_default", 50)
|
||||
try:
|
||||
state_default_int = int(state_default) # type: ignore[arg-type]
|
||||
except (TypeError, ValueError):
|
||||
state_default_int = 50
|
||||
ban = power_levels.get("ban", 50)
|
||||
needed_power_level = max(state_default, ban, max(event_power_levels.values()))
|
||||
try:
|
||||
ban = int(ban) # type: ignore[arg-type]
|
||||
except (TypeError, ValueError):
|
||||
ban = 50
|
||||
needed_power_level = max(
|
||||
state_default_int, ban, max(event_power_levels.values())
|
||||
)
|
||||
|
||||
# Get the user's current power level, this matches the logic in get_user_power_level,
|
||||
# but without the entire state map.
|
||||
user_power_levels = power_levels.setdefault("users", {})
|
||||
if not isinstance(user_power_levels, dict):
|
||||
user_power_levels = {}
|
||||
users_default = power_levels.get("users_default", 0)
|
||||
current_power_level = user_power_levels.get(user_id, users_default)
|
||||
try:
|
||||
current_power_level_int = int(current_power_level) # type: ignore[arg-type]
|
||||
except (TypeError, ValueError):
|
||||
current_power_level_int = 0
|
||||
# Raise the requester's power level in the new room if necessary
|
||||
if current_power_level < needed_power_level:
|
||||
if current_power_level_int < needed_power_level:
|
||||
user_power_levels[user_id] = needed_power_level
|
||||
|
||||
await self._send_events_for_new_room(
|
||||
|
@ -765,6 +783,15 @@ class RoomCreationHandler:
|
|||
if not allowed_by_third_party_rules:
|
||||
raise SynapseError(403, "Room visibility value not allowed.")
|
||||
|
||||
if is_public:
|
||||
if not self.config.roomdirectory.is_publishing_room_allowed(
|
||||
user_id, room_id, room_alias
|
||||
):
|
||||
# Let's just return a generic message, as there may be all sorts of
|
||||
# reasons why we said no. TODO: Allow configurable error messages
|
||||
# per alias creation rule?
|
||||
raise SynapseError(403, "Not allowed to publish room")
|
||||
|
||||
directory_handler = self.hs.get_directory_handler()
|
||||
if room_alias:
|
||||
await directory_handler.create_association(
|
||||
|
@ -775,15 +802,6 @@ class RoomCreationHandler:
|
|||
check_membership=False,
|
||||
)
|
||||
|
||||
if is_public:
|
||||
if not self.config.roomdirectory.is_publishing_room_allowed(
|
||||
user_id, room_id, room_alias
|
||||
):
|
||||
# Lets just return a generic message, as there may be all sorts of
|
||||
# reasons why we said no. TODO: Allow configurable error messages
|
||||
# per alias creation rule?
|
||||
raise SynapseError(403, "Not allowed to publish room")
|
||||
|
||||
preset_config = config.get(
|
||||
"preset",
|
||||
RoomCreationPreset.PRIVATE_CHAT
|
||||
|
|
|
@ -13,6 +13,10 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def generate_fake_event_id() -> str:
|
||||
return "$fake_" + random_string(43)
|
||||
|
||||
|
||||
class RoomBatchHandler:
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
|
@ -180,6 +184,11 @@ class RoomBatchHandler:
|
|||
|
||||
state_event_ids_at_start = []
|
||||
auth_event_ids = initial_auth_event_ids.copy()
|
||||
|
||||
# Make the state events float off on their own so we don't have a
|
||||
# bunch of `@mxid joined the room` noise between each batch
|
||||
prev_event_id_for_state_chain = generate_fake_event_id()
|
||||
|
||||
for state_event in state_events_at_start:
|
||||
assert_params_in_dict(
|
||||
state_event, ["type", "origin_server_ts", "content", "sender"]
|
||||
|
@ -203,10 +212,6 @@ class RoomBatchHandler:
|
|||
# Mark all events as historical
|
||||
event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True
|
||||
|
||||
# Make the state events float off on their own so we don't have a
|
||||
# bunch of `@mxid joined the room` noise between each batch
|
||||
fake_prev_event_id = "$" + random_string(43)
|
||||
|
||||
# TODO: This is pretty much the same as some other code to handle inserting state in this file
|
||||
if event_dict["type"] == EventTypes.Member:
|
||||
membership = event_dict["content"].get("membership", None)
|
||||
|
@ -220,7 +225,7 @@ class RoomBatchHandler:
|
|||
action=membership,
|
||||
content=event_dict["content"],
|
||||
outlier=True,
|
||||
prev_event_ids=[fake_prev_event_id],
|
||||
prev_event_ids=[prev_event_id_for_state_chain],
|
||||
# Make sure to use a copy of this list because we modify it
|
||||
# later in the loop here. Otherwise it will be the same
|
||||
# reference and also update in the event when we append later.
|
||||
|
@ -240,7 +245,7 @@ class RoomBatchHandler:
|
|||
),
|
||||
event_dict,
|
||||
outlier=True,
|
||||
prev_event_ids=[fake_prev_event_id],
|
||||
prev_event_ids=[prev_event_id_for_state_chain],
|
||||
# Make sure to use a copy of this list because we modify it
|
||||
# later in the loop here. Otherwise it will be the same
|
||||
# reference and also update in the event when we append later.
|
||||
|
@ -250,6 +255,8 @@ class RoomBatchHandler:
|
|||
|
||||
state_event_ids_at_start.append(event_id)
|
||||
auth_event_ids.append(event_id)
|
||||
# Connect all the state in a floating chain
|
||||
prev_event_id_for_state_chain = event_id
|
||||
|
||||
return state_event_ids_at_start
|
||||
|
||||
|
@ -296,6 +303,10 @@ class RoomBatchHandler:
|
|||
for ev in events_to_create:
|
||||
assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"])
|
||||
|
||||
assert self.hs.is_mine_id(ev["sender"]), "User must be our own: %s" % (
|
||||
ev["sender"],
|
||||
)
|
||||
|
||||
event_dict = {
|
||||
"type": ev["type"],
|
||||
"origin_server_ts": ev["origin_server_ts"],
|
||||
|
@ -318,6 +329,19 @@ class RoomBatchHandler:
|
|||
historical=True,
|
||||
depth=inherited_depth,
|
||||
)
|
||||
|
||||
assert context._state_group
|
||||
|
||||
# Normally this is done when persisting the event but we have to
|
||||
# pre-emptively do it here because we create all the events first,
|
||||
# then persist them in another pass below. And we want to share
|
||||
# state_groups across the whole batch so this lookup needs to work
|
||||
# for the next event in the batch in this loop.
|
||||
await self.store.store_state_group_id_for_event_id(
|
||||
event_id=event.event_id,
|
||||
state_group_id=context._state_group,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s, auth_event_ids=%s",
|
||||
event,
|
||||
|
@ -325,10 +349,6 @@ class RoomBatchHandler:
|
|||
auth_event_ids,
|
||||
)
|
||||
|
||||
assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % (
|
||||
event.sender,
|
||||
)
|
||||
|
||||
events_to_persist.append((event, context))
|
||||
event_id = event.event_id
|
||||
|
||||
|
|
|
@ -465,17 +465,23 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
|
|||
may be interested in.
|
||||
|
||||
Args:
|
||||
from_key: the stream position at which events should be fetched from
|
||||
service: The appservice which may be interested
|
||||
from_key: the stream position at which events should be fetched from.
|
||||
service: The appservice which may be interested.
|
||||
|
||||
Returns:
|
||||
A two-tuple containing the following:
|
||||
* A list of json dictionaries derived from typing events that the
|
||||
appservice may be interested in.
|
||||
* The latest known room serial.
|
||||
"""
|
||||
with Measure(self.clock, "typing.get_new_events_as"):
|
||||
from_key = int(from_key)
|
||||
handler = self.get_typing_handler()
|
||||
|
||||
events = []
|
||||
for room_id in handler._room_serials.keys():
|
||||
if handler._room_serials[room_id] <= from_key:
|
||||
continue
|
||||
|
||||
if not await service.matches_user_in_member_list(
|
||||
room_id, handler.store
|
||||
):
|
||||
|
|
|
@ -196,63 +196,12 @@ class UserDirectoryHandler(StateDeltasHandler):
|
|||
room_id, prev_event_id, event_id, typ
|
||||
)
|
||||
elif typ == EventTypes.Member:
|
||||
change = await self._get_key_change(
|
||||
await self._handle_room_membership_event(
|
||||
room_id,
|
||||
prev_event_id,
|
||||
event_id,
|
||||
key_name="membership",
|
||||
public_value=Membership.JOIN,
|
||||
state_key,
|
||||
)
|
||||
|
||||
is_remote = not self.is_mine_id(state_key)
|
||||
if change is MatchChange.now_false:
|
||||
# Need to check if the server left the room entirely, if so
|
||||
# we might need to remove all the users in that room
|
||||
is_in_room = await self.store.is_host_joined(
|
||||
room_id, self.server_name
|
||||
)
|
||||
if not is_in_room:
|
||||
logger.debug("Server left room: %r", room_id)
|
||||
# Fetch all the users that we marked as being in user
|
||||
# directory due to being in the room and then check if
|
||||
# need to remove those users or not
|
||||
user_ids = await self.store.get_users_in_dir_due_to_room(
|
||||
room_id
|
||||
)
|
||||
|
||||
for user_id in user_ids:
|
||||
await self._handle_remove_user(room_id, user_id)
|
||||
continue
|
||||
else:
|
||||
logger.debug("Server is still in room: %r", room_id)
|
||||
|
||||
include_in_dir = (
|
||||
is_remote
|
||||
or await self.store.should_include_local_user_in_dir(state_key)
|
||||
)
|
||||
if include_in_dir:
|
||||
if change is MatchChange.no_change:
|
||||
# Handle any profile changes for remote users.
|
||||
# (For local users we are not forced to scan membership
|
||||
# events; instead the rest of the application calls
|
||||
# `handle_local_profile_change`.)
|
||||
if is_remote:
|
||||
await self._handle_profile_change(
|
||||
state_key, room_id, prev_event_id, event_id
|
||||
)
|
||||
continue
|
||||
|
||||
if change is MatchChange.now_true: # The user joined
|
||||
# This may be the first time we've seen a remote user. If
|
||||
# so, ensure we have a directory entry for them. (We don't
|
||||
# need to do this for local users: their directory entry
|
||||
# is created at the point of registration.
|
||||
if is_remote:
|
||||
await self._upsert_directory_entry_for_remote_user(
|
||||
state_key, event_id
|
||||
)
|
||||
await self._track_user_joined_room(room_id, state_key)
|
||||
else: # The user left
|
||||
await self._handle_remove_user(room_id, state_key)
|
||||
else:
|
||||
logger.debug("Ignoring irrelevant type: %r", typ)
|
||||
|
||||
|
@ -317,14 +266,83 @@ class UserDirectoryHandler(StateDeltasHandler):
|
|||
for user_id in users_in_room:
|
||||
await self.store.remove_user_who_share_room(user_id, room_id)
|
||||
|
||||
# Then, re-add them to the tables.
|
||||
# Then, re-add all remote users and some local users to the tables.
|
||||
# NOTE: this is not the most efficient method, as _track_user_joined_room sets
|
||||
# up local_user -> other_user and other_user_whos_local -> local_user,
|
||||
# which when ran over an entire room, will result in the same values
|
||||
# being added multiple times. The batching upserts shouldn't make this
|
||||
# too bad, though.
|
||||
for user_id in users_in_room:
|
||||
await self._track_user_joined_room(room_id, user_id)
|
||||
if not self.is_mine_id(
|
||||
user_id
|
||||
) or await self.store.should_include_local_user_in_dir(user_id):
|
||||
await self._track_user_joined_room(room_id, user_id)
|
||||
|
||||
async def _handle_room_membership_event(
|
||||
self,
|
||||
room_id: str,
|
||||
prev_event_id: str,
|
||||
event_id: str,
|
||||
state_key: str,
|
||||
) -> None:
|
||||
"""Process a single room membershp event.
|
||||
|
||||
We have to do two things:
|
||||
|
||||
1. Update the room-sharing tables.
|
||||
This applies to remote users and non-excluded local users.
|
||||
2. Update the user_directory and user_directory_search tables.
|
||||
This applies to remote users only, because we only become aware of
|
||||
the (and any profile changes) by listening to these events.
|
||||
The rest of the application knows exactly when local users are
|
||||
created or their profile changed---it will directly call methods
|
||||
on this class.
|
||||
"""
|
||||
joined = await self._get_key_change(
|
||||
prev_event_id,
|
||||
event_id,
|
||||
key_name="membership",
|
||||
public_value=Membership.JOIN,
|
||||
)
|
||||
|
||||
# Both cases ignore excluded local users, so start by discarding them.
|
||||
is_remote = not self.is_mine_id(state_key)
|
||||
if not is_remote and not await self.store.should_include_local_user_in_dir(
|
||||
state_key
|
||||
):
|
||||
return
|
||||
|
||||
if joined is MatchChange.now_false:
|
||||
# Need to check if the server left the room entirely, if so
|
||||
# we might need to remove all the users in that room
|
||||
is_in_room = await self.store.is_host_joined(room_id, self.server_name)
|
||||
if not is_in_room:
|
||||
logger.debug("Server left room: %r", room_id)
|
||||
# Fetch all the users that we marked as being in user
|
||||
# directory due to being in the room and then check if
|
||||
# need to remove those users or not
|
||||
user_ids = await self.store.get_users_in_dir_due_to_room(room_id)
|
||||
|
||||
for user_id in user_ids:
|
||||
await self._handle_remove_user(room_id, user_id)
|
||||
else:
|
||||
logger.debug("Server is still in room: %r", room_id)
|
||||
await self._handle_remove_user(room_id, state_key)
|
||||
elif joined is MatchChange.no_change:
|
||||
# Handle any profile changes for remote users.
|
||||
# (For local users the rest of the application calls
|
||||
# `handle_local_profile_change`.)
|
||||
if is_remote:
|
||||
await self._handle_possible_remote_profile_change(
|
||||
state_key, room_id, prev_event_id, event_id
|
||||
)
|
||||
elif joined is MatchChange.now_true: # The user joined
|
||||
# This may be the first time we've seen a remote user. If
|
||||
# so, ensure we have a directory entry for them. (For local users,
|
||||
# the rest of the application calls `handle_local_profile_change`.)
|
||||
if is_remote:
|
||||
await self._upsert_directory_entry_for_remote_user(state_key, event_id)
|
||||
await self._track_user_joined_room(room_id, state_key)
|
||||
|
||||
async def _upsert_directory_entry_for_remote_user(
|
||||
self, user_id: str, event_id: str
|
||||
|
@ -349,61 +367,67 @@ class UserDirectoryHandler(StateDeltasHandler):
|
|||
"""Someone's just joined a room. Update `users_in_public_rooms` or
|
||||
`users_who_share_private_rooms` as appropriate.
|
||||
|
||||
The caller is responsible for ensuring that the given user is not excluded
|
||||
from the user directory.
|
||||
The caller is responsible for ensuring that the given user should be
|
||||
included in the user directory.
|
||||
"""
|
||||
is_public = await self.store.is_room_world_readable_or_publicly_joinable(
|
||||
room_id
|
||||
)
|
||||
other_users_in_room = await self.store.get_users_in_room(room_id)
|
||||
|
||||
if is_public:
|
||||
await self.store.add_users_in_public_rooms(room_id, (user_id,))
|
||||
else:
|
||||
users_in_room = await self.store.get_users_in_room(room_id)
|
||||
other_users_in_room = [
|
||||
other
|
||||
for other in users_in_room
|
||||
if other != user_id
|
||||
and (
|
||||
not self.is_mine_id(other)
|
||||
or await self.store.should_include_local_user_in_dir(other)
|
||||
)
|
||||
]
|
||||
to_insert = set()
|
||||
|
||||
# First, if they're our user then we need to update for every user
|
||||
if self.is_mine_id(user_id):
|
||||
if await self.store.should_include_local_user_in_dir(user_id):
|
||||
for other_user_id in other_users_in_room:
|
||||
if user_id == other_user_id:
|
||||
continue
|
||||
|
||||
to_insert.add((user_id, other_user_id))
|
||||
for other_user_id in other_users_in_room:
|
||||
to_insert.add((user_id, other_user_id))
|
||||
|
||||
# Next we need to update for every local user in the room
|
||||
for other_user_id in other_users_in_room:
|
||||
if user_id == other_user_id:
|
||||
continue
|
||||
|
||||
include_other_user = self.is_mine_id(
|
||||
other_user_id
|
||||
) and await self.store.should_include_local_user_in_dir(other_user_id)
|
||||
if include_other_user:
|
||||
if self.is_mine_id(other_user_id):
|
||||
to_insert.add((other_user_id, user_id))
|
||||
|
||||
if to_insert:
|
||||
await self.store.add_users_who_share_private_room(room_id, to_insert)
|
||||
|
||||
async def _handle_remove_user(self, room_id: str, user_id: str) -> None:
|
||||
"""Called when we might need to remove user from directory
|
||||
"""Called when when someone leaves a room. The user may be local or remote.
|
||||
|
||||
(If the person who left was the last local user in this room, the server
|
||||
is no longer in the room. We call this function to forget that the remaining
|
||||
remote users are in the room, even though they haven't left. So the name is
|
||||
a little misleading!)
|
||||
|
||||
Args:
|
||||
room_id: The room ID that user left or stopped being public that
|
||||
user_id
|
||||
"""
|
||||
logger.debug("Removing user %r", user_id)
|
||||
logger.debug("Removing user %r from room %r", user_id, room_id)
|
||||
|
||||
# Remove user from sharing tables
|
||||
await self.store.remove_user_who_share_room(user_id, room_id)
|
||||
|
||||
# Are they still in any rooms? If not, remove them entirely.
|
||||
rooms_user_is_in = await self.store.get_user_dir_rooms_user_is_in(user_id)
|
||||
# Additionally, if they're a remote user and we're no longer joined
|
||||
# to any rooms they're in, remove them from the user directory.
|
||||
if not self.is_mine_id(user_id):
|
||||
rooms_user_is_in = await self.store.get_user_dir_rooms_user_is_in(user_id)
|
||||
|
||||
if len(rooms_user_is_in) == 0:
|
||||
await self.store.remove_from_user_dir(user_id)
|
||||
if len(rooms_user_is_in) == 0:
|
||||
logger.debug("Removing user %r from directory", user_id)
|
||||
await self.store.remove_from_user_dir(user_id)
|
||||
|
||||
async def _handle_profile_change(
|
||||
async def _handle_possible_remote_profile_change(
|
||||
self,
|
||||
user_id: str,
|
||||
room_id: str,
|
||||
|
@ -411,7 +435,8 @@ class UserDirectoryHandler(StateDeltasHandler):
|
|||
event_id: Optional[str],
|
||||
) -> None:
|
||||
"""Check member event changes for any profile changes and update the
|
||||
database if there are.
|
||||
database if there are. This is intended for remote users only. The caller
|
||||
is responsible for checking that the given user is remote.
|
||||
"""
|
||||
if not prev_event_id or not event_id:
|
||||
return
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue