mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-08-17 13:20:14 -04:00
Merge remote-tracking branch 'upstream/release-v1.73'
This commit is contained in:
commit
bb26f5f0a9
167 changed files with 3234 additions and 1676 deletions
|
@ -16,6 +16,7 @@ import logging
|
|||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.handlers.device import DeviceHandler
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.types import Codes, Requester, UserID, create_requester
|
||||
|
||||
|
@ -76,6 +77,9 @@ class DeactivateAccountHandler:
|
|||
True if identity server supports removing threepids, otherwise False.
|
||||
"""
|
||||
|
||||
# This can only be called on the main process.
|
||||
assert isinstance(self._device_handler, DeviceHandler)
|
||||
|
||||
# Check if this user can be deactivated
|
||||
if not await self._third_party_rules.check_can_deactivate_user(
|
||||
user_id, by_admin
|
||||
|
|
|
@ -65,6 +65,8 @@ DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000
|
|||
|
||||
|
||||
class DeviceWorkerHandler:
|
||||
device_list_updater: "DeviceListWorkerUpdater"
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.clock = hs.get_clock()
|
||||
self.hs = hs
|
||||
|
@ -76,6 +78,8 @@ class DeviceWorkerHandler:
|
|||
self.server_name = hs.hostname
|
||||
self._msc3852_enabled = hs.config.experimental.msc3852_enabled
|
||||
|
||||
self.device_list_updater = DeviceListWorkerUpdater(hs)
|
||||
|
||||
@trace
|
||||
async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
|
||||
"""
|
||||
|
@ -99,6 +103,19 @@ class DeviceWorkerHandler:
|
|||
log_kv(device_map)
|
||||
return devices
|
||||
|
||||
async def get_dehydrated_device(
|
||||
self, user_id: str
|
||||
) -> Optional[Tuple[str, JsonDict]]:
|
||||
"""Retrieve the information for a dehydrated device.
|
||||
|
||||
Args:
|
||||
user_id: the user whose dehydrated device we are looking for
|
||||
Returns:
|
||||
a tuple whose first item is the device ID, and the second item is
|
||||
the dehydrated device information
|
||||
"""
|
||||
return await self.store.get_dehydrated_device(user_id)
|
||||
|
||||
@trace
|
||||
async def get_device(self, user_id: str, device_id: str) -> JsonDict:
|
||||
"""Retrieve the given device
|
||||
|
@ -127,7 +144,7 @@ class DeviceWorkerHandler:
|
|||
@cancellable
|
||||
async def get_device_changes_in_shared_rooms(
|
||||
self, user_id: str, room_ids: Collection[str], from_token: StreamToken
|
||||
) -> Collection[str]:
|
||||
) -> Set[str]:
|
||||
"""Get the set of users whose devices have changed who share a room with
|
||||
the given user.
|
||||
"""
|
||||
|
@ -320,6 +337,8 @@ class DeviceWorkerHandler:
|
|||
|
||||
|
||||
class DeviceHandler(DeviceWorkerHandler):
|
||||
device_list_updater: "DeviceListUpdater"
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
|
||||
|
@ -402,6 +421,9 @@ class DeviceHandler(DeviceWorkerHandler):
|
|||
|
||||
self._check_device_name_length(initial_device_display_name)
|
||||
|
||||
# Prune the user's device list if they already have a lot of devices.
|
||||
await self._prune_too_many_devices(user_id)
|
||||
|
||||
if device_id is not None:
|
||||
new_device = await self.store.store_device(
|
||||
user_id=user_id,
|
||||
|
@ -433,6 +455,14 @@ class DeviceHandler(DeviceWorkerHandler):
|
|||
|
||||
raise errors.StoreError(500, "Couldn't generate a device ID.")
|
||||
|
||||
async def _prune_too_many_devices(self, user_id: str) -> None:
|
||||
"""Delete any excess old devices this user may have."""
|
||||
device_ids = await self.store.check_too_many_devices_for_user(user_id)
|
||||
if not device_ids:
|
||||
return
|
||||
|
||||
await self.delete_devices(user_id, device_ids)
|
||||
|
||||
async def _delete_stale_devices(self) -> None:
|
||||
"""Background task that deletes devices which haven't been accessed for more than
|
||||
a configured time period.
|
||||
|
@ -462,7 +492,7 @@ class DeviceHandler(DeviceWorkerHandler):
|
|||
device_ids = [d for d in device_ids if d != except_device_id]
|
||||
await self.delete_devices(user_id, device_ids)
|
||||
|
||||
async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
|
||||
async def delete_devices(self, user_id: str, device_ids: Collection[str]) -> None:
|
||||
"""Delete several devices
|
||||
|
||||
Args:
|
||||
|
@ -606,19 +636,6 @@ class DeviceHandler(DeviceWorkerHandler):
|
|||
await self.delete_devices(user_id, [old_device_id])
|
||||
return device_id
|
||||
|
||||
async def get_dehydrated_device(
|
||||
self, user_id: str
|
||||
) -> Optional[Tuple[str, JsonDict]]:
|
||||
"""Retrieve the information for a dehydrated device.
|
||||
|
||||
Args:
|
||||
user_id: the user whose dehydrated device we are looking for
|
||||
Returns:
|
||||
a tuple whose first item is the device ID, and the second item is
|
||||
the dehydrated device information
|
||||
"""
|
||||
return await self.store.get_dehydrated_device(user_id)
|
||||
|
||||
async def rehydrate_device(
|
||||
self, user_id: str, access_token: str, device_id: str
|
||||
) -> dict:
|
||||
|
@ -682,13 +699,33 @@ class DeviceHandler(DeviceWorkerHandler):
|
|||
hosts_already_sent_to: Set[str] = set()
|
||||
|
||||
try:
|
||||
stream_id, room_id = await self.store.get_device_change_last_converted_pos()
|
||||
|
||||
while True:
|
||||
self._handle_new_device_update_new_data = False
|
||||
rows = await self.store.get_uncoverted_outbound_room_pokes()
|
||||
max_stream_id = self.store.get_device_stream_token()
|
||||
rows = await self.store.get_uncoverted_outbound_room_pokes(
|
||||
stream_id, room_id
|
||||
)
|
||||
if not rows:
|
||||
# If the DB returned nothing then there is nothing left to
|
||||
# do, *unless* a new device list update happened during the
|
||||
# DB query.
|
||||
|
||||
# Advance `(stream_id, room_id)`.
|
||||
# `max_stream_id` comes from *before* the query for unconverted
|
||||
# rows, which means that any unconverted rows must have a larger
|
||||
# stream ID.
|
||||
if max_stream_id > stream_id:
|
||||
stream_id, room_id = max_stream_id, ""
|
||||
await self.store.set_device_change_last_converted_pos(
|
||||
stream_id, room_id
|
||||
)
|
||||
else:
|
||||
assert max_stream_id == stream_id
|
||||
# Avoid moving `room_id` backwards.
|
||||
pass
|
||||
|
||||
if self._handle_new_device_update_new_data:
|
||||
continue
|
||||
else:
|
||||
|
@ -718,7 +755,6 @@ class DeviceHandler(DeviceWorkerHandler):
|
|||
user_id=user_id,
|
||||
device_id=device_id,
|
||||
room_id=room_id,
|
||||
stream_id=stream_id,
|
||||
hosts=hosts,
|
||||
context=opentracing_context,
|
||||
)
|
||||
|
@ -752,6 +788,12 @@ class DeviceHandler(DeviceWorkerHandler):
|
|||
hosts_already_sent_to.update(hosts)
|
||||
current_stream_id = stream_id
|
||||
|
||||
# Advance `(stream_id, room_id)`.
|
||||
_, _, room_id, stream_id, _ = rows[-1]
|
||||
await self.store.set_device_change_last_converted_pos(
|
||||
stream_id, room_id
|
||||
)
|
||||
|
||||
finally:
|
||||
self._handle_new_device_update_is_processing = False
|
||||
|
||||
|
@ -834,7 +876,6 @@ class DeviceHandler(DeviceWorkerHandler):
|
|||
user_id=user_id,
|
||||
device_id=device_id,
|
||||
room_id=room_id,
|
||||
stream_id=None,
|
||||
hosts=potentially_changed_hosts,
|
||||
context=None,
|
||||
)
|
||||
|
@ -858,7 +899,36 @@ def _update_device_from_client_ips(
|
|||
)
|
||||
|
||||
|
||||
class DeviceListUpdater:
|
||||
class DeviceListWorkerUpdater:
|
||||
"Handles incoming device list updates from federation and contacts the main process over replication"
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
from synapse.replication.http.devices import (
|
||||
ReplicationUserDevicesResyncRestServlet,
|
||||
)
|
||||
|
||||
self._user_device_resync_client = (
|
||||
ReplicationUserDevicesResyncRestServlet.make_client(hs)
|
||||
)
|
||||
|
||||
async def user_device_resync(
|
||||
self, user_id: str, mark_failed_as_stale: bool = True
|
||||
) -> Optional[JsonDict]:
|
||||
"""Fetches all devices for a user and updates the device cache with them.
|
||||
|
||||
Args:
|
||||
user_id: The user's id whose device_list will be updated.
|
||||
mark_failed_as_stale: Whether to mark the user's device list as stale
|
||||
if the attempt to resync failed.
|
||||
Returns:
|
||||
A dict with device info as under the "devices" in the result of this
|
||||
request:
|
||||
https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
|
||||
"""
|
||||
return await self._user_device_resync_client(user_id=user_id)
|
||||
|
||||
|
||||
class DeviceListUpdater(DeviceListWorkerUpdater):
|
||||
"Handles incoming device list updates from federation and updates the DB"
|
||||
|
||||
def __init__(self, hs: "HomeServer", device_handler: DeviceHandler):
|
||||
|
|
|
@ -27,9 +27,9 @@ from twisted.internet import defer
|
|||
|
||||
from synapse.api.constants import EduTypes
|
||||
from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError
|
||||
from synapse.handlers.device import DeviceHandler
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
|
||||
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
|
||||
from synapse.types import (
|
||||
JsonDict,
|
||||
UserID,
|
||||
|
@ -56,27 +56,23 @@ class E2eKeysHandler:
|
|||
self.is_mine = hs.is_mine
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
self._edu_updater = SigningKeyEduUpdater(hs, self)
|
||||
|
||||
federation_registry = hs.get_federation_registry()
|
||||
|
||||
self._is_master = hs.config.worker.worker_app is None
|
||||
if not self._is_master:
|
||||
self._user_device_resync_client = (
|
||||
ReplicationUserDevicesResyncRestServlet.make_client(hs)
|
||||
)
|
||||
else:
|
||||
is_master = hs.config.worker.worker_app is None
|
||||
if is_master:
|
||||
edu_updater = SigningKeyEduUpdater(hs)
|
||||
|
||||
# Only register this edu handler on master as it requires writing
|
||||
# device updates to the db
|
||||
federation_registry.register_edu_handler(
|
||||
EduTypes.SIGNING_KEY_UPDATE,
|
||||
self._edu_updater.incoming_signing_key_update,
|
||||
edu_updater.incoming_signing_key_update,
|
||||
)
|
||||
# also handle the unstable version
|
||||
# FIXME: remove this when enough servers have upgraded
|
||||
federation_registry.register_edu_handler(
|
||||
EduTypes.UNSTABLE_SIGNING_KEY_UPDATE,
|
||||
self._edu_updater.incoming_signing_key_update,
|
||||
edu_updater.incoming_signing_key_update,
|
||||
)
|
||||
|
||||
# doesn't really work as part of the generic query API, because the
|
||||
|
@ -319,14 +315,13 @@ class E2eKeysHandler:
|
|||
# probably be tracking their device lists. However, we haven't
|
||||
# done an initial sync on the device list so we do it now.
|
||||
try:
|
||||
if self._is_master:
|
||||
resync_results = await self.device_handler.device_list_updater.user_device_resync(
|
||||
resync_results = (
|
||||
await self.device_handler.device_list_updater.user_device_resync(
|
||||
user_id
|
||||
)
|
||||
else:
|
||||
resync_results = await self._user_device_resync_client(
|
||||
user_id=user_id
|
||||
)
|
||||
)
|
||||
if resync_results is None:
|
||||
raise ValueError("Device resync failed")
|
||||
|
||||
# Add the device keys to the results.
|
||||
user_devices = resync_results["devices"]
|
||||
|
@ -605,6 +600,8 @@ class E2eKeysHandler:
|
|||
async def upload_keys_for_user(
|
||||
self, user_id: str, device_id: str, keys: JsonDict
|
||||
) -> JsonDict:
|
||||
# This can only be called from the main process.
|
||||
assert isinstance(self.device_handler, DeviceHandler)
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
|
||||
|
@ -732,6 +729,8 @@ class E2eKeysHandler:
|
|||
user_id: the user uploading the keys
|
||||
keys: the signing keys
|
||||
"""
|
||||
# This can only be called from the main process.
|
||||
assert isinstance(self.device_handler, DeviceHandler)
|
||||
|
||||
# if a master key is uploaded, then check it. Otherwise, load the
|
||||
# stored master key, to check signatures on other keys
|
||||
|
@ -823,6 +822,9 @@ class E2eKeysHandler:
|
|||
Raises:
|
||||
SynapseError: if the signatures dict is not valid.
|
||||
"""
|
||||
# This can only be called from the main process.
|
||||
assert isinstance(self.device_handler, DeviceHandler)
|
||||
|
||||
failures = {}
|
||||
|
||||
# signatures to be stored. Each item will be a SignatureListItem
|
||||
|
@ -870,7 +872,7 @@ class E2eKeysHandler:
|
|||
- signatures of the user's master key by the user's devices.
|
||||
|
||||
Args:
|
||||
user_id (string): the user uploading the keys
|
||||
user_id: the user uploading the keys
|
||||
signatures (dict[string, dict]): map of devices to signed keys
|
||||
|
||||
Returns:
|
||||
|
@ -1200,6 +1202,9 @@ class E2eKeysHandler:
|
|||
A tuple of the retrieved key content, the key's ID and the matching VerifyKey.
|
||||
If the key cannot be retrieved, all values in the tuple will instead be None.
|
||||
"""
|
||||
# This can only be called from the main process.
|
||||
assert isinstance(self.device_handler, DeviceHandler)
|
||||
|
||||
try:
|
||||
remote_result = await self.federation.query_user_devices(
|
||||
user.domain, user.to_string()
|
||||
|
@ -1396,11 +1401,14 @@ class SignatureListItem:
|
|||
class SigningKeyEduUpdater:
|
||||
"""Handles incoming signing key updates from federation and updates the DB"""
|
||||
|
||||
def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastores().main
|
||||
self.federation = hs.get_federation_client()
|
||||
self.clock = hs.get_clock()
|
||||
self.e2e_keys_handler = e2e_keys_handler
|
||||
|
||||
device_handler = hs.get_device_handler()
|
||||
assert isinstance(device_handler, DeviceHandler)
|
||||
self._device_handler = device_handler
|
||||
|
||||
self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
|
||||
|
||||
|
@ -1445,9 +1453,6 @@ class SigningKeyEduUpdater:
|
|||
user_id: the user whose updates we are processing
|
||||
"""
|
||||
|
||||
device_handler = self.e2e_keys_handler.device_handler
|
||||
device_list_updater = device_handler.device_list_updater
|
||||
|
||||
async with self._remote_edu_linearizer.queue(user_id):
|
||||
pending_updates = self._pending_updates.pop(user_id, [])
|
||||
if not pending_updates:
|
||||
|
@ -1459,13 +1464,11 @@ class SigningKeyEduUpdater:
|
|||
logger.info("pending updates: %r", pending_updates)
|
||||
|
||||
for master_key, self_signing_key in pending_updates:
|
||||
new_device_ids = (
|
||||
await device_list_updater.process_cross_signing_key_update(
|
||||
user_id,
|
||||
master_key,
|
||||
self_signing_key,
|
||||
)
|
||||
new_device_ids = await self._device_handler.device_list_updater.process_cross_signing_key_update(
|
||||
user_id,
|
||||
master_key,
|
||||
self_signing_key,
|
||||
)
|
||||
device_ids = device_ids + new_device_ids
|
||||
|
||||
await device_handler.notify_device_update(user_id, device_ids)
|
||||
await self._device_handler.notify_device_update(user_id, device_ids)
|
||||
|
|
|
@ -377,8 +377,9 @@ class E2eRoomKeysHandler:
|
|||
"""Deletes a given version of the user's e2e_room_keys backup
|
||||
|
||||
Args:
|
||||
user_id(str): the user whose current backup version we're deleting
|
||||
version(str): the version id of the backup being deleted
|
||||
user_id: the user whose current backup version we're deleting
|
||||
version: Optional. the version ID of the backup version we're deleting
|
||||
If missing, we delete the current backup version info.
|
||||
Raises:
|
||||
NotFoundError: if this backup version doesn't exist
|
||||
"""
|
||||
|
|
|
@ -45,6 +45,7 @@ class EventAuthHandler:
|
|||
def __init__(self, hs: "HomeServer"):
|
||||
self._clock = hs.get_clock()
|
||||
self._store = hs.get_datastores().main
|
||||
self._state_storage_controller = hs.get_storage_controllers().state
|
||||
self._server_name = hs.hostname
|
||||
|
||||
async def check_auth_rules_from_context(
|
||||
|
@ -179,17 +180,22 @@ class EventAuthHandler:
|
|||
this function may return an incorrect result as we are not able to fully
|
||||
track server membership in a room without full state.
|
||||
"""
|
||||
if not allow_partial_state_rooms and await self._store.is_partial_state_room(
|
||||
room_id
|
||||
):
|
||||
raise AuthError(
|
||||
403,
|
||||
"Unable to authorise you right now; room is partial-stated here.",
|
||||
errcode=Codes.UNABLE_DUE_TO_PARTIAL_STATE,
|
||||
)
|
||||
|
||||
if not await self.is_host_in_room(room_id, host):
|
||||
raise AuthError(403, "Host not in room.")
|
||||
if await self._store.is_partial_state_room(room_id):
|
||||
if allow_partial_state_rooms:
|
||||
current_hosts = await self._state_storage_controller.get_current_hosts_in_room_or_partial_state_approximation(
|
||||
room_id
|
||||
)
|
||||
if host not in current_hosts:
|
||||
raise AuthError(403, "Host not in room (partial-state approx).")
|
||||
else:
|
||||
raise AuthError(
|
||||
403,
|
||||
"Unable to authorise you right now; room is partial-stated here.",
|
||||
errcode=Codes.UNABLE_DUE_TO_PARTIAL_STATE,
|
||||
)
|
||||
else:
|
||||
if not await self.is_host_in_room(room_id, host):
|
||||
raise AuthError(403, "Host not in room.")
|
||||
|
||||
async def check_restricted_join_rules(
|
||||
self,
|
||||
|
|
|
@ -379,6 +379,7 @@ class FederationHandler:
|
|||
filtered_extremities = await filter_events_for_server(
|
||||
self._storage_controllers,
|
||||
self.server_name,
|
||||
self.server_name,
|
||||
events_to_check,
|
||||
redact=False,
|
||||
check_history_visibility_only=True,
|
||||
|
@ -1231,7 +1232,9 @@ class FederationHandler:
|
|||
async def on_backfill_request(
|
||||
self, origin: str, room_id: str, pdu_list: List[str], limit: int
|
||||
) -> List[EventBase]:
|
||||
await self._event_auth_handler.assert_host_in_room(room_id, origin)
|
||||
# We allow partially joined rooms since in this case we are filtering out
|
||||
# non-local events in `filter_events_for_server`.
|
||||
await self._event_auth_handler.assert_host_in_room(room_id, origin, True)
|
||||
|
||||
# Synapse asks for 100 events per backfill request. Do not allow more.
|
||||
limit = min(limit, 100)
|
||||
|
@ -1252,7 +1255,7 @@ class FederationHandler:
|
|||
)
|
||||
|
||||
events = await filter_events_for_server(
|
||||
self._storage_controllers, origin, events
|
||||
self._storage_controllers, origin, self.server_name, events
|
||||
)
|
||||
|
||||
return events
|
||||
|
@ -1283,7 +1286,7 @@ class FederationHandler:
|
|||
await self._event_auth_handler.assert_host_in_room(event.room_id, origin)
|
||||
|
||||
events = await filter_events_for_server(
|
||||
self._storage_controllers, origin, [event]
|
||||
self._storage_controllers, origin, self.server_name, [event]
|
||||
)
|
||||
event = events[0]
|
||||
return event
|
||||
|
@ -1296,7 +1299,9 @@ class FederationHandler:
|
|||
latest_events: List[str],
|
||||
limit: int,
|
||||
) -> List[EventBase]:
|
||||
await self._event_auth_handler.assert_host_in_room(room_id, origin)
|
||||
# We allow partially joined rooms since in this case we are filtering out
|
||||
# non-local events in `filter_events_for_server`.
|
||||
await self._event_auth_handler.assert_host_in_room(room_id, origin, True)
|
||||
|
||||
# Only allow up to 20 events to be retrieved per request.
|
||||
limit = min(limit, 20)
|
||||
|
@ -1309,7 +1314,7 @@ class FederationHandler:
|
|||
)
|
||||
|
||||
missing_events = await filter_events_for_server(
|
||||
self._storage_controllers, origin, missing_events
|
||||
self._storage_controllers, origin, self.server_name, missing_events
|
||||
)
|
||||
|
||||
return missing_events
|
||||
|
@ -1596,8 +1601,8 @@ class FederationHandler:
|
|||
Fetch the complexity of a remote room over federation.
|
||||
|
||||
Args:
|
||||
remote_room_hosts (list[str]): The remote servers to ask.
|
||||
room_id (str): The room ID to ask about.
|
||||
remote_room_hosts: The remote servers to ask.
|
||||
room_id: The room ID to ask about.
|
||||
|
||||
Returns:
|
||||
Dict contains the complexity
|
||||
|
|
|
@ -711,7 +711,7 @@ class IdentityHandler:
|
|||
inviter_display_name: The current display name of the
|
||||
inviter.
|
||||
inviter_avatar_url: The URL of the inviter's avatar.
|
||||
id_access_token (str): The access token to authenticate to the identity
|
||||
id_access_token: The access token to authenticate to the identity
|
||||
server with
|
||||
|
||||
Returns:
|
||||
|
|
|
@ -1137,11 +1137,13 @@ class EventCreationHandler:
|
|||
)
|
||||
state_events = await self.store.get_events_as_list(state_event_ids)
|
||||
# Create a StateMap[str]
|
||||
state_map = {(e.type, e.state_key): e.event_id for e in state_events}
|
||||
current_state_ids = {
|
||||
(e.type, e.state_key): e.event_id for e in state_events
|
||||
}
|
||||
# Actually strip down and only use the necessary auth events
|
||||
auth_event_ids = self._event_auth_handler.compute_auth_events(
|
||||
event=temp_event,
|
||||
current_state_ids=state_map,
|
||||
current_state_ids=current_state_ids,
|
||||
for_verification=False,
|
||||
)
|
||||
|
||||
|
|
|
@ -787,7 +787,7 @@ class OidcProvider:
|
|||
Must include an ``access_token`` field.
|
||||
|
||||
Returns:
|
||||
UserInfo: an object representing the user.
|
||||
an object representing the user.
|
||||
"""
|
||||
logger.debug("Using the OAuth2 access_token to request userinfo")
|
||||
metadata = await self.load_metadata()
|
||||
|
@ -1435,6 +1435,7 @@ class UserAttributeDict(TypedDict):
|
|||
localpart: Optional[str]
|
||||
confirm_localpart: bool
|
||||
display_name: Optional[str]
|
||||
picture: Optional[str] # may be omitted by older `OidcMappingProviders`
|
||||
emails: List[str]
|
||||
|
||||
|
||||
|
@ -1520,6 +1521,7 @@ env.filters.update(
|
|||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class JinjaOidcMappingConfig:
|
||||
subject_claim: str
|
||||
picture_claim: str
|
||||
localpart_template: Optional[Template]
|
||||
display_name_template: Optional[Template]
|
||||
email_template: Optional[Template]
|
||||
|
@ -1539,6 +1541,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
|
|||
@staticmethod
|
||||
def parse_config(config: dict) -> JinjaOidcMappingConfig:
|
||||
subject_claim = config.get("subject_claim", "sub")
|
||||
picture_claim = config.get("picture_claim", "picture")
|
||||
|
||||
def parse_template_config(option_name: str) -> Optional[Template]:
|
||||
if option_name not in config:
|
||||
|
@ -1572,6 +1575,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
|
|||
|
||||
return JinjaOidcMappingConfig(
|
||||
subject_claim=subject_claim,
|
||||
picture_claim=picture_claim,
|
||||
localpart_template=localpart_template,
|
||||
display_name_template=display_name_template,
|
||||
email_template=email_template,
|
||||
|
@ -1611,10 +1615,13 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
|
|||
if email:
|
||||
emails.append(email)
|
||||
|
||||
picture = userinfo.get("picture")
|
||||
|
||||
return UserAttributeDict(
|
||||
localpart=localpart,
|
||||
display_name=display_name,
|
||||
emails=emails,
|
||||
picture=picture,
|
||||
confirm_localpart=self._config.confirm_localpart,
|
||||
)
|
||||
|
||||
|
|
|
@ -448,6 +448,12 @@ class PaginationHandler:
|
|||
|
||||
if pagin_config.from_token:
|
||||
from_token = pagin_config.from_token
|
||||
elif pagin_config.direction == "f":
|
||||
from_token = (
|
||||
await self.hs.get_event_sources().get_start_token_for_pagination(
|
||||
room_id
|
||||
)
|
||||
)
|
||||
else:
|
||||
from_token = (
|
||||
await self.hs.get_event_sources().get_current_token_for_pagination(
|
||||
|
|
|
@ -201,7 +201,7 @@ class BasePresenceHandler(abc.ABC):
|
|||
"""Get the current presence state for multiple users.
|
||||
|
||||
Returns:
|
||||
dict: `user_id` -> `UserPresenceState`
|
||||
A mapping of `user_id` -> `UserPresenceState`
|
||||
"""
|
||||
states = {}
|
||||
missing = []
|
||||
|
@ -478,7 +478,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
|
|||
return _NullContextManager()
|
||||
|
||||
prev_state = await self.current_state_for_user(user_id)
|
||||
if prev_state != PresenceState.BUSY:
|
||||
if prev_state.state != PresenceState.BUSY:
|
||||
# We set state here but pass ignore_status_msg = True as we don't want to
|
||||
# cause the status message to be cleared.
|
||||
# Note that this causes last_active_ts to be incremented which is not
|
||||
|
|
|
@ -92,7 +92,6 @@ class ReceiptsHandler:
|
|||
continue
|
||||
|
||||
# Check if these receipts apply to a thread.
|
||||
thread_id = None
|
||||
data = user_values.get("data", {})
|
||||
thread_id = data.get("thread_id")
|
||||
# If the thread ID is invalid, consider it missing.
|
||||
|
|
|
@ -38,6 +38,7 @@ from synapse.api.errors import (
|
|||
)
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.config.server import is_threepid_reserved
|
||||
from synapse.handlers.device import DeviceHandler
|
||||
from synapse.http.servlet import assert_params_in_dict
|
||||
from synapse.replication.http.login import RegisterDeviceReplicationServlet
|
||||
from synapse.replication.http.register import (
|
||||
|
@ -848,6 +849,9 @@ class RegistrationHandler:
|
|||
refresh_token = None
|
||||
refresh_token_id = None
|
||||
|
||||
# This can only run on the main process.
|
||||
assert isinstance(self.device_handler, DeviceHandler)
|
||||
|
||||
registered_device_id = await self.device_handler.check_device_registered(
|
||||
user_id,
|
||||
device_id,
|
||||
|
|
|
@ -13,17 +13,19 @@
|
|||
# limitations under the License.
|
||||
import enum
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Collection, Dict, FrozenSet, Iterable, List, Optional
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.api.constants import EventTypes, RelationTypes
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.events import EventBase, relation_from_event
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
from synapse.logging.opentracing import trace
|
||||
from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent
|
||||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.types import JsonDict, Requester, StreamToken, UserID
|
||||
from synapse.types import JsonDict, Requester, UserID
|
||||
from synapse.util.async_helpers import gather_results
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -172,40 +174,6 @@ class RelationsHandler:
|
|||
|
||||
return return_value
|
||||
|
||||
async def get_relations_for_event(
|
||||
self,
|
||||
event_id: str,
|
||||
event: EventBase,
|
||||
room_id: str,
|
||||
relation_type: str,
|
||||
ignored_users: FrozenSet[str] = frozenset(),
|
||||
) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
|
||||
"""Get a list of events which relate to an event, ordered by topological ordering.
|
||||
|
||||
Args:
|
||||
event_id: Fetch events that relate to this event ID.
|
||||
event: The matching EventBase to event_id.
|
||||
room_id: The room the event belongs to.
|
||||
relation_type: The type of relation.
|
||||
ignored_users: The users ignored by the requesting user.
|
||||
|
||||
Returns:
|
||||
List of event IDs that match relations requested. The rows are of
|
||||
the form `{"event_id": "..."}`.
|
||||
"""
|
||||
|
||||
# Call the underlying storage method, which is cached.
|
||||
related_events, next_token = await self._main_store.get_relations_for_event(
|
||||
event_id, event, room_id, relation_type, direction="f"
|
||||
)
|
||||
|
||||
# Filter out ignored users and convert to the expected format.
|
||||
related_events = [
|
||||
event for event in related_events if event.sender not in ignored_users
|
||||
]
|
||||
|
||||
return related_events, next_token
|
||||
|
||||
async def redact_events_related_to(
|
||||
self,
|
||||
requester: Requester,
|
||||
|
@ -259,51 +227,107 @@ class RelationsHandler:
|
|||
e.msg,
|
||||
)
|
||||
|
||||
async def get_annotations_for_event(
|
||||
self,
|
||||
event_id: str,
|
||||
room_id: str,
|
||||
limit: int = 5,
|
||||
ignored_users: FrozenSet[str] = frozenset(),
|
||||
) -> List[JsonDict]:
|
||||
"""Get a list of annotations on the event, grouped by event type and
|
||||
async def get_annotations_for_events(
|
||||
self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset()
|
||||
) -> Dict[str, List[JsonDict]]:
|
||||
"""Get a list of annotations to the given events, grouped by event type and
|
||||
aggregation key, sorted by count.
|
||||
|
||||
This is used e.g. to get the what and how many reactions have happend
|
||||
This is used e.g. to get the what and how many reactions have happened
|
||||
on an event.
|
||||
|
||||
Args:
|
||||
event_id: Fetch events that relate to this event ID.
|
||||
room_id: The room the event belongs to.
|
||||
limit: Only fetch the `limit` groups.
|
||||
event_ids: Fetch events that relate to these event IDs.
|
||||
ignored_users: The users ignored by the requesting user.
|
||||
|
||||
Returns:
|
||||
List of groups of annotations that match. Each row is a dict with
|
||||
`type`, `key` and `count` fields.
|
||||
A map of event IDs to a list of groups of annotations that match.
|
||||
Each entry is a dict with `type`, `key` and `count` fields.
|
||||
"""
|
||||
# Get the base results for all users.
|
||||
full_results = await self._main_store.get_aggregation_groups_for_event(
|
||||
event_id, room_id, limit
|
||||
full_results = await self._main_store.get_aggregation_groups_for_events(
|
||||
event_ids
|
||||
)
|
||||
|
||||
# Avoid additional logic if there are no ignored users.
|
||||
if not ignored_users:
|
||||
return {
|
||||
event_id: results
|
||||
for event_id, results in full_results.items()
|
||||
if results
|
||||
}
|
||||
|
||||
# Then subtract off the results for any ignored users.
|
||||
ignored_results = await self._main_store.get_aggregation_groups_for_users(
|
||||
event_id, room_id, limit, ignored_users
|
||||
[event_id for event_id, results in full_results.items() if results],
|
||||
ignored_users,
|
||||
)
|
||||
|
||||
filtered_results = []
|
||||
for result in full_results:
|
||||
key = (result["type"], result["key"])
|
||||
if key in ignored_results:
|
||||
result = result.copy()
|
||||
result["count"] -= ignored_results[key]
|
||||
if result["count"] <= 0:
|
||||
continue
|
||||
filtered_results.append(result)
|
||||
filtered_results = {}
|
||||
for event_id, results in full_results.items():
|
||||
# If no annotations, skip.
|
||||
if not results:
|
||||
continue
|
||||
|
||||
# If there are not ignored results for this event, copy verbatim.
|
||||
if event_id not in ignored_results:
|
||||
filtered_results[event_id] = results
|
||||
continue
|
||||
|
||||
# Otherwise, subtract out the ignored results.
|
||||
event_ignored_results = ignored_results[event_id]
|
||||
for result in results:
|
||||
key = (result["type"], result["key"])
|
||||
if key in event_ignored_results:
|
||||
# Ensure to not modify the cache.
|
||||
result = result.copy()
|
||||
result["count"] -= event_ignored_results[key]
|
||||
if result["count"] <= 0:
|
||||
continue
|
||||
filtered_results.setdefault(event_id, []).append(result)
|
||||
|
||||
return filtered_results
|
||||
|
||||
async def get_references_for_events(
|
||||
self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset()
|
||||
) -> Dict[str, List[_RelatedEvent]]:
|
||||
"""Get a list of references to the given events.
|
||||
|
||||
Args:
|
||||
event_ids: Fetch events that relate to this event ID.
|
||||
ignored_users: The users ignored by the requesting user.
|
||||
|
||||
Returns:
|
||||
A map of event IDs to a list related events.
|
||||
"""
|
||||
|
||||
related_events = await self._main_store.get_references_for_events(event_ids)
|
||||
|
||||
# Avoid additional logic if there are no ignored users.
|
||||
if not ignored_users:
|
||||
return {
|
||||
event_id: results
|
||||
for event_id, results in related_events.items()
|
||||
if results
|
||||
}
|
||||
|
||||
# Filter out ignored users.
|
||||
results = {}
|
||||
for event_id, events in related_events.items():
|
||||
# If no references, skip.
|
||||
if not events:
|
||||
continue
|
||||
|
||||
# Filter ignored users out.
|
||||
events = [event for event in events if event.sender not in ignored_users]
|
||||
# If there are no events left, skip this event.
|
||||
if not events:
|
||||
continue
|
||||
|
||||
results[event_id] = events
|
||||
|
||||
return results
|
||||
|
||||
async def _get_threads_for_events(
|
||||
self,
|
||||
events_by_id: Dict[str, EventBase],
|
||||
|
@ -366,59 +390,66 @@ class RelationsHandler:
|
|||
results = {}
|
||||
|
||||
for event_id, summary in summaries.items():
|
||||
if summary:
|
||||
thread_count, latest_thread_event = summary
|
||||
# If no thread, skip.
|
||||
if not summary:
|
||||
continue
|
||||
|
||||
# Subtract off the count of any ignored users.
|
||||
for ignored_user in ignored_users:
|
||||
thread_count -= ignored_results.get((event_id, ignored_user), 0)
|
||||
thread_count, latest_thread_event = summary
|
||||
|
||||
# This is gnarly, but if the latest event is from an ignored user,
|
||||
# attempt to find one that isn't from an ignored user.
|
||||
if latest_thread_event.sender in ignored_users:
|
||||
room_id = latest_thread_event.room_id
|
||||
# Subtract off the count of any ignored users.
|
||||
for ignored_user in ignored_users:
|
||||
thread_count -= ignored_results.get((event_id, ignored_user), 0)
|
||||
|
||||
# If the root event is not found, something went wrong, do
|
||||
# not include a summary of the thread.
|
||||
event = await self._event_handler.get_event(user, room_id, event_id)
|
||||
if event is None:
|
||||
continue
|
||||
# This is gnarly, but if the latest event is from an ignored user,
|
||||
# attempt to find one that isn't from an ignored user.
|
||||
if latest_thread_event.sender in ignored_users:
|
||||
room_id = latest_thread_event.room_id
|
||||
|
||||
potential_events, _ = await self.get_relations_for_event(
|
||||
event_id,
|
||||
event,
|
||||
room_id,
|
||||
RelationTypes.THREAD,
|
||||
ignored_users,
|
||||
)
|
||||
# If the root event is not found, something went wrong, do
|
||||
# not include a summary of the thread.
|
||||
event = await self._event_handler.get_event(user, room_id, event_id)
|
||||
if event is None:
|
||||
continue
|
||||
|
||||
# If all found events are from ignored users, do not include
|
||||
# a summary of the thread.
|
||||
if not potential_events:
|
||||
continue
|
||||
|
||||
# The *last* event returned is the one that is cared about.
|
||||
event = await self._event_handler.get_event(
|
||||
user, room_id, potential_events[-1].event_id
|
||||
)
|
||||
# It is unexpected that the event will not exist.
|
||||
if event is None:
|
||||
logger.warning(
|
||||
"Unable to fetch latest event in a thread with event ID: %s",
|
||||
potential_events[-1].event_id,
|
||||
)
|
||||
continue
|
||||
latest_thread_event = event
|
||||
|
||||
results[event_id] = _ThreadAggregation(
|
||||
latest_event=latest_thread_event,
|
||||
count=thread_count,
|
||||
# If there's a thread summary it must also exist in the
|
||||
# participated dictionary.
|
||||
current_user_participated=events_by_id[event_id].sender == user_id
|
||||
or participated[event_id],
|
||||
# Attempt to find another event to use as the latest event.
|
||||
potential_events, _ = await self._main_store.get_relations_for_event(
|
||||
event_id, event, room_id, RelationTypes.THREAD, direction="f"
|
||||
)
|
||||
|
||||
# Filter out ignored users.
|
||||
potential_events = [
|
||||
event
|
||||
for event in potential_events
|
||||
if event.sender not in ignored_users
|
||||
]
|
||||
|
||||
# If all found events are from ignored users, do not include
|
||||
# a summary of the thread.
|
||||
if not potential_events:
|
||||
continue
|
||||
|
||||
# The *last* event returned is the one that is cared about.
|
||||
event = await self._event_handler.get_event(
|
||||
user, room_id, potential_events[-1].event_id
|
||||
)
|
||||
# It is unexpected that the event will not exist.
|
||||
if event is None:
|
||||
logger.warning(
|
||||
"Unable to fetch latest event in a thread with event ID: %s",
|
||||
potential_events[-1].event_id,
|
||||
)
|
||||
continue
|
||||
latest_thread_event = event
|
||||
|
||||
results[event_id] = _ThreadAggregation(
|
||||
latest_event=latest_thread_event,
|
||||
count=thread_count,
|
||||
# If there's a thread summary it must also exist in the
|
||||
# participated dictionary.
|
||||
current_user_participated=events_by_id[event_id].sender == user_id
|
||||
or participated[event_id],
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@trace
|
||||
|
@ -496,49 +527,56 @@ class RelationsHandler:
|
|||
# (as that is what makes it part of the thread).
|
||||
relations_by_id[latest_thread_event.event_id] = RelationTypes.THREAD
|
||||
|
||||
# Fetch other relations per event.
|
||||
for event in events_by_id.values():
|
||||
# Fetch any annotations (ie, reactions) to bundle with this event.
|
||||
annotations = await self.get_annotations_for_event(
|
||||
event.event_id, event.room_id, ignored_users=ignored_users
|
||||
async def _fetch_annotations() -> None:
|
||||
"""Fetch any annotations (ie, reactions) to bundle with this event."""
|
||||
annotations_by_event_id = await self.get_annotations_for_events(
|
||||
events_by_id.keys(), ignored_users=ignored_users
|
||||
)
|
||||
if annotations:
|
||||
results.setdefault(
|
||||
event.event_id, BundledAggregations()
|
||||
).annotations = {"chunk": annotations}
|
||||
for event_id, annotations in annotations_by_event_id.items():
|
||||
if annotations:
|
||||
results.setdefault(event_id, BundledAggregations()).annotations = {
|
||||
"chunk": annotations
|
||||
}
|
||||
|
||||
# Fetch any references to bundle with this event.
|
||||
references, next_token = await self.get_relations_for_event(
|
||||
event.event_id,
|
||||
event,
|
||||
event.room_id,
|
||||
RelationTypes.REFERENCE,
|
||||
ignored_users=ignored_users,
|
||||
async def _fetch_references() -> None:
|
||||
"""Fetch any references to bundle with this event."""
|
||||
references_by_event_id = await self.get_references_for_events(
|
||||
events_by_id.keys(), ignored_users=ignored_users
|
||||
)
|
||||
if references:
|
||||
aggregations = results.setdefault(event.event_id, BundledAggregations())
|
||||
aggregations.references = {
|
||||
"chunk": [{"event_id": ev.event_id} for ev in references]
|
||||
}
|
||||
for event_id, references in references_by_event_id.items():
|
||||
if references:
|
||||
results.setdefault(event_id, BundledAggregations()).references = {
|
||||
"chunk": [{"event_id": ev.event_id} for ev in references]
|
||||
}
|
||||
|
||||
if next_token:
|
||||
aggregations.references["next_batch"] = await next_token.to_string(
|
||||
self._main_store
|
||||
)
|
||||
async def _fetch_edits() -> None:
|
||||
"""
|
||||
Fetch any edits (but not for redacted events).
|
||||
|
||||
# Fetch any edits (but not for redacted events).
|
||||
#
|
||||
# Note that there is no use in limiting edits by ignored users since the
|
||||
# parent event should be ignored in the first place if the user is ignored.
|
||||
edits = await self._main_store.get_applicable_edits(
|
||||
[
|
||||
event_id
|
||||
for event_id, event in events_by_id.items()
|
||||
if not event.internal_metadata.is_redacted()
|
||||
]
|
||||
Note that there is no use in limiting edits by ignored users since the
|
||||
parent event should be ignored in the first place if the user is ignored.
|
||||
"""
|
||||
edits = await self._main_store.get_applicable_edits(
|
||||
[
|
||||
event_id
|
||||
for event_id, event in events_by_id.items()
|
||||
if not event.internal_metadata.is_redacted()
|
||||
]
|
||||
)
|
||||
for event_id, edit in edits.items():
|
||||
results.setdefault(event_id, BundledAggregations()).replace = edit
|
||||
|
||||
# Parallelize the calls for annotations, references, and edits since they
|
||||
# are unrelated.
|
||||
await make_deferred_yieldable(
|
||||
gather_results(
|
||||
(
|
||||
run_in_background(_fetch_annotations),
|
||||
run_in_background(_fetch_references),
|
||||
run_in_background(_fetch_edits),
|
||||
)
|
||||
)
|
||||
)
|
||||
for event_id, edit in edits.items():
|
||||
results.setdefault(event_id, BundledAggregations()).replace = edit
|
||||
|
||||
return results
|
||||
|
||||
|
@ -571,7 +609,7 @@ class RelationsHandler:
|
|||
room_id, requester, allow_departed_users=True
|
||||
)
|
||||
|
||||
# Note that ignored users are not passed into get_relations_for_event
|
||||
# Note that ignored users are not passed into get_threads
|
||||
# below. Ignored users are handled in filter_events_for_client (and by
|
||||
# not passing them in here we should get a better cache hit rate).
|
||||
thread_roots, next_batch = await self._main_store.get_threads(
|
||||
|
|
|
@ -441,7 +441,7 @@ class DefaultSamlMappingProvider:
|
|||
client_redirect_url: where the client wants to redirect to
|
||||
|
||||
Returns:
|
||||
dict: A dict containing new user attributes. Possible keys:
|
||||
A dict containing new user attributes. Possible keys:
|
||||
* mxid_localpart (str): Required. The localpart of the user's mxid
|
||||
* displayname (str): The displayname of the user
|
||||
* emails (list[str]): Any emails for the user
|
||||
|
@ -483,7 +483,7 @@ class DefaultSamlMappingProvider:
|
|||
Args:
|
||||
config: A dictionary containing configuration options for this provider
|
||||
Returns:
|
||||
SamlConfig: A custom config object for this module
|
||||
A custom config object for this module
|
||||
"""
|
||||
# Parse config options and use defaults where necessary
|
||||
mxid_source_attribute = config.get("mxid_source_attribute", "uid")
|
||||
|
|
|
@ -15,6 +15,7 @@ import logging
|
|||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from synapse.api.errors import Codes, StoreError, SynapseError
|
||||
from synapse.handlers.device import DeviceHandler
|
||||
from synapse.types import Requester
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -29,7 +30,10 @@ class SetPasswordHandler:
|
|||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastores().main
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
self._device_handler = hs.get_device_handler()
|
||||
# This can only be instantiated on the main process.
|
||||
device_handler = hs.get_device_handler()
|
||||
assert isinstance(device_handler, DeviceHandler)
|
||||
self._device_handler = device_handler
|
||||
|
||||
async def set_password(
|
||||
self,
|
||||
|
|
|
@ -12,6 +12,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
import hashlib
|
||||
import io
|
||||
import logging
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
|
@ -37,6 +39,7 @@ from twisted.web.server import Request
|
|||
from synapse.api.constants import LoginType
|
||||
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
|
||||
from synapse.config.sso import SsoAttributeRequirement
|
||||
from synapse.handlers.device import DeviceHandler
|
||||
from synapse.handlers.register import init_counters_for_auth_provider
|
||||
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
|
||||
from synapse.http import get_request_user_agent
|
||||
|
@ -137,6 +140,7 @@ class UserAttributes:
|
|||
localpart: Optional[str]
|
||||
confirm_localpart: bool = False
|
||||
display_name: Optional[str] = None
|
||||
picture: Optional[str] = None
|
||||
emails: Collection[str] = attr.Factory(list)
|
||||
|
||||
|
||||
|
@ -195,6 +199,10 @@ class SsoHandler:
|
|||
self._error_template = hs.config.sso.sso_error_template
|
||||
self._bad_user_template = hs.config.sso.sso_auth_bad_user_template
|
||||
self._profile_handler = hs.get_profile_handler()
|
||||
self._media_repo = (
|
||||
hs.get_media_repository() if hs.config.media.can_load_media_repo else None
|
||||
)
|
||||
self._http_client = hs.get_proxied_blacklisted_http_client()
|
||||
|
||||
# The following template is shown after a successful user interactive
|
||||
# authentication session. It tells the user they can close the window.
|
||||
|
@ -494,6 +502,8 @@ class SsoHandler:
|
|||
await self._profile_handler.set_displayname(
|
||||
user_id_obj, requester, attributes.display_name, True
|
||||
)
|
||||
if attributes.picture:
|
||||
await self.set_avatar(user_id, attributes.picture)
|
||||
|
||||
await self._auth_handler.complete_sso_login(
|
||||
user_id,
|
||||
|
@ -702,8 +712,110 @@ class SsoHandler:
|
|||
await self._store.record_user_external_id(
|
||||
auth_provider_id, remote_user_id, registered_user_id
|
||||
)
|
||||
|
||||
# Set avatar, if available
|
||||
if attributes.picture:
|
||||
await self.set_avatar(registered_user_id, attributes.picture)
|
||||
|
||||
return registered_user_id
|
||||
|
||||
async def set_avatar(self, user_id: str, picture_https_url: str) -> bool:
|
||||
"""Set avatar of the user.
|
||||
|
||||
This downloads the image file from the URL provided, stores that in
|
||||
the media repository and then sets the avatar on the user's profile.
|
||||
|
||||
It can detect if the same image is being saved again and bails early by storing
|
||||
the hash of the file in the `upload_name` of the avatar image.
|
||||
|
||||
Currently, it only supports server configurations which run the media repository
|
||||
within the same process.
|
||||
|
||||
It silently fails and logs a warning by raising an exception and catching it
|
||||
internally if:
|
||||
* it is unable to fetch the image itself (non 200 status code) or
|
||||
* the image supplied is bigger than max allowed size or
|
||||
* the image type is not one of the allowed image types.
|
||||
|
||||
Args:
|
||||
user_id: matrix user ID in the form @localpart:domain as a string.
|
||||
|
||||
picture_https_url: HTTPS url for the picture image file.
|
||||
|
||||
Returns: `True` if the user's avatar has been successfully set to the image at
|
||||
`picture_https_url`.
|
||||
"""
|
||||
if self._media_repo is None:
|
||||
logger.info(
|
||||
"failed to set user avatar because out-of-process media repositories "
|
||||
"are not supported yet "
|
||||
)
|
||||
return False
|
||||
|
||||
try:
|
||||
uid = UserID.from_string(user_id)
|
||||
|
||||
def is_allowed_mime_type(content_type: str) -> bool:
|
||||
if (
|
||||
self._profile_handler.allowed_avatar_mimetypes
|
||||
and content_type
|
||||
not in self._profile_handler.allowed_avatar_mimetypes
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
# download picture, enforcing size limit & mime type check
|
||||
picture = io.BytesIO()
|
||||
|
||||
content_length, headers, uri, code = await self._http_client.get_file(
|
||||
url=picture_https_url,
|
||||
output_stream=picture,
|
||||
max_size=self._profile_handler.max_avatar_size,
|
||||
is_allowed_content_type=is_allowed_mime_type,
|
||||
)
|
||||
|
||||
if code != 200:
|
||||
raise Exception(
|
||||
"GET request to download sso avatar image returned {}".format(code)
|
||||
)
|
||||
|
||||
# upload name includes hash of the image file's content so that we can
|
||||
# easily check if it requires an update or not, the next time user logs in
|
||||
upload_name = "sso_avatar_" + hashlib.sha256(picture.read()).hexdigest()
|
||||
|
||||
# bail if user already has the same avatar
|
||||
profile = await self._profile_handler.get_profile(user_id)
|
||||
if profile["avatar_url"] is not None:
|
||||
server_name = profile["avatar_url"].split("/")[-2]
|
||||
media_id = profile["avatar_url"].split("/")[-1]
|
||||
if server_name == self._server_name:
|
||||
media = await self._media_repo.store.get_local_media(media_id)
|
||||
if media is not None and upload_name == media["upload_name"]:
|
||||
logger.info("skipping saving the user avatar")
|
||||
return True
|
||||
|
||||
# store it in media repository
|
||||
avatar_mxc_url = await self._media_repo.create_content(
|
||||
media_type=headers[b"Content-Type"][0].decode("utf-8"),
|
||||
upload_name=upload_name,
|
||||
content=picture,
|
||||
content_length=content_length,
|
||||
auth_user=uid,
|
||||
)
|
||||
|
||||
# save it as user avatar
|
||||
await self._profile_handler.set_avatar_url(
|
||||
uid,
|
||||
create_requester(uid),
|
||||
str(avatar_mxc_url),
|
||||
)
|
||||
|
||||
logger.info("successfully saved the user avatar")
|
||||
return True
|
||||
except Exception:
|
||||
logger.warning("failed to save the user avatar")
|
||||
return False
|
||||
|
||||
async def complete_sso_ui_auth_request(
|
||||
self,
|
||||
auth_provider_id: str,
|
||||
|
@ -1035,6 +1147,8 @@ class SsoHandler:
|
|||
) -> None:
|
||||
"""Revoke any devices and in-flight logins tied to a provider session.
|
||||
|
||||
Can only be called from the main process.
|
||||
|
||||
Args:
|
||||
auth_provider_id: A unique identifier for this SSO provider, e.g.
|
||||
"oidc" or "saml".
|
||||
|
@ -1042,6 +1156,12 @@ class SsoHandler:
|
|||
expected_user_id: The user we're expecting to logout. If set, it will ignore
|
||||
sessions belonging to other users and log an error.
|
||||
"""
|
||||
|
||||
# It is expected that this is the main process.
|
||||
assert isinstance(
|
||||
self._device_handler, DeviceHandler
|
||||
), "revoking SSO sessions can only be called on the main process"
|
||||
|
||||
# Invalidate any running user-mapping sessions
|
||||
to_delete = []
|
||||
for session_id, session in self._username_mapping_sessions.items():
|
||||
|
|
|
@ -1425,14 +1425,14 @@ class SyncHandler:
|
|||
|
||||
logger.debug("Fetching OTK data")
|
||||
device_id = sync_config.device_id
|
||||
one_time_key_counts: JsonDict = {}
|
||||
one_time_keys_count: JsonDict = {}
|
||||
unused_fallback_key_types: List[str] = []
|
||||
if device_id:
|
||||
# TODO: We should have a way to let clients differentiate between the states of:
|
||||
# * no change in OTK count since the provided since token
|
||||
# * the server has zero OTKs left for this device
|
||||
# Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298
|
||||
one_time_key_counts = await self.store.count_e2e_one_time_keys(
|
||||
one_time_keys_count = await self.store.count_e2e_one_time_keys(
|
||||
user_id, device_id
|
||||
)
|
||||
unused_fallback_key_types = (
|
||||
|
@ -1462,7 +1462,7 @@ class SyncHandler:
|
|||
archived=sync_result_builder.archived,
|
||||
to_device=sync_result_builder.to_device,
|
||||
device_lists=device_lists,
|
||||
device_one_time_keys_count=one_time_key_counts,
|
||||
device_one_time_keys_count=one_time_keys_count,
|
||||
device_unused_fallback_key_types=unused_fallback_key_types,
|
||||
next_batch=sync_result_builder.now_token,
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue