Implement MSC3706: partial state in /send_join response (#11967)

* Make `get_auth_chain_ids` return a Set

It has a set internally, and a set is often useful where it gets used, so let's
avoid converting to an intermediate list.

* Minor refactors in `on_send_join_request`

A little bit of non-functional groundwork

* Implement MSC3706: partial state in /send_join response
This commit is contained in:
Richard van der Hoff 2022-02-12 10:44:16 +00:00 committed by GitHub
parent b2b971f28a
commit 63c46349c4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 262 additions and 21 deletions

View file

@ -61,3 +61,6 @@ class ExperimentalConfig(Config):
self.msc2409_to_device_messages_enabled: bool = experimental.get(
"msc2409_to_device_messages_enabled", False
)
# MSC3706 (server-side support for partial state in /send_join responses)
self.msc3706_enabled: bool = experimental.get("msc3706_enabled", False)

View file

@ -20,6 +20,7 @@ from typing import (
Any,
Awaitable,
Callable,
Collection,
Dict,
Iterable,
List,
@ -64,7 +65,7 @@ from synapse.replication.http.federation import (
ReplicationGetQueryRestServlet,
)
from synapse.storage.databases.main.lock import Lock
from synapse.types import JsonDict, get_domain_from_id
from synapse.types import JsonDict, StateMap, get_domain_from_id
from synapse.util import json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
from synapse.util.caches.response_cache import ResponseCache
@ -571,7 +572,7 @@ class FederationServer(FederationBase):
) -> JsonDict:
state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids)
return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
return {"pdu_ids": state_ids, "auth_chain_ids": list(auth_chain_ids)}
async def _on_context_state_request_compute(
self, room_id: str, event_id: Optional[str]
@ -645,27 +646,61 @@ class FederationServer(FederationBase):
return {"event": ret_pdu.get_pdu_json(time_now)}
async def on_send_join_request(
self, origin: str, content: JsonDict, room_id: str
self,
origin: str,
content: JsonDict,
room_id: str,
caller_supports_partial_state: bool = False,
) -> Dict[str, Any]:
event, context = await self._on_send_membership_event(
origin, content, Membership.JOIN, room_id
)
prev_state_ids = await context.get_prev_state_ids()
state_ids = list(prev_state_ids.values())
auth_chain = await self.store.get_auth_chain(room_id, state_ids)
state = await self.store.get_events(state_ids)
state_event_ids: Collection[str]
servers_in_room: Optional[Collection[str]]
if caller_supports_partial_state:
state_event_ids = _get_event_ids_for_partial_state_join(
event, prev_state_ids
)
servers_in_room = await self.state.get_hosts_in_room_at_events(
room_id, event_ids=event.prev_event_ids()
)
else:
state_event_ids = prev_state_ids.values()
servers_in_room = None
auth_chain_event_ids = await self.store.get_auth_chain_ids(
room_id, state_event_ids
)
# if the caller has opted in, we can omit any auth_chain events which are
# already in state_event_ids
if caller_supports_partial_state:
auth_chain_event_ids.difference_update(state_event_ids)
auth_chain_events = await self.store.get_events_as_list(auth_chain_event_ids)
state_events = await self.store.get_events_as_list(state_event_ids)
# we try to do all the async stuff before this point, so that time_now is as
# accurate as possible.
time_now = self._clock.time_msec()
event_json = event.get_pdu_json()
return {
event_json = event.get_pdu_json(time_now)
resp = {
# TODO Remove the unstable prefix when servers have updated.
"org.matrix.msc3083.v2.event": event_json,
"event": event_json,
"state": [p.get_pdu_json(time_now) for p in state.values()],
"auth_chain": [p.get_pdu_json(time_now) for p in auth_chain],
"state": [p.get_pdu_json(time_now) for p in state_events],
"auth_chain": [p.get_pdu_json(time_now) for p in auth_chain_events],
"org.matrix.msc3706.partial_state": caller_supports_partial_state,
}
if servers_in_room is not None:
resp["org.matrix.msc3706.servers_in_room"] = list(servers_in_room)
return resp
async def on_make_leave_request(
self, origin: str, room_id: str, user_id: str
) -> Dict[str, Any]:
@ -1339,3 +1374,39 @@ class FederationHandlerRegistry:
# error.
logger.warning("No handler registered for query type %s", query_type)
raise NotFoundError("No handler for Query type '%s'" % (query_type,))
def _get_event_ids_for_partial_state_join(
join_event: EventBase,
prev_state_ids: StateMap[str],
) -> Collection[str]:
"""Calculate state to be retuned in a partial_state send_join
Args:
join_event: the join event being send_joined
prev_state_ids: the event ids of the state before the join
Returns:
the event ids to be returned
"""
# return all non-member events
state_event_ids = {
event_id
for (event_type, state_key), event_id in prev_state_ids.items()
if event_type != EventTypes.Member
}
# we also need the current state of the current user (it's going to
# be an auth event for the new join, so we may as well return it)
current_membership_event_id = prev_state_ids.get(
(EventTypes.Member, join_event.state_key)
)
if current_membership_event_id is not None:
state_event_ids.add(current_membership_event_id)
# TODO: return a few more members:
# - those with invites
# - those that are kicked? / banned
return state_event_ids

View file

@ -412,6 +412,16 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet):
PREFIX = FEDERATION_V2_PREFIX
def __init__(
self,
hs: "HomeServer",
authenticator: Authenticator,
ratelimiter: FederationRateLimiter,
server_name: str,
):
super().__init__(hs, authenticator, ratelimiter, server_name)
self._msc3706_enabled = hs.config.experimental.msc3706_enabled
async def on_PUT(
self,
origin: str,
@ -422,7 +432,15 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet):
) -> Tuple[int, JsonDict]:
# TODO(paul): assert that event_id parsed from path actually
# match those given in content
result = await self.handler.on_send_join_request(origin, content, room_id)
partial_state = False
if self._msc3706_enabled:
partial_state = parse_boolean_from_args(
query, "org.matrix.msc3706.partial_state", default=False
)
result = await self.handler.on_send_join_request(
origin, content, room_id, caller_supports_partial_state=partial_state
)
return 200, result

View file

@ -121,7 +121,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
room_id: str,
event_ids: Collection[str],
include_given: bool = False,
) -> List[str]:
) -> Set[str]:
"""Get auth events for given event_ids. The events *must* be state events.
Args:
@ -130,7 +130,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
include_given: include the given events in result
Returns:
list of event_ids
set of event_ids
"""
# Check if we have indexed the room so we can use the chain cover
@ -159,7 +159,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
def _get_auth_chain_ids_using_cover_index_txn(
self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
) -> List[str]:
) -> Set[str]:
"""Calculates the auth chain IDs using the chain index."""
# First we look up the chain ID/sequence numbers for the given events.
@ -272,11 +272,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
txn.execute(sql, (chain_id, max_no))
results.update(r for r, in txn)
return list(results)
return results
def _get_auth_chain_ids_txn(
self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
) -> List[str]:
) -> Set[str]:
"""Calculates the auth chain IDs.
This is used when we don't have a cover index for the room.
@ -331,7 +331,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
front = new_front
results.update(front)
return list(results)
return results
async def get_auth_chain_difference(
self, room_id: str, state_sets: List[Set[str]]