mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-08-15 15:50:11 -04:00
Merge remote-tracking branch 'upstream/release-v1.71'
This commit is contained in:
commit
2337ca829d
135 changed files with 5192 additions and 2356 deletions
|
@ -100,6 +100,7 @@ class AdminHandler:
|
|||
user_info_dict["avatar_url"] = profile.avatar_url
|
||||
user_info_dict["threepids"] = threepids
|
||||
user_info_dict["external_ids"] = external_ids
|
||||
user_info_dict["erased"] = await self.store.is_user_erased(user.to_string())
|
||||
|
||||
return user_info_dict
|
||||
|
||||
|
|
|
@ -38,6 +38,7 @@ from typing import (
|
|||
import attr
|
||||
import bcrypt
|
||||
import unpaddedbase64
|
||||
from prometheus_client import Counter
|
||||
|
||||
from twisted.internet.defer import CancelledError
|
||||
from twisted.web.server import Request
|
||||
|
@ -48,6 +49,7 @@ from synapse.api.errors import (
|
|||
Codes,
|
||||
InteractiveAuthIncompleteError,
|
||||
LoginError,
|
||||
NotFoundError,
|
||||
StoreError,
|
||||
SynapseError,
|
||||
UserDeactivatedError,
|
||||
|
@ -63,10 +65,14 @@ from synapse.http.server import finish_request, respond_with_html
|
|||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.context import defer_to_thread
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage.databases.main.registration import (
|
||||
LoginTokenExpired,
|
||||
LoginTokenLookupResult,
|
||||
LoginTokenReused,
|
||||
)
|
||||
from synapse.types import JsonDict, Requester, UserID
|
||||
from synapse.util import stringutils as stringutils
|
||||
from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
|
||||
from synapse.util.macaroons import LoginTokenAttributes
|
||||
from synapse.util.msisdn import phone_number_to_msisdn
|
||||
from synapse.util.stringutils import base62_encode
|
||||
from synapse.util.threepids import canonicalise_email
|
||||
|
@ -80,6 +86,12 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
INVALID_USERNAME_OR_PASSWORD = "Invalid username or password"
|
||||
|
||||
invalid_login_token_counter = Counter(
|
||||
"synapse_user_login_invalid_login_tokens",
|
||||
"Counts the number of rejected m.login.token on /login",
|
||||
["reason"],
|
||||
)
|
||||
|
||||
|
||||
def convert_client_dict_legacy_fields_to_identifier(
|
||||
submission: JsonDict,
|
||||
|
@ -883,6 +895,25 @@ class AuthHandler:
|
|||
|
||||
return True
|
||||
|
||||
async def create_login_token_for_user_id(
|
||||
self,
|
||||
user_id: str,
|
||||
duration_ms: int = (2 * 60 * 1000),
|
||||
auth_provider_id: Optional[str] = None,
|
||||
auth_provider_session_id: Optional[str] = None,
|
||||
) -> str:
|
||||
login_token = self.generate_login_token()
|
||||
now = self._clock.time_msec()
|
||||
expiry_ts = now + duration_ms
|
||||
await self.store.add_login_token_to_user(
|
||||
user_id=user_id,
|
||||
token=login_token,
|
||||
expiry_ts=expiry_ts,
|
||||
auth_provider_id=auth_provider_id,
|
||||
auth_provider_session_id=auth_provider_session_id,
|
||||
)
|
||||
return login_token
|
||||
|
||||
async def create_refresh_token_for_user_id(
|
||||
self,
|
||||
user_id: str,
|
||||
|
@ -1401,6 +1432,18 @@ class AuthHandler:
|
|||
return None
|
||||
return user_id
|
||||
|
||||
def generate_login_token(self) -> str:
|
||||
"""Generates an opaque string, for use as an short-term login token"""
|
||||
|
||||
# we use the following format for access tokens:
|
||||
# syl_<random string>_<base62 crc check>
|
||||
|
||||
random_string = stringutils.random_string(20)
|
||||
base = f"syl_{random_string}"
|
||||
|
||||
crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
|
||||
return f"{base}_{crc}"
|
||||
|
||||
def generate_access_token(self, for_user: UserID) -> str:
|
||||
"""Generates an opaque string, for use as an access token"""
|
||||
|
||||
|
@ -1427,16 +1470,17 @@ class AuthHandler:
|
|||
crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
|
||||
return f"{base}_{crc}"
|
||||
|
||||
async def validate_short_term_login_token(
|
||||
self, login_token: str
|
||||
) -> LoginTokenAttributes:
|
||||
async def consume_login_token(self, login_token: str) -> LoginTokenLookupResult:
|
||||
try:
|
||||
res = self.macaroon_gen.verify_short_term_login_token(login_token)
|
||||
except Exception:
|
||||
raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN)
|
||||
return await self.store.consume_login_token(login_token)
|
||||
except LoginTokenExpired:
|
||||
invalid_login_token_counter.labels("expired").inc()
|
||||
except LoginTokenReused:
|
||||
invalid_login_token_counter.labels("reused").inc()
|
||||
except NotFoundError:
|
||||
invalid_login_token_counter.labels("not found").inc()
|
||||
|
||||
await self.auth_blocking.check_auth_blocking(res.user_id)
|
||||
return res
|
||||
raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN)
|
||||
|
||||
async def delete_access_token(self, access_token: str) -> None:
|
||||
"""Invalidate a single access token
|
||||
|
@ -1711,7 +1755,7 @@ class AuthHandler:
|
|||
)
|
||||
|
||||
# Create a login token
|
||||
login_token = self.macaroon_gen.generate_short_term_login_token(
|
||||
login_token = await self.create_login_token_for_user_id(
|
||||
registered_user_id,
|
||||
auth_provider_id=auth_provider_id,
|
||||
auth_provider_session_id=auth_provider_session_id,
|
||||
|
|
|
@ -49,6 +49,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class E2eKeysHandler:
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.config = hs.config
|
||||
self.store = hs.get_datastores().main
|
||||
self.federation = hs.get_federation_client()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
|
@ -431,13 +432,17 @@ class E2eKeysHandler:
|
|||
@trace
|
||||
@cancellable
|
||||
async def query_local_devices(
|
||||
self, query: Mapping[str, Optional[List[str]]]
|
||||
self,
|
||||
query: Mapping[str, Optional[List[str]]],
|
||||
include_displaynames: bool = True,
|
||||
) -> Dict[str, Dict[str, dict]]:
|
||||
"""Get E2E device keys for local users
|
||||
|
||||
Args:
|
||||
query: map from user_id to a list
|
||||
of devices to query (None for all devices)
|
||||
include_displaynames: Whether to include device displaynames in the returned
|
||||
device details.
|
||||
|
||||
Returns:
|
||||
A map from user_id -> device_id -> device details
|
||||
|
@ -469,7 +474,9 @@ class E2eKeysHandler:
|
|||
# make sure that each queried user appears in the result dict
|
||||
result_dict[user_id] = {}
|
||||
|
||||
results = await self.store.get_e2e_device_keys_for_cs_api(local_query)
|
||||
results = await self.store.get_e2e_device_keys_for_cs_api(
|
||||
local_query, include_displaynames
|
||||
)
|
||||
|
||||
# Build the result structure
|
||||
for user_id, device_keys in results.items():
|
||||
|
@ -482,11 +489,33 @@ class E2eKeysHandler:
|
|||
async def on_federation_query_client_keys(
|
||||
self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
|
||||
) -> JsonDict:
|
||||
"""Handle a device key query from a federated server"""
|
||||
"""Handle a device key query from a federated server:
|
||||
|
||||
Handles the path: GET /_matrix/federation/v1/users/keys/query
|
||||
|
||||
Args:
|
||||
query_body: The body of the query request. Should contain a key
|
||||
"device_keys" that map to a dictionary of user ID's -> list of
|
||||
device IDs. If the list of device IDs is empty, all devices of
|
||||
that user will be queried.
|
||||
|
||||
Returns:
|
||||
A json dictionary containing the following:
|
||||
- device_keys: A dictionary containing the requested device information.
|
||||
- master_keys: An optional dictionary of user ID -> master cross-signing
|
||||
key info.
|
||||
- self_signing_key: An optional dictionary of user ID -> self-signing
|
||||
key info.
|
||||
"""
|
||||
device_keys_query: Dict[str, Optional[List[str]]] = query_body.get(
|
||||
"device_keys", {}
|
||||
)
|
||||
res = await self.query_local_devices(device_keys_query)
|
||||
res = await self.query_local_devices(
|
||||
device_keys_query,
|
||||
include_displaynames=(
|
||||
self.config.federation.allow_device_name_lookup_over_federation
|
||||
),
|
||||
)
|
||||
ret = {"device_keys": res}
|
||||
|
||||
# add in the cross-signing keys
|
||||
|
|
|
@ -442,6 +442,15 @@ class FederationHandler:
|
|||
# appropriate stuff.
|
||||
# TODO: We can probably do something more intelligent here.
|
||||
return True
|
||||
except NotRetryingDestination as e:
|
||||
logger.info("_maybe_backfill_inner: %s", e)
|
||||
continue
|
||||
except FederationDeniedError:
|
||||
logger.info(
|
||||
"_maybe_backfill_inner: Not attempting to backfill from %s because the homeserver is not on our federation whitelist",
|
||||
dom,
|
||||
)
|
||||
continue
|
||||
except (SynapseError, InvalidResponseError) as e:
|
||||
logger.info("Failed to backfill from %s because %s", dom, e)
|
||||
continue
|
||||
|
@ -477,15 +486,9 @@ class FederationHandler:
|
|||
|
||||
logger.info("Failed to backfill from %s because %s", dom, e)
|
||||
continue
|
||||
except NotRetryingDestination as e:
|
||||
logger.info(str(e))
|
||||
continue
|
||||
except RequestSendFailed as e:
|
||||
logger.info("Failed to get backfill from %s because %s", dom, e)
|
||||
continue
|
||||
except FederationDeniedError as e:
|
||||
logger.info(e)
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.exception("Failed to backfill from %s because %s", dom, e)
|
||||
continue
|
||||
|
@ -1017,7 +1020,9 @@ class FederationHandler:
|
|||
|
||||
context = EventContext.for_outlier(self._storage_controllers)
|
||||
|
||||
await self._bulk_push_rule_evaluator.action_for_event_by_user(event, context)
|
||||
await self._bulk_push_rule_evaluator.action_for_events_by_user(
|
||||
[(event, context)]
|
||||
)
|
||||
try:
|
||||
await self._federation_event_handler.persist_events_and_notify(
|
||||
event.room_id, [(event, context)]
|
||||
|
|
|
@ -58,7 +58,7 @@ from synapse.event_auth import (
|
|||
)
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.federation.federation_client import InvalidResponseError
|
||||
from synapse.federation.federation_client import InvalidResponseError, PulledPduInfo
|
||||
from synapse.logging.context import nested_logging_context
|
||||
from synapse.logging.opentracing import (
|
||||
SynapseTags,
|
||||
|
@ -1517,8 +1517,8 @@ class FederationEventHandler:
|
|||
)
|
||||
|
||||
async def backfill_event_id(
|
||||
self, destination: str, room_id: str, event_id: str
|
||||
) -> EventBase:
|
||||
self, destinations: List[str], room_id: str, event_id: str
|
||||
) -> PulledPduInfo:
|
||||
"""Backfill a single event and persist it as a non-outlier which means
|
||||
we also pull in all of the state and auth events necessary for it.
|
||||
|
||||
|
@ -1530,24 +1530,21 @@ class FederationEventHandler:
|
|||
Raises:
|
||||
FederationError if we are unable to find the event from the destination
|
||||
"""
|
||||
logger.info(
|
||||
"backfill_event_id: event_id=%s from destination=%s", event_id, destination
|
||||
)
|
||||
logger.info("backfill_event_id: event_id=%s", event_id)
|
||||
|
||||
room_version = await self._store.get_room_version(room_id)
|
||||
|
||||
event_from_response = await self._federation_client.get_pdu(
|
||||
[destination],
|
||||
pulled_pdu_info = await self._federation_client.get_pdu(
|
||||
destinations,
|
||||
event_id,
|
||||
room_version,
|
||||
)
|
||||
|
||||
if not event_from_response:
|
||||
if not pulled_pdu_info:
|
||||
raise FederationError(
|
||||
"ERROR",
|
||||
404,
|
||||
"Unable to find event_id=%s from destination=%s to backfill."
|
||||
% (event_id, destination),
|
||||
f"Unable to find event_id={event_id} from remote servers to backfill.",
|
||||
affected=event_id,
|
||||
)
|
||||
|
||||
|
@ -1555,13 +1552,13 @@ class FederationEventHandler:
|
|||
# and auth events to de-outlier it. This also sets up the necessary
|
||||
# `state_groups` for the event.
|
||||
await self._process_pulled_events(
|
||||
destination,
|
||||
[event_from_response],
|
||||
pulled_pdu_info.pull_origin,
|
||||
[pulled_pdu_info.pdu],
|
||||
# Prevent notifications going to clients
|
||||
backfilled=True,
|
||||
)
|
||||
|
||||
return event_from_response
|
||||
return pulled_pdu_info
|
||||
|
||||
@trace
|
||||
@tag_args
|
||||
|
@ -1584,19 +1581,19 @@ class FederationEventHandler:
|
|||
async def get_event(event_id: str) -> None:
|
||||
with nested_logging_context(event_id):
|
||||
try:
|
||||
event = await self._federation_client.get_pdu(
|
||||
pulled_pdu_info = await self._federation_client.get_pdu(
|
||||
[destination],
|
||||
event_id,
|
||||
room_version,
|
||||
)
|
||||
if event is None:
|
||||
if pulled_pdu_info is None:
|
||||
logger.warning(
|
||||
"Server %s didn't return event %s",
|
||||
destination,
|
||||
event_id,
|
||||
)
|
||||
return
|
||||
events.append(event)
|
||||
events.append(pulled_pdu_info.pdu)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
|
@ -2171,8 +2168,8 @@ class FederationEventHandler:
|
|||
min_depth,
|
||||
)
|
||||
else:
|
||||
await self._bulk_push_rule_evaluator.action_for_event_by_user(
|
||||
event, context
|
||||
await self._bulk_push_rule_evaluator.action_for_events_by_user(
|
||||
[(event, context)]
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
|
@ -1442,17 +1442,9 @@ class EventCreationHandler:
|
|||
a room that has been un-partial stated.
|
||||
"""
|
||||
|
||||
for event, context in events_and_context:
|
||||
# Skip push notification actions for historical messages
|
||||
# because we don't want to notify people about old history back in time.
|
||||
# The historical messages also do not have the proper `context.current_state_ids`
|
||||
# and `state_groups` because they have `prev_events` that aren't persisted yet
|
||||
# (historical messages persisted in reverse-chronological order).
|
||||
if not event.internal_metadata.is_historical() and not event.content.get(EventContentFields.MSC2716_HISTORICAL):
|
||||
with opentracing.start_active_span("calculate_push_actions"):
|
||||
await self._bulk_push_rule_evaluator.action_for_event_by_user(
|
||||
event, context
|
||||
)
|
||||
await self._bulk_push_rule_evaluator.action_for_events_by_user(
|
||||
events_and_context
|
||||
)
|
||||
|
||||
try:
|
||||
# If we're a worker we need to hit out to the master.
|
||||
|
|
|
@ -12,14 +12,28 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import binascii
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar, Union
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from urllib.parse import urlencode, urlparse
|
||||
|
||||
import attr
|
||||
import unpaddedbase64
|
||||
from authlib.common.security import generate_token
|
||||
from authlib.jose import JsonWebToken, jwt
|
||||
from authlib.jose import JsonWebToken, JWTClaims
|
||||
from authlib.jose.errors import InvalidClaimError, JoseError, MissingClaimError
|
||||
from authlib.oauth2.auth import ClientAuth
|
||||
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
|
||||
from authlib.oidc.core import CodeIDToken, UserInfo
|
||||
|
@ -35,9 +49,12 @@ from typing_extensions import TypedDict
|
|||
from twisted.web.client import readBody
|
||||
from twisted.web.http_headers import Headers
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.config import ConfigError
|
||||
from synapse.config.oidc import OidcProviderClientSecretJwtKey, OidcProviderConfig
|
||||
from synapse.handlers.sso import MappingException, UserAttributes
|
||||
from synapse.http.server import finish_request
|
||||
from synapse.http.servlet import parse_string
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
|
||||
|
@ -88,6 +105,8 @@ class Token(TypedDict):
|
|||
#: there is no real point of doing this in our case.
|
||||
JWK = Dict[str, str]
|
||||
|
||||
C = TypeVar("C")
|
||||
|
||||
|
||||
#: A JWK Set, as per RFC7517 sec 5.
|
||||
class JWKS(TypedDict):
|
||||
|
@ -247,6 +266,80 @@ class OidcHandler:
|
|||
|
||||
await oidc_provider.handle_oidc_callback(request, session_data, code)
|
||||
|
||||
async def handle_backchannel_logout(self, request: SynapseRequest) -> None:
|
||||
"""Handle an incoming request to /_synapse/client/oidc/backchannel_logout
|
||||
|
||||
This extracts the logout_token from the request and tries to figure out
|
||||
which OpenID Provider it is comming from. This works by matching the iss claim
|
||||
with the issuer and the aud claim with the client_id.
|
||||
|
||||
Since at this point we don't know who signed the JWT, we can't just
|
||||
decode it using authlib since it will always verifies the signature. We
|
||||
have to decode it manually without validating the signature. The actual JWT
|
||||
verification is done in the `OidcProvider.handler_backchannel_logout` method,
|
||||
once we figured out which provider sent the request.
|
||||
|
||||
Args:
|
||||
request: the incoming request from the browser.
|
||||
"""
|
||||
logout_token = parse_string(request, "logout_token")
|
||||
if logout_token is None:
|
||||
raise SynapseError(400, "Missing logout_token in request")
|
||||
|
||||
# A JWT looks like this:
|
||||
# header.payload.signature
|
||||
# where all parts are encoded with urlsafe base64.
|
||||
# The aud and iss claims we care about are in the payload part, which
|
||||
# is a JSON object.
|
||||
try:
|
||||
# By destructuring the list after splitting, we ensure that we have
|
||||
# exactly 3 segments
|
||||
_, payload, _ = logout_token.split(".")
|
||||
except ValueError:
|
||||
raise SynapseError(400, "Invalid logout_token in request")
|
||||
|
||||
try:
|
||||
payload_bytes = unpaddedbase64.decode_base64(payload)
|
||||
claims = json_decoder.decode(payload_bytes.decode("utf-8"))
|
||||
except (json.JSONDecodeError, binascii.Error, UnicodeError):
|
||||
raise SynapseError(400, "Invalid logout_token payload in request")
|
||||
|
||||
try:
|
||||
# Let's extract the iss and aud claims
|
||||
iss = claims["iss"]
|
||||
aud = claims["aud"]
|
||||
# The aud claim can be either a string or a list of string. Here we
|
||||
# normalize it as a list of strings.
|
||||
if isinstance(aud, str):
|
||||
aud = [aud]
|
||||
|
||||
# Check that we have the right types for the aud and the iss claims
|
||||
if not isinstance(iss, str) or not isinstance(aud, list):
|
||||
raise TypeError()
|
||||
for a in aud:
|
||||
if not isinstance(a, str):
|
||||
raise TypeError()
|
||||
|
||||
# At this point we properly checked both claims types
|
||||
issuer: str = iss
|
||||
audience: List[str] = aud
|
||||
except (TypeError, KeyError):
|
||||
raise SynapseError(400, "Invalid issuer/audience in logout_token")
|
||||
|
||||
# Now that we know the audience and the issuer, we can figure out from
|
||||
# what provider it is coming from
|
||||
oidc_provider: Optional[OidcProvider] = None
|
||||
for provider in self._providers.values():
|
||||
if provider.issuer == issuer and provider.client_id in audience:
|
||||
oidc_provider = provider
|
||||
break
|
||||
|
||||
if oidc_provider is None:
|
||||
raise SynapseError(400, "Could not find the OP that issued this event")
|
||||
|
||||
# Ask the provider to handle the logout request.
|
||||
await oidc_provider.handle_backchannel_logout(request, logout_token)
|
||||
|
||||
|
||||
class OidcError(Exception):
|
||||
"""Used to catch errors when calling the token_endpoint"""
|
||||
|
@ -275,6 +368,7 @@ class OidcProvider:
|
|||
provider: OidcProviderConfig,
|
||||
):
|
||||
self._store = hs.get_datastores().main
|
||||
self._clock = hs.get_clock()
|
||||
|
||||
self._macaroon_generaton = macaroon_generator
|
||||
|
||||
|
@ -341,6 +435,7 @@ class OidcProvider:
|
|||
self.idp_brand = provider.idp_brand
|
||||
|
||||
self._sso_handler = hs.get_sso_handler()
|
||||
self._device_handler = hs.get_device_handler()
|
||||
|
||||
self._sso_handler.register_identity_provider(self)
|
||||
|
||||
|
@ -399,6 +494,41 @@ class OidcProvider:
|
|||
# If we're not using userinfo, we need a valid jwks to validate the ID token
|
||||
m.validate_jwks_uri()
|
||||
|
||||
if self._config.backchannel_logout_enabled:
|
||||
if not m.get("backchannel_logout_supported", False):
|
||||
logger.warning(
|
||||
"OIDC Back-Channel Logout is enabled for issuer %r"
|
||||
"but it does not advertise support for it",
|
||||
self.issuer,
|
||||
)
|
||||
|
||||
elif not m.get("backchannel_logout_session_supported", False):
|
||||
logger.warning(
|
||||
"OIDC Back-Channel Logout is enabled and supported "
|
||||
"by issuer %r but it might not send a session ID with "
|
||||
"logout tokens, which is required for the logouts to work",
|
||||
self.issuer,
|
||||
)
|
||||
|
||||
if not self._config.backchannel_logout_ignore_sub:
|
||||
# If OIDC backchannel logouts are enabled, the provider mapping provider
|
||||
# should use the `sub` claim. We verify that by mapping a dumb user and
|
||||
# see if we get back the sub claim
|
||||
user = UserInfo({"sub": "thisisasubject"})
|
||||
try:
|
||||
subject = self._user_mapping_provider.get_remote_user_id(user)
|
||||
if subject != user["sub"]:
|
||||
raise ValueError("Unexpected subject")
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"OIDC Back-Channel Logout is enabled for issuer {self.issuer!r} "
|
||||
"but it looks like the configured `user_mapping_provider` "
|
||||
"does not use the `sub` claim as subject. If it is the case, "
|
||||
"and you want Synapse to ignore the `sub` claim in OIDC "
|
||||
"Back-Channel Logouts, set `backchannel_logout_ignore_sub` "
|
||||
"to `true` in the issuer config."
|
||||
)
|
||||
|
||||
@property
|
||||
def _uses_userinfo(self) -> bool:
|
||||
"""Returns True if the ``userinfo_endpoint`` should be used.
|
||||
|
@ -414,6 +544,16 @@ class OidcProvider:
|
|||
or self._user_profile_method == "userinfo_endpoint"
|
||||
)
|
||||
|
||||
@property
|
||||
def issuer(self) -> str:
|
||||
"""The issuer identifying this provider."""
|
||||
return self._config.issuer
|
||||
|
||||
@property
|
||||
def client_id(self) -> str:
|
||||
"""The client_id used when interacting with this provider."""
|
||||
return self._config.client_id
|
||||
|
||||
async def load_metadata(self, force: bool = False) -> OpenIDProviderMetadata:
|
||||
"""Return the provider metadata.
|
||||
|
||||
|
@ -661,6 +801,59 @@ class OidcProvider:
|
|||
|
||||
return UserInfo(resp)
|
||||
|
||||
async def _verify_jwt(
|
||||
self,
|
||||
alg_values: List[str],
|
||||
token: str,
|
||||
claims_cls: Type[C],
|
||||
claims_options: Optional[dict] = None,
|
||||
claims_params: Optional[dict] = None,
|
||||
) -> C:
|
||||
"""Decode and validate a JWT, re-fetching the JWKS as needed.
|
||||
|
||||
Args:
|
||||
alg_values: list of `alg` values allowed when verifying the JWT.
|
||||
token: the JWT.
|
||||
claims_cls: the JWTClaims class to use to validate the claims.
|
||||
claims_options: dict of options passed to the `claims_cls` constructor.
|
||||
claims_params: dict of params passed to the `claims_cls` constructor.
|
||||
|
||||
Returns:
|
||||
The decoded claims in the JWT.
|
||||
"""
|
||||
jwt = JsonWebToken(alg_values)
|
||||
|
||||
logger.debug("Attempting to decode JWT (%s) %r", claims_cls.__name__, token)
|
||||
|
||||
# Try to decode the keys in cache first, then retry by forcing the keys
|
||||
# to be reloaded
|
||||
jwk_set = await self.load_jwks()
|
||||
try:
|
||||
claims = jwt.decode(
|
||||
token,
|
||||
key=jwk_set,
|
||||
claims_cls=claims_cls,
|
||||
claims_options=claims_options,
|
||||
claims_params=claims_params,
|
||||
)
|
||||
except ValueError:
|
||||
logger.info("Reloading JWKS after decode error")
|
||||
jwk_set = await self.load_jwks(force=True) # try reloading the jwks
|
||||
claims = jwt.decode(
|
||||
token,
|
||||
key=jwk_set,
|
||||
claims_cls=claims_cls,
|
||||
claims_options=claims_options,
|
||||
claims_params=claims_params,
|
||||
)
|
||||
|
||||
logger.debug("Decoded JWT (%s) %r; validating", claims_cls.__name__, claims)
|
||||
|
||||
claims.validate(
|
||||
now=self._clock.time(), leeway=120
|
||||
) # allows 2 min of clock skew
|
||||
return claims
|
||||
|
||||
async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken:
|
||||
"""Return an instance of UserInfo from token's ``id_token``.
|
||||
|
||||
|
@ -673,7 +866,14 @@ class OidcProvider:
|
|||
Returns:
|
||||
The decoded claims in the ID token.
|
||||
"""
|
||||
id_token = token.get("id_token")
|
||||
|
||||
# That has been theoritically been checked by the caller, so even though
|
||||
# assertion are not enabled in production, it is mainly here to appease mypy
|
||||
assert id_token is not None
|
||||
|
||||
metadata = await self.load_metadata()
|
||||
|
||||
claims_params = {
|
||||
"nonce": nonce,
|
||||
"client_id": self._client_auth.client_id,
|
||||
|
@ -683,39 +883,17 @@ class OidcProvider:
|
|||
# in the `id_token` that we can check against.
|
||||
claims_params["access_token"] = token["access_token"]
|
||||
|
||||
claims_options = {"iss": {"values": [metadata["issuer"]]}}
|
||||
|
||||
alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
|
||||
jwt = JsonWebToken(alg_values)
|
||||
|
||||
claim_options = {"iss": {"values": [metadata["issuer"]]}}
|
||||
|
||||
id_token = token["id_token"]
|
||||
logger.debug("Attempting to decode JWT id_token %r", id_token)
|
||||
|
||||
# Try to decode the keys in cache first, then retry by forcing the keys
|
||||
# to be reloaded
|
||||
jwk_set = await self.load_jwks()
|
||||
try:
|
||||
claims = jwt.decode(
|
||||
id_token,
|
||||
key=jwk_set,
|
||||
claims_cls=CodeIDToken,
|
||||
claims_options=claim_options,
|
||||
claims_params=claims_params,
|
||||
)
|
||||
except ValueError:
|
||||
logger.info("Reloading JWKS after decode error")
|
||||
jwk_set = await self.load_jwks(force=True) # try reloading the jwks
|
||||
claims = jwt.decode(
|
||||
id_token,
|
||||
key=jwk_set,
|
||||
claims_cls=CodeIDToken,
|
||||
claims_options=claim_options,
|
||||
claims_params=claims_params,
|
||||
)
|
||||
|
||||
logger.debug("Decoded id_token JWT %r; validating", claims)
|
||||
|
||||
claims.validate(leeway=120) # allows 2 min of clock skew
|
||||
claims = await self._verify_jwt(
|
||||
alg_values=alg_values,
|
||||
token=id_token,
|
||||
claims_cls=CodeIDToken,
|
||||
claims_options=claims_options,
|
||||
claims_params=claims_params,
|
||||
)
|
||||
|
||||
return claims
|
||||
|
||||
|
@ -1036,6 +1214,146 @@ class OidcProvider:
|
|||
# to be strings.
|
||||
return str(remote_user_id)
|
||||
|
||||
async def handle_backchannel_logout(
|
||||
self, request: SynapseRequest, logout_token: str
|
||||
) -> None:
|
||||
"""Handle an incoming request to /_synapse/client/oidc/backchannel_logout
|
||||
|
||||
The OIDC Provider posts a logout token to this endpoint when a user
|
||||
session ends. That token is a JWT signed with the same keys as
|
||||
ID tokens. The OpenID Connect Back-Channel Logout draft explains how to
|
||||
validate the JWT and figure out what session to end.
|
||||
|
||||
Args:
|
||||
request: The request to respond to
|
||||
logout_token: The logout token (a JWT) extracted from the request body
|
||||
"""
|
||||
# Back-Channel Logout can be disabled in the config, hence this check.
|
||||
# This is not that important for now since Synapse is registered
|
||||
# manually to the OP, so not specifying the backchannel-logout URI is
|
||||
# as effective than disabling it here. It might make more sense if we
|
||||
# support dynamic registration in Synapse at some point.
|
||||
if not self._config.backchannel_logout_enabled:
|
||||
logger.warning(
|
||||
f"Received an OIDC Back-Channel Logout request from issuer {self.issuer!r} but it is disabled in config"
|
||||
)
|
||||
|
||||
# TODO: this responds with a 400 status code, which is what the OIDC
|
||||
# Back-Channel Logout spec expects, but spec also suggests answering with
|
||||
# a JSON object, with the `error` and `error_description` fields set, which
|
||||
# we are not doing here.
|
||||
# See https://openid.net/specs/openid-connect-backchannel-1_0.html#BCResponse
|
||||
raise SynapseError(
|
||||
400, "OpenID Connect Back-Channel Logout is disabled for this provider"
|
||||
)
|
||||
|
||||
metadata = await self.load_metadata()
|
||||
|
||||
# As per OIDC Back-Channel Logout 1.0 sec. 2.4:
|
||||
# A Logout Token MUST be signed and MAY also be encrypted. The same
|
||||
# keys are used to sign and encrypt Logout Tokens as are used for ID
|
||||
# Tokens. If the Logout Token is encrypted, it SHOULD replicate the
|
||||
# iss (issuer) claim in the JWT Header Parameters, as specified in
|
||||
# Section 5.3 of [JWT].
|
||||
alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
|
||||
|
||||
# As per sec. 2.6:
|
||||
# 3. Validate the iss, aud, and iat Claims in the same way they are
|
||||
# validated in ID Tokens.
|
||||
# Which means the audience should contain Synapse's client_id and the
|
||||
# issuer should be the IdP issuer
|
||||
claims_options = {
|
||||
"iss": {"values": [metadata["issuer"]]},
|
||||
"aud": {"values": [self.client_id]},
|
||||
}
|
||||
|
||||
try:
|
||||
claims = await self._verify_jwt(
|
||||
alg_values=alg_values,
|
||||
token=logout_token,
|
||||
claims_cls=LogoutToken,
|
||||
claims_options=claims_options,
|
||||
)
|
||||
except JoseError:
|
||||
logger.exception("Invalid logout_token")
|
||||
raise SynapseError(400, "Invalid logout_token")
|
||||
|
||||
# As per sec. 2.6:
|
||||
# 4. Verify that the Logout Token contains a sub Claim, a sid Claim,
|
||||
# or both.
|
||||
# 5. Verify that the Logout Token contains an events Claim whose
|
||||
# value is JSON object containing the member name
|
||||
# http://schemas.openid.net/event/backchannel-logout.
|
||||
# 6. Verify that the Logout Token does not contain a nonce Claim.
|
||||
# This is all verified by the LogoutToken claims class, so at this
|
||||
# point the `sid` claim exists and is a string.
|
||||
sid: str = claims.get("sid")
|
||||
|
||||
# If the `sub` claim was included in the logout token, we check that it matches
|
||||
# that it matches the right user. We can have cases where the `sub` claim is not
|
||||
# the ID saved in database, so we let admins disable this check in config.
|
||||
sub: Optional[str] = claims.get("sub")
|
||||
expected_user_id: Optional[str] = None
|
||||
if sub is not None and not self._config.backchannel_logout_ignore_sub:
|
||||
expected_user_id = await self._store.get_user_by_external_id(
|
||||
self.idp_id, sub
|
||||
)
|
||||
|
||||
# Invalidate any running user-mapping sessions, in-flight login tokens and
|
||||
# active devices
|
||||
await self._sso_handler.revoke_sessions_for_provider_session_id(
|
||||
auth_provider_id=self.idp_id,
|
||||
auth_provider_session_id=sid,
|
||||
expected_user_id=expected_user_id,
|
||||
)
|
||||
|
||||
request.setResponseCode(200)
|
||||
request.setHeader(b"Cache-Control", b"no-cache, no-store")
|
||||
request.setHeader(b"Pragma", b"no-cache")
|
||||
finish_request(request)
|
||||
|
||||
|
||||
class LogoutToken(JWTClaims):
|
||||
"""
|
||||
Holds and verify claims of a logout token, as per
|
||||
https://openid.net/specs/openid-connect-backchannel-1_0.html#LogoutToken
|
||||
"""
|
||||
|
||||
REGISTERED_CLAIMS = ["iss", "sub", "aud", "iat", "jti", "events", "sid"]
|
||||
|
||||
def validate(self, now: Optional[int] = None, leeway: int = 0) -> None:
|
||||
"""Validate everything in claims payload."""
|
||||
super().validate(now, leeway)
|
||||
self.validate_sid()
|
||||
self.validate_events()
|
||||
self.validate_nonce()
|
||||
|
||||
def validate_sid(self) -> None:
|
||||
"""Ensure the sid claim is present"""
|
||||
sid = self.get("sid")
|
||||
if not sid:
|
||||
raise MissingClaimError("sid")
|
||||
|
||||
if not isinstance(sid, str):
|
||||
raise InvalidClaimError("sid")
|
||||
|
||||
def validate_nonce(self) -> None:
|
||||
"""Ensure the nonce claim is absent"""
|
||||
if "nonce" in self:
|
||||
raise InvalidClaimError("nonce")
|
||||
|
||||
def validate_events(self) -> None:
|
||||
"""Ensure the events claim is present and with the right value"""
|
||||
events = self.get("events")
|
||||
if not events:
|
||||
raise MissingClaimError("events")
|
||||
|
||||
if not isinstance(events, dict):
|
||||
raise InvalidClaimError("events")
|
||||
|
||||
if "http://schemas.openid.net/event/backchannel-logout" not in events:
|
||||
raise InvalidClaimError("events")
|
||||
|
||||
|
||||
# number of seconds a newly-generated client secret should be valid for
|
||||
CLIENT_SECRET_VALIDITY_SECONDS = 3600
|
||||
|
@ -1105,6 +1423,7 @@ class JwtClientSecret:
|
|||
logger.info(
|
||||
"Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload
|
||||
)
|
||||
jwt = JsonWebToken(header["alg"])
|
||||
self._cached_secret = jwt.encode(header, payload, self._key.key)
|
||||
self._cached_secret_replacement_time = (
|
||||
expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS
|
||||
|
@ -1119,9 +1438,6 @@ class UserAttributeDict(TypedDict):
|
|||
emails: List[str]
|
||||
|
||||
|
||||
C = TypeVar("C")
|
||||
|
||||
|
||||
class OidcMappingProvider(Generic[C]):
|
||||
"""A mapping provider maps a UserInfo object to user attributes.
|
||||
|
||||
|
|
|
@ -307,7 +307,11 @@ class ProfileHandler:
|
|||
if not self.max_avatar_size and not self.allowed_avatar_mimetypes:
|
||||
return True
|
||||
|
||||
server_name, _, media_id = parse_and_validate_mxc_uri(mxc)
|
||||
host, port, media_id = parse_and_validate_mxc_uri(mxc)
|
||||
if port is not None:
|
||||
server_name = host + ":" + str(port)
|
||||
else:
|
||||
server_name = host
|
||||
|
||||
if server_name == self.server_name:
|
||||
media_info = await self.store.get_local_media(media_id)
|
||||
|
|
|
@ -49,7 +49,6 @@ from synapse.api.constants import (
|
|||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
Codes,
|
||||
HttpResponseException,
|
||||
LimitExceededError,
|
||||
NotFoundError,
|
||||
StoreError,
|
||||
|
@ -60,7 +59,6 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
|||
from synapse.event_auth import validate_event_for_room_version
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.utils import copy_and_fixup_power_levels_contents
|
||||
from synapse.federation.federation_client import InvalidResponseError
|
||||
from synapse.handlers.relations import BundledAggregations
|
||||
from synapse.module_api import NOT_SPAM
|
||||
from synapse.rest.admin._base import assert_user_is_admin
|
||||
|
@ -1070,9 +1068,6 @@ class RoomCreationHandler:
|
|||
event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
|
||||
depth = 1
|
||||
|
||||
# the last event sent/persisted to the db
|
||||
last_sent_event_id: Optional[str] = None
|
||||
|
||||
# the most recently created event
|
||||
prev_event: List[str] = []
|
||||
# a map of event types, state keys -> event_ids. We collect these mappings this as events are
|
||||
|
@ -1117,26 +1112,6 @@ class RoomCreationHandler:
|
|||
|
||||
return new_event, new_context
|
||||
|
||||
async def send(
|
||||
event: EventBase,
|
||||
context: synapse.events.snapshot.EventContext,
|
||||
creator: Requester,
|
||||
) -> int:
|
||||
nonlocal last_sent_event_id
|
||||
|
||||
ev = await self.event_creation_handler.handle_new_client_event(
|
||||
requester=creator,
|
||||
events_and_context=[(event, context)],
|
||||
ratelimit=False,
|
||||
ignore_shadow_ban=True,
|
||||
)
|
||||
|
||||
last_sent_event_id = ev.event_id
|
||||
|
||||
# we know it was persisted, so must have a stream ordering
|
||||
assert ev.internal_metadata.stream_ordering
|
||||
return ev.internal_metadata.stream_ordering
|
||||
|
||||
try:
|
||||
config = self._presets_dict[preset_config]
|
||||
except KeyError:
|
||||
|
@ -1150,10 +1125,14 @@ class RoomCreationHandler:
|
|||
)
|
||||
|
||||
logger.debug("Sending %s in new room", EventTypes.Member)
|
||||
await send(creation_event, creation_context, creator)
|
||||
ev = await self.event_creation_handler.handle_new_client_event(
|
||||
requester=creator,
|
||||
events_and_context=[(creation_event, creation_context)],
|
||||
ratelimit=False,
|
||||
ignore_shadow_ban=True,
|
||||
)
|
||||
last_sent_event_id = ev.event_id
|
||||
|
||||
# Room create event must exist at this point
|
||||
assert last_sent_event_id is not None
|
||||
member_event_id, _ = await self.room_member_handler.update_membership(
|
||||
creator,
|
||||
creator.user,
|
||||
|
@ -1172,6 +1151,7 @@ class RoomCreationHandler:
|
|||
depth += 1
|
||||
state_map[(EventTypes.Member, creator.user.to_string())] = member_event_id
|
||||
|
||||
events_to_send = []
|
||||
# We treat the power levels override specially as this needs to be one
|
||||
# of the first events that get sent into a room.
|
||||
pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None)
|
||||
|
@ -1180,7 +1160,7 @@ class RoomCreationHandler:
|
|||
EventTypes.PowerLevels, pl_content, False
|
||||
)
|
||||
current_state_group = power_context._state_group
|
||||
await send(power_event, power_context, creator)
|
||||
events_to_send.append((power_event, power_context))
|
||||
else:
|
||||
power_level_content: JsonDict = {
|
||||
"users": {creator_id: 9001},
|
||||
|
@ -1229,9 +1209,8 @@ class RoomCreationHandler:
|
|||
False,
|
||||
)
|
||||
current_state_group = pl_context._state_group
|
||||
await send(pl_event, pl_context, creator)
|
||||
events_to_send.append((pl_event, pl_context))
|
||||
|
||||
events_to_send = []
|
||||
if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
|
||||
room_alias_event, room_alias_context = await create_event(
|
||||
EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True
|
||||
|
@ -1509,7 +1488,12 @@ class TimestampLookupHandler:
|
|||
Raises:
|
||||
SynapseError if unable to find any event locally in the given direction
|
||||
"""
|
||||
|
||||
logger.debug(
|
||||
"get_event_for_timestamp(room_id=%s, timestamp=%s, direction=%s) Finding closest event...",
|
||||
room_id,
|
||||
timestamp,
|
||||
direction,
|
||||
)
|
||||
local_event_id = await self.store.get_event_id_for_timestamp(
|
||||
room_id, timestamp, direction
|
||||
)
|
||||
|
@ -1561,85 +1545,54 @@ class TimestampLookupHandler:
|
|||
)
|
||||
)
|
||||
|
||||
# Loop through each homeserver candidate until we get a succesful response
|
||||
for domain in likely_domains:
|
||||
# We don't want to ask our own server for information we don't have
|
||||
if domain == self.server_name:
|
||||
continue
|
||||
remote_response = await self.federation_client.timestamp_to_event(
|
||||
destinations=likely_domains,
|
||||
room_id=room_id,
|
||||
timestamp=timestamp,
|
||||
direction=direction,
|
||||
)
|
||||
if remote_response is not None:
|
||||
logger.debug(
|
||||
"get_event_for_timestamp: remote_response=%s",
|
||||
remote_response,
|
||||
)
|
||||
|
||||
try:
|
||||
remote_response = await self.federation_client.timestamp_to_event(
|
||||
domain, room_id, timestamp, direction
|
||||
)
|
||||
logger.debug(
|
||||
"get_event_for_timestamp: response from domain(%s)=%s",
|
||||
domain,
|
||||
remote_response,
|
||||
remote_event_id = remote_response.event_id
|
||||
remote_origin_server_ts = remote_response.origin_server_ts
|
||||
|
||||
# Backfill this event so we can get a pagination token for
|
||||
# it with `/context` and paginate `/messages` from this
|
||||
# point.
|
||||
pulled_pdu_info = await self.federation_event_handler.backfill_event_id(
|
||||
likely_domains, room_id, remote_event_id
|
||||
)
|
||||
remote_event = pulled_pdu_info.pdu
|
||||
|
||||
# XXX: When we see that the remote server is not trustworthy,
|
||||
# maybe we should not ask them first in the future.
|
||||
if remote_origin_server_ts != remote_event.origin_server_ts:
|
||||
logger.info(
|
||||
"get_event_for_timestamp: Remote server (%s) claimed that remote_event_id=%s occured at remote_origin_server_ts=%s but that isn't true (actually occured at %s). Their claims are dubious and we should consider not trusting them.",
|
||||
pulled_pdu_info.pull_origin,
|
||||
remote_event_id,
|
||||
remote_origin_server_ts,
|
||||
remote_event.origin_server_ts,
|
||||
)
|
||||
|
||||
remote_event_id = remote_response.event_id
|
||||
remote_origin_server_ts = remote_response.origin_server_ts
|
||||
|
||||
# Backfill this event so we can get a pagination token for
|
||||
# it with `/context` and paginate `/messages` from this
|
||||
# point.
|
||||
#
|
||||
# TODO: The requested timestamp may lie in a part of the
|
||||
# event graph that the remote server *also* didn't have,
|
||||
# in which case they will have returned another event
|
||||
# which may be nowhere near the requested timestamp. In
|
||||
# the future, we may need to reconcile that gap and ask
|
||||
# other homeservers, and/or extend `/timestamp_to_event`
|
||||
# to return events on *both* sides of the timestamp to
|
||||
# help reconcile the gap faster.
|
||||
remote_event = (
|
||||
await self.federation_event_handler.backfill_event_id(
|
||||
domain, room_id, remote_event_id
|
||||
)
|
||||
)
|
||||
|
||||
# XXX: When we see that the remote server is not trustworthy,
|
||||
# maybe we should not ask them first in the future.
|
||||
if remote_origin_server_ts != remote_event.origin_server_ts:
|
||||
logger.info(
|
||||
"get_event_for_timestamp: Remote server (%s) claimed that remote_event_id=%s occured at remote_origin_server_ts=%s but that isn't true (actually occured at %s). Their claims are dubious and we should consider not trusting them.",
|
||||
domain,
|
||||
remote_event_id,
|
||||
remote_origin_server_ts,
|
||||
remote_event.origin_server_ts,
|
||||
)
|
||||
|
||||
# Only return the remote event if it's closer than the local event
|
||||
if not local_event or (
|
||||
abs(remote_event.origin_server_ts - timestamp)
|
||||
< abs(local_event.origin_server_ts - timestamp)
|
||||
):
|
||||
logger.info(
|
||||
"get_event_for_timestamp: returning remote_event_id=%s (%s) since it's closer to timestamp=%s than local_event=%s (%s)",
|
||||
remote_event_id,
|
||||
remote_event.origin_server_ts,
|
||||
timestamp,
|
||||
local_event.event_id if local_event else None,
|
||||
local_event.origin_server_ts if local_event else None,
|
||||
)
|
||||
return remote_event_id, remote_origin_server_ts
|
||||
except (HttpResponseException, InvalidResponseError) as ex:
|
||||
# Let's not put a high priority on some other homeserver
|
||||
# failing to respond or giving a random response
|
||||
logger.debug(
|
||||
"get_event_for_timestamp: Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s",
|
||||
domain,
|
||||
type(ex).__name__,
|
||||
ex,
|
||||
ex.args,
|
||||
)
|
||||
except Exception:
|
||||
# But we do want to see some exceptions in our code
|
||||
logger.warning(
|
||||
"get_event_for_timestamp: Failed to fetch /timestamp_to_event from %s because of exception",
|
||||
domain,
|
||||
exc_info=True,
|
||||
# Only return the remote event if it's closer than the local event
|
||||
if not local_event or (
|
||||
abs(remote_event.origin_server_ts - timestamp)
|
||||
< abs(local_event.origin_server_ts - timestamp)
|
||||
):
|
||||
logger.info(
|
||||
"get_event_for_timestamp: returning remote_event_id=%s (%s) since it's closer to timestamp=%s than local_event=%s (%s)",
|
||||
remote_event_id,
|
||||
remote_event.origin_server_ts,
|
||||
timestamp,
|
||||
local_event.event_id if local_event else None,
|
||||
local_event.origin_server_ts if local_event else None,
|
||||
)
|
||||
return remote_event_id, remote_origin_server_ts
|
||||
|
||||
# To appease mypy, we have to add both of these conditions to check for
|
||||
# `None`. We only expect `local_event` to be `None` when
|
||||
|
|
|
@ -191,6 +191,7 @@ class SsoHandler:
|
|||
self._server_name = hs.hostname
|
||||
self._registration_handler = hs.get_registration_handler()
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
self._device_handler = hs.get_device_handler()
|
||||
self._error_template = hs.config.sso.sso_error_template
|
||||
self._bad_user_template = hs.config.sso.sso_auth_bad_user_template
|
||||
self._profile_handler = hs.get_profile_handler()
|
||||
|
@ -1026,6 +1027,76 @@ class SsoHandler:
|
|||
|
||||
return True
|
||||
|
||||
async def revoke_sessions_for_provider_session_id(
|
||||
self,
|
||||
auth_provider_id: str,
|
||||
auth_provider_session_id: str,
|
||||
expected_user_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Revoke any devices and in-flight logins tied to a provider session.
|
||||
|
||||
Args:
|
||||
auth_provider_id: A unique identifier for this SSO provider, e.g.
|
||||
"oidc" or "saml".
|
||||
auth_provider_session_id: The session ID from the provider to logout
|
||||
expected_user_id: The user we're expecting to logout. If set, it will ignore
|
||||
sessions belonging to other users and log an error.
|
||||
"""
|
||||
# Invalidate any running user-mapping sessions
|
||||
to_delete = []
|
||||
for session_id, session in self._username_mapping_sessions.items():
|
||||
if (
|
||||
session.auth_provider_id == auth_provider_id
|
||||
and session.auth_provider_session_id == auth_provider_session_id
|
||||
):
|
||||
to_delete.append(session_id)
|
||||
|
||||
for session_id in to_delete:
|
||||
logger.info("Revoking mapping session %s", session_id)
|
||||
del self._username_mapping_sessions[session_id]
|
||||
|
||||
# Invalidate any in-flight login tokens
|
||||
await self._store.invalidate_login_tokens_by_session_id(
|
||||
auth_provider_id=auth_provider_id,
|
||||
auth_provider_session_id=auth_provider_session_id,
|
||||
)
|
||||
|
||||
# Fetch any device(s) in the store associated with the session ID.
|
||||
devices = await self._store.get_devices_by_auth_provider_session_id(
|
||||
auth_provider_id=auth_provider_id,
|
||||
auth_provider_session_id=auth_provider_session_id,
|
||||
)
|
||||
|
||||
# We have no guarantee that all the devices of that session are for the same
|
||||
# `user_id`. Hence, we have to iterate over the list of devices and log them out
|
||||
# one by one.
|
||||
for device in devices:
|
||||
user_id = device["user_id"]
|
||||
device_id = device["device_id"]
|
||||
|
||||
# If the user_id associated with that device/session is not the one we got
|
||||
# out of the `sub` claim, skip that device and show log an error.
|
||||
if expected_user_id is not None and user_id != expected_user_id:
|
||||
logger.error(
|
||||
"Received a logout notification from SSO provider "
|
||||
f"{auth_provider_id!r} for the user {expected_user_id!r}, but with "
|
||||
f"a session ID ({auth_provider_session_id!r}) which belongs to "
|
||||
f"{user_id!r}. This may happen when the SSO provider user mapper "
|
||||
"uses something else than the standard attribute as mapping ID. "
|
||||
"For OIDC providers, set `backchannel_logout_ignore_sub` to `true` "
|
||||
"in the provider config if that is the case."
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
"Logging out %r (device %r) via SSO (%r) logout notification (session %r).",
|
||||
user_id,
|
||||
device_id,
|
||||
auth_provider_id,
|
||||
auth_provider_session_id,
|
||||
)
|
||||
await self._device_handler.delete_devices(user_id, [device_id])
|
||||
|
||||
|
||||
def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
|
||||
"""Extract the session ID from the cookie
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue