Revert "Revert accidental fast-forward merge from v1.49.0rc1"
This reverts commit 158d73ebdd
.
This commit is contained in:
parent
158d73ebdd
commit
4dd9ea8f4f
165 changed files with 7715 additions and 2703 deletions
synapse/handlers/auth.py

@@ -18,6 +18,7 @@ import time
 import unicodedata
 import urllib.parse
 from binascii import crc32
+from http import HTTPStatus
 from typing import (
     TYPE_CHECKING,
     Any,

@@ -38,6 +39,7 @@ import attr
 import bcrypt
 import pymacaroons
 import unpaddedbase64
+from pymacaroons.exceptions import MacaroonVerificationFailedException

 from twisted.web.server import Request

@@ -181,8 +183,11 @@ class LoginTokenAttributes:

     user_id = attr.ib(type=str)

-    # the SSO Identity Provider that the user authenticated with, to get this token
     auth_provider_id = attr.ib(type=str)
+    """The SSO Identity Provider that the user authenticated with, to get this token."""
+
+    auth_provider_session_id = attr.ib(type=Optional[str])
+    """The session ID advertised by the SSO Identity Provider."""


 class AuthHandler:
@@ -756,53 +761,109 @@ class AuthHandler:
     async def refresh_token(
         self,
         refresh_token: str,
-        valid_until_ms: Optional[int],
-    ) -> Tuple[str, str]:
+        access_token_valid_until_ms: Optional[int],
+        refresh_token_valid_until_ms: Optional[int],
+    ) -> Tuple[str, str, Optional[int]]:
         """
         Consumes a refresh token and generate both a new access token and a new refresh token from it.

         The consumed refresh token is considered invalid after the first use of the new access token or the new refresh token.

+        The lifetime of both the access token and refresh token will be capped so that they
+        do not exceed the session's ultimate expiry time, if applicable.
+
         Args:
             refresh_token: The token to consume.
-            valid_until_ms: The expiration timestamp of the new access token.
-
+            access_token_valid_until_ms: The expiration timestamp of the new access token.
+                None if the access token does not expire.
+            refresh_token_valid_until_ms: The expiration timestamp of the new refresh token.
+                None if the refresh token does not expire.
         Returns:
-            A tuple containing the new access token and refresh token
+            A tuple containing:
+              - the new access token
+              - the new refresh token
+              - the actual expiry time of the access token, which may be earlier than
+                `access_token_valid_until_ms`.
         """

         # Verify the token signature first before looking up the token
         if not self._verify_refresh_token(refresh_token):
-            raise SynapseError(401, "invalid refresh token", Codes.UNKNOWN_TOKEN)
+            raise SynapseError(
+                HTTPStatus.UNAUTHORIZED, "invalid refresh token", Codes.UNKNOWN_TOKEN
+            )

         existing_token = await self.store.lookup_refresh_token(refresh_token)
         if existing_token is None:
-            raise SynapseError(401, "refresh token does not exist", Codes.UNKNOWN_TOKEN)
+            raise SynapseError(
+                HTTPStatus.UNAUTHORIZED,
+                "refresh token does not exist",
+                Codes.UNKNOWN_TOKEN,
+            )

         if (
             existing_token.has_next_access_token_been_used
             or existing_token.has_next_refresh_token_been_refreshed
         ):
             raise SynapseError(
-                403, "refresh token isn't valid anymore", Codes.FORBIDDEN
+                HTTPStatus.FORBIDDEN,
+                "refresh token isn't valid anymore",
+                Codes.FORBIDDEN,
             )

+        now_ms = self._clock.time_msec()
+
+        if existing_token.expiry_ts is not None and existing_token.expiry_ts < now_ms:
+
+            raise SynapseError(
+                HTTPStatus.FORBIDDEN,
+                "The supplied refresh token has expired",
+                Codes.FORBIDDEN,
+            )
+
+        if existing_token.ultimate_session_expiry_ts is not None:
+            # This session has a bounded lifetime, even across refreshes.
+
+            if access_token_valid_until_ms is not None:
+                access_token_valid_until_ms = min(
+                    access_token_valid_until_ms,
+                    existing_token.ultimate_session_expiry_ts,
+                )
+            else:
+                access_token_valid_until_ms = existing_token.ultimate_session_expiry_ts
+
+            if refresh_token_valid_until_ms is not None:
+                refresh_token_valid_until_ms = min(
+                    refresh_token_valid_until_ms,
+                    existing_token.ultimate_session_expiry_ts,
+                )
+            else:
+                refresh_token_valid_until_ms = existing_token.ultimate_session_expiry_ts
+            if existing_token.ultimate_session_expiry_ts < now_ms:
+                raise SynapseError(
+                    HTTPStatus.FORBIDDEN,
+                    "The session has expired and can no longer be refreshed",
+                    Codes.FORBIDDEN,
+                )
+
         (
             new_refresh_token,
             new_refresh_token_id,
         ) = await self.create_refresh_token_for_user_id(
-            user_id=existing_token.user_id, device_id=existing_token.device_id
+            user_id=existing_token.user_id,
+            device_id=existing_token.device_id,
+            expiry_ts=refresh_token_valid_until_ms,
+            ultimate_session_expiry_ts=existing_token.ultimate_session_expiry_ts,
         )
         access_token = await self.create_access_token_for_user_id(
             user_id=existing_token.user_id,
             device_id=existing_token.device_id,
-            valid_until_ms=valid_until_ms,
+            valid_until_ms=access_token_valid_until_ms,
             refresh_token_id=new_refresh_token_id,
         )
         await self.store.replace_refresh_token(
             existing_token.token_id, new_refresh_token_id
         )
-        return access_token, new_refresh_token
+        return access_token, new_refresh_token, access_token_valid_until_ms

     def _verify_refresh_token(self, token: str) -> bool:
         """
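A minimal, self-contained sketch of the lifetime-capping rule described in the docstring above (the function name and signature are illustrative, not Synapse's API): both token expiries are clamped to the session's ultimate expiry, with `None` meaning "does not expire" for the tokens and "no bound" for the session.

```python
from typing import Optional, Tuple


def cap_token_lifetimes(
    access_expiry_ms: Optional[int],
    refresh_expiry_ms: Optional[int],
    ultimate_session_expiry_ms: Optional[int],
) -> Tuple[Optional[int], Optional[int]]:
    """Cap both expiry timestamps at the session's ultimate expiry, if any."""
    if ultimate_session_expiry_ms is None:
        # Unbounded session: the requested expiries stand as-is.
        return access_expiry_ms, refresh_expiry_ms
    cap = ultimate_session_expiry_ms
    access = cap if access_expiry_ms is None else min(access_expiry_ms, cap)
    refresh = cap if refresh_expiry_ms is None else min(refresh_expiry_ms, cap)
    return access, refresh


# A token that would otherwise never expire gets pinned to the session's end:
assert cap_token_lifetimes(None, 2_000, 1_500) == (1_500, 1_500)
```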
@@ -836,6 +897,8 @@ class AuthHandler:
         self,
         user_id: str,
         device_id: str,
+        expiry_ts: Optional[int],
+        ultimate_session_expiry_ts: Optional[int],
     ) -> Tuple[str, int]:
         """
         Creates a new refresh token for the user with the given user ID.

@@ -843,6 +906,13 @@ class AuthHandler:
         Args:
             user_id: canonical user ID
             device_id: the device ID to associate with the token.
+            expiry_ts (milliseconds since the epoch): Time after which the
+                refresh token cannot be used.
+                If None, the refresh token never expires until it has been used.
+            ultimate_session_expiry_ts (milliseconds since the epoch):
+                Time at which the session will end and can not be extended any
+                further.
+                If None, the session can be refreshed indefinitely.

         Returns:
             The newly created refresh token and its ID in the database

@@ -852,6 +922,8 @@ class AuthHandler:
             user_id=user_id,
             token=refresh_token,
             device_id=device_id,
+            expiry_ts=expiry_ts,
+            ultimate_session_expiry_ts=ultimate_session_expiry_ts,
         )
         return refresh_token, refresh_token_id

@@ -1582,6 +1654,7 @@ class AuthHandler:
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,
         new_user: bool = False,
+        auth_provider_session_id: Optional[str] = None,
     ) -> None:
         """Having figured out a mxid for this user, complete the HTTP request

@@ -1597,6 +1670,7 @@ class AuthHandler:
             during successful login. Must be JSON serializable.
             new_user: True if we should use wording appropriate to a user who has just
                 registered.
+            auth_provider_session_id: The session ID from the SSO IdP received during login.
         """
         # If the account has been deactivated, do not proceed with the login
         # flow.

@@ -1617,6 +1691,7 @@ class AuthHandler:
             extra_attributes,
             new_user=new_user,
             user_profile_data=profile,
+            auth_provider_session_id=auth_provider_session_id,
         )

     def _complete_sso_login(

@@ -1628,6 +1703,7 @@ class AuthHandler:
         extra_attributes: Optional[JsonDict] = None,
         new_user: bool = False,
         user_profile_data: Optional[ProfileInfo] = None,
+        auth_provider_session_id: Optional[str] = None,
     ) -> None:
         """
         The synchronous portion of complete_sso_login.

@@ -1649,7 +1725,9 @@ class AuthHandler:

         # Create a login token
         login_token = self.macaroon_gen.generate_short_term_login_token(
-            registered_user_id, auth_provider_id=auth_provider_id
+            registered_user_id,
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
         )

         # Append the login token to the original redirect URL (i.e. with its query
@@ -1754,6 +1832,7 @@ class MacaroonGenerator:
         self,
         user_id: str,
         auth_provider_id: str,
+        auth_provider_session_id: Optional[str] = None,
         duration_in_ms: int = (2 * 60 * 1000),
     ) -> str:
         macaroon = self._generate_base_macaroon(user_id)

@@ -1762,6 +1841,10 @@ class MacaroonGenerator:
         expiry = now + duration_in_ms
         macaroon.add_first_party_caveat("time < %d" % (expiry,))
         macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
+        if auth_provider_session_id is not None:
+            macaroon.add_first_party_caveat(
+                "auth_provider_session_id = %s" % (auth_provider_session_id,)
+            )
         return macaroon.serialize()

     def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:

@@ -1783,15 +1866,28 @@ class MacaroonGenerator:
         user_id = get_value_from_macaroon(macaroon, "user_id")
         auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")

+        auth_provider_session_id: Optional[str] = None
+        try:
+            auth_provider_session_id = get_value_from_macaroon(
+                macaroon, "auth_provider_session_id"
+            )
+        except MacaroonVerificationFailedException:
+            pass
+
         v = pymacaroons.Verifier()
         v.satisfy_exact("gen = 1")
         v.satisfy_exact("type = login")
         v.satisfy_general(lambda c: c.startswith("user_id = "))
         v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
+        v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = "))
         satisfy_expiry(v, self.hs.get_clock().time_msec)
         v.verify(macaroon, self.hs.config.key.macaroon_secret_key)

-        return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id)
+        return LoginTokenAttributes(
+            user_id=user_id,
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
+        )

     def generate_delete_pusher_token(self, user_id: str) -> str:
         macaroon = self._generate_base_macaroon(user_id)
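The two macaroon hunks above pair up: the generator adds an optional `auth_provider_session_id` caveat, and the verifier must declare that caveat satisfiable even when it is absent. A standalone sketch with pymacaroons, using a made-up secret and location:

```python
from typing import Optional

import pymacaroons

SECRET = "not-a-real-secret"  # stand-in for the homeserver's macaroon_secret_key


def make_login_token(user_id: str, provider: str, session_id: Optional[str] = None) -> str:
    m = pymacaroons.Macaroon(location="example.org", identifier="key", key=SECRET)
    m.add_first_party_caveat("gen = 1")
    m.add_first_party_caveat("type = login")
    m.add_first_party_caveat("user_id = %s" % (user_id,))
    m.add_first_party_caveat("auth_provider_id = %s" % (provider,))
    if session_id is not None:
        # Optional caveat, mirroring the generator hunk above.
        m.add_first_party_caveat("auth_provider_session_id = %s" % (session_id,))
    return m.serialize()


def verify_login_token(token: str) -> None:
    m = pymacaroons.Macaroon.deserialize(token)
    v = pymacaroons.Verifier()
    v.satisfy_exact("gen = 1")
    v.satisfy_exact("type = login")
    v.satisfy_general(lambda c: c.startswith("user_id = "))
    v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
    # Without this line, tokens carrying the optional caveat would fail to verify.
    v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = "))
    v.verify(m, SECRET)


verify_login_token(make_login_token("@alice:example.org", "oidc", "sid123"))
```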
synapse/handlers/device.py

@@ -301,6 +301,8 @@ class DeviceHandler(DeviceWorkerHandler):
         user_id: str,
         device_id: Optional[str],
         initial_device_display_name: Optional[str] = None,
+        auth_provider_id: Optional[str] = None,
+        auth_provider_session_id: Optional[str] = None,
     ) -> str:
         """
         If the given device has not been registered, register it with the

@@ -312,6 +314,8 @@ class DeviceHandler(DeviceWorkerHandler):
             user_id: @user:id
             device_id: device id supplied by client
             initial_device_display_name: device display name from client
+            auth_provider_id: The SSO IdP the user used, if any.
+            auth_provider_session_id: The session ID (sid) got from the SSO IdP.
         Returns:
             device id (generated if none was supplied)
         """

@@ -323,6 +327,8 @@ class DeviceHandler(DeviceWorkerHandler):
                 user_id=user_id,
                 device_id=device_id,
                 initial_device_display_name=initial_device_display_name,
+                auth_provider_id=auth_provider_id,
+                auth_provider_session_id=auth_provider_session_id,
             )
             if new_device:
                 await self.notify_device_update(user_id, [device_id])

@@ -337,6 +343,8 @@ class DeviceHandler(DeviceWorkerHandler):
                 user_id=user_id,
                 device_id=new_device_id,
                 initial_device_display_name=initial_device_display_name,
+                auth_provider_id=auth_provider_id,
+                auth_provider_session_id=auth_provider_session_id,
             )
             if new_device:
                 await self.notify_device_update(user_id, [new_device_id])
synapse/handlers/events.py

@@ -122,9 +122,8 @@ class EventStreamHandler:
                 events,
                 time_now,
                 as_client_event=as_client_event,
-                # We don't bundle "live" events, as otherwise clients
-                # will end up double counting annotations.
-                bundle_relations=False,
+                # Don't bundle aggregations as this is a deprecated API.
+                bundle_aggregations=False,
             )

             chunk = {
synapse/handlers/federation.py

@@ -68,6 +68,37 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)


+def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
+    """Get joined domains from state
+
+    Args:
+        state: State map from type/state key to event.
+
+    Returns:
+        Returns a list of servers with the lowest depth of their joins.
+        Sorted by lowest depth first.
+    """
+    joined_users = [
+        (state_key, int(event.depth))
+        for (e_type, state_key), event in state.items()
+        if e_type == EventTypes.Member and event.membership == Membership.JOIN
+    ]
+
+    joined_domains: Dict[str, int] = {}
+    for u, d in joined_users:
+        try:
+            dom = get_domain_from_id(u)
+            old_d = joined_domains.get(dom)
+            if old_d:
+                joined_domains[dom] = min(d, old_d)
+            else:
+                joined_domains[dom] = d
+        except Exception:
+            pass
+
+    return sorted(joined_domains.items(), key=lambda d: d[1])
+
+
 class FederationHandler:
     """Handles general incoming federation requests


@@ -268,36 +299,6 @@ class FederationHandler:

         curr_state = await self.state_handler.get_current_state(room_id)

-        def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
-            """Get joined domains from state
-
-            Args:
-                state: State map from type/state key to event.
-
-            Returns:
-                Returns a list of servers with the lowest depth of their joins.
-                Sorted by lowest depth first.
-            """
-            joined_users = [
-                (state_key, int(event.depth))
-                for (e_type, state_key), event in state.items()
-                if e_type == EventTypes.Member and event.membership == Membership.JOIN
-            ]
-
-            joined_domains: Dict[str, int] = {}
-            for u, d in joined_users:
-                try:
-                    dom = get_domain_from_id(u)
-                    old_d = joined_domains.get(dom)
-                    if old_d:
-                        joined_domains[dom] = min(d, old_d)
-                    else:
-                        joined_domains[dom] = d
-                except Exception:
-                    pass
-
-            return sorted(joined_domains.items(), key=lambda d: d[1])
-
         curr_domains = get_domains_from_state(curr_state)

         likely_domains = [
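To see what `get_domains_from_state` computes, here is a toy run with duck-typed events (only the fields the function reads; "m.room.member" and "join" are the literal values behind `EventTypes.Member` and `Membership.JOIN`):

```python
from collections import namedtuple

FakeEvent = namedtuple("FakeEvent", ["depth", "membership"])

state = {
    ("m.room.member", "@a:one.example"): FakeEvent(depth=3, membership="join"),
    ("m.room.member", "@b:one.example"): FakeEvent(depth=9, membership="join"),
    ("m.room.member", "@c:two.example"): FakeEvent(depth=5, membership="join"),
    ("m.room.member", "@d:two.example"): FakeEvent(depth=2, membership="leave"),
}

# The same logic in miniature: keep joins only, take each domain's minimum
# depth, and sort ascending so the earliest-joined server comes first.
joined: dict = {}
for (etype, user), ev in state.items():
    if etype == "m.room.member" and ev.membership == "join":
        domain = user.split(":", 1)[1]
        joined[domain] = min(ev.depth, joined.get(domain, ev.depth))

assert sorted(joined.items(), key=lambda kv: kv[1]) == [
    ("one.example", 3),
    ("two.example", 5),
]
```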
synapse/handlers/initial_sync.py

@@ -165,7 +165,11 @@ class InitialSyncHandler:

                 invite_event = await self.store.get_event(event.event_id)
                 d["invite"] = await self._event_serializer.serialize_event(
-                    invite_event, time_now, as_client_event
+                    invite_event,
+                    time_now,
+                    # Don't bundle aggregations as this is a deprecated API.
+                    bundle_aggregations=False,
+                    as_client_event=as_client_event,
                 )

             rooms_ret.append(d)

@@ -216,7 +220,11 @@ class InitialSyncHandler:
                 d["messages"] = {
                     "chunk": (
                         await self._event_serializer.serialize_events(
-                            messages, time_now=time_now, as_client_event=as_client_event
+                            messages,
+                            time_now=time_now,
+                            # Don't bundle aggregations as this is a deprecated API.
+                            bundle_aggregations=False,
+                            as_client_event=as_client_event,
                         )
                     ),
                     "start": await start_token.to_string(self.store),

@@ -226,6 +234,8 @@ class InitialSyncHandler:
                 d["state"] = await self._event_serializer.serialize_events(
                     current_state.values(),
                     time_now=time_now,
+                    # Don't bundle aggregations as this is a deprecated API.
+                    bundle_aggregations=False,
                     as_client_event=as_client_event,
                 )


@@ -366,14 +376,18 @@ class InitialSyncHandler:
             "room_id": room_id,
             "messages": {
                 "chunk": (
-                    await self._event_serializer.serialize_events(messages, time_now)
+                    # Don't bundle aggregations as this is a deprecated API.
+                    await self._event_serializer.serialize_events(
+                        messages, time_now, bundle_aggregations=False
+                    )
                 ),
                 "start": await start_token.to_string(self.store),
                 "end": await end_token.to_string(self.store),
             },
             "state": (
+                # Don't bundle aggregations as this is a deprecated API.
                 await self._event_serializer.serialize_events(
-                    room_state.values(), time_now
+                    room_state.values(), time_now, bundle_aggregations=False
                 )
             ),
             "presence": [],

@@ -392,8 +406,9 @@ class InitialSyncHandler:

         # TODO: These concurrently
         time_now = self.clock.time_msec()
+        # Don't bundle aggregations as this is a deprecated API.
         state = await self._event_serializer.serialize_events(
-            current_state.values(), time_now
+            current_state.values(), time_now, bundle_aggregations=False
         )

         now_token = self.hs.get_event_sources().get_current_token()

@@ -467,7 +482,10 @@ class InitialSyncHandler:
             "room_id": room_id,
             "messages": {
                 "chunk": (
-                    await self._event_serializer.serialize_events(messages, time_now)
+                    # Don't bundle aggregations as this is a deprecated API.
+                    await self._event_serializer.serialize_events(
+                        messages, time_now, bundle_aggregations=False
+                    )
                ),
                 "start": await start_token.to_string(self.store),
                 "end": await end_token.to_string(self.store),
synapse/handlers/message.py

@@ -247,13 +247,7 @@ class MessageHandler:
             room_state = room_state_events[membership_event_id]

         now = self.clock.time_msec()
-        events = await self._event_serializer.serialize_events(
-            room_state.values(),
-            now,
-            # We don't bother bundling aggregations in when asked for state
-            # events, as clients won't use them.
-            bundle_relations=False,
-        )
+        events = await self._event_serializer.serialize_events(room_state.values(), now)
         return events

     async def get_joined_members(self, requester: Requester, room_id: str) -> dict:
synapse/handlers/oidc.py

@@ -23,7 +23,7 @@ from authlib.common.security import generate_token
 from authlib.jose import JsonWebToken, jwt
 from authlib.oauth2.auth import ClientAuth
 from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
-from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
+from authlib.oidc.core import CodeIDToken, UserInfo
 from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
 from jinja2 import Environment, Template
 from pymacaroons.exceptions import (

@@ -117,7 +118,8 @@ class OidcHandler:
         for idp_id, p in self._providers.items():
             try:
                 await p.load_metadata()
-                await p.load_jwks()
+                if not p._uses_userinfo:
+                    await p.load_jwks()
             except Exception as e:
                 raise Exception(
                     "Error while initialising OIDC provider %r" % (idp_id,)

@@ -498,10 +499,6 @@ class OidcProvider:
             return await self._jwks.get()

     async def _load_jwks(self) -> JWKS:
-        if self._uses_userinfo:
-            # We're not using jwt signing, return an empty jwk set
-            return {"keys": []}
-
         metadata = await self.load_metadata()

         # Load the JWKS using the `jwks_uri` metadata.

@@ -663,7 +660,7 @@ class OidcProvider:

         return UserInfo(resp)

-    async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
+    async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken:
         """Return an instance of UserInfo from token's ``id_token``.

         Args:

@@ -673,7 +670,7 @@ class OidcProvider:
                 request. This value should match the one inside the token.

         Returns:
-            An object representing the user.
+            The decoded claims in the ID token.
         """
         metadata = await self.load_metadata()
         claims_params = {

@@ -684,9 +681,6 @@ class OidcProvider:
             # If we got an `access_token`, there should be an `at_hash` claim
             # in the `id_token` that we can check against.
             claims_params["access_token"] = token["access_token"]
-            claims_cls = CodeIDToken
-        else:
-            claims_cls = ImplicitIDToken

         alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
         jwt = JsonWebToken(alg_values)

@@ -703,7 +697,7 @@ class OidcProvider:
             claims = jwt.decode(
                 id_token,
                 key=jwk_set,
-                claims_cls=claims_cls,
+                claims_cls=CodeIDToken,
                 claims_options=claim_options,
                 claims_params=claims_params,
             )

@@ -713,7 +707,7 @@ class OidcProvider:
             claims = jwt.decode(
                 id_token,
                 key=jwk_set,
-                claims_cls=claims_cls,
+                claims_cls=CodeIDToken,
                 claims_options=claim_options,
                 claims_params=claims_params,
             )

@@ -721,7 +715,8 @@ class OidcProvider:
         logger.debug("Decoded id_token JWT %r; validating", claims)

         claims.validate(leeway=120)  # allows 2 min of clock skew
-        return UserInfo(claims)
+
+        return claims

     async def handle_redirect_request(
         self,

@@ -837,8 +832,22 @@ class OidcProvider:

         logger.debug("Successfully obtained OAuth2 token data: %r", token)

-        # Now that we have a token, get the userinfo, either by decoding the
-        # `id_token` or by fetching the `userinfo_endpoint`.
+        # If there is an id_token, it should be validated, regardless of whether
+        # the userinfo endpoint is used or not.
+        if token.get("id_token") is not None:
+            try:
+                id_token = await self._parse_id_token(token, nonce=session_data.nonce)
+                sid = id_token.get("sid")
+            except Exception as e:
+                logger.exception("Invalid id_token")
+                self._sso_handler.render_error(request, "invalid_token", str(e))
+                return
+        else:
+            id_token = None
+            sid = None
+
+        # Now that we have a token, get the userinfo either from the `id_token`
+        # claims or by fetching the `userinfo_endpoint`.
         if self._uses_userinfo:
             try:
                 userinfo = await self._fetch_userinfo(token)

@@ -846,13 +855,14 @@ class OidcProvider:
                 logger.exception("Could not fetch userinfo")
                 self._sso_handler.render_error(request, "fetch_error", str(e))
                 return
+        elif id_token is not None:
+            userinfo = UserInfo(id_token)
         else:
-            try:
-                userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
-            except Exception as e:
-                logger.exception("Invalid id_token")
-                self._sso_handler.render_error(request, "invalid_token", str(e))
-                return
+            logger.error("Missing id_token in token response")
+            self._sso_handler.render_error(
+                request, "invalid_token", "Missing id_token in token response"
+            )
+            return

         # first check if we're doing a UIA
         if session_data.ui_auth_session_id:

@@ -884,7 +894,7 @@ class OidcProvider:
         # Call the mapper to register/login the user
         try:
             await self._complete_oidc_login(
-                userinfo, token, request, session_data.client_redirect_url
+                userinfo, token, request, session_data.client_redirect_url, sid
             )
         except MappingException as e:
             logger.exception("Could not map user")

@@ -896,6 +906,7 @@ class OidcProvider:
         token: Token,
         request: SynapseRequest,
         client_redirect_url: str,
+        sid: Optional[str],
     ) -> None:
         """Given a UserInfo response, complete the login flow

@@ -1008,6 +1019,7 @@ class OidcProvider:
             oidc_response_to_user_attributes,
             grandfather_existing_users,
             extra_attributes,
+            auth_provider_session_id=sid,
         )

     def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
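The OIDC hunks above collapse the implicit-flow path: every `id_token` is now decoded as a `CodeIDToken` and validated up front, with the SSO session ID read from its `sid` claim. A hedged sketch of that validation step with authlib (the function and its parameters are illustrative; real claim options vary by provider):

```python
from typing import Optional

from authlib.jose import JsonWebToken
from authlib.oidc.core import CodeIDToken


def validate_id_token(
    id_token: str,
    jwk_set: dict,
    nonce: str,
    access_token: Optional[str] = None,
) -> CodeIDToken:
    """Decode and validate an OIDC id_token, returning its claims."""
    claims_params = {"nonce": nonce}
    if access_token is not None:
        # With an access token in hand, CodeIDToken also checks `at_hash`.
        claims_params["access_token"] = access_token
    jwt = JsonWebToken(["RS256"])
    claims = jwt.decode(
        id_token,
        key=jwk_set,
        claims_cls=CodeIDToken,
        claims_params=claims_params,
    )
    claims.validate(leeway=120)  # tolerate 2 minutes of clock skew
    return claims


# The IdP's session ID, if any, then comes from claims.get("sid").
```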
synapse/handlers/pagination.py

@@ -406,9 +406,6 @@ class PaginationHandler:
             force: set true to skip checking for joined users.
         """
         with await self.pagination_lock.write(room_id):
-            # check we know about the room
-            await self.store.get_room_version_id(room_id)
-
             # first check that we have no users in this room
             if not force:
                 joined = await self.store.is_host_joined(room_id, self._server_name)
synapse/handlers/presence.py

@@ -421,7 +421,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
             self._on_shutdown,
         )

-    def _on_shutdown(self) -> None:
+    async def _on_shutdown(self) -> None:
         if self._presence_enabled:
             self.hs.get_tcp_replication().send_command(
                 ClearUserSyncsCommand(self.instance_id)
synapse/handlers/register.py

@@ -1,4 +1,5 @@
 # Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -116,9 +117,13 @@ class RegistrationHandler:
         self.pusher_pool = hs.get_pusherpool()

         self.session_lifetime = hs.config.registration.session_lifetime
+        self.nonrefreshable_access_token_lifetime = (
+            hs.config.registration.nonrefreshable_access_token_lifetime
+        )
         self.refreshable_access_token_lifetime = (
             hs.config.registration.refreshable_access_token_lifetime
         )
+        self.refresh_token_lifetime = hs.config.registration.refresh_token_lifetime

         init_counters_for_auth_provider("")

@@ -741,6 +746,7 @@ class RegistrationHandler:
         is_appservice_ghost: bool = False,
         auth_provider_id: Optional[str] = None,
         should_issue_refresh_token: bool = False,
+        auth_provider_session_id: Optional[str] = None,
     ) -> Tuple[str, str, Optional[int], Optional[str]]:
         """Register a device for a user and generate an access token.

@@ -751,9 +757,9 @@ class RegistrationHandler:
             device_id: The device ID to check, or None to generate a new one.
             initial_display_name: An optional display name for the device.
             is_guest: Whether this is a guest account
-            auth_provider_id: The SSO IdP the user used, if any (just used for the
-                prometheus metrics).
+            auth_provider_id: The SSO IdP the user used, if any.
             should_issue_refresh_token: Whether it should also issue a refresh token
+            auth_provider_session_id: The session ID received during login from the SSO IdP.
         Returns:
             Tuple of device ID, access token, access token expiration time and refresh token
         """

@@ -764,6 +770,8 @@ class RegistrationHandler:
             is_guest=is_guest,
             is_appservice_ghost=is_appservice_ghost,
             should_issue_refresh_token=should_issue_refresh_token,
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
         )

         login_counter.labels(

@@ -786,6 +794,8 @@ class RegistrationHandler:
         is_guest: bool = False,
         is_appservice_ghost: bool = False,
         should_issue_refresh_token: bool = False,
+        auth_provider_id: Optional[str] = None,
+        auth_provider_session_id: Optional[str] = None,
     ) -> LoginDict:
         """Helper for register_device

@@ -793,40 +803,86 @@ class RegistrationHandler:
         class and RegisterDeviceReplicationServlet.
         """
         assert not self.hs.config.worker.worker_app
-        valid_until_ms = None
+        now_ms = self.clock.time_msec()
+        access_token_expiry = None
         if self.session_lifetime is not None:
             if is_guest:
                 raise Exception(
                     "session_lifetime is not currently implemented for guest access"
                 )
-            valid_until_ms = self.clock.time_msec() + self.session_lifetime
+            access_token_expiry = now_ms + self.session_lifetime
+
+        if self.nonrefreshable_access_token_lifetime is not None:
+            if access_token_expiry is not None:
+                # Don't allow the non-refreshable access token to outlive the
+                # session.
+                access_token_expiry = min(
+                    now_ms + self.nonrefreshable_access_token_lifetime,
+                    access_token_expiry,
+                )
+            else:
+                access_token_expiry = now_ms + self.nonrefreshable_access_token_lifetime

         refresh_token = None
         refresh_token_id = None

         registered_device_id = await self.device_handler.check_device_registered(
-            user_id, device_id, initial_display_name
+            user_id,
+            device_id,
+            initial_display_name,
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
         )
         if is_guest:
-            assert valid_until_ms is None
+            assert access_token_expiry is None
             access_token = self.macaroon_gen.generate_guest_access_token(user_id)
         else:
             if should_issue_refresh_token:
+                # A refreshable access token lifetime must be configured
+                # since we're told to issue a refresh token (the caller checks
+                # that this value is set before setting this flag).
+                assert self.refreshable_access_token_lifetime is not None
+
+                # Set the expiry time of the refreshable access token
+                access_token_expiry = now_ms + self.refreshable_access_token_lifetime
+
+                # Set the refresh token expiry time (if configured)
+                refresh_token_expiry = None
+                if self.refresh_token_lifetime is not None:
+                    refresh_token_expiry = now_ms + self.refresh_token_lifetime
+
+                # Set an ultimate session expiry time (if configured)
+                ultimate_session_expiry_ts = None
+                if self.session_lifetime is not None:
+                    ultimate_session_expiry_ts = now_ms + self.session_lifetime
+
+                    # Also ensure that the issued tokens don't outlive the
+                    # session.
+                    # (It would be weird to configure a homeserver with a shorter
+                    # session lifetime than token lifetime, but may as well handle
+                    # it.)
+                    access_token_expiry = min(
+                        access_token_expiry, ultimate_session_expiry_ts
+                    )
+                    if refresh_token_expiry is not None:
+                        refresh_token_expiry = min(
+                            refresh_token_expiry, ultimate_session_expiry_ts
+                        )
+
                 (
                     refresh_token,
                     refresh_token_id,
                 ) = await self._auth_handler.create_refresh_token_for_user_id(
                     user_id,
                     device_id=registered_device_id,
-                )
-                valid_until_ms = (
-                    self.clock.time_msec() + self.refreshable_access_token_lifetime
+                    expiry_ts=refresh_token_expiry,
+                    ultimate_session_expiry_ts=ultimate_session_expiry_ts,
                 )

             access_token = await self._auth_handler.create_access_token_for_user_id(
                 user_id,
                 device_id=registered_device_id,
-                valid_until_ms=valid_until_ms,
+                valid_until_ms=access_token_expiry,
                 is_appservice_ghost=is_appservice_ghost,
                 refresh_token_id=refresh_token_id,
             )
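Worked numbers for the issuance-time arithmetic above (values are illustrative, in milliseconds): a refreshable access token is first given its full lifetime, then clamped so it cannot outlive the session.

```python
now_ms = 1_000_000
session_lifetime = 60_000                    # session ends at 1_060_000
refreshable_access_token_lifetime = 300_000  # token alone would last to 1_300_000

access_token_expiry = now_ms + refreshable_access_token_lifetime
ultimate_session_expiry_ts = now_ms + session_lifetime

# The min() from the hunk above: the token must not outlive the session.
access_token_expiry = min(access_token_expiry, ultimate_session_expiry_ts)
assert access_token_expiry == 1_060_000
```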
@@ -834,7 +890,7 @@ class RegistrationHandler:
         return {
             "device_id": registered_device_id,
             "access_token": access_token,
-            "valid_until_ms": valid_until_ms,
+            "valid_until_ms": access_token_expiry,
             "refresh_token": refresh_token,
         }

synapse/handlers/room.py

@@ -46,6 +46,7 @@ from synapse.api.constants import (
 from synapse.api.errors import (
     AuthError,
     Codes,
+    HttpResponseException,
     LimitExceededError,
     NotFoundError,
     StoreError,

@@ -56,6 +57,8 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.event_auth import validate_event_for_room_version
 from synapse.events import EventBase
 from synapse.events.utils import copy_power_levels_contents
+from synapse.federation.federation_client import InvalidResponseError
+from synapse.handlers.federation import get_domains_from_state
 from synapse.rest.admin._base import assert_user_is_admin
 from synapse.storage.state import StateFilter
 from synapse.streams import EventSource
@@ -1220,6 +1223,147 @@ class RoomContextHandler:
         return results


+class TimestampLookupHandler:
+    def __init__(self, hs: "HomeServer"):
+        self.server_name = hs.hostname
+        self.store = hs.get_datastore()
+        self.state_handler = hs.get_state_handler()
+        self.federation_client = hs.get_federation_client()
+
+    async def get_event_for_timestamp(
+        self,
+        requester: Requester,
+        room_id: str,
+        timestamp: int,
+        direction: str,
+    ) -> Tuple[str, int]:
+        """Find the closest event to the given timestamp in the given direction.
+        If we can't find an event locally or the event we have locally is next to a gap,
+        it will ask other federated homeservers for an event.
+
+        Args:
+            requester: The user making the request according to the access token
+            room_id: Room to fetch the event from
+            timestamp: The point in time (inclusive) we should navigate from in
+                the given direction to find the closest event.
+            direction: ["f"|"b"] to indicate whether we should navigate forward
+                or backward from the given timestamp to find the closest event.
+
+        Returns:
+            A tuple containing the `event_id` closest to the given timestamp in
+            the given direction and the `origin_server_ts`.
+
+        Raises:
+            SynapseError if unable to find any event locally in the given direction
+        """
+
+        local_event_id = await self.store.get_event_id_for_timestamp(
+            room_id, timestamp, direction
+        )
+        logger.debug(
+            "get_event_for_timestamp: locally, we found event_id=%s closest to timestamp=%s",
+            local_event_id,
+            timestamp,
+        )
+
+        # Check for gaps in the history where events could be hiding in between
+        # the timestamp given and the event we were able to find locally
+        is_event_next_to_backward_gap = False
+        is_event_next_to_forward_gap = False
+        if local_event_id:
+            local_event = await self.store.get_event(
+                local_event_id, allow_none=False, allow_rejected=False
+            )
+
+            if direction == "f":
+                # We only need to check for a backward gap if we're looking forwards
+                # to ensure there is nothing in between.
+                is_event_next_to_backward_gap = (
+                    await self.store.is_event_next_to_backward_gap(local_event)
+                )
+            elif direction == "b":
+                # We only need to check for a forward gap if we're looking backwards
+                # to ensure there is nothing in between
+                is_event_next_to_forward_gap = (
+                    await self.store.is_event_next_to_forward_gap(local_event)
+                )
+
+        # If we found a gap, we should probably ask another homeserver first
+        # about more history in between
+        if (
+            not local_event_id
+            or is_event_next_to_backward_gap
+            or is_event_next_to_forward_gap
+        ):
+            logger.debug(
+                "get_event_for_timestamp: locally, we found event_id=%s closest to timestamp=%s which is next to a gap in event history so we're asking other homeservers first",
+                local_event_id,
+                timestamp,
+            )
+
+            # Find other homeservers from the given state in the room
+            curr_state = await self.state_handler.get_current_state(room_id)
+            curr_domains = get_domains_from_state(curr_state)
+            likely_domains = [
+                domain for domain, depth in curr_domains if domain != self.server_name
+            ]
+
+            # Loop through each homeserver candidate until we get a successful response
+            for domain in likely_domains:
+                try:
+                    remote_response = await self.federation_client.timestamp_to_event(
+                        domain, room_id, timestamp, direction
+                    )
+                    logger.debug(
+                        "get_event_for_timestamp: response from domain(%s)=%s",
+                        domain,
+                        remote_response,
+                    )
+
+                    # TODO: Do we want to persist this as an extremity?
+                    # TODO: I think ideally, we would try to backfill from
+                    # this event and run this whole
+                    # `get_event_for_timestamp` function again to make sure
+                    # they didn't give us an event from their gappy history.
+                    remote_event_id = remote_response.event_id
+                    origin_server_ts = remote_response.origin_server_ts
+
+                    # Only return the remote event if it's closer than the local event
+                    if not local_event or (
+                        abs(origin_server_ts - timestamp)
+                        < abs(local_event.origin_server_ts - timestamp)
+                    ):
+                        return remote_event_id, origin_server_ts
+                except (HttpResponseException, InvalidResponseError) as ex:
+                    # Let's not put a high priority on some other homeserver
+                    # failing to respond or giving a random response
+                    logger.debug(
+                        "Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s",
+                        domain,
+                        type(ex).__name__,
+                        ex,
+                        ex.args,
+                    )
+                except Exception as ex:
+                    # But we do want to see some exceptions in our code
+                    logger.warning(
+                        "Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s",
+                        domain,
+                        type(ex).__name__,
+                        ex,
+                        ex.args,
+                    )
+
+        if not local_event_id:
+            raise SynapseError(
+                404,
+                "Unable to find event from %s in direction %s" % (timestamp, direction),
+                errcode=Codes.NOT_FOUND,
+            )
+
+        return local_event_id, local_event.origin_server_ts
+
+
 class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
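The core of the remote-vs-local decision above is a closeness comparison on `origin_server_ts`. In isolation (illustrative names):

```python
def closer_event(timestamp: int, local_ts: int, remote_ts: int) -> str:
    """Return which event sits nearer the requested timestamp."""
    if abs(remote_ts - timestamp) < abs(local_ts - timestamp):
        return "remote"
    return "local"


# A remote event 1s before the target beats a local event 4s after it:
assert closer_event(timestamp=10_000, local_ts=14_000, remote_ts=9_000) == "remote"
```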
@@ -1391,20 +1535,13 @@ class RoomShutdownHandler:
             await self.store.block_room(room_id, requester_user_id)

         if not await self.store.get_room(room_id):
-            if block:
-                # We allow you to block an unknown room.
-                return {
-                    "kicked_users": [],
-                    "failed_to_kick_users": [],
-                    "local_aliases": [],
-                    "new_room_id": None,
-                }
-            else:
-                # But if you don't want to preventatively block another room,
-                # this function can't do anything useful.
-                raise NotFoundError(
-                    "Cannot shut down room: unknown room id %s" % (room_id,)
-                )
+            # if we don't know about the room, there is nothing left to do.
+            return {
+                "kicked_users": [],
+                "failed_to_kick_users": [],
+                "local_aliases": [],
+                "new_room_id": None,
+            }

         if new_room_user_id is not None:
             if not self.hs.is_mine_id(new_room_user_id):
synapse/handlers/room_summary.py

@@ -36,8 +36,9 @@ from synapse.api.errors import (
     SynapseError,
     UnsupportedRoomVersionError,
 )
+from synapse.api.ratelimiting import Ratelimiter
 from synapse.events import EventBase
-from synapse.types import JsonDict
+from synapse.types import JsonDict, Requester
 from synapse.util.caches.response_cache import ResponseCache

 if TYPE_CHECKING:

@@ -93,6 +94,9 @@ class RoomSummaryHandler:
         self._event_serializer = hs.get_event_client_serializer()
         self._server_name = hs.hostname
         self._federation_client = hs.get_federation_client()
+        self._ratelimiter = Ratelimiter(
+            store=self._store, clock=hs.get_clock(), rate_hz=5, burst_count=10
+        )

         # If a user tries to fetch the same page multiple times in quick succession,
         # only process the first attempt and return its result to subsequent requests.

@@ -249,7 +253,7 @@ class RoomSummaryHandler:

     async def get_room_hierarchy(
         self,
-        requester: str,
+        requester: Requester,
         requested_room_id: str,
         suggested_only: bool = False,
         max_depth: Optional[int] = None,

@@ -276,6 +280,8 @@ class RoomSummaryHandler:
         Returns:
             The JSON hierarchy dictionary.
         """
+        await self._ratelimiter.ratelimit(requester)
+
         # If a user tries to fetch the same page multiple times in quick succession,
         # only process the first attempt and return its result to subsequent requests.
         #

@@ -283,7 +289,7 @@ class RoomSummaryHandler:
         # to process multiple requests for the same page will result in errors.
         return await self._pagination_response_cache.wrap(
             (
-                requester,
+                requester.user.to_string(),
                 requested_room_id,
                 suggested_only,
                 max_depth,

@@ -291,7 +297,7 @@ class RoomSummaryHandler:
                 from_token,
             ),
             self._get_room_hierarchy,
-            requester,
+            requester.user.to_string(),
             requested_room_id,
             suggested_only,
             max_depth,
synapse/handlers/sso.py

@@ -365,6 +365,7 @@ class SsoHandler:
         sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
         grandfather_existing_users: Callable[[], Awaitable[Optional[str]]],
         extra_login_attributes: Optional[JsonDict] = None,
+        auth_provider_session_id: Optional[str] = None,
     ) -> None:
         """
         Given an SSO ID, retrieve the user ID for it and possibly register the user.

@@ -415,6 +416,8 @@ class SsoHandler:
             extra_login_attributes: An optional dictionary of extra
                 attributes to be provided to the client in the login response.

+            auth_provider_session_id: An optional session ID from the IdP.
+
         Raises:
             MappingException if there was a problem mapping the response to a user.
             RedirectException: if the mapping provider needs to redirect the user

@@ -490,6 +493,7 @@ class SsoHandler:
             client_redirect_url,
             extra_login_attributes,
             new_user=new_user,
+            auth_provider_session_id=auth_provider_session_id,
         )

     async def _call_attribute_mapper(
synapse/handlers/sync.py

@@ -334,6 +334,19 @@ class SyncHandler:
         full_state: bool,
         cache_context: ResponseCacheContext[SyncRequestKey],
     ) -> SyncResult:
+        """The start of the machinery that produces a /sync response.
+
+        See https://spec.matrix.org/v1.1/client-server-api/#syncing for full details.
+
+        This method does high-level bookkeeping:
+        - tracking the kind of sync in the logging context
+        - deleting any to_device messages whose delivery has been acknowledged.
+        - deciding if we should dispatch an instant or delayed response
+        - marking the sync as being lazily loaded, if appropriate
+
+        Computing the body of the response begins in the next method,
+        `current_sync_for_user`.
+        """
         if since_token is None:
             sync_type = "initial_sync"
         elif full_state:

@@ -363,7 +376,7 @@ class SyncHandler:
                 sync_config, since_token, full_state=full_state
             )
         else:
-
+            # Otherwise, we wait for something to happen and report it to the user.
             async def current_sync_callback(
                 before_token: StreamToken, after_token: StreamToken
             ) -> SyncResult:

@@ -402,7 +415,12 @@ class SyncHandler:
         since_token: Optional[StreamToken] = None,
         full_state: bool = False,
     ) -> SyncResult:
-        """Get the sync for client needed to match what the server has now."""
+        """Generates the response body of a sync result, represented as a SyncResult.
+
+        This is a wrapper around `generate_sync_result` which starts an open tracing
+        span to track the sync. See `generate_sync_result` for the next part of your
+        indoctrination.
+        """
         with start_active_span("current_sync_for_user"):
             log_kv({"since_token": since_token})
             sync_result = await self.generate_sync_result(

@@ -560,7 +578,7 @@ class SyncHandler:
             # that have happened since `since_key` up to `end_key`, so we
             # can just use `get_room_events_stream_for_room`.
             # Otherwise, we want to return the last N events in the room
-            # in toplogical ordering.
+            # in topological ordering.
             if since_key:
                 events, end_key = await self.store.get_room_events_stream_for_room(
                     room_id,
@@ -1042,7 +1060,18 @@ class SyncHandler:
         since_token: Optional[StreamToken] = None,
         full_state: bool = False,
     ) -> SyncResult:
-        """Generates a sync result."""
+        """Generates the response body of a sync result.
+
+        This is represented by a `SyncResult` struct, which is built from small pieces
+        using a `SyncResultBuilder`. See also
+            https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3sync
+        the `sync_result_builder` is passed as a mutable ("inout") parameter to various
+        helper functions. These retrieve and process the data which forms the sync body,
+        often writing to the `sync_result_builder` to store their output.
+
+        At the end, we transfer data from the `sync_result_builder` to a new `SyncResult`
+        instance to signify that the sync calculation is complete.
+        """
         # NB: The now_token gets changed by some of the generate_sync_* methods,
         # this is due to some of the underlying streams not supporting the ability
         # to query up to a given point.
@@ -1344,14 +1373,22 @@ class SyncHandler:
     async def _generate_sync_entry_for_account_data(
         self, sync_result_builder: "SyncResultBuilder"
     ) -> Dict[str, Dict[str, JsonDict]]:
-        """Generates the account data portion of the sync response. Populates
-        `sync_result_builder` with the result.
+        """Generates the account data portion of the sync response.
+
+        Account data (called "Client Config" in the spec) can be set either globally
+        or for a specific room. Account data consists of a list of events which
+        accumulate state, much like a room.
+
+        This function retrieves global and per-room account data. The former is written
+        to the given `sync_result_builder`. The latter is returned directly, to be
+        later written to the `sync_result_builder` on a room-by-room basis.

         Args:
             sync_result_builder

         Returns:
-            A dictionary containing the per room account data.
+            A dictionary whose keys (room ids) map to the per room account data for that
+            room.
         """
         sync_config = sync_result_builder.sync_config
         user_id = sync_result_builder.sync_config.user.to_string()

@@ -1359,7 +1396,7 @@ class SyncHandler:

         if since_token and not sync_result_builder.full_state:
             (
-                account_data,
+                global_account_data,
                 account_data_by_room,
             ) = await self.store.get_updated_account_data_for_user(
                 user_id, since_token.account_data_key

@@ -1370,23 +1407,23 @@ class SyncHandler:
             )

             if push_rules_changed:
-                account_data["m.push_rules"] = await self.push_rules_for_user(
+                global_account_data["m.push_rules"] = await self.push_rules_for_user(
                     sync_config.user
                 )
         else:
             (
-                account_data,
+                global_account_data,
                 account_data_by_room,
             ) = await self.store.get_account_data_for_user(sync_config.user.to_string())

-            account_data["m.push_rules"] = await self.push_rules_for_user(
+            global_account_data["m.push_rules"] = await self.push_rules_for_user(
                 sync_config.user
             )

         account_data_for_user = await sync_config.filter_collection.filter_account_data(
             [
                 {"type": account_data_type, "content": content}
-                for account_data_type, content in account_data.items()
+                for account_data_type, content in global_account_data.items()
             ]
         )

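For orientation, the two shapes this function juggles look roughly like this (contents are illustrative): global account data is a flat type-to-content mapping, while the returned per-room mapping adds a room_id level.

```python
global_account_data = {
    "m.push_rules": {"global": {"override": []}},
    "m.ignored_user_list": {"ignored_users": {"@spam:example.org": {}}},
}

account_data_by_room = {
    "!room:example.org": {"m.fully_read": {"event_id": "$somewhere"}},
}
# The former is filtered and written to the sync_result_builder; the latter is
# returned and merged in room by room.
```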
@@ -1460,18 +1497,31 @@ class SyncHandler:
         """Generates the rooms portion of the sync response. Populates the
         `sync_result_builder` with the result.

+        In the response that reaches the client, rooms are divided into four categories:
+        `invite`, `join`, `knock`, `leave`. These aren't the same as the four sets of
+        room ids returned by this function.
+
         Args:
             sync_result_builder
             account_data_by_room: Dictionary of per room account data

         Returns:
-            Returns a 4-tuple of
-            `(newly_joined_rooms, newly_joined_or_invited_users,
-            newly_left_rooms, newly_left_users)`
+            Returns a 4-tuple describing rooms the user has joined or left, and users
+            who've joined or left any rooms the user is in. This gets used later in
+            `_generate_sync_entry_for_device_list`.
+
+            Its entries are:
+            - newly_joined_rooms
+            - newly_joined_or_invited_or_knocked_users
+            - newly_left_rooms
+            - newly_left_users
         """
+        since_token = sync_result_builder.since_token
+
+        # 1. Start by fetching all ephemeral events in rooms we've joined (if required).
         user_id = sync_result_builder.sync_config.user.to_string()
         block_all_room_ephemeral = (
-            sync_result_builder.since_token is None
+            since_token is None
             and sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral()
         )

@@ -1485,9 +1535,8 @@ class SyncHandler:
             )
             sync_result_builder.now_token = now_token

-        # We check up front if anything has changed, if it hasn't then there is
+        # 2. We check up front if anything has changed, if it hasn't then there is
         # no point in going further.
-        since_token = sync_result_builder.since_token
         if not sync_result_builder.full_state:
             if since_token and not ephemeral_by_room and not account_data_by_room:
                 have_changed = await self._have_rooms_changed(sync_result_builder)

@@ -1500,20 +1549,8 @@ class SyncHandler:
                 logger.debug("no-oping sync")
                 return set(), set(), set(), set()

-        ignored_account_data = (
-            await self.store.get_global_account_data_by_type_for_user(
-                AccountDataTypes.IGNORED_USER_LIST, user_id=user_id
-            )
-        )
-
-        # If there is ignored users account data and it matches the proper type,
-        # then use it.
-        ignored_users: FrozenSet[str] = frozenset()
-        if ignored_account_data:
-            ignored_users_data = ignored_account_data.get("ignored_users", {})
-            if isinstance(ignored_users_data, dict):
-                ignored_users = frozenset(ignored_users_data.keys())
-
+        # 3. Work out which rooms need reporting in the sync response.
+        ignored_users = await self._get_ignored_users(user_id)
         if since_token:
             room_changes = await self._get_rooms_changed(
                 sync_result_builder, ignored_users

@@ -1523,7 +1560,6 @@ class SyncHandler:
             )
         else:
             room_changes = await self._get_all_rooms(sync_result_builder, ignored_users)
-
         tags_by_room = await self.store.get_tags_for_user(user_id)

         log_kv({"rooms_changed": len(room_changes.room_entries)})
@@ -1534,6 +1570,8 @@ class SyncHandler:
         newly_joined_rooms = room_changes.newly_joined_rooms
         newly_left_rooms = room_changes.newly_left_rooms

+        # 4. We need to apply further processing to `room_entries` (rooms considered
+        # joined or archived).
         async def handle_room_entries(room_entry: "RoomSyncResultBuilder") -> None:
             logger.debug("Generating room entry for %s", room_entry.room_id)
             await self._generate_room_entry(

@@ -1552,31 +1590,13 @@ class SyncHandler:
         sync_result_builder.invited.extend(invited)
         sync_result_builder.knocked.extend(knocked)

-        # Now we want to get any newly joined, invited or knocking users
-        newly_joined_or_invited_or_knocked_users = set()
-        newly_left_users = set()
-        if since_token:
-            for joined_sync in sync_result_builder.joined:
-                it = itertools.chain(
-                    joined_sync.timeline.events, joined_sync.state.values()
-                )
-                for event in it:
-                    if event.type == EventTypes.Member:
-                        if (
-                            event.membership == Membership.JOIN
-                            or event.membership == Membership.INVITE
-                            or event.membership == Membership.KNOCK
-                        ):
-                            newly_joined_or_invited_or_knocked_users.add(
-                                event.state_key
-                            )
-                        else:
-                            prev_content = event.unsigned.get("prev_content", {})
-                            prev_membership = prev_content.get("membership", None)
-                            if prev_membership == Membership.JOIN:
-                                newly_left_users.add(event.state_key)
-
-        newly_left_users -= newly_joined_or_invited_or_knocked_users
+        # 5. Work out which users have joined or left rooms we're in. We use this
+        # to build the device_list part of the sync response in
+        # `_generate_sync_entry_for_device_list`.
+        (
+            newly_joined_or_invited_or_knocked_users,
+            newly_left_users,
+        ) = sync_result_builder.calculate_user_changes()

         return (
             set(newly_joined_rooms),
@@ -1585,11 +1605,36 @@ class SyncHandler:
             newly_left_users,
         )

+    async def _get_ignored_users(self, user_id: str) -> FrozenSet[str]:
+        """Retrieve the users ignored by the given user from their global account_data.
+
+        Returns an empty set if
+        - there is no global account_data entry for ignored_users
+        - there is such an entry, but it's not a JSON object.
+        """
+        # TODO: Can we `SELECT ignored_user_id FROM ignored_users WHERE ignorer_user_id=?;` instead?
+        ignored_account_data = (
+            await self.store.get_global_account_data_by_type_for_user(
+                AccountDataTypes.IGNORED_USER_LIST, user_id=user_id
+            )
+        )
+
+        # If there is ignored users account data and it matches the proper type,
+        # then use it.
+        ignored_users: FrozenSet[str] = frozenset()
+        if ignored_account_data:
+            ignored_users_data = ignored_account_data.get("ignored_users", {})
+            if isinstance(ignored_users_data, dict):
+                ignored_users = frozenset(ignored_users_data.keys())
+        return ignored_users
+
     async def _have_rooms_changed(
         self, sync_result_builder: "SyncResultBuilder"
     ) -> bool:
         """Returns whether there may be any new events that should be sent down
         the sync. Returns True if there are.
+
+        Does not modify the `sync_result_builder`.
         """
         user_id = sync_result_builder.sync_config.user.to_string()
         since_token = sync_result_builder.since_token
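The `m.ignored_user_list` content maps ignored user IDs to (currently empty) objects, so the ignored set is just the dict's keys. A toy run of the extraction logic above:

```python
ignored_account_data = {
    "ignored_users": {"@spam:example.org": {}, "@troll:example.org": {}}
}

ignored_users_data = ignored_account_data.get("ignored_users", {})
ignored_users = (
    frozenset(ignored_users_data.keys())
    if isinstance(ignored_users_data, dict)
    else frozenset()
)
assert "@spam:example.org" in ignored_users
```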
@@ -1597,12 +1642,13 @@ class SyncHandler:

         assert since_token

-        # Get a list of membership change events that have happened.
-        rooms_changed = await self.store.get_membership_changes_for_user(
+        # Get a list of membership change events that have happened to the user
+        # requesting the sync.
+        membership_changes = await self.store.get_membership_changes_for_user(
             user_id, since_token.room_key, now_token.room_key
         )

-        if rooms_changed:
+        if membership_changes:
             return True

         stream_id = since_token.room_key.stream

@@ -1614,7 +1660,25 @@ class SyncHandler:
     async def _get_rooms_changed(
         self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
     ) -> _RoomChanges:
-        """Gets the the changes that have happened since the last sync."""
+        """Determine the changes in rooms to report to the user.
+
+        Ideally, we want to report all events whose stream ordering `s` lies in the
+        range `since_token < s <= now_token`, where the two tokens are read from the
+        sync_result_builder.
+
+        If there are too many events in that range to report, things get complicated.
+        In this situation we return a truncated list of the most recent events, and
+        indicate in the response that there is a "gap" of omitted events. Additionally:
+
+        - we include a "state_delta", to describe the changes in state over the gap,
+        - we include all membership events applying to the user making the request,
+          even those in the gap.
+
+        See the spec for the rationale:
+        https://spec.matrix.org/v1.1/client-server-api/#syncing
+
+        The sync_result_builder is not modified by this function.
+        """
         user_id = sync_result_builder.sync_config.user.to_string()
         since_token = sync_result_builder.since_token
         now_token = sync_result_builder.now_token
@ -1622,21 +1686,36 @@ class SyncHandler:
|
|||
|
||||
assert since_token
|
||||
|
||||
# Get a list of membership change events that have happened.
|
||||
rooms_changed = await self.store.get_membership_changes_for_user(
|
||||
# The spec
|
||||
# https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3sync
|
||||
# notes that membership events need special consideration:
|
||||
#
|
||||
# > When a sync is limited, the server MUST return membership events for events
|
||||
# > in the gap (between since and the start of the returned timeline), regardless
|
||||
# > as to whether or not they are redundant.
|
||||
#
|
||||
# We fetch such events here, but we only seem to use them for categorising rooms
|
||||
# as newly joined, newly left, invited or knocked.
|
||||
# TODO: we've already called this function and ran this query in
|
||||
# _have_rooms_changed. We could keep the results in memory to avoid a
|
||||
# second query, at the cost of more complicated source code.
|
||||
membership_change_events = await self.store.get_membership_changes_for_user(
|
||||
user_id, since_token.room_key, now_token.room_key
|
||||
)
|
||||
|
||||
mem_change_events_by_room_id: Dict[str, List[EventBase]] = {}
|
||||
for event in rooms_changed:
|
||||
for event in membership_change_events:
|
||||
mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
|
||||
|
||||
newly_joined_rooms = []
|
||||
newly_left_rooms = []
|
||||
room_entries = []
|
||||
invited = []
|
||||
knocked = []
|
||||
newly_joined_rooms: List[str] = []
|
||||
newly_left_rooms: List[str] = []
|
||||
room_entries: List[RoomSyncResultBuilder] = []
|
||||
invited: List[InvitedSyncResult] = []
|
||||
knocked: List[KnockedSyncResult] = []
|
||||
for room_id, events in mem_change_events_by_room_id.items():
|
||||
# The body of this loop will add this room to at least one of the five lists
|
||||
# above. Things get messy if you've e.g. joined, left, joined then left the
|
||||
# room all in the same sync period.
|
||||
logger.debug(
|
||||
"Membership changes in %s: [%s]",
|
||||
room_id,
|
||||
|
@ -1691,6 +1770,7 @@ class SyncHandler:
|
|||
|
||||
if not non_joins:
|
||||
continue
|
||||
last_non_join = non_joins[-1]
|
||||
|
||||
# Check if we have left the room. This can either be because we were
|
||||
# joined before *or* that we since joined and then left.
|
||||
|
@ -1712,18 +1792,18 @@ class SyncHandler:
|
|||
newly_left_rooms.append(room_id)
|
||||
|
||||
# Only bother if we're still currently invited
|
||||
should_invite = non_joins[-1].membership == Membership.INVITE
|
||||
should_invite = last_non_join.membership == Membership.INVITE
|
||||
if should_invite:
|
||||
if event.sender not in ignored_users:
|
||||
invite_room_sync = InvitedSyncResult(room_id, invite=non_joins[-1])
|
||||
if last_non_join.sender not in ignored_users:
|
||||
invite_room_sync = InvitedSyncResult(room_id, invite=last_non_join)
|
||||
if invite_room_sync:
|
||||
invited.append(invite_room_sync)
|
||||
|
||||
# Only bother if our latest membership in the room is knock (and we haven't
|
||||
# been accepted/rejected in the meantime).
|
||||
should_knock = non_joins[-1].membership == Membership.KNOCK
|
||||
should_knock = last_non_join.membership == Membership.KNOCK
|
||||
if should_knock:
|
||||
knock_room_sync = KnockedSyncResult(room_id, knock=non_joins[-1])
|
||||
knock_room_sync = KnockedSyncResult(room_id, knock=last_non_join)
|
||||
if knock_room_sync:
|
||||
knocked.append(knock_room_sync)
|
||||
|
||||
|
@ -1781,7 +1861,9 @@ class SyncHandler:
|
|||
|
||||
timeline_limit = sync_config.filter_collection.timeline_limit()
|
||||
|
||||
# Get all events for rooms we're currently joined to.
|
||||
# Get all events since the `from_key` in rooms we're currently joined to.
|
||||
# If there are too many, we get the most recent events only. This leaves
|
||||
# a "gap" in the timeline, as described by the spec for /sync.
|
||||
room_to_events = await self.store.get_room_events_stream_for_rooms(
|
||||
room_ids=sync_result_builder.joined_room_ids,
|
||||
from_key=since_token.room_key,
|
||||
|
@ -1842,6 +1924,10 @@ class SyncHandler:
|
|||
) -> _RoomChanges:
|
||||
"""Returns entries for all rooms for the user.
|
||||
|
||||
Like `_get_rooms_changed`, but assumes the `since_token` is `None`.
|
||||
|
||||
This function does not modify the sync_result_builder.
|
||||
|
||||
Args:
|
||||
sync_result_builder
|
||||
ignored_users: Set of users ignored by user.
|
||||
|
@ -1853,16 +1939,9 @@ class SyncHandler:
|
|||
now_token = sync_result_builder.now_token
|
||||
sync_config = sync_result_builder.sync_config
|
||||
|
||||
membership_list = (
|
||||
Membership.INVITE,
|
||||
Membership.KNOCK,
|
||||
Membership.JOIN,
|
||||
Membership.LEAVE,
|
||||
Membership.BAN,
|
||||
)
|
||||
|
||||
room_list = await self.store.get_rooms_for_local_user_where_membership_is(
|
||||
user_id=user_id, membership_list=membership_list
|
||||
user_id=user_id,
|
||||
membership_list=Membership.LIST,
|
||||
)
|
||||
|
||||
room_entries = []
|
||||
|
@ -2212,8 +2291,7 @@ def _calculate_state(
|
|||
# to only include membership events for the senders in the timeline.
|
||||
# In practice, we can do this by removing them from the p_ids list,
|
||||
# which is the list of relevant state we know we have already sent to the client.
|
||||
# see https://github.com/matrix-org/synapse/pull/2970
|
||||
# /files/efcdacad7d1b7f52f879179701c7e0d9b763511f#r204732809
|
||||
# see https://github.com/matrix-org/synapse/pull/2970/files/efcdacad7d1b7f52f879179701c7e0d9b763511f#r204732809
|
||||
|
||||
if lazy_load_members:
|
||||
p_ids.difference_update(
|
||||
|
@ -2262,6 +2340,39 @@ class SyncResultBuilder:
|
|||
groups: Optional[GroupsSyncResult] = None
|
||||
to_device: List[JsonDict] = attr.Factory(list)
|
||||
|
||||
def calculate_user_changes(self) -> Tuple[Set[str], Set[str]]:
|
||||
"""Work out which other users have joined or left rooms we are joined to.
|
||||
|
||||
This data only is only useful for an incremental sync.
|
||||
|
||||
The SyncResultBuilder is not modified by this function.
|
||||
"""
|
||||
newly_joined_or_invited_or_knocked_users = set()
|
||||
newly_left_users = set()
|
||||
if self.since_token:
|
||||
for joined_sync in self.joined:
|
||||
it = itertools.chain(
|
||||
joined_sync.timeline.events, joined_sync.state.values()
|
||||
)
|
||||
for event in it:
|
||||
if event.type == EventTypes.Member:
|
||||
if (
|
||||
event.membership == Membership.JOIN
|
||||
or event.membership == Membership.INVITE
|
||||
or event.membership == Membership.KNOCK
|
||||
):
|
||||
newly_joined_or_invited_or_knocked_users.add(
|
||||
event.state_key
|
||||
)
|
||||
else:
|
||||
prev_content = event.unsigned.get("prev_content", {})
|
||||
prev_membership = prev_content.get("membership", None)
|
||||
if prev_membership == Membership.JOIN:
|
||||
newly_left_users.add(event.state_key)
|
||||
|
||||
newly_left_users -= newly_joined_or_invited_or_knocked_users
|
||||
return newly_joined_or_invited_or_knocked_users, newly_left_users
|
||||
|
||||
|
||||
@attr.s(slots=True, auto_attribs=True)
|
||||
class RoomSyncResultBuilder:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue