Merge remote-tracking branch 'upstream/release-v1.64'

This commit is contained in:
Tulir Asokan 2022-07-28 10:49:41 +03:00
commit b0f213fd3d
176 changed files with 4539 additions and 2643 deletions

View file

@ -104,14 +104,15 @@ class ApplicationServicesHandler:
with Measure(self.clock, "notify_interested_services"):
self.is_processing = True
try:
limit = 100
upper_bound = -1
while upper_bound < self.current_max:
last_token = await self.store.get_appservice_last_pos()
(
upper_bound,
events,
) = await self.store.get_new_events_for_appservice(
self.current_max, limit
event_to_received_ts,
) = await self.store.get_all_new_events_stream(
last_token, self.current_max, limit=100, get_prev_content=True
)
events_by_room: Dict[str, List[EventBase]] = {}
@ -150,7 +151,7 @@ class ApplicationServicesHandler:
)
now = self.clock.time_msec()
ts = await self.store.get_received_ts(event.event_id)
ts = event_to_received_ts[event.event_id]
assert ts is not None
synapse.metrics.event_processing_lag_by_event.labels(
@ -187,7 +188,7 @@ class ApplicationServicesHandler:
if events:
now = self.clock.time_msec()
ts = await self.store.get_received_ts(events[-1].event_id)
ts = event_to_received_ts[events[-1].event_id]
assert ts is not None
synapse.metrics.event_processing_lag.labels(

View file

@ -118,8 +118,8 @@ class DeviceWorkerHandler:
ips = await self.store.get_last_client_ip_by_device(user_id, device_id)
_update_device_from_client_ips(device, ips)
set_tag("device", device)
set_tag("ips", ips)
set_tag("device", str(device))
set_tag("ips", str(ips))
return device
@ -170,7 +170,7 @@ class DeviceWorkerHandler:
"""
set_tag("user_id", user_id)
set_tag("from_token", from_token)
set_tag("from_token", str(from_token))
now_room_key = self.store.get_room_max_token()
room_ids = await self.store.get_rooms_for_user(user_id)
@ -795,7 +795,7 @@ class DeviceListUpdater:
"""
set_tag("origin", origin)
set_tag("edu_content", edu_content)
set_tag("edu_content", str(edu_content))
user_id = edu_content.pop("user_id")
device_id = edu_content.pop("device_id")
stream_id = str(edu_content.pop("stream_id")) # They may come as ints

View file

@ -15,7 +15,7 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple
import attr
from canonicaljson import encode_canonical_json
@ -92,7 +92,11 @@ class E2eKeysHandler:
@trace
async def query_devices(
self, query_body: JsonDict, timeout: int, from_user_id: str, from_device_id: str
self,
query_body: JsonDict,
timeout: int,
from_user_id: str,
from_device_id: Optional[str],
) -> JsonDict:
"""Handle a device key query from a client
@ -120,9 +124,7 @@ class E2eKeysHandler:
the number of in-flight queries at a time.
"""
async with self._query_devices_linearizer.queue((from_user_id, from_device_id)):
device_keys_query: Dict[str, Iterable[str]] = query_body.get(
"device_keys", {}
)
device_keys_query: Dict[str, List[str]] = query_body.get("device_keys", {})
# separate users by domain.
# make a map from domain to user_id to device_ids
@ -136,8 +138,8 @@ class E2eKeysHandler:
else:
remote_queries[user_id] = device_ids
set_tag("local_key_query", local_query)
set_tag("remote_key_query", remote_queries)
set_tag("local_key_query", str(local_query))
set_tag("remote_key_query", str(remote_queries))
# First get local devices.
# A map of destination -> failure response.
@ -341,7 +343,7 @@ class E2eKeysHandler:
failure = _exception_to_failure(e)
failures[destination] = failure
set_tag("error", True)
set_tag("reason", failure)
set_tag("reason", str(failure))
return
@ -392,7 +394,7 @@ class E2eKeysHandler:
@trace
async def query_local_devices(
self, query: Dict[str, Optional[List[str]]]
self, query: Mapping[str, Optional[List[str]]]
) -> Dict[str, Dict[str, dict]]:
"""Get E2E device keys for local users
@ -403,7 +405,7 @@ class E2eKeysHandler:
Returns:
A map from user_id -> device_id -> device details
"""
set_tag("local_query", query)
set_tag("local_query", str(query))
local_query: List[Tuple[str, Optional[str]]] = []
result_dict: Dict[str, Dict[str, dict]] = {}
@ -461,7 +463,7 @@ class E2eKeysHandler:
@trace
async def claim_one_time_keys(
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
) -> JsonDict:
local_query: List[Tuple[str, str, str]] = []
remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
@ -475,8 +477,8 @@ class E2eKeysHandler:
domain = get_domain_from_id(user_id)
remote_queries.setdefault(domain, {})[user_id] = one_time_keys
set_tag("local_key_query", local_query)
set_tag("remote_key_query", remote_queries)
set_tag("local_key_query", str(local_query))
set_tag("remote_key_query", str(remote_queries))
results = await self.store.claim_e2e_one_time_keys(local_query)
@ -506,7 +508,7 @@ class E2eKeysHandler:
failure = _exception_to_failure(e)
failures[destination] = failure
set_tag("error", True)
set_tag("reason", failure)
set_tag("reason", str(failure))
await make_deferred_yieldable(
defer.gatherResults(
@ -609,7 +611,7 @@ class E2eKeysHandler:
result = await self.store.count_e2e_one_time_keys(user_id, device_id)
set_tag("one_time_key_counts", result)
set_tag("one_time_key_counts", str(result))
return {"one_time_key_counts": result}
async def _upload_one_time_keys_for_user(

View file

@ -14,7 +14,7 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING, Dict, Optional, cast
from typing_extensions import Literal
@ -97,7 +97,7 @@ class E2eRoomKeysHandler:
user_id, version, room_id, session_id
)
log_kv(results)
log_kv(cast(JsonDict, results))
return results
@trace

View file

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import itertools
import logging
from http import HTTPStatus
@ -347,7 +348,7 @@ class FederationEventHandler:
event.internal_metadata.send_on_behalf_of = origin
context = await self._state_handler.compute_event_context(event)
context = await self._check_event_auth(origin, event, context)
await self._check_event_auth(origin, event, context)
if context.rejected:
raise SynapseError(
403, f"{event.membership} event was rejected", Codes.FORBIDDEN
@ -485,7 +486,7 @@ class FederationEventHandler:
partial_state=partial_state,
)
context = await self._check_event_auth(origin, event, context)
await self._check_event_auth(origin, event, context)
if context.rejected:
raise SynapseError(400, "Join event was rejected")
@ -765,10 +766,24 @@ class FederationEventHandler:
"""
logger.info("Processing pulled event %s", event)
# these should not be outliers.
assert (
not event.internal_metadata.is_outlier()
), "pulled event unexpectedly flagged as outlier"
# This function should not be used to persist outliers (use something
# else) because this does a bunch of operations that aren't necessary
# (extra work; in particular, it makes sure we have all the prev_events
# and resolves the state across those prev events). If you happen to run
# into a situation where the event you're trying to process/backfill is
# marked as an `outlier`, then you should update that spot to return an
# `EventBase` copy that doesn't have `outlier` flag set.
#
# `EventBase` is used to represent both an event we have not yet
# persisted, and one that we have persisted and now keep in the cache.
# In an ideal world this method would only be called with the first type
# of event, but it turns out that's not actually the case and for
# example, you could get an event from cache that is marked as an
# `outlier` (fix up that spot though).
assert not event.internal_metadata.is_outlier(), (
"Outlier event passed to _process_pulled_event. "
"To persist an event as a non-outlier, make sure to pass in a copy without `event.internal_metadata.outlier = true`."
)
event_id = event.event_id
@ -778,7 +793,7 @@ class FederationEventHandler:
if existing:
if not existing.internal_metadata.is_outlier():
logger.info(
"Ignoring received event %s which we have already seen",
"_process_pulled_event: Ignoring received event %s which we have already seen",
event_id,
)
return
@ -1036,6 +1051,9 @@ class FederationEventHandler:
# XXX: this doesn't sound right? it means that we'll end up with incomplete
# state.
failed_to_fetch = desired_events - event_metadata.keys()
# `event_id` could be missing from `event_metadata` because it's not necessarily
# a state event. We've already checked that we've fetched it above.
failed_to_fetch.discard(event_id)
if failed_to_fetch:
logger.warning(
"Failed to fetch missing state events for %s %s",
@ -1116,11 +1134,7 @@ class FederationEventHandler:
state_ids_before_event=state_ids,
)
try:
context = await self._check_event_auth(
origin,
event,
context,
)
await self._check_event_auth(origin, event, context)
except AuthError as e:
# This happens only if we couldn't find the auth events. We'll already have
# logged a warning, so now we just convert to a FederationError.
@ -1315,6 +1329,53 @@ class FederationEventHandler:
marker_event,
)
async def backfill_event_id(
self, destination: str, room_id: str, event_id: str
) -> EventBase:
"""Backfill a single event and persist it as a non-outlier which means
we also pull in all of the state and auth events necessary for it.
Args:
destination: The homeserver to pull the given event_id from.
room_id: The room where the event is from.
event_id: The event ID to backfill.
Raises:
FederationError if we are unable to find the event from the destination
"""
logger.info(
"backfill_event_id: event_id=%s from destination=%s", event_id, destination
)
room_version = await self._store.get_room_version(room_id)
event_from_response = await self._federation_client.get_pdu(
[destination],
event_id,
room_version,
)
if not event_from_response:
raise FederationError(
"ERROR",
404,
"Unable to find event_id=%s from destination=%s to backfill."
% (event_id, destination),
affected=event_id,
)
# Persist the event we just fetched, including pulling all of the state
# and auth events to de-outlier it. This also sets up the necessary
# `state_groups` for the event.
await self._process_pulled_events(
destination,
[event_from_response],
# Prevent notifications going to clients
backfilled=True,
)
return event_from_response
async def _get_events_and_persist(
self, destination: str, room_id: str, event_ids: Collection[str]
) -> None:
@ -1495,11 +1556,8 @@ class FederationEventHandler:
)
async def _check_event_auth(
self,
origin: str,
event: EventBase,
context: EventContext,
) -> EventContext:
self, origin: str, event: EventBase, context: EventContext
) -> None:
"""
Checks whether an event should be rejected (for failing auth checks).
@ -1509,9 +1567,6 @@ class FederationEventHandler:
context:
The event context.
Returns:
The updated context object.
Raises:
AuthError if we were unable to find copies of the event's auth events.
(Most other failures just cause us to set `context.rejected`.)
@ -1526,7 +1581,7 @@ class FederationEventHandler:
logger.warning("While validating received event %r: %s", event, e)
# TODO: use a different rejected reason here?
context.rejected = RejectedReason.AUTH_ERROR
return context
return
# next, check that we have all of the event's auth events.
#
@ -1538,6 +1593,9 @@ class FederationEventHandler:
)
# ... and check that the event passes auth at those auth events.
# https://spec.matrix.org/v1.3/server-server-api/#checks-performed-on-receipt-of-a-pdu:
# 4. Passes authorization rules based on the events auth events,
# otherwise it is rejected.
try:
await check_state_independent_auth_rules(self._store, event)
check_state_dependent_auth_rules(event, claimed_auth_events)
@ -1546,55 +1604,90 @@ class FederationEventHandler:
"While checking auth of %r against auth_events: %s", event, e
)
context.rejected = RejectedReason.AUTH_ERROR
return context
return
# now check auth against what we think the auth events *should* be.
event_types = event_auth.auth_types_for_event(event.room_version, event)
prev_state_ids = await context.get_prev_state_ids(
StateFilter.from_types(event_types)
)
# now check the auth rules pass against the room state before the event
# https://spec.matrix.org/v1.3/server-server-api/#checks-performed-on-receipt-of-a-pdu:
# 5. Passes authorization rules based on the state before the event,
# otherwise it is rejected.
#
# ... however, if we only have partial state for the room, then there is a good
# chance that we'll be missing some of the state needed to auth the new event.
# So, we state-resolve the auth events that we are given against the state that
# we know about, which ensures things like bans are applied. (Note that we'll
# already have checked we have all the auth events, in
# _load_or_fetch_auth_events_for_event above)
if context.partial_state:
room_version = await self._store.get_room_version_id(event.room_id)
auth_events_ids = self._event_auth_handler.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events_x = await self._store.get_events(auth_events_ids)
calculated_auth_event_map = {
(e.type, e.state_key): e for e in auth_events_x.values()
}
local_state_id_map = await context.get_prev_state_ids()
claimed_auth_events_id_map = {
(ev.type, ev.state_key): ev.event_id for ev in claimed_auth_events
}
try:
updated_auth_events = await self._update_auth_events_for_auth(
event,
calculated_auth_event_map=calculated_auth_event_map,
state_for_auth_id_map = (
await self._state_resolution_handler.resolve_events_with_store(
event.room_id,
room_version,
[local_state_id_map, claimed_auth_events_id_map],
event_map=None,
state_res_store=StateResolutionStore(self._store),
)
)
except Exception:
# We don't really mind if the above fails, so lets not fail
# processing if it does. However, it really shouldn't fail so
# let's still log as an exception since we'll still want to fix
# any bugs.
logger.exception(
"Failed to double check auth events for %s with remote. "
"Ignoring failure and continuing processing of event.",
event.event_id,
)
updated_auth_events = None
if updated_auth_events:
context = await self._update_context_for_auth_events(
event, context, updated_auth_events
)
auth_events_for_auth = updated_auth_events
else:
auth_events_for_auth = calculated_auth_event_map
event_types = event_auth.auth_types_for_event(event.room_version, event)
state_for_auth_id_map = await context.get_prev_state_ids(
StateFilter.from_types(event_types)
)
calculated_auth_event_ids = self._event_auth_handler.compute_auth_events(
event, state_for_auth_id_map, for_verification=True
)
# if those are the same, we're done here.
if collections.Counter(event.auth_event_ids()) == collections.Counter(
calculated_auth_event_ids
):
return
# otherwise, re-run the auth checks based on what we calculated.
calculated_auth_events = await self._store.get_events_as_list(
calculated_auth_event_ids
)
# log the differences
claimed_auth_event_map = {(e.type, e.state_key): e for e in claimed_auth_events}
calculated_auth_event_map = {
(e.type, e.state_key): e for e in calculated_auth_events
}
logger.info(
"event's auth_events are different to our calculated auth_events. "
"Claimed but not calculated: %s. Calculated but not claimed: %s",
[
ev
for k, ev in claimed_auth_event_map.items()
if k not in calculated_auth_event_map
or calculated_auth_event_map[k].event_id != ev.event_id
],
[
ev
for k, ev in calculated_auth_event_map.items()
if k not in claimed_auth_event_map
or claimed_auth_event_map[k].event_id != ev.event_id
],
)
try:
check_state_dependent_auth_rules(event, auth_events_for_auth.values())
check_state_dependent_auth_rules(event, calculated_auth_events)
except AuthError as e:
logger.warning("Failed auth resolution for %r because %s", event, e)
logger.warning(
"While checking auth of %r against room state before the event: %s",
event,
e,
)
context.rejected = RejectedReason.AUTH_ERROR
return context
async def _maybe_kick_guest_users(self, event: EventBase) -> None:
if event.type != EventTypes.GuestAccess:
return
@ -1618,11 +1711,21 @@ class FederationEventHandler:
"""Checks if we should soft fail the event; if so, marks the event as
such.
Does nothing for events in rooms with partial state, since we may not have an
accurate membership event for the sender in the current state.
Args:
event
state_ids: The state at the event if we don't have all the event's prev events
origin: The host the event originates from.
"""
if await self._store.is_partial_state_room(event.room_id):
# We might not know the sender's membership in the current state, so don't
# soft fail anything. Even if we do have a membership for the sender in the
# current state, it may have been derived from state resolution between
# partial and full state and may not be accurate.
return
extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id)
extrem_ids = set(extrem_ids_list)
prev_event_ids = set(event.prev_event_ids())
@ -1704,93 +1807,6 @@ class FederationEventHandler:
soft_failed_event_counter.inc()
event.internal_metadata.soft_failed = True
async def _update_auth_events_for_auth(
self,
event: EventBase,
calculated_auth_event_map: StateMap[EventBase],
) -> Optional[StateMap[EventBase]]:
"""Helper for _check_event_auth. See there for docs.
Checks whether a given event has the expected auth events. If it
doesn't then we talk to the remote server to compare state to see if
we can come to a consensus (e.g. if one server missed some valid
state).
This attempts to resolve any potential divergence of state between
servers, but is not essential and so failures should not block further
processing of the event.
Args:
event:
calculated_auth_event_map:
Our calculated auth_events based on the state of the room
at the event's position in the DAG.
Returns:
updated auth event map, or None if no changes are needed.
"""
assert not event.internal_metadata.outlier
# check for events which are in the event's claimed auth_events, but not
# in our calculated event map.
event_auth_events = set(event.auth_event_ids())
different_auth = event_auth_events.difference(
e.event_id for e in calculated_auth_event_map.values()
)
if not different_auth:
return None
logger.info(
"auth_events refers to events which are not in our calculated auth "
"chain: %s",
different_auth,
)
# XXX: currently this checks for redactions but I'm not convinced that is
# necessary?
different_events = await self._store.get_events_as_list(different_auth)
# double-check they're all in the same room - we should already have checked
# this but it doesn't hurt to check again.
for d in different_events:
assert (
d.room_id == event.room_id
), f"Event {event.event_id} refers to auth_event {d.event_id} which is in a different room"
# now we state-resolve between our own idea of the auth events, and the remote's
# idea of them.
local_state = calculated_auth_event_map.values()
remote_auth_events = dict(calculated_auth_event_map)
remote_auth_events.update({(d.type, d.state_key): d for d in different_events})
remote_state = remote_auth_events.values()
room_version = await self._store.get_room_version_id(event.room_id)
new_state = await self._state_handler.resolve_events(
room_version, (local_state, remote_state), event
)
different_state = {
(d.type, d.state_key): d
for d in new_state.values()
if calculated_auth_event_map.get((d.type, d.state_key)) != d
}
if not different_state:
logger.info("State res returned no new state")
return None
logger.info(
"After state res: updating auth_events with new state %s",
different_state.values(),
)
# take a copy of calculated_auth_event_map before we modify it.
auth_events = dict(calculated_auth_event_map)
auth_events.update(different_state)
return auth_events
async def _load_or_fetch_auth_events_for_event(
self, destination: str, event: EventBase
) -> Collection[EventBase]:
@ -1888,61 +1904,6 @@ class FederationEventHandler:
await self._auth_and_persist_outliers(room_id, remote_auth_events)
async def _update_context_for_auth_events(
self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase]
) -> EventContext:
"""Update the state_ids in an event context after auth event resolution,
storing the changes as a new state group.
Args:
event: The event we're handling the context for
context: initial event context
auth_events: Events to update in the event context.
Returns:
new event context
"""
# exclude the state key of the new event from the current_state in the context.
if event.is_state():
event_key: Optional[Tuple[str, str]] = (event.type, event.state_key)
else:
event_key = None
state_updates = {
k: a.event_id for k, a in auth_events.items() if k != event_key
}
current_state_ids = await context.get_current_state_ids()
current_state_ids = dict(current_state_ids) # type: ignore
current_state_ids.update(state_updates)
prev_state_ids = await context.get_prev_state_ids()
prev_state_ids = dict(prev_state_ids)
prev_state_ids.update({k: a.event_id for k, a in auth_events.items()})
# create a new state group as a delta from the existing one.
prev_group = context.state_group
state_group = await self._state_storage_controller.store_state_group(
event.event_id,
event.room_id,
prev_group=prev_group,
delta_ids=state_updates,
current_state_ids=current_state_ids,
)
return EventContext.with_state(
storage=self._storage_controllers,
state_group=state_group,
state_group_before_event=context.state_group_before_event,
state_delta_due_to_event=state_updates,
prev_group=prev_group,
delta_ids=state_updates,
partial_state=context.partial_state,
)
async def _run_push_actions_and_persist_event(
self, event: EventBase, context: EventContext, backfilled: bool = False
) -> None:
@ -2093,6 +2054,10 @@ class FederationEventHandler:
event, event_pos, max_stream_token, extra_users=extra_users
)
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
# TODO retrieve the previous state, and exclude join -> join transitions
self._notifier.notify_user_joined_room(event.event_id, event.room_id)
def _sanity_check_event(self, ev: EventBase) -> None:
"""
Do some early sanity checks of a received event

View file

@ -26,7 +26,6 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.http import RequestTimedOutError
from synapse.http.client import SimpleHttpClient
from synapse.http.site import SynapseRequest
@ -163,8 +162,7 @@ class IdentityHandler:
sid: str,
mxid: str,
id_server: str,
id_access_token: Optional[str] = None,
use_v2: bool = True,
id_access_token: str,
) -> JsonDict:
"""Bind a 3PID to an identity server
@ -174,8 +172,7 @@ class IdentityHandler:
mxid: The MXID to bind the 3PID to
id_server: The domain of the identity server to query
id_access_token: The access token to authenticate to the identity
server with, if necessary. Required if use_v2 is true
use_v2: Whether to use v2 Identity Service API endpoints. Defaults to True
server with
Raises:
SynapseError: On any of the following conditions
@ -187,24 +184,15 @@ class IdentityHandler:
"""
logger.debug("Proxying threepid bind request for %s to %s", mxid, id_server)
# If an id_access_token is not supplied, force usage of v1
if id_access_token is None:
use_v2 = False
if not valid_id_server_location(id_server):
raise SynapseError(
400,
"id_server must be a valid hostname with optional port and path components",
)
# Decide which API endpoint URLs to use
headers = {}
bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid}
if use_v2:
bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,)
headers["Authorization"] = create_id_access_token_header(id_access_token) # type: ignore
else:
bind_url = "https://%s/_matrix/identity/api/v1/3pid/bind" % (id_server,)
bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,)
headers = {"Authorization": create_id_access_token_header(id_access_token)}
try:
# Use the blacklisting http client as this call is only to identity servers
@ -223,21 +211,14 @@ class IdentityHandler:
return data
except HttpResponseException as e:
if e.code != 404 or not use_v2:
logger.error("3PID bind failed with Matrix error: %r", e)
raise e.to_synapse_error()
logger.error("3PID bind failed with Matrix error: %r", e)
raise e.to_synapse_error()
except RequestTimedOutError:
raise SynapseError(500, "Timed out contacting identity server")
except CodeMessageException as e:
data = json_decoder.decode(e.msg) # XXX WAT?
return data
logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url)
res = await self.bind_threepid(
client_secret, sid, mxid, id_server, id_access_token, use_v2=False
)
return res
async def try_unbind_threepid(self, mxid: str, threepid: dict) -> bool:
"""Attempt to remove a 3PID from an identity server, or if one is not provided, all
identity servers we're aware the binding is present on
@ -300,8 +281,8 @@ class IdentityHandler:
"id_server must be a valid hostname with optional port and path components",
)
url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
url_bytes = b"/_matrix/identity/api/v1/3pid/unbind"
url = "https://%s/_matrix/identity/v2/3pid/unbind" % (id_server,)
url_bytes = b"/_matrix/identity/v2/3pid/unbind"
content = {
"mxid": mxid,
@ -434,48 +415,6 @@ class IdentityHandler:
return session_id
async def requestEmailToken(
self,
id_server: str,
email: str,
client_secret: str,
send_attempt: int,
next_link: Optional[str] = None,
) -> JsonDict:
"""
Request an external server send an email on our behalf for the purposes of threepid
validation.
Args:
id_server: The identity server to proxy to
email: The email to send the message to
client_secret: The unique client_secret sends by the user
send_attempt: Which attempt this is
next_link: A link to redirect the user to once they submit the token
Returns:
The json response body from the server
"""
params = {
"email": email,
"client_secret": client_secret,
"send_attempt": send_attempt,
}
if next_link:
params["next_link"] = next_link
try:
data = await self.http_client.post_json_get_json(
id_server + "/_matrix/identity/api/v1/validate/email/requestToken",
params,
)
return data
except HttpResponseException as e:
logger.info("Proxied requestToken failed: %r", e)
raise e.to_synapse_error()
except RequestTimedOutError:
raise SynapseError(500, "Timed out contacting identity server")
async def requestMsisdnToken(
self,
id_server: str,
@ -549,18 +488,7 @@ class IdentityHandler:
validation_session = None
# Try to validate as email
if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
# Remote emails will only be used if a valid identity server is provided.
assert (
self.hs.config.registration.account_threepid_delegate_email is not None
)
# Ask our delegated email identity server
validation_session = await self.threepid_from_creds(
self.hs.config.registration.account_threepid_delegate_email,
threepid_creds,
)
elif self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
if self.hs.config.email.can_verify_email:
# Get a validated session matching these details
validation_session = await self.store.get_threepid_validation_session(
"email", client_secret, sid=sid, validated=True

View file

@ -463,6 +463,7 @@ class EventCreationHandler:
)
self._events_shard_config = self.config.worker.events_shard_config
self._instance_name = hs.get_instance_name()
self._notifier = hs.get_notifier()
self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state
@ -1452,7 +1453,12 @@ class EventCreationHandler:
if state_entry.state_group in self._external_cache_joined_hosts_updates:
return
joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry)
state = await state_entry.get_state(
self._storage_controllers.state, StateFilter.all()
)
joined_hosts = await self.store.get_joined_hosts(
event.room_id, state, state_entry
)
# Note that the expiry times must be larger than the expiry time in
# _external_cache_joined_hosts_updates.
@ -1554,6 +1560,16 @@ class EventCreationHandler:
requester, is_admin_redaction=is_admin_redaction
)
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
(
current_membership,
_,
) = await self.store.get_local_current_membership_for_user_in_room(
event.state_key, event.room_id
)
if current_membership != Membership.JOIN:
self._notifier.notify_user_joined_room(event.event_id, event.room_id)
await self._maybe_kick_guest_users(event, context)
validation_override = event.sender in self.config.meow.validation_override
@ -1861,13 +1877,8 @@ class EventCreationHandler:
# For each room we need to find a joined member we can use to send
# the dummy event with.
latest_event_ids = await self.store.get_prev_events_for_room(room_id)
members = await self.state.get_current_users_in_room(
room_id, latest_event_ids=latest_event_ids
)
members = await self.store.get_local_users_in_room(room_id)
for user_id in members:
if not self.hs.is_mine_id(user_id):
continue
requester = create_requester(user_id, authenticated_entity=self.server_name)
try:
event, context = await self.create_event(
@ -1878,7 +1889,6 @@ class EventCreationHandler:
"room_id": room_id,
"sender": user_id,
},
prev_event_ids=latest_event_ids,
)
event.internal_metadata.proactively_send = False

View file

@ -34,7 +34,6 @@ from typing import (
Callable,
Collection,
Dict,
FrozenSet,
Generator,
Iterable,
List,
@ -42,7 +41,6 @@ from typing import (
Set,
Tuple,
Type,
Union,
)
from prometheus_client import Counter
@ -68,7 +66,6 @@ from synapse.storage.databases.main import DataStore
from synapse.streams import EventSource
from synapse.types import JsonDict, StreamKeyType, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.descriptors import _CacheContext, cached
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
@ -1656,15 +1653,18 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
# doesn't return. C.f. #5503.
return [], max_token
# Figure out which other users this user should receive updates for
users_interested_in = await self._get_interested_in(user, explicit_room_id)
# Figure out which other users this user should explicitly receive
# updates for
additional_users_interested_in = (
await self.get_presence_router().get_interested_users(user.to_string())
)
# We have a set of users that we're interested in the presence of. We want to
# cross-reference that with the users that have actually changed their presence.
# Check whether this user should see all user updates
if users_interested_in == PresenceRouter.ALL_USERS:
if additional_users_interested_in == PresenceRouter.ALL_USERS:
# Provide presence state for all users
presence_updates = await self._filter_all_presence_updates_for_user(
user_id, include_offline, from_key
@ -1673,34 +1673,47 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
return presence_updates, max_token
# Make mypy happy. users_interested_in should now be a set
assert not isinstance(users_interested_in, str)
assert not isinstance(additional_users_interested_in, str)
# We always care about our own presence.
additional_users_interested_in.add(user_id)
if explicit_room_id:
user_ids = await self.store.get_users_in_room(explicit_room_id)
additional_users_interested_in.update(user_ids)
# The set of users that we're interested in and that have had a presence update.
# We'll actually pull the presence updates for these users at the end.
interested_and_updated_users: Union[Set[str], FrozenSet[str]] = set()
interested_and_updated_users: Collection[str]
if from_key is not None:
# First get all users that have had a presence update
updated_users = stream_change_cache.get_all_entities_changed(from_key)
# Cross-reference users we're interested in with those that have had updates.
# Use a slightly-optimised method for processing smaller sets of updates.
if updated_users is not None and len(updated_users) < 500:
# For small deltas, it's quicker to get all changes and then
# cross-reference with the users we're interested in
if updated_users is not None:
# If we have the full list of changes for presence we can
# simply check which ones share a room with the user.
get_updates_counter.labels("stream").inc()
for other_user_id in updated_users:
if other_user_id in users_interested_in:
# mypy thinks this variable could be a FrozenSet as it's possibly set
# to one in the `get_entities_changed` call below, and `add()` is not
# method on a FrozenSet. That doesn't affect us here though, as
# `interested_and_updated_users` is clearly a set() above.
interested_and_updated_users.add(other_user_id) # type: ignore
sharing_users = await self.store.do_users_share_a_room(
user_id, updated_users
)
interested_and_updated_users = (
sharing_users.union(additional_users_interested_in)
).intersection(updated_users)
else:
# Too many possible updates. Find all users we can see and check
# if any of them have changed.
get_updates_counter.labels("full").inc()
users_interested_in = (
await self.store.get_users_who_share_room_with_user(user_id)
)
users_interested_in.update(additional_users_interested_in)
interested_and_updated_users = (
stream_change_cache.get_entities_changed(
users_interested_in, from_key
@ -1709,7 +1722,10 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
else:
# No from_key has been specified. Return the presence for all users
# this user is interested in
interested_and_updated_users = users_interested_in
interested_and_updated_users = (
await self.store.get_users_who_share_room_with_user(user_id)
)
interested_and_updated_users.update(additional_users_interested_in)
# Retrieve the current presence state for each user
users_to_state = await self.get_presence_handler().current_state_for_users(
@ -1804,62 +1820,6 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
def get_current_key(self) -> int:
return self.store.get_current_presence_token()
@cached(num_args=2, cache_context=True)
async def _get_interested_in(
self,
user: UserID,
explicit_room_id: Optional[str] = None,
cache_context: Optional[_CacheContext] = None,
) -> Union[Set[str], str]:
"""Returns the set of users that the given user should see presence
updates for.
Args:
user: The user to retrieve presence updates for.
explicit_room_id: The users that are in the room will be returned.
Returns:
A set of user IDs to return presence updates for, or "ALL" to return all
known updates.
"""
user_id = user.to_string()
users_interested_in = set()
users_interested_in.add(user_id) # So that we receive our own presence
# cache_context isn't likely to ever be None due to the @cached decorator,
# but we can't have a non-optional argument after the optional argument
# explicit_room_id either. Assert cache_context is not None so we can use it
# without mypy complaining.
assert cache_context
# Check with the presence router whether we should poll additional users for
# their presence information
additional_users = await self.get_presence_router().get_interested_users(
user.to_string()
)
if additional_users == PresenceRouter.ALL_USERS:
# If the module requested that this user see the presence updates of *all*
# users, then simply return that instead of calculating what rooms this
# user shares
return PresenceRouter.ALL_USERS
# Add the additional users from the router
users_interested_in.update(additional_users)
# Find the users who share a room with this user
users_who_share_room = await self.store.get_users_who_share_room_with_user(
user_id, on_invalidate=cache_context.invalidate
)
users_interested_in.update(users_who_share_room)
if explicit_room_id:
user_ids = await self.store.get_users_in_room(
explicit_room_id, on_invalidate=cache_context.invalidate
)
users_interested_in.update(user_ids)
return users_interested_in
def handle_timeouts(
user_states: List[UserPresenceState],

View file

@ -901,7 +901,11 @@ class RoomCreationHandler:
# override any attempt to set room versions via the creation_content
creation_content["room_version"] = room_version.identifier
last_stream_id = await self._send_events_for_new_room(
(
last_stream_id,
last_sent_event_id,
depth,
) = await self._send_events_for_new_room(
requester,
room_id,
preset_config=preset_config,
@ -917,7 +921,7 @@ class RoomCreationHandler:
if "name" in config:
name = config["name"]
(
_,
name_event,
last_stream_id,
) = await self.event_creation_handler.create_and_send_nonmember_event(
requester,
@ -929,12 +933,16 @@ class RoomCreationHandler:
"content": {"name": name},
},
ratelimit=False,
prev_event_ids=[last_sent_event_id],
depth=depth,
)
last_sent_event_id = name_event.event_id
depth += 1
if "topic" in config:
topic = config["topic"]
(
_,
topic_event,
last_stream_id,
) = await self.event_creation_handler.create_and_send_nonmember_event(
requester,
@ -946,7 +954,11 @@ class RoomCreationHandler:
"content": {"topic": topic},
},
ratelimit=False,
prev_event_ids=[last_sent_event_id],
depth=depth,
)
last_sent_event_id = topic_event.event_id
depth += 1
# we avoid dropping the lock between invites, as otherwise joins can
# start coming in and making the createRoom slow.
@ -961,7 +973,7 @@ class RoomCreationHandler:
for invitee in invite_list:
(
_,
member_event_id,
last_stream_id,
) = await self.room_member_handler.update_membership_locked(
requester,
@ -971,7 +983,11 @@ class RoomCreationHandler:
ratelimit=False,
content=content,
new_room=True,
prev_event_ids=[last_sent_event_id],
depth=depth,
)
last_sent_event_id = member_event_id
depth += 1
for invite_3pid in invite_3pid_list:
id_server = invite_3pid["id_server"]
@ -980,7 +996,10 @@ class RoomCreationHandler:
medium = invite_3pid["medium"]
# Note that do_3pid_invite can raise a ShadowBanError, but this was
# handled above by emptying invite_3pid_list.
last_stream_id = await self.hs.get_room_member_handler().do_3pid_invite(
(
member_event_id,
last_stream_id,
) = await self.hs.get_room_member_handler().do_3pid_invite(
room_id,
requester.user,
medium,
@ -989,7 +1008,11 @@ class RoomCreationHandler:
requester,
txn_id=None,
id_access_token=id_access_token,
prev_event_ids=[last_sent_event_id],
depth=depth,
)
last_sent_event_id = member_event_id
depth += 1
result = {"room_id": room_id}
@ -1017,20 +1040,22 @@ class RoomCreationHandler:
power_level_content_override: Optional[JsonDict] = None,
creator_join_profile: Optional[JsonDict] = None,
ratelimit: bool = True,
) -> int:
) -> Tuple[int, str, int]:
"""Sends the initial events into a new room.
`power_level_content_override` doesn't apply when initial state has
power level state event content.
Returns:
The stream_id of the last event persisted.
A tuple containing the stream ID, event ID and depth of the last
event sent to the room.
"""
creator_id = creator.user.to_string()
event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
depth = 1
last_sent_event_id: Optional[str] = None
def create(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict:
@ -1043,6 +1068,7 @@ class RoomCreationHandler:
async def send(etype: str, content: JsonDict, **kwargs: Any) -> int:
nonlocal last_sent_event_id
nonlocal depth
event = create(etype, content, **kwargs)
logger.debug("Sending %s in new room", etype)
@ -1059,9 +1085,11 @@ class RoomCreationHandler:
# Note: we don't pass state_event_ids here because this triggers
# an additional query per event to look them up from the events table.
prev_event_ids=[last_sent_event_id] if last_sent_event_id else [],
depth=depth,
)
last_sent_event_id = sent_event.event_id
depth += 1
return last_stream_id
@ -1087,6 +1115,7 @@ class RoomCreationHandler:
content=creator_join_profile,
new_room=True,
prev_event_ids=[last_sent_event_id],
depth=depth,
)
last_sent_event_id = member_event_id
@ -1180,7 +1209,7 @@ class RoomCreationHandler:
content={"algorithm": RoomEncryptionAlgorithms.DEFAULT},
)
return last_sent_stream_id
return last_sent_stream_id, last_sent_event_id, depth
def _generate_room_id(self) -> str:
"""Generates a random room ID.
@ -1367,6 +1396,7 @@ class TimestampLookupHandler:
self.store = hs.get_datastores().main
self.state_handler = hs.get_state_handler()
self.federation_client = hs.get_federation_client()
self.federation_event_handler = hs.get_federation_event_handler()
self._storage_controllers = hs.get_storage_controllers()
async def get_event_for_timestamp(
@ -1462,38 +1492,68 @@ class TimestampLookupHandler:
remote_response,
)
# TODO: Do we want to persist this as an extremity?
# TODO: I think ideally, we would try to backfill from
# this event and run this whole
# `get_event_for_timestamp` function again to make sure
# they didn't give us an event from their gappy history.
remote_event_id = remote_response.event_id
origin_server_ts = remote_response.origin_server_ts
remote_origin_server_ts = remote_response.origin_server_ts
# Backfill this event so we can get a pagination token for
# it with `/context` and paginate `/messages` from this
# point.
#
# TODO: The requested timestamp may lie in a part of the
# event graph that the remote server *also* didn't have,
# in which case they will have returned another event
# which may be nowhere near the requested timestamp. In
# the future, we may need to reconcile that gap and ask
# other homeservers, and/or extend `/timestamp_to_event`
# to return events on *both* sides of the timestamp to
# help reconcile the gap faster.
remote_event = (
await self.federation_event_handler.backfill_event_id(
domain, room_id, remote_event_id
)
)
# XXX: When we see that the remote server is not trustworthy,
# maybe we should not ask them first in the future.
if remote_origin_server_ts != remote_event.origin_server_ts:
logger.info(
"get_event_for_timestamp: Remote server (%s) claimed that remote_event_id=%s occured at remote_origin_server_ts=%s but that isn't true (actually occured at %s). Their claims are dubious and we should consider not trusting them.",
domain,
remote_event_id,
remote_origin_server_ts,
remote_event.origin_server_ts,
)
# Only return the remote event if it's closer than the local event
if not local_event or (
abs(origin_server_ts - timestamp)
abs(remote_event.origin_server_ts - timestamp)
< abs(local_event.origin_server_ts - timestamp)
):
return remote_event_id, origin_server_ts
logger.info(
"get_event_for_timestamp: returning remote_event_id=%s (%s) since it's closer to timestamp=%s than local_event=%s (%s)",
remote_event_id,
remote_event.origin_server_ts,
timestamp,
local_event.event_id if local_event else None,
local_event.origin_server_ts if local_event else None,
)
return remote_event_id, remote_origin_server_ts
except (HttpResponseException, InvalidResponseError) as ex:
# Let's not put a high priority on some other homeserver
# failing to respond or giving a random response
logger.debug(
"Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s",
"get_event_for_timestamp: Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s",
domain,
type(ex).__name__,
ex,
ex.args,
)
except Exception as ex:
except Exception:
# But we do want to see some exceptions in our code
logger.warning(
"Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s",
"get_event_for_timestamp: Failed to fetch /timestamp_to_event from %s because of exception",
domain,
type(ex).__name__,
ex,
ex.args,
exc_info=True,
)
# To appease mypy, we have to add both of these conditions to check for

View file

@ -94,12 +94,29 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
burst_count=hs.config.ratelimiting.rc_joins_local.burst_count,
)
# Tracks joins from local users to rooms this server isn't a member of.
# I.e. joins this server makes by requesting /make_join /send_join from
# another server.
self._join_rate_limiter_remote = Ratelimiter(
store=self.store,
clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second,
burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
)
# TODO: find a better place to keep this Ratelimiter.
# It needs to be
# - written to by event persistence code
# - written to by something which can snoop on replication streams
# - read by the RoomMemberHandler to rate limit joins from local users
# - read by the FederationServer to rate limit make_joins and send_joins from
# other homeservers
# I wonder if a homeserver-wide collection of rate limiters might be cleaner?
self._join_rate_per_room_limiter = Ratelimiter(
store=self.store,
clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_joins_per_room.per_second,
burst_count=hs.config.ratelimiting.rc_joins_per_room.burst_count,
)
# Ratelimiter for invites, keyed by room (across all issuers, all
# recipients).
@ -136,6 +153,18 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
)
self.request_ratelimiter = hs.get_request_ratelimiter()
hs.get_notifier().add_new_join_in_room_callback(self._on_user_joined_room)
def _on_user_joined_room(self, event_id: str, room_id: str) -> None:
"""Notify the rate limiter that a room join has occurred.
Use this to inform the RoomMemberHandler about joins that have either
- taken place on another homeserver, or
- on another worker in this homeserver.
Joins actioned by this worker should use the usual `ratelimit` method, which
checks the limit and increments the counter in one go.
"""
self._join_rate_per_room_limiter.record_action(requester=None, key=room_id)
@abc.abstractmethod
async def _remote_join(
@ -285,6 +314,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
allow_no_prev_events: bool = False,
prev_event_ids: Optional[List[str]] = None,
state_event_ids: Optional[List[str]] = None,
depth: Optional[int] = None,
txn_id: Optional[str] = None,
ratelimit: bool = True,
content: Optional[dict] = None,
@ -315,6 +345,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
prev_events are set so we need to set them ourself via this argument.
This should normally be left as None, which will cause the auth_event_ids
to be calculated based on the room state at the prev_events.
depth: Override the depth used to order the event in the DAG.
Should normally be set to None, which will cause the depth to be calculated
based on the prev_events.
txn_id:
ratelimit:
@ -370,6 +403,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
allow_no_prev_events=allow_no_prev_events,
prev_event_ids=prev_event_ids,
state_event_ids=state_event_ids,
depth=depth,
require_consent=require_consent,
outlier=outlier,
historical=historical,
@ -391,6 +425,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# up blocking profile updates.
if newly_joined and ratelimit:
await self._join_rate_limiter_local.ratelimit(requester)
await self._join_rate_per_room_limiter.ratelimit(
requester, key=room_id, update=False
)
result_event = await self.event_creation_handler.handle_new_client_event(
requester,
@ -466,6 +503,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
allow_no_prev_events: bool = False,
prev_event_ids: Optional[List[str]] = None,
state_event_ids: Optional[List[str]] = None,
depth: Optional[int] = None,
) -> Tuple[str, int]:
"""Update a user's membership in a room.
@ -501,6 +539,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
prev_events are set so we need to set them ourself via this argument.
This should normally be left as None, which will cause the auth_event_ids
to be calculated based on the room state at the prev_events.
depth: Override the depth used to order the event in the DAG.
Should normally be set to None, which will cause the depth to be calculated
based on the prev_events.
Returns:
A tuple of the new event ID and stream ID.
@ -540,6 +581,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
allow_no_prev_events=allow_no_prev_events,
prev_event_ids=prev_event_ids,
state_event_ids=state_event_ids,
depth=depth,
)
return result
@ -562,6 +604,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
allow_no_prev_events: bool = False,
prev_event_ids: Optional[List[str]] = None,
state_event_ids: Optional[List[str]] = None,
depth: Optional[int] = None,
) -> Tuple[str, int]:
"""Helper for update_membership.
@ -599,6 +642,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
prev_events are set so we need to set them ourself via this argument.
This should normally be left as None, which will cause the auth_event_ids
to be calculated based on the room state at the prev_events.
depth: Override the depth used to order the event in the DAG.
Should normally be set to None, which will cause the depth to be calculated
based on the prev_events.
Returns:
A tuple of the new event ID and stream ID.
@ -732,6 +778,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
allow_no_prev_events=allow_no_prev_events,
prev_event_ids=prev_event_ids,
state_event_ids=state_event_ids,
depth=depth,
content=content,
require_consent=require_consent,
outlier=outlier,
@ -740,14 +787,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
latest_event_ids = await self.store.get_prev_events_for_room(room_id)
current_state_ids = await self.state_handler.get_current_state_ids(
room_id, latest_event_ids=latest_event_ids
state_before_join = await self.state_handler.compute_state_after_events(
room_id, latest_event_ids
)
# TODO: Refactor into dictionary of explicitly allowed transitions
# between old and new state, with specific error messages for some
# transitions and generic otherwise
old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
old_state_id = state_before_join.get((EventTypes.Member, target.to_string()))
if old_state_id:
old_state = await self.store.get_event(old_state_id, allow_none=True)
old_membership = old_state.content.get("membership") if old_state else None
@ -798,11 +845,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if action == "kick":
raise AuthError(403, "The target user is not in the room")
is_host_in_room = await self._is_host_in_room(current_state_ids)
is_host_in_room = await self._is_host_in_room(state_before_join)
if effective_membership_state == Membership.JOIN:
if requester.is_guest:
guest_can_join = await self._can_guest_join(current_state_ids)
guest_can_join = await self._can_guest_join(state_before_join)
if not guest_can_join:
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
@ -840,13 +887,23 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Check if a remote join should be performed.
remote_join, remote_room_hosts = await self._should_perform_remote_join(
target.to_string(), room_id, remote_room_hosts, content, is_host_in_room
target.to_string(),
room_id,
remote_room_hosts,
content,
is_host_in_room,
state_before_join,
)
if remote_join:
if ratelimit:
await self._join_rate_limiter_remote.ratelimit(
requester,
)
await self._join_rate_per_room_limiter.ratelimit(
requester,
key=room_id,
update=False,
)
inviter = await self._get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter):
@ -967,6 +1024,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
ratelimit=ratelimit,
prev_event_ids=latest_event_ids,
state_event_ids=state_event_ids,
depth=depth,
content=content,
require_consent=require_consent,
outlier=outlier,
@ -979,6 +1037,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
remote_room_hosts: List[str],
content: JsonDict,
is_host_in_room: bool,
state_before_join: StateMap[str],
) -> Tuple[bool, List[str]]:
"""
Check whether the server should do a remote join (as opposed to a local
@ -998,6 +1057,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
content: The content to use as the event body of the join. This may
be modified.
is_host_in_room: True if the host is in the room.
state_before_join: The state before the join event (i.e. the resolution of
the states after its parent events).
Returns:
A tuple of:
@ -1014,20 +1075,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# If the host is in the room, but not one of the authorised hosts
# for restricted join rules, a remote join must be used.
room_version = await self.store.get_room_version(room_id)
current_state_ids = await self._storage_controllers.state.get_current_state_ids(
room_id
)
# If restricted join rules are not being used, a local join can always
# be used.
if not await self.event_auth_handler.has_restricted_join_rules(
current_state_ids, room_version
state_before_join, room_version
):
return False, []
# If the user is invited to the room or already joined, the join
# event can always be issued locally.
prev_member_event_id = current_state_ids.get((EventTypes.Member, user_id), None)
prev_member_event_id = state_before_join.get((EventTypes.Member, user_id), None)
prev_member_event = None
if prev_member_event_id:
prev_member_event = await self.store.get_event(prev_member_event_id)
@ -1042,10 +1100,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
#
# If not, generate a new list of remote hosts based on which
# can issue invites.
event_map = await self.store.get_events(current_state_ids.values())
event_map = await self.store.get_events(state_before_join.values())
current_state = {
state_key: event_map[event_id]
for state_key, event_id in current_state_ids.items()
for state_key, event_id in state_before_join.items()
}
allowed_servers = get_servers_from_users(
get_users_which_can_issue_invite(current_state)
@ -1059,7 +1117,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Ensure the member should be allowed access via membership in a room.
await self.event_auth_handler.check_restricted_join_rules(
current_state_ids, room_version, user_id, prev_member_event
state_before_join, room_version, user_id, prev_member_event
)
# If this is going to be a local join, additional information must
@ -1069,7 +1127,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
EventContentFields.AUTHORISING_USER
] = await self.event_auth_handler.get_user_which_could_invite(
room_id,
current_state_ids,
state_before_join,
)
return False, []
@ -1322,7 +1380,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
requester: Requester,
txn_id: Optional[str],
id_access_token: Optional[str] = None,
) -> int:
prev_event_ids: Optional[List[str]] = None,
depth: Optional[int] = None,
) -> Tuple[str, int]:
"""Invite a 3PID to a room.
Args:
@ -1335,9 +1395,13 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
txn_id: The transaction ID this is part of, or None if this is not
part of a transaction.
id_access_token: The optional identity server access token.
depth: Override the depth used to order the event in the DAG.
prev_event_ids: The event IDs to use as the prev events
Should normally be set to None, which will cause the depth to be calculated
based on the prev_events.
Returns:
The new stream ID.
Tuple of event ID and stream ordering position
Raises:
ShadowBanError if the requester has been shadow-banned.
@ -1383,7 +1447,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# We don't check the invite against the spamchecker(s) here (through
# user_may_invite) because we'll do it further down the line anyway (in
# update_membership_locked).
_, stream_id = await self.update_membership(
event_id, stream_id = await self.update_membership(
requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
)
else:
@ -1402,7 +1466,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
additional_fields=spam_check[1],
)
stream_id = await self._make_and_store_3pid_invite(
event, stream_id = await self._make_and_store_3pid_invite(
requester,
id_server,
medium,
@ -1411,9 +1475,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
inviter,
txn_id=txn_id,
id_access_token=id_access_token,
prev_event_ids=prev_event_ids,
depth=depth,
)
event_id = event.event_id
return stream_id
return event_id, stream_id
async def _make_and_store_3pid_invite(
self,
@ -1425,7 +1492,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
user: UserID,
txn_id: Optional[str],
id_access_token: Optional[str] = None,
) -> int:
prev_event_ids: Optional[List[str]] = None,
depth: Optional[int] = None,
) -> Tuple[EventBase, int]:
room_state = await self._storage_controllers.state.get_current_state(
room_id,
StateFilter.from_types(
@ -1518,8 +1587,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
},
ratelimit=False,
txn_id=txn_id,
prev_event_ids=prev_event_ids,
depth=depth,
)
return stream_id
return event, stream_id
async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool:
# Have we just created the room, and is this about to be the very

View file

@ -23,10 +23,12 @@ from pkg_resources import parse_version
import twisted
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IOpenSSLContextFactory, IReactorTCP
from twisted.internet.interfaces import IOpenSSLContextFactory
from twisted.internet.ssl import optionsForClientTLS
from twisted.mail.smtp import ESMTPSender, ESMTPSenderFactory
from synapse.logging.context import make_deferred_yieldable
from synapse.types import ISynapseReactor
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -48,7 +50,7 @@ class _NoTLSESMTPSender(ESMTPSender):
async def _sendmail(
reactor: IReactorTCP,
reactor: ISynapseReactor,
smtphost: str,
smtpport: int,
from_addr: str,
@ -59,6 +61,7 @@ async def _sendmail(
require_auth: bool = False,
require_tls: bool = False,
enable_tls: bool = True,
force_tls: bool = False,
) -> None:
"""A simple wrapper around ESMTPSenderFactory, to allow substitution in tests
@ -73,8 +76,9 @@ async def _sendmail(
password: password to give when authenticating
require_auth: if auth is not offered, fail the request
require_tls: if TLS is not offered, fail the reqest
enable_tls: True to enable TLS. If this is False and require_tls is True,
enable_tls: True to enable STARTTLS. If this is False and require_tls is True,
the request will fail.
force_tls: True to enable Implicit TLS.
"""
msg = BytesIO(msg_bytes)
d: "Deferred[object]" = Deferred()
@ -105,13 +109,23 @@ async def _sendmail(
# set to enable TLS.
factory = build_sender_factory(hostname=smtphost if enable_tls else None)
reactor.connectTCP(
smtphost,
smtpport,
factory,
timeout=30,
bindAddress=None,
)
if force_tls:
reactor.connectSSL(
smtphost,
smtpport,
factory,
optionsForClientTLS(smtphost),
timeout=30,
bindAddress=None,
)
else:
reactor.connectTCP(
smtphost,
smtpport,
factory,
timeout=30,
bindAddress=None,
)
await make_deferred_yieldable(d)
@ -132,6 +146,7 @@ class SendEmailHandler:
self._smtp_pass = passwd.encode("utf-8") if passwd is not None else None
self._require_transport_security = hs.config.email.require_transport_security
self._enable_tls = hs.config.email.enable_smtp_tls
self._force_tls = hs.config.email.force_tls
self._sendmail = _sendmail
@ -189,4 +204,5 @@ class SendEmailHandler:
require_auth=self._smtp_user is not None,
require_tls=self._require_transport_security,
enable_tls=self._enable_tls,
force_tls=self._force_tls,
)

View file

@ -19,7 +19,6 @@ from twisted.web.client import PartialDownloadError
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.util import json_decoder
if TYPE_CHECKING:
@ -153,7 +152,7 @@ class _BaseThreepidAuthChecker:
logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,))
# msisdns are currently always ThreepidBehaviour.REMOTE
# msisdns are currently always verified via the IS
if medium == "msisdn":
if not self.hs.config.registration.account_threepid_delegate_msisdn:
raise SynapseError(
@ -164,18 +163,7 @@ class _BaseThreepidAuthChecker:
threepid_creds,
)
elif medium == "email":
if (
self.hs.config.email.threepid_behaviour_email
== ThreepidBehaviour.REMOTE
):
assert self.hs.config.registration.account_threepid_delegate_email
threepid = await identity_handler.threepid_from_creds(
self.hs.config.registration.account_threepid_delegate_email,
threepid_creds,
)
elif (
self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL
):
if self.hs.config.email.can_verify_email:
threepid = None
row = await self.store.get_threepid_validation_session(
medium,
@ -227,10 +215,7 @@ class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChec
_BaseThreepidAuthChecker.__init__(self, hs)
def is_enabled(self) -> bool:
return self.hs.config.email.threepid_behaviour_email in (
ThreepidBehaviour.REMOTE,
ThreepidBehaviour.LOCAL,
)
return self.hs.config.email.can_verify_email
async def check_auth(self, authdict: dict, clientip: str) -> Any:
return await self._check_threepid("email", authdict)