Merge remote-tracking branch 'upstream/release-v1.57'

Tulir Asokan 2022-04-21 13:53:47 +03:00
commit b2fa6ec9f6
248 changed files with 14616 additions and 8934 deletions

View file

@ -12,8 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from typing import TYPE_CHECKING, Collection, List, Optional, Tuple
from typing import TYPE_CHECKING, Awaitable, Callable, Collection, List, Optional, Tuple
from synapse.replication.http.account_data import (
ReplicationAddTagRestServlet,
@ -27,6 +28,12 @@ from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
ON_ACCOUNT_DATA_UPDATED_CALLBACK = Callable[
[str, Optional[str], str, JsonDict], Awaitable
]
class AccountDataHandler:
def __init__(self, hs: "HomeServer"):
@ -40,6 +47,44 @@ class AccountDataHandler:
self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs)
self._account_data_writers = hs.config.worker.writers.account_data
self._on_account_data_updated_callbacks: List[
ON_ACCOUNT_DATA_UPDATED_CALLBACK
] = []
def register_module_callbacks(
self, on_account_data_updated: Optional[ON_ACCOUNT_DATA_UPDATED_CALLBACK] = None
) -> None:
"""Register callbacks from modules."""
if on_account_data_updated is not None:
self._on_account_data_updated_callbacks.append(on_account_data_updated)
async def _notify_modules(
self,
user_id: str,
room_id: Optional[str],
account_data_type: str,
content: JsonDict,
) -> None:
"""Notifies modules about new account data changes.
A change can be either a new account data type being added, or the content
associated with a type being changed. Account data for a given type is removed by
changing the associated content to an empty dictionary.
Note that this is not called when the tags associated with a room change.
Args:
user_id: The user whose account data is changing.
room_id: The ID of the room the account data change concerns, if any.
account_data_type: The type of the account data.
content: The content that is now associated with this type.
"""
for callback in self._on_account_data_updated_callbacks:
try:
await callback(user_id, room_id, account_data_type, content)
except Exception as e:
logger.exception("Failed to run module callback %s: %s", callback, e)
async def add_account_data_to_room(
self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
) -> int:
@ -63,6 +108,8 @@ class AccountDataHandler:
"account_data_key", max_stream_id, users=[user_id]
)
await self._notify_modules(user_id, room_id, account_data_type, content)
return max_stream_id
else:
response = await self._room_data_client(
@ -96,6 +143,9 @@ class AccountDataHandler:
self._notifier.on_new_event(
"account_data_key", max_stream_id, users=[user_id]
)
await self._notify_modules(user_id, None, account_data_type, content)
return max_stream_id
else:
response = await self._user_data_client(

View file

@ -180,9 +180,9 @@ class AccountValidityHandler:
expiring_users = await self.store.get_users_expiring_soon()
if expiring_users:
for user in expiring_users:
for user_id, expiration_ts_ms in expiring_users:
await self._send_renewal_email(
user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
user_id=user_id, expiration_ts=expiration_ts_ms
)
async def send_renewal_email_to_user(self, user_id: str) -> None:

View file

@ -33,7 +33,13 @@ from synapse.metrics.background_process_metrics import (
wrap_as_background_process,
)
from synapse.storage.databases.main.directory import RoomAliasMapping
from synapse.types import JsonDict, RoomAlias, RoomStreamToken, UserID
from synapse.types import (
DeviceListUpdates,
JsonDict,
RoomAlias,
RoomStreamToken,
UserID,
)
from synapse.util.async_helpers import Linearizer
from synapse.util.metrics import Measure
@ -58,6 +64,9 @@ class ApplicationServicesHandler:
self._msc2409_to_device_messages_enabled = (
hs.config.experimental.msc2409_to_device_messages_enabled
)
self._msc3202_transaction_extensions_enabled = (
hs.config.experimental.msc3202_transaction_extensions
)
self.current_max = 0
self.is_processing = False
@ -204,9 +213,9 @@ class ApplicationServicesHandler:
Args:
stream_key: The stream the event came from.
`stream_key` can be "typing_key", "receipt_key", "presence_key" or
"to_device_key". Any other value for `stream_key` will cause this function
to return early.
`stream_key` can be "typing_key", "receipt_key", "presence_key",
"to_device_key" or "device_list_key". Any other value for `stream_key`
will cause this function to return early.
Ephemeral events will only be pushed to appservices that have opted into
receiving them by setting `push_ephemeral` to true in their registration
@ -230,6 +239,7 @@ class ApplicationServicesHandler:
"receipt_key",
"presence_key",
"to_device_key",
"device_list_key",
):
return
@ -253,15 +263,37 @@ class ApplicationServicesHandler:
):
return
# Ignore device lists if the feature flag is not enabled
if (
stream_key == "device_list_key"
and not self._msc3202_transaction_extensions_enabled
):
return
# Check whether there are any appservices which have registered to receive
# ephemeral events.
#
# Note that whether these events are actually relevant to these appservices
# is decided later on.
services = self.store.get_app_services()
services = [
service
for service in self.store.get_app_services()
if service.supports_ephemeral
for service in services
# Different stream keys require different support booleans
if (
stream_key
in (
"typing_key",
"receipt_key",
"presence_key",
"to_device_key",
)
and service.supports_ephemeral
)
or (
stream_key == "device_list_key"
and service.msc3202_transaction_extensions
)
]
if not services:
# Bail out early if none of the target appservices have explicitly registered
@ -298,10 +330,8 @@ class ApplicationServicesHandler:
continue
# Since we read/update the stream position for this AS/stream
with (
await self._ephemeral_events_linearizer.queue(
(service.id, stream_key)
)
async with self._ephemeral_events_linearizer.queue(
(service.id, stream_key)
):
if stream_key == "receipt_key":
events = await self._handle_receipts(service, new_token)
@ -336,6 +366,20 @@ class ApplicationServicesHandler:
service, "to_device", new_token
)
elif stream_key == "device_list_key":
device_list_summary = await self._get_device_list_summary(
service, new_token
)
if device_list_summary:
self.scheduler.enqueue_for_appservice(
service, device_list_summary=device_list_summary
)
# Persist the latest handled stream token for this appservice
await self.store.set_appservice_stream_type_pos(
service, "device_list", new_token
)
async def _handle_typing(
self, service: ApplicationService, new_token: int
) -> List[JsonDict]:
@ -542,6 +586,96 @@ class ApplicationServicesHandler:
return message_payload
async def _get_device_list_summary(
self,
appservice: ApplicationService,
new_key: int,
) -> DeviceListUpdates:
"""
Retrieve a list of users who have changed their device lists.
Args:
appservice: The application service to retrieve device list changes for.
new_key: The stream key of the device list change that triggered this method call.
Returns:
A set of device list updates, comprising users that the appservice needs to:
* resync the device list of, and
* stop tracking the device list of.
"""
# Fetch the last successfully processed device list update stream ID
# for this appservice.
from_key = await self.store.get_type_stream_id_for_appservice(
appservice, "device_list"
)
# Fetch the users who have modified their device list since then.
users_with_changed_device_lists = (
await self.store.get_users_whose_devices_changed(from_key, to_key=new_key)
)
# Filter out any users the application service is not interested in
#
# For each user who changed their device list, we want to check whether this
# appservice would be interested in the change.
filtered_users_with_changed_device_lists = {
user_id
for user_id in users_with_changed_device_lists
if await self._is_appservice_interested_in_device_lists_of_user(
appservice, user_id
)
}
# Create a summary of "changed" and "left" users.
# TODO: Calculate "left" users.
device_list_summary = DeviceListUpdates(
changed=filtered_users_with_changed_device_lists
)
return device_list_summary
async def _is_appservice_interested_in_device_lists_of_user(
self,
appservice: ApplicationService,
user_id: str,
) -> bool:
"""
Returns whether a given application service is interested in the device list
updates of a given user.
The application service is interested in the user's device list updates if any
of the following are true:
* The user is the appservice's sender localpart user.
* The user is in the appservice's user namespace.
* At least one member of one room that the user is a part of is in the
appservice's user namespace.
* The appservice is explicitly (via room ID or alias) interested in at
least one room that the user is in.
Args:
appservice: The application service to gauge interest of.
user_id: The ID of the user whose device list interest is in question.
Returns:
True if the application service is interested in the user's device lists, False
otherwise.
"""
# This method checks against both the sender localpart user as well as if the
# user is in the appservice's user namespace.
if appservice.is_interested_in_user(user_id):
return True
# Determine whether any of the rooms the user is in justifies sending this
# device list update to the application service.
room_ids = await self.store.get_rooms_for_user(user_id)
for room_id in room_ids:
# This method covers checking room members for appservice interest as well as
# room ID and alias checks.
if await appservice.is_interested_in_room(room_id, self.store):
return True
return False
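To make the first check concrete, here is a hypothetical appservice registration fragment, written as the Python dict its YAML parses into (the regex and server name are made up). An appservice with this user namespace is "interested" in every @irc_* user, so all of their device list updates pass the filter above:

# Hypothetical appservice registration fragment (normally YAML), parsed:
namespaces = {
    "users": [
        {"regex": "@irc_.*:example\\.org", "exclusive": True},
    ],
}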
async def query_user_exists(self, user_id: str) -> bool:
"""Check if any application service knows this user_id exists.

View file

@ -211,6 +211,7 @@ class AuthHandler:
self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.auth.password_enabled
self._password_localdb_enabled = hs.config.auth.password_localdb_enabled
self._third_party_rules = hs.get_third_party_event_rules()
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`.
@ -1505,6 +1506,8 @@ class AuthHandler:
user_id, medium, address, validated_at, self.hs.get_clock().time_msec()
)
await self._third_party_rules.on_threepid_bind(user_id, medium, address)
async def delete_threepid(
self, user_id: str, medium: str, address: str, id_server: Optional[str] = None
) -> bool:

View file

@ -37,7 +37,10 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
)
from synapse.types import (
JsonDict,
StreamToken,
@ -278,6 +281,22 @@ class DeviceHandler(DeviceWorkerHandler):
hs.get_distributor().observe("user_left_room", self.user_left_room)
# Whether `_handle_new_device_update_async` is currently processing.
self._handle_new_device_update_is_processing = False
# If a new device update may have happened while the loop was
# processing.
self._handle_new_device_update_new_data = False
# On start up check if there are any updates pending.
hs.get_reactor().callWhenRunning(self._handle_new_device_update_async)
# Used to decide if we calculate outbound pokes up front or not. By
# default we do to allow safely downgrading Synapse.
self.use_new_device_lists_changes_in_room = (
hs.config.server.use_new_device_lists_changes_in_room
)
def _check_device_name_length(self, name: Optional[str]) -> None:
"""
Checks whether a device name is longer than the maximum allowed length.
@ -469,19 +488,26 @@ class DeviceHandler(DeviceWorkerHandler):
# No changes to notify about, so this is a no-op.
return
users_who_share_room = await self.store.get_users_who_share_room_with_user(
user_id
)
room_ids = await self.store.get_rooms_for_user(user_id)
hosts: Set[str] = set()
if self.hs.is_mine_id(user_id):
hosts.update(get_domain_from_id(u) for u in users_who_share_room)
hosts.discard(self.server_name)
hosts: Optional[Set[str]] = None
if not self.use_new_device_lists_changes_in_room:
hosts = set()
set_tag("target_hosts", hosts)
if self.hs.is_mine_id(user_id):
for room_id in room_ids:
joined_users = await self.store.get_users_in_room(room_id)
hosts.update(get_domain_from_id(u) for u in joined_users)
set_tag("target_hosts", hosts)
hosts.discard(self.server_name)
position = await self.store.add_device_change_to_streams(
user_id, device_ids, list(hosts)
user_id,
device_ids,
hosts=hosts,
room_ids=room_ids,
)
if not position:
@ -495,9 +521,12 @@ class DeviceHandler(DeviceWorkerHandler):
# specify the user ID too since the user should always get their own device list
# updates, even if they aren't in any rooms.
users_to_notify = users_who_share_room.union({user_id})
self.notifier.on_new_event(
"device_list_key", position, users={user_id}, rooms=room_ids
)
self.notifier.on_new_event("device_list_key", position, users=users_to_notify)
# We may need to do some processing asynchronously.
self._handle_new_device_update_async()
if hosts:
logger.info(
@ -614,6 +643,85 @@ class DeviceHandler(DeviceWorkerHandler):
return {"success": True}
@wrap_as_background_process("_handle_new_device_update_async")
async def _handle_new_device_update_async(self) -> None:
"""Called when we have a new local device list update that we need to
send out over federation.
This happens in the background so as not to block the original request
that generated the device update.
"""
if self._handle_new_device_update_is_processing:
self._handle_new_device_update_new_data = True
return
self._handle_new_device_update_is_processing = True
# The stream ID we processed previous iteration (if any), and the set of
# hosts we've already poked about for this update. This is so that we
# don't poke the same remote server about the same update repeatedly.
current_stream_id = None
hosts_already_sent_to: Set[str] = set()
try:
while True:
self._handle_new_device_update_new_data = False
rows = await self.store.get_uncoverted_outbound_room_pokes()
if not rows:
# If the DB returned nothing then there is nothing left to
# do, *unless* a new device list update happened during the
# DB query.
if self._handle_new_device_update_new_data:
continue
else:
return
for user_id, device_id, room_id, stream_id, opentracing_context in rows:
joined_user_ids = await self.store.get_users_in_room(room_id)
hosts = {get_domain_from_id(u) for u in joined_user_ids}
hosts.discard(self.server_name)
# Check if we've already sent this update to some hosts
if current_stream_id == stream_id:
hosts -= hosts_already_sent_to
await self.store.add_device_list_outbound_pokes(
user_id=user_id,
device_id=device_id,
room_id=room_id,
stream_id=stream_id,
hosts=hosts,
context=opentracing_context,
)
# Notify replication that we've updated the device list stream.
self.notifier.notify_replication()
if hosts:
logger.info(
"Sending device list update notif for %r to: %r",
user_id,
hosts,
)
for host in hosts:
self.federation_sender.send_device_messages(
host, immediate=False
)
log_kv(
{"message": "sent device update to host", "host": host}
)
if current_stream_id != stream_id:
# Clear the set of hosts we've already sent to as we're
# processing a new update.
hosts_already_sent_to.clear()
hosts_already_sent_to.update(hosts)
current_stream_id = stream_id
finally:
self._handle_new_device_update_is_processing = False
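The _handle_new_device_update_is_processing / _handle_new_device_update_new_data pair implements a "single worker with coalesced wake-ups" pattern. A stripped-down sketch of just that pattern, with the work queue and processing as stand-ins:

import asyncio
from typing import List

class CoalescingWorker:
    """Sketch of the wake-up pattern used by _handle_new_device_update_async."""

    def __init__(self) -> None:
        self._is_processing = False
        self._new_data = False
        self._queue: List[str] = []  # stand-in for the DB-backed work table

    async def poke(self) -> None:
        if self._is_processing:
            # A loop is already running: flag it so it does one more pass.
            self._new_data = True
            return
        self._is_processing = True
        try:
            while True:
                self._new_data = False
                rows, self._queue = self._queue, []
                if not rows:
                    # Nothing left, *unless* a poke raced with the fetch.
                    if self._new_data:
                        continue
                    return
                await asyncio.sleep(0)  # stand-in for processing `rows`
        finally:
            self._is_processing = False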
def _update_device_from_client_ips(
device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
@ -725,7 +833,7 @@ class DeviceListUpdater:
async def _handle_device_updates(self, user_id: str) -> None:
"Actually handle pending updates."
with (await self._remote_edu_linearizer.queue(user_id)):
async with self._remote_edu_linearizer.queue(user_id):
pending_updates = self._pending_updates.pop(user_id, [])
if not pending_updates:
# This can happen since we batch updates
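The `with (await linearizer.queue(key))` to `async with linearizer.queue(key)` rewrite recurs throughout this merge. A minimal sketch of a Linearizer whose queue() is directly usable as an async context manager; this is not Synapse's actual implementation, which is Twisted-based and more featureful:

import asyncio
from collections import defaultdict
from contextlib import asynccontextmanager
from typing import AsyncIterator, DefaultDict, Hashable

class Linearizer:
    """Minimal sketch: serialises coroutines that share a key."""

    def __init__(self) -> None:
        self._locks: DefaultDict[Hashable, asyncio.Lock] = defaultdict(asyncio.Lock)

    @asynccontextmanager
    async def queue(self, key: Hashable) -> AsyncIterator[None]:
        # queue() returns an async context manager, enabling
        # `async with linearizer.queue(key):` in place of the old
        # `with (await linearizer.queue(key)):` idiom.
        async with self._locks[key]:
            yield

async def handle_update(linearizer: Linearizer, user_id: str) -> None:
    async with linearizer.queue(user_id):
        ...  # only one coroutine per user_id runs this block at a time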

View file

@ -118,7 +118,7 @@ class E2eKeysHandler:
from_device_id: the device making the query. This is used to limit
the number of in-flight queries at a time.
"""
with await self._query_devices_linearizer.queue((from_user_id, from_device_id)):
async with self._query_devices_linearizer.queue((from_user_id, from_device_id)):
device_keys_query: Dict[str, Iterable[str]] = query_body.get(
"device_keys", {}
)
@ -1386,7 +1386,7 @@ class SigningKeyEduUpdater:
device_handler = self.e2e_keys_handler.device_handler
device_list_updater = device_handler.device_list_updater
with (await self._remote_edu_linearizer.queue(user_id)):
async with self._remote_edu_linearizer.queue(user_id):
pending_updates = self._pending_updates.pop(user_id, [])
if not pending_updates:
# This can happen since we batch updates

View file

@ -83,7 +83,7 @@ class E2eRoomKeysHandler:
# we deliberately take the lock to get keys so that changing the version
# works atomically
with (await self._upload_linearizer.queue(user_id)):
async with self._upload_linearizer.queue(user_id):
# make sure the backup version exists
try:
await self.store.get_e2e_room_keys_version_info(user_id, version)
@ -126,7 +126,7 @@ class E2eRoomKeysHandler:
"""
# lock for consistency with uploading
with (await self._upload_linearizer.queue(user_id)):
async with self._upload_linearizer.queue(user_id):
# make sure the backup version exists
try:
version_info = await self.store.get_e2e_room_keys_version_info(
@ -187,7 +187,7 @@ class E2eRoomKeysHandler:
# TODO: Validate the JSON to make sure it has the right keys.
# XXX: perhaps we should use a finer grained lock here?
with (await self._upload_linearizer.queue(user_id)):
async with self._upload_linearizer.queue(user_id):
# Check that the version we're trying to upload is the current version
try:
@ -332,7 +332,7 @@ class E2eRoomKeysHandler:
# TODO: Validate the JSON to make sure it has the right keys.
# lock everyone out until we've switched version
with (await self._upload_linearizer.queue(user_id)):
async with self._upload_linearizer.queue(user_id):
new_version = await self.store.create_e2e_room_keys_version(
user_id, version_info
)
@ -359,7 +359,7 @@ class E2eRoomKeysHandler:
}
"""
with (await self._upload_linearizer.queue(user_id)):
async with self._upload_linearizer.queue(user_id):
try:
res = await self.store.get_e2e_room_keys_version_info(user_id, version)
except StoreError as e:
@ -383,7 +383,7 @@ class E2eRoomKeysHandler:
NotFoundError: if this backup version doesn't exist
"""
with (await self._upload_linearizer.queue(user_id)):
async with self._upload_linearizer.queue(user_id):
try:
await self.store.delete_e2e_room_keys_version(user_id, version)
except StoreError as e:
@ -413,7 +413,7 @@ class E2eRoomKeysHandler:
raise SynapseError(
400, "Version in body does not match", Codes.INVALID_PARAM
)
with (await self._upload_linearizer.queue(user_id)):
async with self._upload_linearizer.queue(user_id):
try:
old_info = await self.store.get_e2e_room_keys_version_info(
user_id, version

View file

@ -151,7 +151,7 @@ class FederationHandler:
return. This is used as part of the heuristic to decide if we
should back paginate.
"""
with (await self._room_backfill.queue(room_id)):
async with self._room_backfill.queue(room_id):
return await self._maybe_backfill_inner(room_id, current_depth, limit)
async def _maybe_backfill_inner(

View file

@ -224,7 +224,7 @@ class FederationEventHandler:
len(missing_prevs),
shortstr(missing_prevs),
)
with (await self._room_pdu_linearizer.queue(pdu.room_id)):
async with self._room_pdu_linearizer.queue(pdu.room_id):
logger.info(
"Acquired room lock to fetch %d missing prev_events",
len(missing_prevs),
@ -469,6 +469,12 @@ class FederationEventHandler:
if context.rejected:
raise SynapseError(400, "Join event was rejected")
# the remote server is responsible for sending our join event to the rest
# of the federation. Indeed, attempting to do so will result in problems
# when we try to look up the state before the join (to get the server list)
# and discover that we do not have it.
event.internal_metadata.proactively_send = False
return await self.persist_events_and_notify(room_id, [(event, context)])
async def backfill(
@ -891,10 +897,24 @@ class FederationEventHandler:
logger.debug("We are also missing %i auth events", len(missing_auth_events))
missing_events = missing_desired_events | missing_auth_events
logger.debug("Fetching %i events from remote", len(missing_events))
await self._get_events_and_persist(
destination=destination, room_id=room_id, event_ids=missing_events
)
# Making an individual request for each of 1000s of events has a lot of
# overhead. On the other hand, we don't really want to fetch all of the events
# if we already have most of them.
#
# As an arbitrary heuristic, if we are missing more than 10% of the events, then
# we fetch the whole state.
#
# TODO: might it be better to have an API which lets us do an aggregate event
# request
if (len(missing_events) * 10) >= len(auth_event_ids) + len(state_event_ids):
logger.debug("Requesting complete state from remote")
await self._get_state_and_persist(destination, room_id, event_id)
else:
logger.debug("Fetching %i events from remote", len(missing_events))
await self._get_events_and_persist(
destination=destination, room_id=room_id, event_ids=missing_events
)
# we need to make sure we re-load from the database to get the rejected
# state correct.
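With hypothetical numbers, the branch condition above reads as "at least 10% of the desired state and auth events are missing":

# Illustrative numbers only:
missing_events_count = 150
total = 1000  # len(auth_event_ids) + len(state_event_ids)
fetch_full_state = (missing_events_count * 10) >= total  # 1500 >= 1000 -> True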
@ -953,6 +973,27 @@ class FederationEventHandler:
return remote_state
async def _get_state_and_persist(
self, destination: str, room_id: str, event_id: str
) -> None:
"""Get the complete room state at a given event, and persist any new events
as outliers"""
room_version = await self._store.get_room_version(room_id)
auth_events, state_events = await self._federation_client.get_room_state(
destination, room_id, event_id=event_id, room_version=room_version
)
logger.info("/state returned %i events", len(auth_events) + len(state_events))
await self._auth_and_persist_outliers(
room_id, itertools.chain(auth_events, state_events)
)
# we also need the event itself.
if not await self._store.have_seen_event(room_id, event_id):
await self._get_events_and_persist(
destination=destination, room_id=room_id, event_ids=(event_id,)
)
async def _process_received_pdu(
self,
origin: str,

View file

@ -858,8 +858,6 @@ class IdentityHandler:
if room_type is not None:
invite_config["room_type"] = room_type
# TODO The unstable field is deprecated and should be removed in the future.
invite_config["org.matrix.msc3288.room_type"] = room_type
# If a custom web client location is available, include it in the request.
if self._web_client_location:

View file

@ -853,7 +853,7 @@ class EventCreationHandler:
# a situation where event persistence can't keep up, causing
# extremities to pile up, which in turn leads to state resolution
# taking longer.
with (await self.limiter.queue(event_dict["room_id"])):
async with self.limiter.queue(event_dict["room_id"]):
if txn_id and requester.access_token_id:
existing_event_id = await self.store.get_event_id_from_transaction_id(
event_dict["room_id"],

View file

@ -441,7 +441,14 @@ class PaginationHandler:
if pagin_config.from_token:
from_token = pagin_config.from_token
else:
from_token = self.hs.get_event_sources().get_current_token_for_pagination()
from_token = (
await self.hs.get_event_sources().get_current_token_for_pagination(
room_id
)
)
# We expect `/messages` to use historic pagination tokens by default but
# `/messages` should still work with live tokens when manually provided.
assert from_token.room_key.topological
if pagin_config.limit is None:
# This shouldn't happen as we've set a default limit before this

View file

@ -1030,7 +1030,7 @@ class PresenceHandler(BasePresenceHandler):
is_syncing: Whether or not the user is now syncing
sync_time_msec: Time in ms when the user was last syncing
"""
with (await self.external_sync_linearizer.queue(process_id)):
async with self.external_sync_linearizer.queue(process_id):
prev_state = await self.current_state_for_user(user_id)
process_presence = self.external_process_to_current_syncs.setdefault(
@ -1071,7 +1071,7 @@ class PresenceHandler(BasePresenceHandler):
Used when the process has stopped/disappeared.
"""
with (await self.external_sync_linearizer.queue(process_id)):
async with self.external_sync_linearizer.queue(process_id):
process_presence = self.external_process_to_current_syncs.pop(
process_id, set()
)
@ -1625,7 +1625,7 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
# We'll actually pull the presence updates for these users at the end.
interested_and_updated_users: Union[Set[str], FrozenSet[str]] = set()
if from_key:
if from_key is not None:
# First get all users that have had a presence update
updated_users = stream_change_cache.get_all_entities_changed(from_key)

View file

@ -41,7 +41,7 @@ class ReadMarkerHandler:
the read marker has changed.
"""
with await self.read_marker_linearizer.queue((room_id, user_id)):
async with self.read_marker_linearizer.queue((room_id, user_id)):
existing_read_marker = await self.store.get_account_data_for_room_and_type(
user_id, room_id, "m.fully_read"
)

View file

@ -12,7 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Dict, Iterable, Optional, cast
from typing import (
TYPE_CHECKING,
Collection,
Dict,
FrozenSet,
Iterable,
List,
Optional,
Tuple,
)
import attr
from frozendict import frozendict
@ -20,12 +29,12 @@ from frozendict import frozendict
from synapse.api.constants import RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.types import JsonDict, Requester, StreamToken
from synapse.storage.databases.main.relations import _RelatedEvent
from synapse.types import JsonDict, Requester, StreamToken, UserID
from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
@ -116,7 +125,10 @@ class RelationsHandler:
if event is None:
raise SynapseError(404, "Unknown parent event.")
pagination_chunk = await self._main_store.get_relations_for_event(
# Note that ignored users are not passed into get_relations_for_event
# below. Ignored users are handled in filter_events_for_client (and by
# not passing them in here we should get a better cache hit rate).
related_events, next_token = await self._main_store.get_relations_for_event(
event_id=event_id,
event=event,
room_id=room_id,
@ -130,7 +142,7 @@ class RelationsHandler:
)
events = await self._main_store.get_events_as_list(
[c["event_id"] for c in pagination_chunk.chunk]
[e.event_id for e in related_events]
)
events = await filter_events_for_client(
@ -152,14 +164,100 @@ class RelationsHandler:
events, now, bundle_aggregations=aggregations
)
return_value = await pagination_chunk.to_dict(self._main_store)
return_value["chunk"] = serialized_events
return_value["original_event"] = original_event
return_value = {
"chunk": serialized_events,
"original_event": original_event,
}
if next_token:
return_value["next_batch"] = await next_token.to_string(self._main_store)
if from_token:
return_value["prev_batch"] = await from_token.to_string(self._main_store)
return return_value
async def get_relations_for_event(
self,
event_id: str,
event: EventBase,
room_id: str,
relation_type: str,
ignored_users: FrozenSet[str] = frozenset(),
) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
"""Get a list of events which relate to an event, ordered by topological ordering.
Args:
event_id: Fetch events that relate to this event ID.
event: The matching EventBase to event_id.
room_id: The room the event belongs to.
relation_type: The type of relation.
ignored_users: The users ignored by the requesting user.
Returns:
A list of the related events, along with a token to fetch the next page
of results (if any).
"""
# Call the underlying storage method, which is cached.
related_events, next_token = await self._main_store.get_relations_for_event(
event_id, event, room_id, relation_type, direction="f"
)
# Filter out ignored users and convert to the expected format.
related_events = [
event for event in related_events if event.sender not in ignored_users
]
return related_events, next_token
async def get_annotations_for_event(
self,
event_id: str,
room_id: str,
limit: int = 5,
ignored_users: FrozenSet[str] = frozenset(),
) -> List[JsonDict]:
"""Get a list of annotations on the event, grouped by event type and
aggregation key, sorted by count.
This is used e.g. to get which reactions have happened on an event and
how many of each.
Args:
event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
limit: Only fetch the `limit` groups.
ignored_users: The users ignored by the requesting user.
Returns:
List of groups of annotations that match. Each row is a dict with
`type`, `key` and `count` fields.
"""
# Get the base results for all users.
full_results = await self._main_store.get_aggregation_groups_for_event(
event_id, room_id, limit
)
# Then subtract off the results for any ignored users.
ignored_results = await self._main_store.get_aggregation_groups_for_users(
event_id, room_id, limit, ignored_users
)
filtered_results = []
for result in full_results:
key = (result["type"], result["key"])
if key in ignored_results:
result = result.copy()
result["count"] -= ignored_results[key]
if result["count"] <= 0:
continue
filtered_results.append(result)
return filtered_results
async def _get_bundled_aggregation_for_event(
self, event: EventBase, user_id: str
self, event: EventBase, ignored_users: FrozenSet[str]
) -> Optional[BundledAggregations]:
"""Generate bundled aggregations for an event.
@ -167,7 +265,7 @@ class RelationsHandler:
Args:
event: The event to calculate bundled aggregations for.
user_id: The user requesting the bundled aggregations.
ignored_users: The users ignored by the requesting user.
Returns:
The bundled aggregations for an event, if bundled aggregations are
@ -190,23 +288,125 @@ class RelationsHandler:
# while others need more processing during serialization.
aggregations = BundledAggregations()
annotations = await self._main_store.get_aggregation_groups_for_event(
event_id, room_id
annotations = await self.get_annotations_for_event(
event_id, room_id, ignored_users=ignored_users
)
if annotations.chunk:
aggregations.annotations = await annotations.to_dict(
cast("DataStore", self)
)
if annotations:
aggregations.annotations = {"chunk": annotations}
references = await self._main_store.get_relations_for_event(
event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
references, next_token = await self.get_relations_for_event(
event_id,
event,
room_id,
RelationTypes.REFERENCE,
ignored_users=ignored_users,
)
if references.chunk:
aggregations.references = await references.to_dict(cast("DataStore", self))
if references:
aggregations.references = {
"chunk": [{"event_id": event.event_id} for event in references]
}
if next_token:
aggregations.references["next_batch"] = await next_token.to_string(
self._main_store
)
# Store the bundled aggregations in the event metadata for later use.
return aggregations
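For reference, aggregations.references as built above serializes into roughly this shape (the event IDs and batch token are made up):

# Hypothetical serialized form of the reference aggregations:
references = {
    "chunk": [{"event_id": "$abc"}, {"event_id": "$def"}],
    # "next_batch" only appears when a next_token was returned:
    "next_batch": "t123-456",
}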
async def get_threads_for_events(
self, event_ids: Collection[str], user_id: str, ignored_users: FrozenSet[str]
) -> Dict[str, _ThreadAggregation]:
"""Get the bundled aggregations for threads for the requested events.
Args:
event_ids: Events to get aggregations for threads.
user_id: The user requesting the bundled aggregations.
ignored_users: The users ignored by the requesting user.
Returns:
A dictionary mapping event ID to the thread information.
May not contain a value for all requested event IDs.
"""
user = UserID.from_string(user_id)
# Fetch thread summaries.
summaries = await self._main_store.get_thread_summaries(event_ids)
# Only fetch participated for a limited selection based on what had
# summaries.
thread_event_ids = [
event_id for event_id, summary in summaries.items() if summary
]
participated = await self._main_store.get_threads_participated(
thread_event_ids, user_id
)
# Then subtract off the results for any ignored users.
ignored_results = await self._main_store.get_threaded_messages_per_user(
thread_event_ids, ignored_users
)
# A map of event ID to the thread aggregation.
results = {}
for event_id, summary in summaries.items():
if summary:
thread_count, latest_thread_event, edit = summary
# Subtract off the count of any ignored users.
for ignored_user in ignored_users:
thread_count -= ignored_results.get((event_id, ignored_user), 0)
# This is gnarly, but if the latest event is from an ignored user,
# attempt to find one that isn't from an ignored user.
if latest_thread_event.sender in ignored_users:
room_id = latest_thread_event.room_id
# If the root event is not found, something went wrong; do
# not include a summary of the thread.
event = await self._event_handler.get_event(user, room_id, event_id)
if event is None:
continue
potential_events, _ = await self.get_relations_for_event(
event_id,
event,
room_id,
RelationTypes.THREAD,
ignored_users,
)
# If all found events are from ignored users, do not include
# a summary of the thread.
if not potential_events:
continue
# The *last* event returned is the one we care about.
event = await self._event_handler.get_event(
user, room_id, potential_events[-1].event_id
)
# It is unexpected that the event will not exist.
if event is None:
logger.warning(
"Unable to fetch latest event in a thread with event ID: %s",
potential_events[-1].event_id,
)
continue
latest_thread_event = event
results[event_id] = _ThreadAggregation(
latest_event=latest_thread_event,
latest_edit=edit,
count=thread_count,
# If there's a thread summary it must also exist in the
# participated dictionary.
current_user_participated=participated[event_id],
)
return results
async def get_bundled_aggregations(
self, events: Iterable[EventBase], user_id: str
) -> Dict[str, BundledAggregations]:
@ -230,13 +430,21 @@ class RelationsHandler:
# event ID -> bundled aggregation in non-serialized form.
results: Dict[str, BundledAggregations] = {}
# Fetch any ignored users of the requesting user.
ignored_users = await self._main_store.ignored_users(user_id)
# Fetch other relations per event.
for event in events_by_id.values():
event_result = await self._get_bundled_aggregation_for_event(event, user_id)
event_result = await self._get_bundled_aggregation_for_event(
event, ignored_users
)
if event_result:
results[event.event_id] = event_result
# Fetch any edits (but not for redacted events).
#
# Note that there is no use in limiting edits by ignored users since the
# parent event should be ignored in the first place if the user is ignored.
edits = await self._main_store.get_applicable_edits(
[
event_id
@ -247,25 +455,10 @@ class RelationsHandler:
for event_id, edit in edits.items():
results.setdefault(event_id, BundledAggregations()).replace = edit
# Fetch thread summaries.
summaries = await self._main_store.get_thread_summaries(events_by_id.keys())
# Only fetch participated for a limited selection based on what had
# summaries.
participated = await self._main_store.get_threads_participated(
[event_id for event_id, summary in summaries.items() if summary], user_id
threads = await self.get_threads_for_events(
events_by_id.keys(), user_id, ignored_users
)
for event_id, summary in summaries.items():
if summary:
thread_count, latest_thread_event, edit = summary
results.setdefault(
event_id, BundledAggregations()
).thread = _ThreadAggregation(
latest_event=latest_thread_event,
latest_edit=edit,
count=thread_count,
# If there's a thread summary it must also exist in the
# participated dictionary.
current_user_participated=participated[event_id],
)
for event_id, thread in threads.items():
results.setdefault(event_id, BundledAggregations()).thread = thread
return results

View file

@ -771,7 +771,9 @@ class RoomCreationHandler:
% (user_id,),
)
visibility = config.get("visibility", None)
# The spec says rooms should default to private visibility if
# `visibility` is not specified.
visibility = config.get("visibility", "private")
is_public = visibility == "public"
if "room_id" in config:
@ -891,7 +893,7 @@ class RoomCreationHandler:
#
# we also don't need to check the requester's shadow-ban here, as we
# have already done so above (and potentially emptied invite_list).
with (await self.room_member_handler.member_linearizer.queue((room_id,))):
async with self.room_member_handler.member_linearizer.queue((room_id,)):
content = {}
is_direct = config.get("is_direct", None)
if is_direct:
@ -1452,8 +1454,8 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
def get_current_key(self) -> RoomStreamToken:
return self.store.get_room_max_token()
def get_current_key_for_room(self, room_id: str) -> Awaitable[str]:
return self.store.get_room_events_max_id(room_id)
def get_current_key_for_room(self, room_id: str) -> Awaitable[RoomStreamToken]:
return self.store.get_current_room_stream_token_for_room_id(room_id)
class ShutdownRoomResponse(TypedDict):

View file

@ -158,8 +158,8 @@ class RoomBatchHandler:
) -> List[str]:
"""Takes all `state_events_at_start` event dictionaries and creates/persists
them in a floating state event chain which doesn't resolve into the current room
state. They are floating because they reference no prev_events and are marked
as outliers which disconnects them from the normal DAG.
state. They are floating because they reference no prev_events which disconnects
them from the normal DAG.
Args:
state_events_at_start:
@ -215,31 +215,23 @@ class RoomBatchHandler:
room_id=room_id,
action=membership,
content=event_dict["content"],
# Mark as an outlier to disconnect it from the normal DAG
# and not show up between batches of history.
outlier=True,
historical=True,
# Only the first event in the state chain should be floating.
# The rest should hang off each other in a chain.
allow_no_prev_events=index == 0,
prev_event_ids=prev_event_ids_for_state_chain,
# Since each state event is marked as an outlier, the
# `EventContext.for_outlier()` won't have any `state_ids`
# set and therefore can't derive any state even though the
# prev_events are set. Also since the first event in the
# state chain is floating with no `prev_events`, it can't
# derive state from anywhere automatically. So we need to
# set some state explicitly.
# The first event in the state chain is floating with no
# `prev_events` which means it can't derive state from
# anywhere automatically. So we need to set some state
# explicitly.
#
# Make sure to use a copy of this list because we modify it
# later in the loop here. Otherwise it will be the same
# reference and also update in the event when we append later.
# reference and also update in the event when we append
# later.
state_event_ids=state_event_ids.copy(),
)
else:
# TODO: Add some complement tests that adds state that is not member joins
# and will use this code path. Maybe we only want to support join state events
# and can get rid of this `else`?
(
event,
_,
@ -248,21 +240,15 @@ class RoomBatchHandler:
state_event["sender"], app_service_requester.app_service
),
event_dict,
# Mark as an outlier to disconnect it from the normal DAG
# and not show up between batches of history.
outlier=True,
historical=True,
# Only the first event in the state chain should be floating.
# The rest should hang off each other in a chain.
allow_no_prev_events=index == 0,
prev_event_ids=prev_event_ids_for_state_chain,
# Since each state event is marked as an outlier, the
# `EventContext.for_outlier()` won't have any `state_ids`
# set and therefore can't derive any state even though the
# prev_events are set. Also since the first event in the
# state chain is floating with no `prev_events`, it can't
# derive state from anywhere automatically. So we need to
# set some state explicitly.
# The first event in the state chain is floating with no
# `prev_events` which means it can't derive state from
# anywhere automatically. So we need to set some state
# explicitly.
#
# Make sure to use a copy of this list because we modify it
# later in the loop here. Otherwise it will be the same

View file

@ -515,8 +515,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# We first linearise by the application service (to try to limit concurrent joins
# by application services), and then by room ID.
with (await self.member_as_limiter.queue(as_id)):
with (await self.member_linearizer.queue(key)):
async with self.member_as_limiter.queue(as_id):
async with self.member_linearizer.queue(key):
result = await self.update_membership_locked(
requester,
target,

View file

@ -59,8 +59,6 @@ class SearchHandler:
self.state_store = self.storage.state
self.auth = hs.get_auth()
self._msc3666_enabled = hs.config.experimental.msc3666_enabled
async def get_old_rooms_from_upgraded_room(self, room_id: str) -> Iterable[str]:
"""Retrieves room IDs of old rooms in the history of an upgraded room.
@ -353,22 +351,20 @@ class SearchHandler:
state = await self.state_handler.get_current_state(room_id)
state_results[room_id] = list(state.values())
aggregations = None
if self._msc3666_enabled:
aggregations = await self._relations_handler.get_bundled_aggregations(
# Generate an iterable of EventBase for all the events that will be
# returned, including contextual events.
itertools.chain(
# The events_before and events_after for each context.
itertools.chain.from_iterable(
itertools.chain(context["events_before"], context["events_after"]) # type: ignore[arg-type]
for context in contexts.values()
),
# The returned events.
search_result.allowed_events,
aggregations = await self._relations_handler.get_bundled_aggregations(
# Generate an iterable of EventBase for all the events that will be
# returned, including contextual events.
itertools.chain(
# The events_before and events_after for each context.
itertools.chain.from_iterable(
itertools.chain(context["events_before"], context["events_after"]) # type: ignore[arg-type]
for context in contexts.values()
),
user.to_string(),
)
# The returned events.
search_result.allowed_events,
),
user.to_string(),
)
# We're now about to serialize the events. We should not make any
# blocking calls after this. Otherwise, the 'age' will be wrong.

View file

@ -430,7 +430,7 @@ class SsoHandler:
# grab a lock while we try to find a mapping for this user. This seems...
# optimistic, especially for implementations that end up redirecting to
# interstitial pages.
with await self._mapping_lock.queue(auth_provider_id):
async with self._mapping_lock.queue(auth_provider_id):
# first of all, check if we already have a mapping for this user
user_id = await self.get_sso_user_by_remote_user_id(
auth_provider_id,

View file

@ -13,17 +13,7 @@
# limitations under the License.
import itertools
import logging
from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
FrozenSet,
List,
Optional,
Set,
Tuple,
)
from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple
import attr
from prometheus_client import Counter
@ -41,6 +31,7 @@ from synapse.storage.databases.main.event_push_actions import NotifCounts
from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter
from synapse.types import (
DeviceListUpdates,
JsonDict,
MutableStateMap,
Requester,
@ -184,21 +175,6 @@ class GroupsSyncResult:
return bool(self.join or self.invite or self.leave)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DeviceLists:
"""
Attributes:
changed: List of user_ids whose devices may have changed
left: List of user_ids whose devices we no longer track
"""
changed: Collection[str]
left: Collection[str]
def __bool__(self) -> bool:
return bool(self.changed or self.left)
@attr.s(slots=True, auto_attribs=True)
class _RoomChanges:
"""The set of room entries to include in the sync, plus the set of joined
@ -240,7 +216,7 @@ class SyncResult:
knocked: List[KnockedSyncResult]
archived: List[ArchivedSyncResult]
to_device: List[JsonDict]
device_lists: DeviceLists
device_lists: DeviceListUpdates
device_one_time_keys_count: JsonDict
device_unused_fallback_key_types: List[str]
groups: Optional[GroupsSyncResult]
@ -298,6 +274,8 @@ class SyncHandler:
expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
)
self.rooms_to_exclude = hs.config.server.rooms_to_exclude_from_sync
async def wait_for_sync_for_user(
self,
requester: Requester,
@ -1176,8 +1154,9 @@ class SyncHandler:
await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
)
logger.debug("Fetching group data")
await self._generate_sync_entry_for_groups(sync_result_builder)
if self.hs_config.experimental.groups_enabled:
logger.debug("Fetching group data")
await self._generate_sync_entry_for_groups(sync_result_builder)
num_events = 0
@ -1261,8 +1240,8 @@ class SyncHandler:
newly_joined_or_invited_or_knocked_users: Set[str],
newly_left_rooms: Set[str],
newly_left_users: Set[str],
) -> DeviceLists:
"""Generate the DeviceLists section of sync
) -> DeviceListUpdates:
"""Generate the DeviceListUpdates section of sync
Args:
sync_result_builder
@ -1380,9 +1359,11 @@ class SyncHandler:
if any(e.room_id in joined_rooms for e in entries):
newly_left_users.discard(user_id)
return DeviceLists(changed=users_that_have_changed, left=newly_left_users)
return DeviceListUpdates(
changed=users_that_have_changed, left=newly_left_users
)
else:
return DeviceLists(changed=[], left=[])
return DeviceListUpdates()
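The removed DeviceLists class from this file is superseded by synapse.types.DeviceListUpdates, now shared with the appservice handler. Judging by the no-argument construction above, it presumably mirrors the old class with default factories; a sketch under that assumption:

from typing import Set

import attr

@attr.s(slots=True, frozen=True, auto_attribs=True)
class DeviceListUpdates:
    """Sketch, assuming it mirrors the removed DeviceLists class.

    Attributes:
        changed: user_ids whose devices may have changed
        left: user_ids whose devices we no longer track
    """

    changed: Set[str] = attr.Factory(set)
    left: Set[str] = attr.Factory(set)

    def __bool__(self) -> bool:
        return bool(self.changed or self.left)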
async def _generate_sync_entry_for_to_device(
self, sync_result_builder: "SyncResultBuilder"
@ -1606,13 +1587,15 @@ class SyncHandler:
ignored_users = await self.store.ignored_users(user_id)
if since_token:
room_changes = await self._get_rooms_changed(
sync_result_builder, ignored_users
sync_result_builder, ignored_users, self.rooms_to_exclude
)
tags_by_room = await self.store.get_updated_tags(
user_id, since_token.account_data_key
)
else:
room_changes = await self._get_all_rooms(sync_result_builder, ignored_users)
room_changes = await self._get_all_rooms(
sync_result_builder, ignored_users, self.rooms_to_exclude
)
tags_by_room = await self.store.get_tags_for_user(user_id)
log_kv({"rooms_changed": len(room_changes.room_entries)})
@ -1688,7 +1671,10 @@ class SyncHandler:
return False
async def _get_rooms_changed(
self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
self,
sync_result_builder: "SyncResultBuilder",
ignored_users: FrozenSet[str],
excluded_rooms: List[str],
) -> _RoomChanges:
"""Determine the changes in rooms to report to the user.
@ -1720,7 +1706,7 @@ class SyncHandler:
# _have_rooms_changed. We could keep the results in memory to avoid a
# second query, at the cost of more complicated source code.
membership_change_events = await self.store.get_membership_changes_for_user(
user_id, since_token.room_key, now_token.room_key
user_id, since_token.room_key, now_token.room_key, excluded_rooms
)
mem_change_events_by_room_id: Dict[str, List[EventBase]] = {}
@ -1864,6 +1850,7 @@ class SyncHandler:
full_state=False,
since_token=since_token,
upto_token=leave_token,
out_of_band=leave_event.internal_metadata.is_out_of_band_membership(),
)
)
@ -1921,7 +1908,10 @@ class SyncHandler:
)
async def _get_all_rooms(
self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
self,
sync_result_builder: "SyncResultBuilder",
ignored_users: FrozenSet[str],
ignored_rooms: List[str],
) -> _RoomChanges:
"""Returns entries for all rooms for the user.
@ -1932,7 +1922,7 @@ class SyncHandler:
Args:
sync_result_builder
ignored_users: Set of users ignored by user.
ignored_rooms: List of rooms to ignore.
"""
user_id = sync_result_builder.sync_config.user.to_string()
@ -1943,6 +1933,7 @@ class SyncHandler:
room_list = await self.store.get_rooms_for_local_user_where_membership_is(
user_id=user_id,
membership_list=Membership.LIST,
excluded_rooms=ignored_rooms,
)
room_entries = []
@ -2126,33 +2117,41 @@ class SyncHandler:
):
return
state = await self.compute_state_delta(
room_id,
batch,
sync_config,
since_token,
now_token,
full_state=full_state,
)
if not room_builder.out_of_band:
state = await self.compute_state_delta(
room_id,
batch,
sync_config,
since_token,
now_token,
full_state=full_state,
)
else:
# An out of band room won't have any state changes.
state = {}
summary: Optional[JsonDict] = {}
# we include a summary in room responses when we're lazy loading
# members (as the client otherwise doesn't have enough info to form
# the name itself).
if sync_config.filter_collection.lazy_load_members() and (
# we recalculate the summary:
# if there are membership changes in the timeline, or
# if membership has changed during a gappy sync, or
# if this is an initial sync.
any(ev.type == EventTypes.Member for ev in batch.events)
or (
# XXX: this may include false positives in the form of LL
# members which have snuck into state
batch.limited
and any(t == EventTypes.Member for (t, k) in state)
if (
not room_builder.out_of_band
and sync_config.filter_collection.lazy_load_members()
and (
# we recalculate the summary:
# if there are membership changes in the timeline, or
# if membership has changed during a gappy sync, or
# if this is an initial sync.
any(ev.type == EventTypes.Member for ev in batch.events)
or (
# XXX: this may include false positives in the form of LL
# members which have snuck into state
batch.limited
and any(t == EventTypes.Member for (t, k) in state)
)
or since_token is None
)
or since_token is None
):
summary = await self.compute_summary(
room_id, sync_config, batch, state, now_token
@ -2396,6 +2395,8 @@ class RoomSyncResultBuilder:
full_state: Whether the full state should be sent in result
since_token: Earliest point to return events from, or None
upto_token: Latest point to return events from.
out_of_band: whether the events in the room are "out of band" events
and the server isn't in the room.
"""
room_id: str
@ -2405,3 +2406,5 @@ class RoomSyncResultBuilder:
full_state: bool
since_token: Optional[StreamToken]
upto_token: StreamToken
out_of_band: bool = False

View file

@ -107,6 +107,8 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
# TODO: get this from the homeserver rather than creating a new one for
# each request
try:
assert self._secret is not None
resp_body = await self._http_client.post_urlencoded_get_json(
self._url,
args={