mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-08-22 06:29:25 -04:00
Merge remote-tracking branch 'upstream/release-v1.32.0'
This commit is contained in:
commit
dbe12b3eb1
207 changed files with 4883 additions and 1550 deletions
|
@ -49,7 +49,7 @@ class BaseHandler:
|
|||
|
||||
# The rate_hz and burst_count are overridden on a per-user basis
|
||||
self.request_ratelimiter = Ratelimiter(
|
||||
clock=self.clock, rate_hz=0, burst_count=0
|
||||
store=self.store, clock=self.clock, rate_hz=0, burst_count=0
|
||||
)
|
||||
self._rc_message = self.hs.config.rc_message
|
||||
|
||||
|
@ -57,6 +57,7 @@ class BaseHandler:
|
|||
# by the presence of rate limits in the config
|
||||
if self.hs.config.rc_admin_redaction:
|
||||
self.admin_redaction_ratelimiter = Ratelimiter(
|
||||
store=self.store,
|
||||
clock=self.clock,
|
||||
rate_hz=self.hs.config.rc_admin_redaction.per_second,
|
||||
burst_count=self.hs.config.rc_admin_redaction.burst_count,
|
||||
|
@ -91,11 +92,6 @@ class BaseHandler:
|
|||
if app_service is not None:
|
||||
return # do not ratelimit app service senders
|
||||
|
||||
# Disable rate limiting of users belonging to any AS that is configured
|
||||
# not to be rate limited in its registration file (rate_limited: true|false).
|
||||
if requester.app_service and not requester.app_service.is_rate_limited():
|
||||
return
|
||||
|
||||
messages_per_second = self._rc_message.per_second
|
||||
burst_count = self._rc_message.burst_count
|
||||
|
||||
|
@ -113,11 +109,11 @@ class BaseHandler:
|
|||
if is_admin_redaction and self.admin_redaction_ratelimiter:
|
||||
# If we have separate config for admin redactions, use a separate
|
||||
# ratelimiter as to not have user_ids clash
|
||||
self.admin_redaction_ratelimiter.ratelimit(user_id, update=update)
|
||||
await self.admin_redaction_ratelimiter.ratelimit(requester, update=update)
|
||||
else:
|
||||
# Override rate and burst count per-user
|
||||
self.request_ratelimiter.ratelimit(
|
||||
user_id,
|
||||
await self.request_ratelimiter.ratelimit(
|
||||
requester,
|
||||
rate_hz=messages_per_second,
|
||||
burst_count=burst_count,
|
||||
update=update,
|
||||
|
|
|
@ -18,7 +18,7 @@ import email.utils
|
|||
import logging
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from typing import TYPE_CHECKING, List
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from synapse.api.errors import StoreError, SynapseError
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
|
@ -241,7 +241,10 @@ class AccountValidityHandler:
|
|||
return True
|
||||
|
||||
async def renew_account_for_user(
|
||||
self, user_id: str, expiration_ts: int = None, email_sent: bool = False
|
||||
self,
|
||||
user_id: str,
|
||||
expiration_ts: Optional[int] = None,
|
||||
email_sent: bool = False,
|
||||
) -> int:
|
||||
"""Renews the account attached to a given user by pushing back the
|
||||
expiration date by the current validity period in the server's
|
||||
|
|
|
@ -182,7 +182,7 @@ class ApplicationServicesHandler:
|
|||
self,
|
||||
stream_key: str,
|
||||
new_token: Optional[int],
|
||||
users: Collection[Union[str, UserID]] = [],
|
||||
users: Optional[Collection[Union[str, UserID]]] = None,
|
||||
):
|
||||
"""This is called by the notifier in the background
|
||||
when a ephemeral event handled by the homeserver.
|
||||
|
@ -215,7 +215,7 @@ class ApplicationServicesHandler:
|
|||
# We only start a new background process if necessary rather than
|
||||
# optimistically (to cut down on overhead).
|
||||
self._notify_interested_services_ephemeral(
|
||||
services, stream_key, new_token, users
|
||||
services, stream_key, new_token, users or []
|
||||
)
|
||||
|
||||
@wrap_as_background_process("notify_interested_services_ephemeral")
|
||||
|
|
|
@ -238,6 +238,7 @@ class AuthHandler(BaseHandler):
|
|||
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
|
||||
# as per `rc_login.failed_attempts`.
|
||||
self._failed_uia_attempts_ratelimiter = Ratelimiter(
|
||||
store=self.store,
|
||||
clock=self.clock,
|
||||
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
|
||||
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
|
||||
|
@ -248,6 +249,7 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
# Ratelimitier for failed /login attempts
|
||||
self._failed_login_attempts_ratelimiter = Ratelimiter(
|
||||
store=self.store,
|
||||
clock=hs.get_clock(),
|
||||
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
|
||||
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
|
||||
|
@ -352,7 +354,7 @@ class AuthHandler(BaseHandler):
|
|||
requester_user_id = requester.user.to_string()
|
||||
|
||||
# Check if we should be ratelimited due to too many previous failed attempts
|
||||
self._failed_uia_attempts_ratelimiter.ratelimit(requester_user_id, update=False)
|
||||
await self._failed_uia_attempts_ratelimiter.ratelimit(requester, update=False)
|
||||
|
||||
# build a list of supported flows
|
||||
supported_ui_auth_types = await self._get_available_ui_auth_types(
|
||||
|
@ -373,7 +375,9 @@ class AuthHandler(BaseHandler):
|
|||
)
|
||||
except LoginError:
|
||||
# Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
|
||||
self._failed_uia_attempts_ratelimiter.can_do_action(requester_user_id)
|
||||
await self._failed_uia_attempts_ratelimiter.can_do_action(
|
||||
requester,
|
||||
)
|
||||
raise
|
||||
|
||||
# find the completed login type
|
||||
|
@ -982,8 +986,8 @@ class AuthHandler(BaseHandler):
|
|||
# We also apply account rate limiting using the 3PID as a key, as
|
||||
# otherwise using 3PID bypasses the ratelimiting based on user ID.
|
||||
if ratelimit:
|
||||
self._failed_login_attempts_ratelimiter.ratelimit(
|
||||
(medium, address), update=False
|
||||
await self._failed_login_attempts_ratelimiter.ratelimit(
|
||||
None, (medium, address), update=False
|
||||
)
|
||||
|
||||
# Check for login providers that support 3pid login types
|
||||
|
@ -1016,8 +1020,8 @@ class AuthHandler(BaseHandler):
|
|||
# this code path, which is fine as then the per-user ratelimit
|
||||
# will kick in below.
|
||||
if ratelimit:
|
||||
self._failed_login_attempts_ratelimiter.can_do_action(
|
||||
(medium, address)
|
||||
await self._failed_login_attempts_ratelimiter.can_do_action(
|
||||
None, (medium, address)
|
||||
)
|
||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||
|
||||
|
@ -1039,8 +1043,8 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
# Check if we've hit the failed ratelimit (but don't update it)
|
||||
if ratelimit:
|
||||
self._failed_login_attempts_ratelimiter.ratelimit(
|
||||
qualified_user_id.lower(), update=False
|
||||
await self._failed_login_attempts_ratelimiter.ratelimit(
|
||||
None, qualified_user_id.lower(), update=False
|
||||
)
|
||||
|
||||
try:
|
||||
|
@ -1051,8 +1055,8 @@ class AuthHandler(BaseHandler):
|
|||
# exception and masking the LoginError. The actual ratelimiting
|
||||
# should have happened above.
|
||||
if ratelimit:
|
||||
self._failed_login_attempts_ratelimiter.can_do_action(
|
||||
qualified_user_id.lower()
|
||||
await self._failed_login_attempts_ratelimiter.can_do_action(
|
||||
None, qualified_user_id.lower()
|
||||
)
|
||||
raise
|
||||
|
||||
|
|
|
@ -631,7 +631,7 @@ class DeviceListUpdater:
|
|||
max_len=10000,
|
||||
expiry_ms=30 * 60 * 1000,
|
||||
iterable=True,
|
||||
)
|
||||
) # type: ExpiringCache[str, Set[str]]
|
||||
|
||||
# Attempt to resync out of sync device lists every 30s.
|
||||
self._resync_retry_in_progress = False
|
||||
|
@ -760,7 +760,7 @@ class DeviceListUpdater:
|
|||
"""Given a list of updates for a user figure out if we need to do a full
|
||||
resync, or whether we have enough data that we can just apply the delta.
|
||||
"""
|
||||
seen_updates = self._seen_updates.get(user_id, set())
|
||||
seen_updates = self._seen_updates.get(user_id, set()) # type: Set[str]
|
||||
|
||||
extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id)
|
||||
|
||||
|
|
|
@ -21,10 +21,10 @@ from synapse.api.errors import SynapseError
|
|||
from synapse.api.ratelimiting import Ratelimiter
|
||||
from synapse.logging.context import run_in_background
|
||||
from synapse.logging.opentracing import (
|
||||
SynapseTags,
|
||||
get_active_span_text_map,
|
||||
log_kv,
|
||||
set_tag,
|
||||
start_active_span,
|
||||
)
|
||||
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
|
||||
from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
|
||||
|
@ -81,6 +81,7 @@ class DeviceMessageHandler:
|
|||
)
|
||||
|
||||
self._ratelimiter = Ratelimiter(
|
||||
store=self.store,
|
||||
clock=hs.get_clock(),
|
||||
rate_hz=hs.config.rc_key_requests.per_second,
|
||||
burst_count=hs.config.rc_key_requests.burst_count,
|
||||
|
@ -182,7 +183,10 @@ class DeviceMessageHandler:
|
|||
) -> None:
|
||||
sender_user_id = requester.user.to_string()
|
||||
|
||||
set_tag("number_of_messages", len(messages))
|
||||
message_id = random_string(16)
|
||||
set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id)
|
||||
|
||||
log_kv({"number_of_to_device_messages": len(messages)})
|
||||
set_tag("sender", sender_user_id)
|
||||
local_messages = {}
|
||||
remote_messages = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
|
||||
|
@ -191,8 +195,8 @@ class DeviceMessageHandler:
|
|||
if (
|
||||
message_type == EduTypes.RoomKeyRequest
|
||||
and user_id != sender_user_id
|
||||
and self._ratelimiter.can_do_action(
|
||||
(sender_user_id, requester.device_id)
|
||||
and await self._ratelimiter.can_do_action(
|
||||
requester, (sender_user_id, requester.device_id)
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
@ -204,32 +208,35 @@ class DeviceMessageHandler:
|
|||
"content": message_content,
|
||||
"type": message_type,
|
||||
"sender": sender_user_id,
|
||||
"message_id": message_id,
|
||||
}
|
||||
for device_id, message_content in by_device.items()
|
||||
}
|
||||
if messages_by_device:
|
||||
local_messages[user_id] = messages_by_device
|
||||
log_kv(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"device_id": list(messages_by_device),
|
||||
}
|
||||
)
|
||||
else:
|
||||
destination = get_domain_from_id(user_id)
|
||||
remote_messages.setdefault(destination, {})[user_id] = by_device
|
||||
|
||||
message_id = random_string(16)
|
||||
|
||||
context = get_active_span_text_map()
|
||||
|
||||
remote_edu_contents = {}
|
||||
for destination, messages in remote_messages.items():
|
||||
with start_active_span("to_device_for_user"):
|
||||
set_tag("destination", destination)
|
||||
remote_edu_contents[destination] = {
|
||||
"messages": messages,
|
||||
"sender": sender_user_id,
|
||||
"type": message_type,
|
||||
"message_id": message_id,
|
||||
"org.matrix.opentracing_context": json_encoder.encode(context),
|
||||
}
|
||||
log_kv({"destination": destination})
|
||||
remote_edu_contents[destination] = {
|
||||
"messages": messages,
|
||||
"sender": sender_user_id,
|
||||
"type": message_type,
|
||||
"message_id": message_id,
|
||||
"org.matrix.opentracing_context": json_encoder.encode(context),
|
||||
}
|
||||
|
||||
log_kv({"local_messages": local_messages})
|
||||
stream_id = await self.store.add_messages_to_device_inbox(
|
||||
local_messages, remote_edu_contents
|
||||
)
|
||||
|
@ -238,7 +245,6 @@ class DeviceMessageHandler:
|
|||
"to_device_key", stream_id, users=local_messages.keys()
|
||||
)
|
||||
|
||||
log_kv({"remote_messages": remote_messages})
|
||||
if self.federation_sender:
|
||||
for destination in remote_messages.keys():
|
||||
# Enqueue a new federation transaction to send the new
|
||||
|
|
|
@ -38,7 +38,6 @@ from synapse.types import (
|
|||
)
|
||||
from synapse.util import json_decoder, unwrapFirstError
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.util.retryutils import NotRetryingDestination
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -1008,7 +1007,7 @@ class E2eKeysHandler:
|
|||
return signature_list, failures
|
||||
|
||||
async def _get_e2e_cross_signing_verify_key(
|
||||
self, user_id: str, key_type: str, from_user_id: str = None
|
||||
self, user_id: str, key_type: str, from_user_id: Optional[str] = None
|
||||
) -> Tuple[JsonDict, str, VerifyKey]:
|
||||
"""Fetch locally or remotely query for a cross-signing public key.
|
||||
|
||||
|
@ -1292,17 +1291,6 @@ class SigningKeyEduUpdater:
|
|||
# user_id -> list of updates waiting to be handled.
|
||||
self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
|
||||
|
||||
# Recently seen stream ids. We don't bother keeping these in the DB,
|
||||
# but they're useful to have them about to reduce the number of spurious
|
||||
# resyncs.
|
||||
self._seen_updates = ExpiringCache(
|
||||
cache_name="signing_key_update_edu",
|
||||
clock=self.clock,
|
||||
max_len=10000,
|
||||
expiry_ms=30 * 60 * 1000,
|
||||
iterable=True,
|
||||
)
|
||||
|
||||
async def incoming_signing_key_update(
|
||||
self, origin: str, edu_content: JsonDict
|
||||
) -> None:
|
||||
|
|
|
@ -21,7 +21,17 @@ import itertools
|
|||
import logging
|
||||
from collections.abc import Container
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import attr
|
||||
from signedjson.key import decode_verify_key_bytes
|
||||
|
@ -171,15 +181,17 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
|
||||
|
||||
async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
|
||||
async def on_receive_pdu(
|
||||
self, origin: str, pdu: EventBase, sent_to_us_directly: bool = False
|
||||
) -> None:
|
||||
"""Process a PDU received via a federation /send/ transaction, or
|
||||
via backfill of missing prev_events
|
||||
|
||||
Args:
|
||||
origin (str): server which initiated the /send/ transaction. Will
|
||||
origin: server which initiated the /send/ transaction. Will
|
||||
be used to fetch missing events or state.
|
||||
pdu (FrozenEvent): received PDU
|
||||
sent_to_us_directly (bool): True if this event was pushed to us; False if
|
||||
pdu: received PDU
|
||||
sent_to_us_directly: True if this event was pushed to us; False if
|
||||
we pulled it as the result of a missing prev_event.
|
||||
"""
|
||||
|
||||
|
@ -411,13 +423,15 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
await self._process_received_pdu(origin, pdu, state=state)
|
||||
|
||||
async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
|
||||
async def _get_missing_events_for_pdu(
|
||||
self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
origin (str): Origin of the pdu. Will be called to get the missing events
|
||||
origin: Origin of the pdu. Will be called to get the missing events
|
||||
pdu: received pdu
|
||||
prevs (set(str)): List of event ids which we are missing
|
||||
min_depth (int): Minimum depth of events to return.
|
||||
prevs: List of event ids which we are missing
|
||||
min_depth: Minimum depth of events to return.
|
||||
"""
|
||||
|
||||
room_id = pdu.room_id
|
||||
|
@ -778,7 +792,7 @@ class FederationHandler(BaseHandler):
|
|||
origin: str,
|
||||
event: EventBase,
|
||||
state: Optional[Iterable[EventBase]],
|
||||
):
|
||||
) -> None:
|
||||
"""Called when we have a new pdu. We need to do auth checks and put it
|
||||
through the StateHandler.
|
||||
|
||||
|
@ -887,7 +901,9 @@ class FederationHandler(BaseHandler):
|
|||
logger.exception("Failed to resync device for %s", sender)
|
||||
|
||||
@log_function
|
||||
async def backfill(self, dest, room_id, limit, extremities):
|
||||
async def backfill(
|
||||
self, dest: str, room_id: str, limit: int, extremities: List[str]
|
||||
) -> List[EventBase]:
|
||||
"""Trigger a backfill request to `dest` for the given `room_id`
|
||||
|
||||
This will attempt to get more events from the remote. If the other side
|
||||
|
@ -1142,16 +1158,15 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
curr_state = await self.state_handler.get_current_state(room_id)
|
||||
|
||||
def get_domains_from_state(state):
|
||||
def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
|
||||
"""Get joined domains from state
|
||||
|
||||
Args:
|
||||
state (dict[tuple, FrozenEvent]): State map from type/state
|
||||
key to event.
|
||||
state: State map from type/state key to event.
|
||||
|
||||
Returns:
|
||||
list[tuple[str, int]]: Returns a list of servers with the
|
||||
lowest depth of their joins. Sorted by lowest depth first.
|
||||
Returns a list of servers with the lowest depth of their joins.
|
||||
Sorted by lowest depth first.
|
||||
"""
|
||||
joined_users = [
|
||||
(state_key, int(event.depth))
|
||||
|
@ -1179,7 +1194,7 @@ class FederationHandler(BaseHandler):
|
|||
domain for domain, depth in curr_domains if domain != self.server_name
|
||||
]
|
||||
|
||||
async def try_backfill(domains):
|
||||
async def try_backfill(domains: List[str]) -> bool:
|
||||
# TODO: Should we try multiple of these at a time?
|
||||
for dom in domains:
|
||||
try:
|
||||
|
@ -1258,21 +1273,25 @@ class FederationHandler(BaseHandler):
|
|||
}
|
||||
|
||||
for e_id, _ in sorted_extremeties_tuple:
|
||||
likely_domains = get_domains_from_state(states[e_id])
|
||||
likely_extremeties_domains = get_domains_from_state(states[e_id])
|
||||
|
||||
success = await try_backfill(
|
||||
[dom for dom, _ in likely_domains if dom not in tried_domains]
|
||||
[
|
||||
dom
|
||||
for dom, _ in likely_extremeties_domains
|
||||
if dom not in tried_domains
|
||||
]
|
||||
)
|
||||
if success:
|
||||
return True
|
||||
|
||||
tried_domains.update(dom for dom, _ in likely_domains)
|
||||
tried_domains.update(dom for dom, _ in likely_extremeties_domains)
|
||||
|
||||
return False
|
||||
|
||||
async def _get_events_and_persist(
|
||||
self, destination: str, room_id: str, events: Iterable[str]
|
||||
):
|
||||
) -> None:
|
||||
"""Fetch the given events from a server, and persist them as outliers.
|
||||
|
||||
This function *does not* recursively get missing auth events of the
|
||||
|
@ -1348,7 +1367,7 @@ class FederationHandler(BaseHandler):
|
|||
event_infos,
|
||||
)
|
||||
|
||||
def _sanity_check_event(self, ev):
|
||||
def _sanity_check_event(self, ev: EventBase) -> None:
|
||||
"""
|
||||
Do some early sanity checks of a received event
|
||||
|
||||
|
@ -1357,9 +1376,7 @@ class FederationHandler(BaseHandler):
|
|||
or cascade of event fetches.
|
||||
|
||||
Args:
|
||||
ev (synapse.events.EventBase): event to be checked
|
||||
|
||||
Returns: None
|
||||
ev: event to be checked
|
||||
|
||||
Raises:
|
||||
SynapseError if the event does not pass muster
|
||||
|
@ -1380,7 +1397,7 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events")
|
||||
|
||||
async def send_invite(self, target_host, event):
|
||||
async def send_invite(self, target_host: str, event: EventBase) -> EventBase:
|
||||
"""Sends the invite to the remote server for signing.
|
||||
|
||||
Invites must be signed by the invitee's server before distribution.
|
||||
|
@ -1528,12 +1545,13 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
run_in_background(self._handle_queued_pdus, room_queue)
|
||||
|
||||
async def _handle_queued_pdus(self, room_queue):
|
||||
async def _handle_queued_pdus(
|
||||
self, room_queue: List[Tuple[EventBase, str]]
|
||||
) -> None:
|
||||
"""Process PDUs which got queued up while we were busy send_joining.
|
||||
|
||||
Args:
|
||||
room_queue (list[FrozenEvent, str]): list of PDUs to be processed
|
||||
and the servers that sent them
|
||||
room_queue: list of PDUs to be processed and the servers that sent them
|
||||
"""
|
||||
for p, origin in room_queue:
|
||||
try:
|
||||
|
@ -1612,7 +1630,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
return event
|
||||
|
||||
async def on_send_join_request(self, origin, pdu):
|
||||
async def on_send_join_request(self, origin: str, pdu: EventBase) -> JsonDict:
|
||||
"""We have received a join event for a room. Fully process it and
|
||||
respond with the current state and auth chains.
|
||||
"""
|
||||
|
@ -1668,7 +1686,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
async def on_invite_request(
|
||||
self, origin: str, event: EventBase, room_version: RoomVersion
|
||||
):
|
||||
) -> EventBase:
|
||||
"""We've got an invite event. Process and persist it. Sign it.
|
||||
|
||||
Respond with the now signed event.
|
||||
|
@ -1711,7 +1729,7 @@ class FederationHandler(BaseHandler):
|
|||
member_handler = self.hs.get_room_member_handler()
|
||||
# We don't rate limit based on room ID, as that should be done by
|
||||
# sending server.
|
||||
member_handler.ratelimit_invite(None, event.state_key)
|
||||
await member_handler.ratelimit_invite(None, None, event.state_key)
|
||||
|
||||
# keep a record of the room version, if we don't yet know it.
|
||||
# (this may get overwritten if we later get a different room version in a
|
||||
|
@ -1772,7 +1790,7 @@ class FederationHandler(BaseHandler):
|
|||
room_id: str,
|
||||
user_id: str,
|
||||
membership: str,
|
||||
content: JsonDict = {},
|
||||
content: JsonDict,
|
||||
params: Optional[Dict[str, Union[str, Iterable[str]]]] = None,
|
||||
) -> Tuple[str, EventBase, RoomVersion]:
|
||||
(
|
||||
|
@ -1841,7 +1859,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
return event
|
||||
|
||||
async def on_send_leave_request(self, origin, pdu):
|
||||
async def on_send_leave_request(self, origin: str, pdu: EventBase) -> None:
|
||||
""" We have received a leave event for a room. Fully process it."""
|
||||
event = pdu
|
||||
|
||||
|
@ -1969,12 +1987,17 @@ class FederationHandler(BaseHandler):
|
|||
else:
|
||||
return None
|
||||
|
||||
async def get_min_depth_for_context(self, context):
|
||||
async def get_min_depth_for_context(self, context: str) -> int:
|
||||
return await self.store.get_min_depth(context)
|
||||
|
||||
async def _handle_new_event(
|
||||
self, origin, event, state=None, auth_events=None, backfilled=False
|
||||
):
|
||||
self,
|
||||
origin: str,
|
||||
event: EventBase,
|
||||
state: Optional[Iterable[EventBase]] = None,
|
||||
auth_events: Optional[MutableStateMap[EventBase]] = None,
|
||||
backfilled: bool = False,
|
||||
) -> EventContext:
|
||||
context = await self._prep_event(
|
||||
origin, event, state=state, auth_events=auth_events, backfilled=backfilled
|
||||
)
|
||||
|
@ -2280,40 +2303,14 @@ class FederationHandler(BaseHandler):
|
|||
logger.warning("Soft-failing %r because %s", event, e)
|
||||
event.internal_metadata.soft_failed = True
|
||||
|
||||
async def on_query_auth(
|
||||
self, origin, event_id, room_id, remote_auth_chain, rejects, missing
|
||||
):
|
||||
in_room = await self.auth.check_host_in_room(room_id, origin)
|
||||
if not in_room:
|
||||
raise AuthError(403, "Host not in room.")
|
||||
|
||||
event = await self.store.get_event(event_id, check_room_id=room_id)
|
||||
|
||||
# Just go through and process each event in `remote_auth_chain`. We
|
||||
# don't want to fall into the trap of `missing` being wrong.
|
||||
for e in remote_auth_chain:
|
||||
try:
|
||||
await self._handle_new_event(origin, e)
|
||||
except AuthError:
|
||||
pass
|
||||
|
||||
# Now get the current auth_chain for the event.
|
||||
local_auth_chain = await self.store.get_auth_chain(
|
||||
room_id, list(event.auth_event_ids()), include_given=True
|
||||
)
|
||||
|
||||
# TODO: Check if we would now reject event_id. If so we need to tell
|
||||
# everyone.
|
||||
|
||||
ret = await self.construct_auth_difference(local_auth_chain, remote_auth_chain)
|
||||
|
||||
logger.debug("on_query_auth returning: %s", ret)
|
||||
|
||||
return ret
|
||||
|
||||
async def on_get_missing_events(
|
||||
self, origin, room_id, earliest_events, latest_events, limit
|
||||
):
|
||||
self,
|
||||
origin: str,
|
||||
room_id: str,
|
||||
earliest_events: List[str],
|
||||
latest_events: List[str],
|
||||
limit: int,
|
||||
) -> List[EventBase]:
|
||||
in_room = await self.auth.check_host_in_room(room_id, origin)
|
||||
if not in_room:
|
||||
raise AuthError(403, "Host not in room.")
|
||||
|
@ -2617,8 +2614,8 @@ class FederationHandler(BaseHandler):
|
|||
assumes that we have already processed all events in remote_auth
|
||||
|
||||
Params:
|
||||
local_auth (list)
|
||||
remote_auth (list)
|
||||
local_auth
|
||||
remote_auth
|
||||
|
||||
Returns:
|
||||
dict
|
||||
|
@ -2742,8 +2739,8 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
@log_function
|
||||
async def exchange_third_party_invite(
|
||||
self, sender_user_id, target_user_id, room_id, signed
|
||||
):
|
||||
self, sender_user_id: str, target_user_id: str, room_id: str, signed: JsonDict
|
||||
) -> None:
|
||||
third_party_invite = {"signed": signed}
|
||||
|
||||
event_dict = {
|
||||
|
@ -2835,8 +2832,12 @@ class FederationHandler(BaseHandler):
|
|||
await member_handler.send_membership_event(None, event, context)
|
||||
|
||||
async def add_display_name_to_third_party_invite(
|
||||
self, room_version, event_dict, event, context
|
||||
):
|
||||
self,
|
||||
room_version: str,
|
||||
event_dict: JsonDict,
|
||||
event: EventBase,
|
||||
context: EventContext,
|
||||
) -> Tuple[EventBase, EventContext]:
|
||||
key = (
|
||||
EventTypes.ThirdPartyInvite,
|
||||
event.content["third_party_invite"]["signed"]["token"],
|
||||
|
@ -2872,13 +2873,13 @@ class FederationHandler(BaseHandler):
|
|||
EventValidator().validate_new(event, self.config)
|
||||
return (event, context)
|
||||
|
||||
async def _check_signature(self, event, context):
|
||||
async def _check_signature(self, event: EventBase, context: EventContext) -> None:
|
||||
"""
|
||||
Checks that the signature in the event is consistent with its invite.
|
||||
|
||||
Args:
|
||||
event (Event): The m.room.member event to check
|
||||
context (EventContext):
|
||||
event: The m.room.member event to check
|
||||
context:
|
||||
|
||||
Raises:
|
||||
AuthError: if signature didn't match any keys, or key has been
|
||||
|
@ -2964,13 +2965,13 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
raise last_exception
|
||||
|
||||
async def _check_key_revocation(self, public_key, url):
|
||||
async def _check_key_revocation(self, public_key: str, url: str) -> None:
|
||||
"""
|
||||
Checks whether public_key has been revoked.
|
||||
|
||||
Args:
|
||||
public_key (str): base-64 encoded public key.
|
||||
url (str): Key revocation URL.
|
||||
public_key: base-64 encoded public key.
|
||||
url: Key revocation URL.
|
||||
|
||||
Raises:
|
||||
AuthError: if they key has been revoked.
|
||||
|
|
|
@ -61,17 +61,19 @@ class IdentityHandler(BaseHandler):
|
|||
|
||||
# Ratelimiters for `/requestToken` endpoints.
|
||||
self._3pid_validation_ratelimiter_ip = Ratelimiter(
|
||||
store=self.store,
|
||||
clock=hs.get_clock(),
|
||||
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
|
||||
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
|
||||
)
|
||||
self._3pid_validation_ratelimiter_address = Ratelimiter(
|
||||
store=self.store,
|
||||
clock=hs.get_clock(),
|
||||
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
|
||||
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
|
||||
)
|
||||
|
||||
def ratelimit_request_token_requests(
|
||||
async def ratelimit_request_token_requests(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
medium: str,
|
||||
|
@ -85,8 +87,12 @@ class IdentityHandler(BaseHandler):
|
|||
address: The actual threepid ID, e.g. the phone number or email address
|
||||
"""
|
||||
|
||||
self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP()))
|
||||
self._3pid_validation_ratelimiter_address.ratelimit((medium, address))
|
||||
await self._3pid_validation_ratelimiter_ip.ratelimit(
|
||||
None, (medium, request.getClientIP())
|
||||
)
|
||||
await self._3pid_validation_ratelimiter_address.ratelimit(
|
||||
None, (medium, address)
|
||||
)
|
||||
|
||||
async def threepid_from_creds(
|
||||
self, id_server: str, creds: Dict[str, str]
|
||||
|
|
|
@ -137,7 +137,7 @@ class MessageHandler:
|
|||
self,
|
||||
user_id: str,
|
||||
room_id: str,
|
||||
state_filter: StateFilter = StateFilter.all(),
|
||||
state_filter: Optional[StateFilter] = None,
|
||||
at_token: Optional[StreamToken] = None,
|
||||
is_guest: bool = False,
|
||||
) -> List[dict]:
|
||||
|
@ -164,6 +164,8 @@ class MessageHandler:
|
|||
AuthError (403) if the user doesn't have permission to view
|
||||
members of this room.
|
||||
"""
|
||||
state_filter = state_filter or StateFilter.all()
|
||||
|
||||
if at_token:
|
||||
# FIXME this claims to get the state at a stream position, but
|
||||
# get_recent_events_for_room operates by topo ordering. This therefore
|
||||
|
@ -385,7 +387,7 @@ class EventCreationHandler:
|
|||
self._events_shard_config = self.config.worker.events_shard_config
|
||||
self._instance_name = hs.get_instance_name()
|
||||
|
||||
self.room_invite_state_types = self.hs.config.room_invite_state_types
|
||||
self.room_invite_state_types = self.hs.config.api.room_prejoin_state
|
||||
|
||||
self.membership_types_to_include_profile_data_in = (
|
||||
{Membership.JOIN, Membership.INVITE}
|
||||
|
@ -876,7 +878,7 @@ class EventCreationHandler:
|
|||
event: EventBase,
|
||||
context: EventContext,
|
||||
ratelimit: bool = True,
|
||||
extra_users: List[UserID] = [],
|
||||
extra_users: Optional[List[UserID]] = None,
|
||||
ignore_shadow_ban: bool = False,
|
||||
) -> EventBase:
|
||||
"""Processes a new event.
|
||||
|
@ -904,6 +906,7 @@ class EventCreationHandler:
|
|||
Raises:
|
||||
ShadowBanError if the requester has been shadow-banned.
|
||||
"""
|
||||
extra_users = extra_users or []
|
||||
|
||||
# we don't apply shadow-banning to membership events here. Invites are blocked
|
||||
# higher up the stack, and we allow shadow-banned users to send join and leave
|
||||
|
@ -1073,7 +1076,7 @@ class EventCreationHandler:
|
|||
event: EventBase,
|
||||
context: EventContext,
|
||||
ratelimit: bool = True,
|
||||
extra_users: List[UserID] = [],
|
||||
extra_users: Optional[List[UserID]] = None,
|
||||
) -> EventBase:
|
||||
"""Called when we have fully built the event, have already
|
||||
calculated the push actions for the event, and checked auth.
|
||||
|
@ -1085,6 +1088,8 @@ class EventCreationHandler:
|
|||
it was de-duplicated (e.g. because we had already persisted an
|
||||
event with the same transaction ID.)
|
||||
"""
|
||||
extra_users = extra_users or []
|
||||
|
||||
assert self.storage.persistence is not None
|
||||
assert self._events_shard_config.should_handle(
|
||||
self._instance_name, event.room_id
|
||||
|
|
|
@ -25,7 +25,17 @@ The methods that define policy are:
|
|||
import abc
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Dict,
|
||||
FrozenSet,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
from prometheus_client import Counter
|
||||
from typing_extensions import ContextManager
|
||||
|
@ -34,6 +44,7 @@ import synapse.metrics
|
|||
from synapse.api.constants import EventTypes, Membership, PresenceState
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.presence import UserPresenceState
|
||||
from synapse.events.presence_router import PresenceRouter
|
||||
from synapse.logging.context import run_in_background
|
||||
from synapse.logging.utils import log_function
|
||||
from synapse.metrics import LaterGauge
|
||||
|
@ -42,7 +53,7 @@ from synapse.state import StateHandler
|
|||
from synapse.storage.databases.main import DataStore
|
||||
from synapse.types import Collection, JsonDict, UserID, get_domain_from_id
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.util.caches.descriptors import _CacheContext, cached
|
||||
from synapse.util.metrics import Measure
|
||||
from synapse.util.wheel_timer import WheelTimer
|
||||
|
||||
|
@ -209,6 +220,7 @@ class PresenceHandler(BasePresenceHandler):
|
|||
self.notifier = hs.get_notifier()
|
||||
self.federation = hs.get_federation_sender()
|
||||
self.state = hs.get_state_handler()
|
||||
self.presence_router = hs.get_presence_router()
|
||||
self._presence_enabled = hs.config.use_presence
|
||||
|
||||
federation_registry = hs.get_federation_registry()
|
||||
|
@ -653,7 +665,7 @@ class PresenceHandler(BasePresenceHandler):
|
|||
"""
|
||||
stream_id, max_token = await self.store.update_presence(states)
|
||||
|
||||
parties = await get_interested_parties(self.store, states)
|
||||
parties = await get_interested_parties(self.store, self.presence_router, states)
|
||||
room_ids_to_states, users_to_states = parties
|
||||
|
||||
self.notifier.on_new_event(
|
||||
|
@ -1041,7 +1053,12 @@ class PresenceEventSource:
|
|||
#
|
||||
# Presence -> Notifier -> PresenceEventSource -> Presence
|
||||
#
|
||||
# Same with get_module_api, get_presence_router
|
||||
#
|
||||
# AuthHandler -> Notifier -> PresenceEventSource -> ModuleApi -> AuthHandler
|
||||
self.get_presence_handler = hs.get_presence_handler
|
||||
self.get_module_api = hs.get_module_api
|
||||
self.get_presence_router = hs.get_presence_router
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
self.state = hs.get_state_handler()
|
||||
|
@ -1054,8 +1071,8 @@ class PresenceEventSource:
|
|||
room_ids=None,
|
||||
include_offline=True,
|
||||
explicit_room_id=None,
|
||||
**kwargs
|
||||
):
|
||||
**kwargs,
|
||||
) -> Tuple[List[UserPresenceState], int]:
|
||||
# The process for getting presence events are:
|
||||
# 1. Get the rooms the user is in.
|
||||
# 2. Get the list of user in the rooms.
|
||||
|
@ -1068,7 +1085,17 @@ class PresenceEventSource:
|
|||
# We don't try and limit the presence updates by the current token, as
|
||||
# sending down the rare duplicate is not a concern.
|
||||
|
||||
user_id = user.to_string()
|
||||
stream_change_cache = self.store.presence_stream_cache
|
||||
|
||||
with Measure(self.clock, "presence.get_new_events"):
|
||||
if user_id in self.get_module_api()._send_full_presence_to_local_users:
|
||||
# This user has been specified by a module to receive all current, online
|
||||
# user presence. Removing from_key and setting include_offline to false
|
||||
# will do effectively this.
|
||||
from_key = None
|
||||
include_offline = False
|
||||
|
||||
if from_key is not None:
|
||||
from_key = int(from_key)
|
||||
|
||||
|
@ -1091,59 +1118,209 @@ class PresenceEventSource:
|
|||
# doesn't return. C.f. #5503.
|
||||
return [], max_token
|
||||
|
||||
presence = self.get_presence_handler()
|
||||
stream_change_cache = self.store.presence_stream_cache
|
||||
|
||||
# Figure out which other users this user should receive updates for
|
||||
users_interested_in = await self._get_interested_in(user, explicit_room_id)
|
||||
|
||||
user_ids_changed = set() # type: Collection[str]
|
||||
changed = None
|
||||
if from_key:
|
||||
changed = stream_change_cache.get_all_entities_changed(from_key)
|
||||
# We have a set of users that we're interested in the presence of. We want to
|
||||
# cross-reference that with the users that have actually changed their presence.
|
||||
|
||||
if changed is not None and len(changed) < 500:
|
||||
assert isinstance(user_ids_changed, set)
|
||||
# Check whether this user should see all user updates
|
||||
|
||||
# For small deltas, its quicker to get all changes and then
|
||||
# work out if we share a room or they're in our presence list
|
||||
get_updates_counter.labels("stream").inc()
|
||||
for other_user_id in changed:
|
||||
if other_user_id in users_interested_in:
|
||||
user_ids_changed.add(other_user_id)
|
||||
else:
|
||||
# Too many possible updates. Find all users we can see and check
|
||||
# if any of them have changed.
|
||||
get_updates_counter.labels("full").inc()
|
||||
if users_interested_in == PresenceRouter.ALL_USERS:
|
||||
# Provide presence state for all users
|
||||
presence_updates = await self._filter_all_presence_updates_for_user(
|
||||
user_id, include_offline, from_key
|
||||
)
|
||||
|
||||
if from_key:
|
||||
user_ids_changed = stream_change_cache.get_entities_changed(
|
||||
users_interested_in, from_key
|
||||
# Remove the user from the list of users to receive all presence
|
||||
if user_id in self.get_module_api()._send_full_presence_to_local_users:
|
||||
self.get_module_api()._send_full_presence_to_local_users.remove(
|
||||
user_id
|
||||
)
|
||||
|
||||
return presence_updates, max_token
|
||||
|
||||
# Make mypy happy. users_interested_in should now be a set
|
||||
assert not isinstance(users_interested_in, str)
|
||||
|
||||
# The set of users that we're interested in and that have had a presence update.
|
||||
# We'll actually pull the presence updates for these users at the end.
|
||||
interested_and_updated_users = (
|
||||
set()
|
||||
) # type: Union[Set[str], FrozenSet[str]]
|
||||
|
||||
if from_key:
|
||||
# First get all users that have had a presence update
|
||||
updated_users = stream_change_cache.get_all_entities_changed(from_key)
|
||||
|
||||
# Cross-reference users we're interested in with those that have had updates.
|
||||
# Use a slightly-optimised method for processing smaller sets of updates.
|
||||
if updated_users is not None and len(updated_users) < 500:
|
||||
# For small deltas, it's quicker to get all changes and then
|
||||
# cross-reference with the users we're interested in
|
||||
get_updates_counter.labels("stream").inc()
|
||||
for other_user_id in updated_users:
|
||||
if other_user_id in users_interested_in:
|
||||
# mypy thinks this variable could be a FrozenSet as it's possibly set
|
||||
# to one in the `get_entities_changed` call below, and `add()` is not
|
||||
# method on a FrozenSet. That doesn't affect us here though, as
|
||||
# `interested_and_updated_users` is clearly a set() above.
|
||||
interested_and_updated_users.add(other_user_id) # type: ignore
|
||||
else:
|
||||
user_ids_changed = users_interested_in
|
||||
# Too many possible updates. Find all users we can see and check
|
||||
# if any of them have changed.
|
||||
get_updates_counter.labels("full").inc()
|
||||
|
||||
updates = await presence.current_state_for_users(user_ids_changed)
|
||||
interested_and_updated_users = (
|
||||
stream_change_cache.get_entities_changed(
|
||||
users_interested_in, from_key
|
||||
)
|
||||
)
|
||||
else:
|
||||
# No from_key has been specified. Return the presence for all users
|
||||
# this user is interested in
|
||||
interested_and_updated_users = users_interested_in
|
||||
|
||||
if include_offline:
|
||||
return (list(updates.values()), max_token)
|
||||
else:
|
||||
return (
|
||||
[s for s in updates.values() if s.state != PresenceState.OFFLINE],
|
||||
max_token,
|
||||
# Retrieve the current presence state for each user
|
||||
users_to_state = await self.get_presence_handler().current_state_for_users(
|
||||
interested_and_updated_users
|
||||
)
|
||||
presence_updates = list(users_to_state.values())
|
||||
|
||||
# Remove the user from the list of users to receive all presence
|
||||
if user_id in self.get_module_api()._send_full_presence_to_local_users:
|
||||
self.get_module_api()._send_full_presence_to_local_users.remove(user_id)
|
||||
|
||||
if not include_offline:
|
||||
# Filter out offline presence states
|
||||
presence_updates = self._filter_offline_presence_state(presence_updates)
|
||||
|
||||
return presence_updates, max_token
|
||||
|
||||
async def _filter_all_presence_updates_for_user(
|
||||
self,
|
||||
user_id: str,
|
||||
include_offline: bool,
|
||||
from_key: Optional[int] = None,
|
||||
) -> List[UserPresenceState]:
|
||||
"""
|
||||
Computes the presence updates a user should receive.
|
||||
|
||||
First pulls presence updates from the database. Then consults PresenceRouter
|
||||
for whether any updates should be excluded by user ID.
|
||||
|
||||
Args:
|
||||
user_id: The User ID of the user to compute presence updates for.
|
||||
include_offline: Whether to include offline presence states from the results.
|
||||
from_key: The minimum stream ID of updates to pull from the database
|
||||
before filtering.
|
||||
|
||||
Returns:
|
||||
A list of presence states for the given user to receive.
|
||||
"""
|
||||
if from_key:
|
||||
# Only return updates since the last sync
|
||||
updated_users = self.store.presence_stream_cache.get_all_entities_changed(
|
||||
from_key
|
||||
)
|
||||
if not updated_users:
|
||||
updated_users = []
|
||||
|
||||
# Get the actual presence update for each change
|
||||
users_to_state = await self.get_presence_handler().current_state_for_users(
|
||||
updated_users
|
||||
)
|
||||
presence_updates = list(users_to_state.values())
|
||||
|
||||
if not include_offline:
|
||||
# Filter out offline states
|
||||
presence_updates = self._filter_offline_presence_state(presence_updates)
|
||||
else:
|
||||
users_to_state = await self.store.get_presence_for_all_users(
|
||||
include_offline=include_offline
|
||||
)
|
||||
|
||||
presence_updates = list(users_to_state.values())
|
||||
|
||||
# TODO: This feels wildly inefficient, and it's unfortunate we need to ask the
|
||||
# module for information on a number of users when we then only take the info
|
||||
# for a single user
|
||||
|
||||
# Filter through the presence router
|
||||
users_to_state_set = await self.get_presence_router().get_users_for_states(
|
||||
presence_updates
|
||||
)
|
||||
|
||||
# We only want the mapping for the syncing user
|
||||
presence_updates = list(users_to_state_set[user_id])
|
||||
|
||||
# Return presence information for all users
|
||||
return presence_updates
|
||||
|
||||
def _filter_offline_presence_state(
|
||||
self, presence_updates: Iterable[UserPresenceState]
|
||||
) -> List[UserPresenceState]:
|
||||
"""Given an iterable containing user presence updates, return a list with any offline
|
||||
presence states removed.
|
||||
|
||||
Args:
|
||||
presence_updates: Presence states to filter
|
||||
|
||||
Returns:
|
||||
A new list with any offline presence states removed.
|
||||
"""
|
||||
return [
|
||||
update
|
||||
for update in presence_updates
|
||||
if update.state != PresenceState.OFFLINE
|
||||
]
|
||||
|
||||
def get_current_key(self):
|
||||
return self.store.get_current_presence_token()
|
||||
|
||||
@cached(num_args=2, cache_context=True)
|
||||
async def _get_interested_in(self, user, explicit_room_id, cache_context):
|
||||
async def _get_interested_in(
|
||||
self,
|
||||
user: UserID,
|
||||
explicit_room_id: Optional[str] = None,
|
||||
cache_context: Optional[_CacheContext] = None,
|
||||
) -> Union[Set[str], str]:
|
||||
"""Returns the set of users that the given user should see presence
|
||||
updates for
|
||||
updates for.
|
||||
|
||||
Args:
|
||||
user: The user to retrieve presence updates for.
|
||||
explicit_room_id: The users that are in the room will be returned.
|
||||
|
||||
Returns:
|
||||
A set of user IDs to return presence updates for, or "ALL" to return all
|
||||
known updates.
|
||||
"""
|
||||
user_id = user.to_string()
|
||||
users_interested_in = set()
|
||||
users_interested_in.add(user_id) # So that we receive our own presence
|
||||
|
||||
# cache_context isn't likely to ever be None due to the @cached decorator,
|
||||
# but we can't have a non-optional argument after the optional argument
|
||||
# explicit_room_id either. Assert cache_context is not None so we can use it
|
||||
# without mypy complaining.
|
||||
assert cache_context
|
||||
|
||||
# Check with the presence router whether we should poll additional users for
|
||||
# their presence information
|
||||
additional_users = await self.get_presence_router().get_interested_users(
|
||||
user.to_string()
|
||||
)
|
||||
if additional_users == PresenceRouter.ALL_USERS:
|
||||
# If the module requested that this user see the presence updates of *all*
|
||||
# users, then simply return that instead of calculating what rooms this
|
||||
# user shares
|
||||
return PresenceRouter.ALL_USERS
|
||||
|
||||
# Add the additional users from the router
|
||||
users_interested_in.update(additional_users)
|
||||
|
||||
# Find the users who share a room with this user
|
||||
users_who_share_room = await self.store.get_users_who_share_room_with_user(
|
||||
user_id, on_invalidate=cache_context.invalidate
|
||||
)
|
||||
|
@ -1314,14 +1491,15 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now):
|
|||
|
||||
|
||||
async def get_interested_parties(
|
||||
store: DataStore, states: List[UserPresenceState]
|
||||
store: DataStore, presence_router: PresenceRouter, states: List[UserPresenceState]
|
||||
) -> Tuple[Dict[str, List[UserPresenceState]], Dict[str, List[UserPresenceState]]]:
|
||||
"""Given a list of states return which entities (rooms, users)
|
||||
are interested in the given states.
|
||||
|
||||
Args:
|
||||
store
|
||||
states
|
||||
store: The homeserver's data store.
|
||||
presence_router: A module for augmenting the destinations for presence updates.
|
||||
states: A list of incoming user presence updates.
|
||||
|
||||
Returns:
|
||||
A 2-tuple of `(room_ids_to_states, users_to_states)`,
|
||||
|
@ -1337,11 +1515,22 @@ async def get_interested_parties(
|
|||
# Always notify self
|
||||
users_to_states.setdefault(state.user_id, []).append(state)
|
||||
|
||||
# Ask a presence routing module for any additional parties if one
|
||||
# is loaded.
|
||||
router_users_to_states = await presence_router.get_users_for_states(states)
|
||||
|
||||
# Update the dictionaries with additional destinations and state to send
|
||||
for user_id, user_states in router_users_to_states.items():
|
||||
users_to_states.setdefault(user_id, []).extend(user_states)
|
||||
|
||||
return room_ids_to_states, users_to_states
|
||||
|
||||
|
||||
async def get_interested_remotes(
|
||||
store: DataStore, states: List[UserPresenceState], state_handler: StateHandler
|
||||
store: DataStore,
|
||||
presence_router: PresenceRouter,
|
||||
states: List[UserPresenceState],
|
||||
state_handler: StateHandler,
|
||||
) -> List[Tuple[Collection[str], List[UserPresenceState]]]:
|
||||
"""Given a list of presence states figure out which remote servers
|
||||
should be sent which.
|
||||
|
@ -1349,9 +1538,10 @@ async def get_interested_remotes(
|
|||
All the presence states should be for local users only.
|
||||
|
||||
Args:
|
||||
store
|
||||
states
|
||||
state_handler
|
||||
store: The homeserver's data store.
|
||||
presence_router: A module for augmenting the destinations for presence updates.
|
||||
states: A list of incoming user presence updates.
|
||||
state_handler:
|
||||
|
||||
Returns:
|
||||
A list of 2-tuples of destinations and states, where for
|
||||
|
@ -1363,7 +1553,9 @@ async def get_interested_remotes(
|
|||
# First we look up the rooms each user is in (as well as any explicit
|
||||
# subscriptions), then for each distinct room we look up the remote
|
||||
# hosts in those rooms.
|
||||
room_ids_to_states, users_to_states = await get_interested_parties(store, states)
|
||||
room_ids_to_states, users_to_states = await get_interested_parties(
|
||||
store, presence_router, states
|
||||
)
|
||||
|
||||
for room_id, states in room_ids_to_states.items():
|
||||
hosts = await state_handler.get_current_hosts_in_room(room_id)
|
||||
|
|
|
@ -174,7 +174,7 @@ class RegistrationHandler(BaseHandler):
|
|||
user_type: Optional[str] = None,
|
||||
default_display_name: Optional[str] = None,
|
||||
address: Optional[str] = None,
|
||||
bind_emails: Iterable[str] = [],
|
||||
bind_emails: Optional[Iterable[str]] = None,
|
||||
by_admin: bool = False,
|
||||
user_agent_ips: Optional[List[Tuple[str, str]]] = None,
|
||||
auth_provider_id: Optional[str] = None,
|
||||
|
@ -209,7 +209,9 @@ class RegistrationHandler(BaseHandler):
|
|||
Raises:
|
||||
SynapseError if there was a problem registering.
|
||||
"""
|
||||
self.check_registration_ratelimit(address)
|
||||
bind_emails = bind_emails or []
|
||||
|
||||
await self.check_registration_ratelimit(address)
|
||||
|
||||
result = await self.spam_checker.check_registration_for_spam(
|
||||
threepid,
|
||||
|
@ -590,7 +592,7 @@ class RegistrationHandler(BaseHandler):
|
|||
errcode=Codes.EXCLUSIVE,
|
||||
)
|
||||
|
||||
def check_registration_ratelimit(self, address: Optional[str]) -> None:
|
||||
async def check_registration_ratelimit(self, address: Optional[str]) -> None:
|
||||
"""A simple helper method to check whether the registration rate limit has been hit
|
||||
for a given IP address
|
||||
|
||||
|
@ -604,7 +606,7 @@ class RegistrationHandler(BaseHandler):
|
|||
if not address:
|
||||
return
|
||||
|
||||
self.ratelimiter.ratelimit(address)
|
||||
await self.ratelimiter.ratelimit(None, address)
|
||||
|
||||
async def register_with_store(
|
||||
self,
|
||||
|
|
|
@ -20,7 +20,7 @@ from http import HTTPStatus
|
|||
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
|
||||
|
||||
from synapse import types
|
||||
from synapse.api.constants import AccountDataTypes, EventTypes, Membership
|
||||
from synapse.api.constants import AccountDataTypes, EventTypes, JoinRules, Membership
|
||||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
Codes,
|
||||
|
@ -29,6 +29,7 @@ from synapse.api.errors import (
|
|||
SynapseError,
|
||||
)
|
||||
from synapse.api.ratelimiting import Ratelimiter
|
||||
from synapse.api.room_versions import RoomVersion
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
|
||||
|
@ -75,22 +76,26 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
self.allow_per_room_profiles = self.config.allow_per_room_profiles
|
||||
|
||||
self._join_rate_limiter_local = Ratelimiter(
|
||||
store=self.store,
|
||||
clock=self.clock,
|
||||
rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
|
||||
burst_count=hs.config.ratelimiting.rc_joins_local.burst_count,
|
||||
)
|
||||
self._join_rate_limiter_remote = Ratelimiter(
|
||||
store=self.store,
|
||||
clock=self.clock,
|
||||
rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second,
|
||||
burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
|
||||
)
|
||||
|
||||
self._invites_per_room_limiter = Ratelimiter(
|
||||
store=self.store,
|
||||
clock=self.clock,
|
||||
rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second,
|
||||
burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count,
|
||||
)
|
||||
self._invites_per_user_limiter = Ratelimiter(
|
||||
store=self.store,
|
||||
clock=self.clock,
|
||||
rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second,
|
||||
burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count,
|
||||
|
@ -159,15 +164,76 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
async def forget(self, user: UserID, room_id: str) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
def ratelimit_invite(self, room_id: Optional[str], invitee_user_id: str):
|
||||
async def ratelimit_invite(
|
||||
self,
|
||||
requester: Optional[Requester],
|
||||
room_id: Optional[str],
|
||||
invitee_user_id: str,
|
||||
):
|
||||
"""Ratelimit invites by room and by target user.
|
||||
|
||||
If room ID is missing then we just rate limit by target user.
|
||||
"""
|
||||
if room_id:
|
||||
self._invites_per_room_limiter.ratelimit(room_id)
|
||||
await self._invites_per_room_limiter.ratelimit(requester, room_id)
|
||||
|
||||
self._invites_per_user_limiter.ratelimit(invitee_user_id)
|
||||
await self._invites_per_user_limiter.ratelimit(requester, invitee_user_id)
|
||||
|
||||
async def _can_join_without_invite(
|
||||
self, state_ids: StateMap[str], room_version: RoomVersion, user_id: str
|
||||
) -> bool:
|
||||
"""
|
||||
Check whether a user can join a room without an invite.
|
||||
|
||||
When joining a room with restricted joined rules (as defined in MSC3083),
|
||||
the membership of spaces must be checked during join.
|
||||
|
||||
Args:
|
||||
state_ids: The state of the room as it currently is.
|
||||
room_version: The room version of the room being joined.
|
||||
user_id: The user joining the room.
|
||||
|
||||
Returns:
|
||||
True if the user can join the room, false otherwise.
|
||||
"""
|
||||
# This only applies to room versions which support the new join rule.
|
||||
if not room_version.msc3083_join_rules:
|
||||
return True
|
||||
|
||||
# If there's no join rule, then it defaults to public (so this doesn't apply).
|
||||
join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None)
|
||||
if not join_rules_event_id:
|
||||
return True
|
||||
|
||||
# If the join rule is not restricted, this doesn't apply.
|
||||
join_rules_event = await self.store.get_event(join_rules_event_id)
|
||||
if join_rules_event.content.get("join_rule") != JoinRules.MSC3083_RESTRICTED:
|
||||
return True
|
||||
|
||||
# If allowed is of the wrong form, then only allow invited users.
|
||||
allowed_spaces = join_rules_event.content.get("allow", [])
|
||||
if not isinstance(allowed_spaces, list):
|
||||
return False
|
||||
|
||||
# Get the list of joined rooms and see if there's an overlap.
|
||||
joined_rooms = await self.store.get_rooms_for_user(user_id)
|
||||
|
||||
# Pull out the other room IDs, invalid data gets filtered.
|
||||
for space in allowed_spaces:
|
||||
if not isinstance(space, dict):
|
||||
continue
|
||||
|
||||
space_id = space.get("space")
|
||||
if not isinstance(space_id, str):
|
||||
continue
|
||||
|
||||
# The user was joined to one of the spaces specified, they can join
|
||||
# this room!
|
||||
if space_id in joined_rooms:
|
||||
return True
|
||||
|
||||
# The user was not in any of the required spaces.
|
||||
return False
|
||||
|
||||
async def _local_membership_update(
|
||||
self,
|
||||
|
@ -226,9 +292,25 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
|
||||
if event.membership == Membership.JOIN:
|
||||
newly_joined = True
|
||||
user_is_invited = False
|
||||
if prev_member_event_id:
|
||||
prev_member_event = await self.store.get_event(prev_member_event_id)
|
||||
newly_joined = prev_member_event.membership != Membership.JOIN
|
||||
user_is_invited = prev_member_event.membership == Membership.INVITE
|
||||
|
||||
# If the member is not already in the room and is not accepting an invite,
|
||||
# check if they should be allowed access via membership in a space.
|
||||
if (
|
||||
newly_joined
|
||||
and not user_is_invited
|
||||
and not await self._can_join_without_invite(
|
||||
prev_state_ids, event.room_version, user_id
|
||||
)
|
||||
):
|
||||
raise AuthError(
|
||||
403,
|
||||
"You do not belong to any of the required spaces to join this room.",
|
||||
)
|
||||
|
||||
# Only rate-limit if the user actually joined the room, otherwise we'll end
|
||||
# up blocking profile updates.
|
||||
|
@ -237,7 +319,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
(
|
||||
allowed,
|
||||
time_allowed,
|
||||
) = self._join_rate_limiter_local.can_requester_do_action(requester)
|
||||
) = await self._join_rate_limiter_local.can_do_action(requester)
|
||||
|
||||
if not allowed:
|
||||
raise LimitExceededError(
|
||||
|
@ -421,9 +503,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
if effective_membership_state == Membership.INVITE:
|
||||
target_id = target.to_string()
|
||||
if ratelimit:
|
||||
# Don't ratelimit application services.
|
||||
if not requester.app_service or requester.app_service.is_rate_limited():
|
||||
self.ratelimit_invite(room_id, target_id)
|
||||
await self.ratelimit_invite(requester, room_id, target_id)
|
||||
|
||||
# block any attempts to invite the server notices mxid
|
||||
if target_id == self._server_notices_mxid:
|
||||
|
@ -534,7 +614,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
(
|
||||
allowed,
|
||||
time_allowed,
|
||||
) = self._join_rate_limiter_remote.can_requester_do_action(
|
||||
) = await self._join_rate_limiter_remote.can_do_action(
|
||||
requester,
|
||||
)
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@ from synapse.api.constants import AccountDataTypes, EventTypes, Membership
|
|||
from synapse.api.filtering import FilterCollection
|
||||
from synapse.events import EventBase
|
||||
from synapse.logging.context import current_context
|
||||
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
|
||||
from synapse.push.clientformat import format_push_rules_for_user
|
||||
from synapse.storage.roommember import MemberSummary
|
||||
from synapse.storage.state import StateFilter
|
||||
|
@ -251,13 +252,13 @@ class SyncHandler:
|
|||
self.storage = hs.get_storage()
|
||||
self.state_store = self.storage.state
|
||||
|
||||
# ExpiringCache((User, Device)) -> LruCache(state_key => event_id)
|
||||
# ExpiringCache((User, Device)) -> LruCache(user_id => event_id)
|
||||
self.lazy_loaded_members_cache = ExpiringCache(
|
||||
"lazy_loaded_members_cache",
|
||||
self.clock,
|
||||
max_len=0,
|
||||
expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
|
||||
)
|
||||
) # type: ExpiringCache[Tuple[str, Optional[str]], LruCache[str, str]]
|
||||
|
||||
async def wait_for_sync_for_user(
|
||||
self,
|
||||
|
@ -340,7 +341,14 @@ class SyncHandler:
|
|||
full_state: bool = False,
|
||||
) -> SyncResult:
|
||||
"""Get the sync for client needed to match what the server has now."""
|
||||
return await self.generate_sync_result(sync_config, since_token, full_state)
|
||||
with start_active_span("current_sync_for_user"):
|
||||
log_kv({"since_token": since_token})
|
||||
sync_result = await self.generate_sync_result(
|
||||
sync_config, since_token, full_state
|
||||
)
|
||||
|
||||
set_tag(SynapseTags.SYNC_RESULT, bool(sync_result))
|
||||
return sync_result
|
||||
|
||||
async def push_rules_for_user(self, user: UserID) -> JsonDict:
|
||||
user_id = user.to_string()
|
||||
|
@ -540,7 +548,7 @@ class SyncHandler:
|
|||
)
|
||||
|
||||
async def get_state_after_event(
|
||||
self, event: EventBase, state_filter: StateFilter = StateFilter.all()
|
||||
self, event: EventBase, state_filter: Optional[StateFilter] = None
|
||||
) -> StateMap[str]:
|
||||
"""
|
||||
Get the room state after the given event
|
||||
|
@ -550,7 +558,7 @@ class SyncHandler:
|
|||
state_filter: The state filter used to fetch state from the database.
|
||||
"""
|
||||
state_ids = await self.state_store.get_state_ids_for_event(
|
||||
event.event_id, state_filter=state_filter
|
||||
event.event_id, state_filter=state_filter or StateFilter.all()
|
||||
)
|
||||
if event.is_state():
|
||||
state_ids = dict(state_ids)
|
||||
|
@ -561,7 +569,7 @@ class SyncHandler:
|
|||
self,
|
||||
room_id: str,
|
||||
stream_position: StreamToken,
|
||||
state_filter: StateFilter = StateFilter.all(),
|
||||
state_filter: Optional[StateFilter] = None,
|
||||
) -> StateMap[str]:
|
||||
"""Get the room state at a particular stream position
|
||||
|
||||
|
@ -581,7 +589,7 @@ class SyncHandler:
|
|||
if last_events:
|
||||
last_event = last_events[-1]
|
||||
state = await self.get_state_after_event(
|
||||
last_event, state_filter=state_filter
|
||||
last_event, state_filter=state_filter or StateFilter.all()
|
||||
)
|
||||
|
||||
else:
|
||||
|
@ -725,8 +733,10 @@ class SyncHandler:
|
|||
|
||||
def get_lazy_loaded_members_cache(
|
||||
self, cache_key: Tuple[str, Optional[str]]
|
||||
) -> LruCache:
|
||||
cache = self.lazy_loaded_members_cache.get(cache_key)
|
||||
) -> LruCache[str, str]:
|
||||
cache = self.lazy_loaded_members_cache.get(
|
||||
cache_key
|
||||
) # type: Optional[LruCache[str, str]]
|
||||
if cache is None:
|
||||
logger.debug("creating LruCache for %r", cache_key)
|
||||
cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE)
|
||||
|
@ -963,6 +973,7 @@ class SyncHandler:
|
|||
# to query up to a given point.
|
||||
# Always use the `now_token` in `SyncResultBuilder`
|
||||
now_token = self.event_sources.get_current_token()
|
||||
log_kv({"now_token": now_token})
|
||||
|
||||
logger.debug(
|
||||
"Calculating sync response for %r between %s and %s",
|
||||
|
@ -1224,6 +1235,13 @@ class SyncHandler:
|
|||
user_id, device_id, since_stream_id, now_token.to_device_key
|
||||
)
|
||||
|
||||
for message in messages:
|
||||
# We pop here as we shouldn't be sending the message ID down
|
||||
# `/sync`
|
||||
message_id = message.pop("message_id", None)
|
||||
if message_id:
|
||||
set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id)
|
||||
|
||||
logger.debug(
|
||||
"Returning %d to-device messages between %d and %d (current token: %d)",
|
||||
len(messages),
|
||||
|
|
|
@ -19,7 +19,10 @@ from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
|
|||
|
||||
from synapse.api.errors import AuthError, ShadowBanError, SynapseError
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.metrics.background_process_metrics import (
|
||||
run_as_background_process,
|
||||
wrap_as_background_process,
|
||||
)
|
||||
from synapse.replication.tcp.streams import TypingStream
|
||||
from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
|
@ -86,6 +89,7 @@ class FollowerTypingHandler:
|
|||
self._member_last_federation_poke = {}
|
||||
self.wheel_timer = WheelTimer(bucket_size=5000)
|
||||
|
||||
@wrap_as_background_process("typing._handle_timeouts")
|
||||
def _handle_timeouts(self) -> None:
|
||||
logger.debug("Checking for typing timeouts")
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue