Revert "Revert accidental fast-forward merge from v1.49.0rc1"

This reverts commit 158d73ebdd.
This commit is contained in:
Olivier Wilkinson (reivilibre) 2021-12-14 14:22:01 +00:00
parent 158d73ebdd
commit 4dd9ea8f4f
165 changed files with 7715 additions and 2703 deletions

View file

@ -18,6 +18,7 @@ import time
import unicodedata
import urllib.parse
from binascii import crc32
from http import HTTPStatus
from typing import (
TYPE_CHECKING,
Any,
@ -38,6 +39,7 @@ import attr
import bcrypt
import pymacaroons
import unpaddedbase64
from pymacaroons.exceptions import MacaroonVerificationFailedException
from twisted.web.server import Request
@ -181,8 +183,11 @@ class LoginTokenAttributes:
user_id = attr.ib(type=str)
# the SSO Identity Provider that the user authenticated with, to get this token
auth_provider_id = attr.ib(type=str)
"""The SSO Identity Provider that the user authenticated with, to get this token."""
auth_provider_session_id = attr.ib(type=Optional[str])
"""The session ID advertised by the SSO Identity Provider."""
class AuthHandler:
@ -756,53 +761,109 @@ class AuthHandler:
async def refresh_token(
self,
refresh_token: str,
valid_until_ms: Optional[int],
) -> Tuple[str, str]:
access_token_valid_until_ms: Optional[int],
refresh_token_valid_until_ms: Optional[int],
) -> Tuple[str, str, Optional[int]]:
"""
Consumes a refresh token and generate both a new access token and a new refresh token from it.
The consumed refresh token is considered invalid after the first use of the new access token or the new refresh token.
The lifetime of both the access token and refresh token will be capped so that they
do not exceed the session's ultimate expiry time, if applicable.
Args:
refresh_token: The token to consume.
valid_until_ms: The expiration timestamp of the new access token.
access_token_valid_until_ms: The expiration timestamp of the new access token.
None if the access token does not expire.
refresh_token_valid_until_ms: The expiration timestamp of the new refresh token.
None if the refresh token does not expire.
Returns:
A tuple containing the new access token and refresh token
A tuple containing:
- the new access token
- the new refresh token
- the actual expiry time of the access token, which may be earlier than
`access_token_valid_until_ms`.
"""
# Verify the token signature first before looking up the token
if not self._verify_refresh_token(refresh_token):
raise SynapseError(401, "invalid refresh token", Codes.UNKNOWN_TOKEN)
raise SynapseError(
HTTPStatus.UNAUTHORIZED, "invalid refresh token", Codes.UNKNOWN_TOKEN
)
existing_token = await self.store.lookup_refresh_token(refresh_token)
if existing_token is None:
raise SynapseError(401, "refresh token does not exist", Codes.UNKNOWN_TOKEN)
raise SynapseError(
HTTPStatus.UNAUTHORIZED,
"refresh token does not exist",
Codes.UNKNOWN_TOKEN,
)
if (
existing_token.has_next_access_token_been_used
or existing_token.has_next_refresh_token_been_refreshed
):
raise SynapseError(
403, "refresh token isn't valid anymore", Codes.FORBIDDEN
HTTPStatus.FORBIDDEN,
"refresh token isn't valid anymore",
Codes.FORBIDDEN,
)
now_ms = self._clock.time_msec()
if existing_token.expiry_ts is not None and existing_token.expiry_ts < now_ms:
raise SynapseError(
HTTPStatus.FORBIDDEN,
"The supplied refresh token has expired",
Codes.FORBIDDEN,
)
if existing_token.ultimate_session_expiry_ts is not None:
# This session has a bounded lifetime, even across refreshes.
if access_token_valid_until_ms is not None:
access_token_valid_until_ms = min(
access_token_valid_until_ms,
existing_token.ultimate_session_expiry_ts,
)
else:
access_token_valid_until_ms = existing_token.ultimate_session_expiry_ts
if refresh_token_valid_until_ms is not None:
refresh_token_valid_until_ms = min(
refresh_token_valid_until_ms,
existing_token.ultimate_session_expiry_ts,
)
else:
refresh_token_valid_until_ms = existing_token.ultimate_session_expiry_ts
if existing_token.ultimate_session_expiry_ts < now_ms:
raise SynapseError(
HTTPStatus.FORBIDDEN,
"The session has expired and can no longer be refreshed",
Codes.FORBIDDEN,
)
(
new_refresh_token,
new_refresh_token_id,
) = await self.create_refresh_token_for_user_id(
user_id=existing_token.user_id, device_id=existing_token.device_id
user_id=existing_token.user_id,
device_id=existing_token.device_id,
expiry_ts=refresh_token_valid_until_ms,
ultimate_session_expiry_ts=existing_token.ultimate_session_expiry_ts,
)
access_token = await self.create_access_token_for_user_id(
user_id=existing_token.user_id,
device_id=existing_token.device_id,
valid_until_ms=valid_until_ms,
valid_until_ms=access_token_valid_until_ms,
refresh_token_id=new_refresh_token_id,
)
await self.store.replace_refresh_token(
existing_token.token_id, new_refresh_token_id
)
return access_token, new_refresh_token
return access_token, new_refresh_token, access_token_valid_until_ms
def _verify_refresh_token(self, token: str) -> bool:
"""
@ -836,6 +897,8 @@ class AuthHandler:
self,
user_id: str,
device_id: str,
expiry_ts: Optional[int],
ultimate_session_expiry_ts: Optional[int],
) -> Tuple[str, int]:
"""
Creates a new refresh token for the user with the given user ID.
@ -843,6 +906,13 @@ class AuthHandler:
Args:
user_id: canonical user ID
device_id: the device ID to associate with the token.
expiry_ts (milliseconds since the epoch): Time after which the
refresh token cannot be used.
If None, the refresh token never expires until it has been used.
ultimate_session_expiry_ts (milliseconds since the epoch):
Time at which the session will end and can not be extended any
further.
If None, the session can be refreshed indefinitely.
Returns:
The newly created refresh token and its ID in the database
@ -852,6 +922,8 @@ class AuthHandler:
user_id=user_id,
token=refresh_token,
device_id=device_id,
expiry_ts=expiry_ts,
ultimate_session_expiry_ts=ultimate_session_expiry_ts,
)
return refresh_token, refresh_token_id
@ -1582,6 +1654,7 @@ class AuthHandler:
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
new_user: bool = False,
auth_provider_session_id: Optional[str] = None,
) -> None:
"""Having figured out a mxid for this user, complete the HTTP request
@ -1597,6 +1670,7 @@ class AuthHandler:
during successful login. Must be JSON serializable.
new_user: True if we should use wording appropriate to a user who has just
registered.
auth_provider_session_id: The session ID from the SSO IdP received during login.
"""
# If the account has been deactivated, do not proceed with the login
# flow.
@ -1617,6 +1691,7 @@ class AuthHandler:
extra_attributes,
new_user=new_user,
user_profile_data=profile,
auth_provider_session_id=auth_provider_session_id,
)
def _complete_sso_login(
@ -1628,6 +1703,7 @@ class AuthHandler:
extra_attributes: Optional[JsonDict] = None,
new_user: bool = False,
user_profile_data: Optional[ProfileInfo] = None,
auth_provider_session_id: Optional[str] = None,
) -> None:
"""
The synchronous portion of complete_sso_login.
@ -1649,7 +1725,9 @@ class AuthHandler:
# Create a login token
login_token = self.macaroon_gen.generate_short_term_login_token(
registered_user_id, auth_provider_id=auth_provider_id
registered_user_id,
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
)
# Append the login token to the original redirect URL (i.e. with its query
@ -1754,6 +1832,7 @@ class MacaroonGenerator:
self,
user_id: str,
auth_provider_id: str,
auth_provider_session_id: Optional[str] = None,
duration_in_ms: int = (2 * 60 * 1000),
) -> str:
macaroon = self._generate_base_macaroon(user_id)
@ -1762,6 +1841,10 @@ class MacaroonGenerator:
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
if auth_provider_session_id is not None:
macaroon.add_first_party_caveat(
"auth_provider_session_id = %s" % (auth_provider_session_id,)
)
return macaroon.serialize()
def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
@ -1783,15 +1866,28 @@ class MacaroonGenerator:
user_id = get_value_from_macaroon(macaroon, "user_id")
auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
auth_provider_session_id: Optional[str] = None
try:
auth_provider_session_id = get_value_from_macaroon(
macaroon, "auth_provider_session_id"
)
except MacaroonVerificationFailedException:
pass
v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")
v.satisfy_exact("type = login")
v.satisfy_general(lambda c: c.startswith("user_id = "))
v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = "))
satisfy_expiry(v, self.hs.get_clock().time_msec)
v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id)
return LoginTokenAttributes(
user_id=user_id,
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
)
def generate_delete_pusher_token(self, user_id: str) -> str:
macaroon = self._generate_base_macaroon(user_id)

View file

@ -301,6 +301,8 @@ class DeviceHandler(DeviceWorkerHandler):
user_id: str,
device_id: Optional[str],
initial_device_display_name: Optional[str] = None,
auth_provider_id: Optional[str] = None,
auth_provider_session_id: Optional[str] = None,
) -> str:
"""
If the given device has not been registered, register it with the
@ -312,6 +314,8 @@ class DeviceHandler(DeviceWorkerHandler):
user_id: @user:id
device_id: device id supplied by client
initial_device_display_name: device display name from client
auth_provider_id: The SSO IdP the user used, if any.
auth_provider_session_id: The session ID (sid) got from the SSO IdP.
Returns:
device id (generated if none was supplied)
"""
@ -323,6 +327,8 @@ class DeviceHandler(DeviceWorkerHandler):
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
)
if new_device:
await self.notify_device_update(user_id, [device_id])
@ -337,6 +343,8 @@ class DeviceHandler(DeviceWorkerHandler):
user_id=user_id,
device_id=new_device_id,
initial_device_display_name=initial_device_display_name,
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
)
if new_device:
await self.notify_device_update(user_id, [new_device_id])

View file

@ -122,9 +122,8 @@ class EventStreamHandler:
events,
time_now,
as_client_event=as_client_event,
# We don't bundle "live" events, as otherwise clients
# will end up double counting annotations.
bundle_relations=False,
# Don't bundle aggregations as this is a deprecated API.
bundle_aggregations=False,
)
chunk = {

View file

@ -68,6 +68,37 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
"""Get joined domains from state
Args:
state: State map from type/state key to event.
Returns:
Returns a list of servers with the lowest depth of their joins.
Sorted by lowest depth first.
"""
joined_users = [
(state_key, int(event.depth))
for (e_type, state_key), event in state.items()
if e_type == EventTypes.Member and event.membership == Membership.JOIN
]
joined_domains: Dict[str, int] = {}
for u, d in joined_users:
try:
dom = get_domain_from_id(u)
old_d = joined_domains.get(dom)
if old_d:
joined_domains[dom] = min(d, old_d)
else:
joined_domains[dom] = d
except Exception:
pass
return sorted(joined_domains.items(), key=lambda d: d[1])
class FederationHandler:
"""Handles general incoming federation requests
@ -268,36 +299,6 @@ class FederationHandler:
curr_state = await self.state_handler.get_current_state(room_id)
def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
"""Get joined domains from state
Args:
state: State map from type/state key to event.
Returns:
Returns a list of servers with the lowest depth of their joins.
Sorted by lowest depth first.
"""
joined_users = [
(state_key, int(event.depth))
for (e_type, state_key), event in state.items()
if e_type == EventTypes.Member and event.membership == Membership.JOIN
]
joined_domains: Dict[str, int] = {}
for u, d in joined_users:
try:
dom = get_domain_from_id(u)
old_d = joined_domains.get(dom)
if old_d:
joined_domains[dom] = min(d, old_d)
else:
joined_domains[dom] = d
except Exception:
pass
return sorted(joined_domains.items(), key=lambda d: d[1])
curr_domains = get_domains_from_state(curr_state)
likely_domains = [

View file

@ -165,7 +165,11 @@ class InitialSyncHandler:
invite_event = await self.store.get_event(event.event_id)
d["invite"] = await self._event_serializer.serialize_event(
invite_event, time_now, as_client_event
invite_event,
time_now,
# Don't bundle aggregations as this is a deprecated API.
bundle_aggregations=False,
as_client_event=as_client_event,
)
rooms_ret.append(d)
@ -216,7 +220,11 @@ class InitialSyncHandler:
d["messages"] = {
"chunk": (
await self._event_serializer.serialize_events(
messages, time_now=time_now, as_client_event=as_client_event
messages,
time_now=time_now,
# Don't bundle aggregations as this is a deprecated API.
bundle_aggregations=False,
as_client_event=as_client_event,
)
),
"start": await start_token.to_string(self.store),
@ -226,6 +234,8 @@ class InitialSyncHandler:
d["state"] = await self._event_serializer.serialize_events(
current_state.values(),
time_now=time_now,
# Don't bundle aggregations as this is a deprecated API.
bundle_aggregations=False,
as_client_event=as_client_event,
)
@ -366,14 +376,18 @@ class InitialSyncHandler:
"room_id": room_id,
"messages": {
"chunk": (
await self._event_serializer.serialize_events(messages, time_now)
# Don't bundle aggregations as this is a deprecated API.
await self._event_serializer.serialize_events(
messages, time_now, bundle_aggregations=False
)
),
"start": await start_token.to_string(self.store),
"end": await end_token.to_string(self.store),
},
"state": (
# Don't bundle aggregations as this is a deprecated API.
await self._event_serializer.serialize_events(
room_state.values(), time_now
room_state.values(), time_now, bundle_aggregations=False
)
),
"presence": [],
@ -392,8 +406,9 @@ class InitialSyncHandler:
# TODO: These concurrently
time_now = self.clock.time_msec()
# Don't bundle aggregations as this is a deprecated API.
state = await self._event_serializer.serialize_events(
current_state.values(), time_now
current_state.values(), time_now, bundle_aggregations=False
)
now_token = self.hs.get_event_sources().get_current_token()
@ -467,7 +482,10 @@ class InitialSyncHandler:
"room_id": room_id,
"messages": {
"chunk": (
await self._event_serializer.serialize_events(messages, time_now)
# Don't bundle aggregations as this is a deprecated API.
await self._event_serializer.serialize_events(
messages, time_now, bundle_aggregations=False
)
),
"start": await start_token.to_string(self.store),
"end": await end_token.to_string(self.store),

View file

@ -247,13 +247,7 @@ class MessageHandler:
room_state = room_state_events[membership_event_id]
now = self.clock.time_msec()
events = await self._event_serializer.serialize_events(
room_state.values(),
now,
# We don't bother bundling aggregations in when asked for state
# events, as clients won't use them.
bundle_relations=False,
)
events = await self._event_serializer.serialize_events(room_state.values(), now)
return events
async def get_joined_members(self, requester: Requester, room_id: str) -> dict:

View file

@ -23,7 +23,7 @@ from authlib.common.security import generate_token
from authlib.jose import JsonWebToken, jwt
from authlib.oauth2.auth import ClientAuth
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
from authlib.oidc.core import CodeIDToken, UserInfo
from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
from jinja2 import Environment, Template
from pymacaroons.exceptions import (
@ -117,7 +117,8 @@ class OidcHandler:
for idp_id, p in self._providers.items():
try:
await p.load_metadata()
await p.load_jwks()
if not p._uses_userinfo:
await p.load_jwks()
except Exception as e:
raise Exception(
"Error while initialising OIDC provider %r" % (idp_id,)
@ -498,10 +499,6 @@ class OidcProvider:
return await self._jwks.get()
async def _load_jwks(self) -> JWKS:
if self._uses_userinfo:
# We're not using jwt signing, return an empty jwk set
return {"keys": []}
metadata = await self.load_metadata()
# Load the JWKS using the `jwks_uri` metadata.
@ -663,7 +660,7 @@ class OidcProvider:
return UserInfo(resp)
async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken:
"""Return an instance of UserInfo from token's ``id_token``.
Args:
@ -673,7 +670,7 @@ class OidcProvider:
request. This value should match the one inside the token.
Returns:
An object representing the user.
The decoded claims in the ID token.
"""
metadata = await self.load_metadata()
claims_params = {
@ -684,9 +681,6 @@ class OidcProvider:
# If we got an `access_token`, there should be an `at_hash` claim
# in the `id_token` that we can check against.
claims_params["access_token"] = token["access_token"]
claims_cls = CodeIDToken
else:
claims_cls = ImplicitIDToken
alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
jwt = JsonWebToken(alg_values)
@ -703,7 +697,7 @@ class OidcProvider:
claims = jwt.decode(
id_token,
key=jwk_set,
claims_cls=claims_cls,
claims_cls=CodeIDToken,
claims_options=claim_options,
claims_params=claims_params,
)
@ -713,7 +707,7 @@ class OidcProvider:
claims = jwt.decode(
id_token,
key=jwk_set,
claims_cls=claims_cls,
claims_cls=CodeIDToken,
claims_options=claim_options,
claims_params=claims_params,
)
@ -721,7 +715,8 @@ class OidcProvider:
logger.debug("Decoded id_token JWT %r; validating", claims)
claims.validate(leeway=120) # allows 2 min of clock skew
return UserInfo(claims)
return claims
async def handle_redirect_request(
self,
@ -837,8 +832,22 @@ class OidcProvider:
logger.debug("Successfully obtained OAuth2 token data: %r", token)
# Now that we have a token, get the userinfo, either by decoding the
# `id_token` or by fetching the `userinfo_endpoint`.
# If there is an id_token, it should be validated, regardless of the
# userinfo endpoint is used or not.
if token.get("id_token") is not None:
try:
id_token = await self._parse_id_token(token, nonce=session_data.nonce)
sid = id_token.get("sid")
except Exception as e:
logger.exception("Invalid id_token")
self._sso_handler.render_error(request, "invalid_token", str(e))
return
else:
id_token = None
sid = None
# Now that we have a token, get the userinfo either from the `id_token`
# claims or by fetching the `userinfo_endpoint`.
if self._uses_userinfo:
try:
userinfo = await self._fetch_userinfo(token)
@ -846,13 +855,14 @@ class OidcProvider:
logger.exception("Could not fetch userinfo")
self._sso_handler.render_error(request, "fetch_error", str(e))
return
elif id_token is not None:
userinfo = UserInfo(id_token)
else:
try:
userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
except Exception as e:
logger.exception("Invalid id_token")
self._sso_handler.render_error(request, "invalid_token", str(e))
return
logger.error("Missing id_token in token response")
self._sso_handler.render_error(
request, "invalid_token", "Missing id_token in token response"
)
return
# first check if we're doing a UIA
if session_data.ui_auth_session_id:
@ -884,7 +894,7 @@ class OidcProvider:
# Call the mapper to register/login the user
try:
await self._complete_oidc_login(
userinfo, token, request, session_data.client_redirect_url
userinfo, token, request, session_data.client_redirect_url, sid
)
except MappingException as e:
logger.exception("Could not map user")
@ -896,6 +906,7 @@ class OidcProvider:
token: Token,
request: SynapseRequest,
client_redirect_url: str,
sid: Optional[str],
) -> None:
"""Given a UserInfo response, complete the login flow
@ -1008,6 +1019,7 @@ class OidcProvider:
oidc_response_to_user_attributes,
grandfather_existing_users,
extra_attributes,
auth_provider_session_id=sid,
)
def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:

View file

@ -406,9 +406,6 @@ class PaginationHandler:
force: set true to skip checking for joined users.
"""
with await self.pagination_lock.write(room_id):
# check we know about the room
await self.store.get_room_version_id(room_id)
# first check that we have no users in this room
if not force:
joined = await self.store.is_host_joined(room_id, self._server_name)

View file

@ -421,7 +421,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
self._on_shutdown,
)
def _on_shutdown(self) -> None:
async def _on_shutdown(self) -> None:
if self._presence_enabled:
self.hs.get_tcp_replication().send_command(
ClearUserSyncsCommand(self.instance_id)

View file

@ -1,4 +1,5 @@
# Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -116,9 +117,13 @@ class RegistrationHandler:
self.pusher_pool = hs.get_pusherpool()
self.session_lifetime = hs.config.registration.session_lifetime
self.nonrefreshable_access_token_lifetime = (
hs.config.registration.nonrefreshable_access_token_lifetime
)
self.refreshable_access_token_lifetime = (
hs.config.registration.refreshable_access_token_lifetime
)
self.refresh_token_lifetime = hs.config.registration.refresh_token_lifetime
init_counters_for_auth_provider("")
@ -741,6 +746,7 @@ class RegistrationHandler:
is_appservice_ghost: bool = False,
auth_provider_id: Optional[str] = None,
should_issue_refresh_token: bool = False,
auth_provider_session_id: Optional[str] = None,
) -> Tuple[str, str, Optional[int], Optional[str]]:
"""Register a device for a user and generate an access token.
@ -751,9 +757,9 @@ class RegistrationHandler:
device_id: The device ID to check, or None to generate a new one.
initial_display_name: An optional display name for the device.
is_guest: Whether this is a guest account
auth_provider_id: The SSO IdP the user used, if any (just used for the
prometheus metrics).
auth_provider_id: The SSO IdP the user used, if any.
should_issue_refresh_token: Whether it should also issue a refresh token
auth_provider_session_id: The session ID received during login from the SSO IdP.
Returns:
Tuple of device ID, access token, access token expiration time and refresh token
"""
@ -764,6 +770,8 @@ class RegistrationHandler:
is_guest=is_guest,
is_appservice_ghost=is_appservice_ghost,
should_issue_refresh_token=should_issue_refresh_token,
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
)
login_counter.labels(
@ -786,6 +794,8 @@ class RegistrationHandler:
is_guest: bool = False,
is_appservice_ghost: bool = False,
should_issue_refresh_token: bool = False,
auth_provider_id: Optional[str] = None,
auth_provider_session_id: Optional[str] = None,
) -> LoginDict:
"""Helper for register_device
@ -793,40 +803,86 @@ class RegistrationHandler:
class and RegisterDeviceReplicationServlet.
"""
assert not self.hs.config.worker.worker_app
valid_until_ms = None
now_ms = self.clock.time_msec()
access_token_expiry = None
if self.session_lifetime is not None:
if is_guest:
raise Exception(
"session_lifetime is not currently implemented for guest access"
)
valid_until_ms = self.clock.time_msec() + self.session_lifetime
access_token_expiry = now_ms + self.session_lifetime
if self.nonrefreshable_access_token_lifetime is not None:
if access_token_expiry is not None:
# Don't allow the non-refreshable access token to outlive the
# session.
access_token_expiry = min(
now_ms + self.nonrefreshable_access_token_lifetime,
access_token_expiry,
)
else:
access_token_expiry = now_ms + self.nonrefreshable_access_token_lifetime
refresh_token = None
refresh_token_id = None
registered_device_id = await self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
user_id,
device_id,
initial_display_name,
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
)
if is_guest:
assert valid_until_ms is None
assert access_token_expiry is None
access_token = self.macaroon_gen.generate_guest_access_token(user_id)
else:
if should_issue_refresh_token:
# A refreshable access token lifetime must be configured
# since we're told to issue a refresh token (the caller checks
# that this value is set before setting this flag).
assert self.refreshable_access_token_lifetime is not None
# Set the expiry time of the refreshable access token
access_token_expiry = now_ms + self.refreshable_access_token_lifetime
# Set the refresh token expiry time (if configured)
refresh_token_expiry = None
if self.refresh_token_lifetime is not None:
refresh_token_expiry = now_ms + self.refresh_token_lifetime
# Set an ultimate session expiry time (if configured)
ultimate_session_expiry_ts = None
if self.session_lifetime is not None:
ultimate_session_expiry_ts = now_ms + self.session_lifetime
# Also ensure that the issued tokens don't outlive the
# session.
# (It would be weird to configure a homeserver with a shorter
# session lifetime than token lifetime, but may as well handle
# it.)
access_token_expiry = min(
access_token_expiry, ultimate_session_expiry_ts
)
if refresh_token_expiry is not None:
refresh_token_expiry = min(
refresh_token_expiry, ultimate_session_expiry_ts
)
(
refresh_token,
refresh_token_id,
) = await self._auth_handler.create_refresh_token_for_user_id(
user_id,
device_id=registered_device_id,
)
valid_until_ms = (
self.clock.time_msec() + self.refreshable_access_token_lifetime
expiry_ts=refresh_token_expiry,
ultimate_session_expiry_ts=ultimate_session_expiry_ts,
)
access_token = await self._auth_handler.create_access_token_for_user_id(
user_id,
device_id=registered_device_id,
valid_until_ms=valid_until_ms,
valid_until_ms=access_token_expiry,
is_appservice_ghost=is_appservice_ghost,
refresh_token_id=refresh_token_id,
)
@ -834,7 +890,7 @@ class RegistrationHandler:
return {
"device_id": registered_device_id,
"access_token": access_token,
"valid_until_ms": valid_until_ms,
"valid_until_ms": access_token_expiry,
"refresh_token": refresh_token,
}

View file

@ -46,6 +46,7 @@ from synapse.api.constants import (
from synapse.api.errors import (
AuthError,
Codes,
HttpResponseException,
LimitExceededError,
NotFoundError,
StoreError,
@ -56,6 +57,8 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase
from synapse.events.utils import copy_power_levels_contents
from synapse.federation.federation_client import InvalidResponseError
from synapse.handlers.federation import get_domains_from_state
from synapse.rest.admin._base import assert_user_is_admin
from synapse.storage.state import StateFilter
from synapse.streams import EventSource
@ -1220,6 +1223,147 @@ class RoomContextHandler:
return results
class TimestampLookupHandler:
def __init__(self, hs: "HomeServer"):
self.server_name = hs.hostname
self.store = hs.get_datastore()
self.state_handler = hs.get_state_handler()
self.federation_client = hs.get_federation_client()
async def get_event_for_timestamp(
self,
requester: Requester,
room_id: str,
timestamp: int,
direction: str,
) -> Tuple[str, int]:
"""Find the closest event to the given timestamp in the given direction.
If we can't find an event locally or the event we have locally is next to a gap,
it will ask other federated homeservers for an event.
Args:
requester: The user making the request according to the access token
room_id: Room to fetch the event from
timestamp: The point in time (inclusive) we should navigate from in
the given direction to find the closest event.
direction: ["f"|"b"] to indicate whether we should navigate forward
or backward from the given timestamp to find the closest event.
Returns:
A tuple containing the `event_id` closest to the given timestamp in
the given direction and the `origin_server_ts`.
Raises:
SynapseError if unable to find any event locally in the given direction
"""
local_event_id = await self.store.get_event_id_for_timestamp(
room_id, timestamp, direction
)
logger.debug(
"get_event_for_timestamp: locally, we found event_id=%s closest to timestamp=%s",
local_event_id,
timestamp,
)
# Check for gaps in the history where events could be hiding in between
# the timestamp given and the event we were able to find locally
is_event_next_to_backward_gap = False
is_event_next_to_forward_gap = False
if local_event_id:
local_event = await self.store.get_event(
local_event_id, allow_none=False, allow_rejected=False
)
if direction == "f":
# We only need to check for a backward gap if we're looking forwards
# to ensure there is nothing in between.
is_event_next_to_backward_gap = (
await self.store.is_event_next_to_backward_gap(local_event)
)
elif direction == "b":
# We only need to check for a forward gap if we're looking backwards
# to ensure there is nothing in between
is_event_next_to_forward_gap = (
await self.store.is_event_next_to_forward_gap(local_event)
)
# If we found a gap, we should probably ask another homeserver first
# about more history in between
if (
not local_event_id
or is_event_next_to_backward_gap
or is_event_next_to_forward_gap
):
logger.debug(
"get_event_for_timestamp: locally, we found event_id=%s closest to timestamp=%s which is next to a gap in event history so we're asking other homeservers first",
local_event_id,
timestamp,
)
# Find other homeservers from the given state in the room
curr_state = await self.state_handler.get_current_state(room_id)
curr_domains = get_domains_from_state(curr_state)
likely_domains = [
domain for domain, depth in curr_domains if domain != self.server_name
]
# Loop through each homeserver candidate until we get a succesful response
for domain in likely_domains:
try:
remote_response = await self.federation_client.timestamp_to_event(
domain, room_id, timestamp, direction
)
logger.debug(
"get_event_for_timestamp: response from domain(%s)=%s",
domain,
remote_response,
)
# TODO: Do we want to persist this as an extremity?
# TODO: I think ideally, we would try to backfill from
# this event and run this whole
# `get_event_for_timestamp` function again to make sure
# they didn't give us an event from their gappy history.
remote_event_id = remote_response.event_id
origin_server_ts = remote_response.origin_server_ts
# Only return the remote event if it's closer than the local event
if not local_event or (
abs(origin_server_ts - timestamp)
< abs(local_event.origin_server_ts - timestamp)
):
return remote_event_id, origin_server_ts
except (HttpResponseException, InvalidResponseError) as ex:
# Let's not put a high priority on some other homeserver
# failing to respond or giving a random response
logger.debug(
"Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s",
domain,
type(ex).__name__,
ex,
ex.args,
)
except Exception as ex:
# But we do want to see some exceptions in our code
logger.warning(
"Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s",
domain,
type(ex).__name__,
ex,
ex.args,
)
if not local_event_id:
raise SynapseError(
404,
"Unable to find event from %s in direction %s" % (timestamp, direction),
errcode=Codes.NOT_FOUND,
)
return local_event_id, local_event.origin_server_ts
class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@ -1391,20 +1535,13 @@ class RoomShutdownHandler:
await self.store.block_room(room_id, requester_user_id)
if not await self.store.get_room(room_id):
if block:
# We allow you to block an unknown room.
return {
"kicked_users": [],
"failed_to_kick_users": [],
"local_aliases": [],
"new_room_id": None,
}
else:
# But if you don't want to preventatively block another room,
# this function can't do anything useful.
raise NotFoundError(
"Cannot shut down room: unknown room id %s" % (room_id,)
)
# if we don't know about the room, there is nothing left to do.
return {
"kicked_users": [],
"failed_to_kick_users": [],
"local_aliases": [],
"new_room_id": None,
}
if new_room_user_id is not None:
if not self.hs.is_mine_id(new_room_user_id):

View file

@ -36,8 +36,9 @@ from synapse.api.errors import (
SynapseError,
UnsupportedRoomVersionError,
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.events import EventBase
from synapse.types import JsonDict
from synapse.types import JsonDict, Requester
from synapse.util.caches.response_cache import ResponseCache
if TYPE_CHECKING:
@ -93,6 +94,9 @@ class RoomSummaryHandler:
self._event_serializer = hs.get_event_client_serializer()
self._server_name = hs.hostname
self._federation_client = hs.get_federation_client()
self._ratelimiter = Ratelimiter(
store=self._store, clock=hs.get_clock(), rate_hz=5, burst_count=10
)
# If a user tries to fetch the same page multiple times in quick succession,
# only process the first attempt and return its result to subsequent requests.
@ -249,7 +253,7 @@ class RoomSummaryHandler:
async def get_room_hierarchy(
self,
requester: str,
requester: Requester,
requested_room_id: str,
suggested_only: bool = False,
max_depth: Optional[int] = None,
@ -276,6 +280,8 @@ class RoomSummaryHandler:
Returns:
The JSON hierarchy dictionary.
"""
await self._ratelimiter.ratelimit(requester)
# If a user tries to fetch the same page multiple times in quick succession,
# only process the first attempt and return its result to subsequent requests.
#
@ -283,7 +289,7 @@ class RoomSummaryHandler:
# to process multiple requests for the same page will result in errors.
return await self._pagination_response_cache.wrap(
(
requester,
requester.user.to_string(),
requested_room_id,
suggested_only,
max_depth,
@ -291,7 +297,7 @@ class RoomSummaryHandler:
from_token,
),
self._get_room_hierarchy,
requester,
requester.user.to_string(),
requested_room_id,
suggested_only,
max_depth,

View file

@ -365,6 +365,7 @@ class SsoHandler:
sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
grandfather_existing_users: Callable[[], Awaitable[Optional[str]]],
extra_login_attributes: Optional[JsonDict] = None,
auth_provider_session_id: Optional[str] = None,
) -> None:
"""
Given an SSO ID, retrieve the user ID for it and possibly register the user.
@ -415,6 +416,8 @@ class SsoHandler:
extra_login_attributes: An optional dictionary of extra
attributes to be provided to the client in the login response.
auth_provider_session_id: An optional session ID from the IdP.
Raises:
MappingException if there was a problem mapping the response to a user.
RedirectException: if the mapping provider needs to redirect the user
@ -490,6 +493,7 @@ class SsoHandler:
client_redirect_url,
extra_login_attributes,
new_user=new_user,
auth_provider_session_id=auth_provider_session_id,
)
async def _call_attribute_mapper(

View file

@ -334,6 +334,19 @@ class SyncHandler:
full_state: bool,
cache_context: ResponseCacheContext[SyncRequestKey],
) -> SyncResult:
"""The start of the machinery that produces a /sync response.
See https://spec.matrix.org/v1.1/client-server-api/#syncing for full details.
This method does high-level bookkeeping:
- tracking the kind of sync in the logging context
- deleting any to_device messages whose delivery has been acknowledged.
- deciding if we should dispatch an instant or delayed response
- marking the sync as being lazily loaded, if appropriate
Computing the body of the response begins in the next method,
`current_sync_for_user`.
"""
if since_token is None:
sync_type = "initial_sync"
elif full_state:
@ -363,7 +376,7 @@ class SyncHandler:
sync_config, since_token, full_state=full_state
)
else:
# Otherwise, we wait for something to happen and report it to the user.
async def current_sync_callback(
before_token: StreamToken, after_token: StreamToken
) -> SyncResult:
@ -402,7 +415,12 @@ class SyncHandler:
since_token: Optional[StreamToken] = None,
full_state: bool = False,
) -> SyncResult:
"""Get the sync for client needed to match what the server has now."""
"""Generates the response body of a sync result, represented as a SyncResult.
This is a wrapper around `generate_sync_result` which starts an open tracing
span to track the sync. See `generate_sync_result` for the next part of your
indoctrination.
"""
with start_active_span("current_sync_for_user"):
log_kv({"since_token": since_token})
sync_result = await self.generate_sync_result(
@ -560,7 +578,7 @@ class SyncHandler:
# that have happened since `since_key` up to `end_key`, so we
# can just use `get_room_events_stream_for_room`.
# Otherwise, we want to return the last N events in the room
# in toplogical ordering.
# in topological ordering.
if since_key:
events, end_key = await self.store.get_room_events_stream_for_room(
room_id,
@ -1042,7 +1060,18 @@ class SyncHandler:
since_token: Optional[StreamToken] = None,
full_state: bool = False,
) -> SyncResult:
"""Generates a sync result."""
"""Generates the response body of a sync result.
This is represented by a `SyncResult` struct, which is built from small pieces
using a `SyncResultBuilder`. See also
https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3sync
the `sync_result_builder` is passed as a mutable ("inout") parameter to various
helper functions. These retrieve and process the data which forms the sync body,
often writing to the `sync_result_builder` to store their output.
At the end, we transfer data from the `sync_result_builder` to a new `SyncResult`
instance to signify that the sync calculation is complete.
"""
# NB: The now_token gets changed by some of the generate_sync_* methods,
# this is due to some of the underlying streams not supporting the ability
# to query up to a given point.
@ -1344,14 +1373,22 @@ class SyncHandler:
async def _generate_sync_entry_for_account_data(
self, sync_result_builder: "SyncResultBuilder"
) -> Dict[str, Dict[str, JsonDict]]:
"""Generates the account data portion of the sync response. Populates
`sync_result_builder` with the result.
"""Generates the account data portion of the sync response.
Account data (called "Client Config" in the spec) can be set either globally
or for a specific room. Account data consists of a list of events which
accumulate state, much like a room.
This function retrieves global and per-room account data. The former is written
to the given `sync_result_builder`. The latter is returned directly, to be
later written to the `sync_result_builder` on a room-by-room basis.
Args:
sync_result_builder
Returns:
A dictionary containing the per room account data.
A dictionary whose keys (room ids) map to the per room account data for that
room.
"""
sync_config = sync_result_builder.sync_config
user_id = sync_result_builder.sync_config.user.to_string()
@ -1359,7 +1396,7 @@ class SyncHandler:
if since_token and not sync_result_builder.full_state:
(
account_data,
global_account_data,
account_data_by_room,
) = await self.store.get_updated_account_data_for_user(
user_id, since_token.account_data_key
@ -1370,23 +1407,23 @@ class SyncHandler:
)
if push_rules_changed:
account_data["m.push_rules"] = await self.push_rules_for_user(
global_account_data["m.push_rules"] = await self.push_rules_for_user(
sync_config.user
)
else:
(
account_data,
global_account_data,
account_data_by_room,
) = await self.store.get_account_data_for_user(sync_config.user.to_string())
account_data["m.push_rules"] = await self.push_rules_for_user(
global_account_data["m.push_rules"] = await self.push_rules_for_user(
sync_config.user
)
account_data_for_user = await sync_config.filter_collection.filter_account_data(
[
{"type": account_data_type, "content": content}
for account_data_type, content in account_data.items()
for account_data_type, content in global_account_data.items()
]
)
@ -1460,18 +1497,31 @@ class SyncHandler:
"""Generates the rooms portion of the sync response. Populates the
`sync_result_builder` with the result.
In the response that reaches the client, rooms are divided into four categories:
`invite`, `join`, `knock`, `leave`. These aren't the same as the four sets of
room ids returned by this function.
Args:
sync_result_builder
account_data_by_room: Dictionary of per room account data
Returns:
Returns a 4-tuple of
`(newly_joined_rooms, newly_joined_or_invited_users,
newly_left_rooms, newly_left_users)`
Returns a 4-tuple describing rooms the user has joined or left, and users who've
joined or left rooms any rooms the user is in. This gets used later in
`_generate_sync_entry_for_device_list`.
Its entries are:
- newly_joined_rooms
- newly_joined_or_invited_or_knocked_users
- newly_left_rooms
- newly_left_users
"""
since_token = sync_result_builder.since_token
# 1. Start by fetching all ephemeral events in rooms we've joined (if required).
user_id = sync_result_builder.sync_config.user.to_string()
block_all_room_ephemeral = (
sync_result_builder.since_token is None
since_token is None
and sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral()
)
@ -1485,9 +1535,8 @@ class SyncHandler:
)
sync_result_builder.now_token = now_token
# We check up front if anything has changed, if it hasn't then there is
# 2. We check up front if anything has changed, if it hasn't then there is
# no point in going further.
since_token = sync_result_builder.since_token
if not sync_result_builder.full_state:
if since_token and not ephemeral_by_room and not account_data_by_room:
have_changed = await self._have_rooms_changed(sync_result_builder)
@ -1500,20 +1549,8 @@ class SyncHandler:
logger.debug("no-oping sync")
return set(), set(), set(), set()
ignored_account_data = (
await self.store.get_global_account_data_by_type_for_user(
AccountDataTypes.IGNORED_USER_LIST, user_id=user_id
)
)
# If there is ignored users account data and it matches the proper type,
# then use it.
ignored_users: FrozenSet[str] = frozenset()
if ignored_account_data:
ignored_users_data = ignored_account_data.get("ignored_users", {})
if isinstance(ignored_users_data, dict):
ignored_users = frozenset(ignored_users_data.keys())
# 3. Work out which rooms need reporting in the sync response.
ignored_users = await self._get_ignored_users(user_id)
if since_token:
room_changes = await self._get_rooms_changed(
sync_result_builder, ignored_users
@ -1523,7 +1560,6 @@ class SyncHandler:
)
else:
room_changes = await self._get_all_rooms(sync_result_builder, ignored_users)
tags_by_room = await self.store.get_tags_for_user(user_id)
log_kv({"rooms_changed": len(room_changes.room_entries)})
@ -1534,6 +1570,8 @@ class SyncHandler:
newly_joined_rooms = room_changes.newly_joined_rooms
newly_left_rooms = room_changes.newly_left_rooms
# 4. We need to apply further processing to `room_entries` (rooms considered
# joined or archived).
async def handle_room_entries(room_entry: "RoomSyncResultBuilder") -> None:
logger.debug("Generating room entry for %s", room_entry.room_id)
await self._generate_room_entry(
@ -1552,31 +1590,13 @@ class SyncHandler:
sync_result_builder.invited.extend(invited)
sync_result_builder.knocked.extend(knocked)
# Now we want to get any newly joined, invited or knocking users
newly_joined_or_invited_or_knocked_users = set()
newly_left_users = set()
if since_token:
for joined_sync in sync_result_builder.joined:
it = itertools.chain(
joined_sync.timeline.events, joined_sync.state.values()
)
for event in it:
if event.type == EventTypes.Member:
if (
event.membership == Membership.JOIN
or event.membership == Membership.INVITE
or event.membership == Membership.KNOCK
):
newly_joined_or_invited_or_knocked_users.add(
event.state_key
)
else:
prev_content = event.unsigned.get("prev_content", {})
prev_membership = prev_content.get("membership", None)
if prev_membership == Membership.JOIN:
newly_left_users.add(event.state_key)
newly_left_users -= newly_joined_or_invited_or_knocked_users
# 5. Work out which users have joined or left rooms we're in. We use this
# to build the device_list part of the sync response in
# `_generate_sync_entry_for_device_list`.
(
newly_joined_or_invited_or_knocked_users,
newly_left_users,
) = sync_result_builder.calculate_user_changes()
return (
set(newly_joined_rooms),
@ -1585,11 +1605,36 @@ class SyncHandler:
newly_left_users,
)
async def _get_ignored_users(self, user_id: str) -> FrozenSet[str]:
"""Retrieve the users ignored by the given user from their global account_data.
Returns an empty set if
- there is no global account_data entry for ignored_users
- there is such an entry, but it's not a JSON object.
"""
# TODO: Can we `SELECT ignored_user_id FROM ignored_users WHERE ignorer_user_id=?;` instead?
ignored_account_data = (
await self.store.get_global_account_data_by_type_for_user(
AccountDataTypes.IGNORED_USER_LIST, user_id=user_id
)
)
# If there is ignored users account data and it matches the proper type,
# then use it.
ignored_users: FrozenSet[str] = frozenset()
if ignored_account_data:
ignored_users_data = ignored_account_data.get("ignored_users", {})
if isinstance(ignored_users_data, dict):
ignored_users = frozenset(ignored_users_data.keys())
return ignored_users
async def _have_rooms_changed(
self, sync_result_builder: "SyncResultBuilder"
) -> bool:
"""Returns whether there may be any new events that should be sent down
the sync. Returns True if there are.
Does not modify the `sync_result_builder`.
"""
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
@ -1597,12 +1642,13 @@ class SyncHandler:
assert since_token
# Get a list of membership change events that have happened.
rooms_changed = await self.store.get_membership_changes_for_user(
# Get a list of membership change events that have happened to the user
# requesting the sync.
membership_changes = await self.store.get_membership_changes_for_user(
user_id, since_token.room_key, now_token.room_key
)
if rooms_changed:
if membership_changes:
return True
stream_id = since_token.room_key.stream
@ -1614,7 +1660,25 @@ class SyncHandler:
async def _get_rooms_changed(
self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
) -> _RoomChanges:
"""Gets the the changes that have happened since the last sync."""
"""Determine the changes in rooms to report to the user.
Ideally, we want to report all events whose stream ordering `s` lies in the
range `since_token < s <= now_token`, where the two tokens are read from the
sync_result_builder.
If there are too many events in that range to report, things get complicated.
In this situation we return a truncated list of the most recent events, and
indicate in the response that there is a "gap" of omitted events. Additionally:
- we include a "state_delta", to describe the changes in state over the gap,
- we include all membership events applying to the user making the request,
even those in the gap.
See the spec for the rationale:
https://spec.matrix.org/v1.1/client-server-api/#syncing
The sync_result_builder is not modified by this function.
"""
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
now_token = sync_result_builder.now_token
@ -1622,21 +1686,36 @@ class SyncHandler:
assert since_token
# Get a list of membership change events that have happened.
rooms_changed = await self.store.get_membership_changes_for_user(
# The spec
# https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3sync
# notes that membership events need special consideration:
#
# > When a sync is limited, the server MUST return membership events for events
# > in the gap (between since and the start of the returned timeline), regardless
# > as to whether or not they are redundant.
#
# We fetch such events here, but we only seem to use them for categorising rooms
# as newly joined, newly left, invited or knocked.
# TODO: we've already called this function and ran this query in
# _have_rooms_changed. We could keep the results in memory to avoid a
# second query, at the cost of more complicated source code.
membership_change_events = await self.store.get_membership_changes_for_user(
user_id, since_token.room_key, now_token.room_key
)
mem_change_events_by_room_id: Dict[str, List[EventBase]] = {}
for event in rooms_changed:
for event in membership_change_events:
mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
newly_joined_rooms = []
newly_left_rooms = []
room_entries = []
invited = []
knocked = []
newly_joined_rooms: List[str] = []
newly_left_rooms: List[str] = []
room_entries: List[RoomSyncResultBuilder] = []
invited: List[InvitedSyncResult] = []
knocked: List[KnockedSyncResult] = []
for room_id, events in mem_change_events_by_room_id.items():
# The body of this loop will add this room to at least one of the five lists
# above. Things get messy if you've e.g. joined, left, joined then left the
# room all in the same sync period.
logger.debug(
"Membership changes in %s: [%s]",
room_id,
@ -1691,6 +1770,7 @@ class SyncHandler:
if not non_joins:
continue
last_non_join = non_joins[-1]
# Check if we have left the room. This can either be because we were
# joined before *or* that we since joined and then left.
@ -1712,18 +1792,18 @@ class SyncHandler:
newly_left_rooms.append(room_id)
# Only bother if we're still currently invited
should_invite = non_joins[-1].membership == Membership.INVITE
should_invite = last_non_join.membership == Membership.INVITE
if should_invite:
if event.sender not in ignored_users:
invite_room_sync = InvitedSyncResult(room_id, invite=non_joins[-1])
if last_non_join.sender not in ignored_users:
invite_room_sync = InvitedSyncResult(room_id, invite=last_non_join)
if invite_room_sync:
invited.append(invite_room_sync)
# Only bother if our latest membership in the room is knock (and we haven't
# been accepted/rejected in the meantime).
should_knock = non_joins[-1].membership == Membership.KNOCK
should_knock = last_non_join.membership == Membership.KNOCK
if should_knock:
knock_room_sync = KnockedSyncResult(room_id, knock=non_joins[-1])
knock_room_sync = KnockedSyncResult(room_id, knock=last_non_join)
if knock_room_sync:
knocked.append(knock_room_sync)
@ -1781,7 +1861,9 @@ class SyncHandler:
timeline_limit = sync_config.filter_collection.timeline_limit()
# Get all events for rooms we're currently joined to.
# Get all events since the `from_key` in rooms we're currently joined to.
# If there are too many, we get the most recent events only. This leaves
# a "gap" in the timeline, as described by the spec for /sync.
room_to_events = await self.store.get_room_events_stream_for_rooms(
room_ids=sync_result_builder.joined_room_ids,
from_key=since_token.room_key,
@ -1842,6 +1924,10 @@ class SyncHandler:
) -> _RoomChanges:
"""Returns entries for all rooms for the user.
Like `_get_rooms_changed`, but assumes the `since_token` is `None`.
This function does not modify the sync_result_builder.
Args:
sync_result_builder
ignored_users: Set of users ignored by user.
@ -1853,16 +1939,9 @@ class SyncHandler:
now_token = sync_result_builder.now_token
sync_config = sync_result_builder.sync_config
membership_list = (
Membership.INVITE,
Membership.KNOCK,
Membership.JOIN,
Membership.LEAVE,
Membership.BAN,
)
room_list = await self.store.get_rooms_for_local_user_where_membership_is(
user_id=user_id, membership_list=membership_list
user_id=user_id,
membership_list=Membership.LIST,
)
room_entries = []
@ -2212,8 +2291,7 @@ def _calculate_state(
# to only include membership events for the senders in the timeline.
# In practice, we can do this by removing them from the p_ids list,
# which is the list of relevant state we know we have already sent to the client.
# see https://github.com/matrix-org/synapse/pull/2970
# /files/efcdacad7d1b7f52f879179701c7e0d9b763511f#r204732809
# see https://github.com/matrix-org/synapse/pull/2970/files/efcdacad7d1b7f52f879179701c7e0d9b763511f#r204732809
if lazy_load_members:
p_ids.difference_update(
@ -2262,6 +2340,39 @@ class SyncResultBuilder:
groups: Optional[GroupsSyncResult] = None
to_device: List[JsonDict] = attr.Factory(list)
def calculate_user_changes(self) -> Tuple[Set[str], Set[str]]:
"""Work out which other users have joined or left rooms we are joined to.
This data only is only useful for an incremental sync.
The SyncResultBuilder is not modified by this function.
"""
newly_joined_or_invited_or_knocked_users = set()
newly_left_users = set()
if self.since_token:
for joined_sync in self.joined:
it = itertools.chain(
joined_sync.timeline.events, joined_sync.state.values()
)
for event in it:
if event.type == EventTypes.Member:
if (
event.membership == Membership.JOIN
or event.membership == Membership.INVITE
or event.membership == Membership.KNOCK
):
newly_joined_or_invited_or_knocked_users.add(
event.state_key
)
else:
prev_content = event.unsigned.get("prev_content", {})
prev_membership = prev_content.get("membership", None)
if prev_membership == Membership.JOIN:
newly_left_users.add(event.state_key)
newly_left_users -= newly_joined_or_invited_or_knocked_users
return newly_joined_or_invited_or_knocked_users, newly_left_users
@attr.s(slots=True, auto_attribs=True)
class RoomSyncResultBuilder: