Rip out auth-event reconciliation code (#12943)

There is a corner in `_check_event_auth` (long known as "the weird corner") where, if we get an event with auth_events which don't match those we were expecting, we attempt to resolve the diffence between our state and the remote's with a state resolution.

This isn't specced, and there's general agreement we shouldn't be doing it.

However, it turns out that the faster-joins code was relying on it, so we need to introduce something similar (but rather simpler) for that.
This commit is contained in:
Richard van der Hoff 2022-07-14 22:52:26 +01:00 committed by GitHub
parent df55b377be
commit fe15a865a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 90 additions and 377 deletions

1
changelog.d/12943.misc Normal file
View File

@ -0,0 +1 @@
Remove code which incorrectly attempted to reconcile state with remote servers when processing incoming events.

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import collections
import itertools import itertools
import logging import logging
from http import HTTPStatus from http import HTTPStatus
@ -347,7 +348,7 @@ class FederationEventHandler:
event.internal_metadata.send_on_behalf_of = origin event.internal_metadata.send_on_behalf_of = origin
context = await self._state_handler.compute_event_context(event) context = await self._state_handler.compute_event_context(event)
context = await self._check_event_auth(origin, event, context) await self._check_event_auth(origin, event, context)
if context.rejected: if context.rejected:
raise SynapseError( raise SynapseError(
403, f"{event.membership} event was rejected", Codes.FORBIDDEN 403, f"{event.membership} event was rejected", Codes.FORBIDDEN
@ -485,7 +486,7 @@ class FederationEventHandler:
partial_state=partial_state, partial_state=partial_state,
) )
context = await self._check_event_auth(origin, event, context) await self._check_event_auth(origin, event, context)
if context.rejected: if context.rejected:
raise SynapseError(400, "Join event was rejected") raise SynapseError(400, "Join event was rejected")
@ -1116,11 +1117,7 @@ class FederationEventHandler:
state_ids_before_event=state_ids, state_ids_before_event=state_ids,
) )
try: try:
context = await self._check_event_auth( await self._check_event_auth(origin, event, context)
origin,
event,
context,
)
except AuthError as e: except AuthError as e:
# This happens only if we couldn't find the auth events. We'll already have # This happens only if we couldn't find the auth events. We'll already have
# logged a warning, so now we just convert to a FederationError. # logged a warning, so now we just convert to a FederationError.
@ -1495,11 +1492,8 @@ class FederationEventHandler:
) )
async def _check_event_auth( async def _check_event_auth(
self, self, origin: str, event: EventBase, context: EventContext
origin: str, ) -> None:
event: EventBase,
context: EventContext,
) -> EventContext:
""" """
Checks whether an event should be rejected (for failing auth checks). Checks whether an event should be rejected (for failing auth checks).
@ -1509,9 +1503,6 @@ class FederationEventHandler:
context: context:
The event context. The event context.
Returns:
The updated context object.
Raises: Raises:
AuthError if we were unable to find copies of the event's auth events. AuthError if we were unable to find copies of the event's auth events.
(Most other failures just cause us to set `context.rejected`.) (Most other failures just cause us to set `context.rejected`.)
@ -1526,7 +1517,7 @@ class FederationEventHandler:
logger.warning("While validating received event %r: %s", event, e) logger.warning("While validating received event %r: %s", event, e)
# TODO: use a different rejected reason here? # TODO: use a different rejected reason here?
context.rejected = RejectedReason.AUTH_ERROR context.rejected = RejectedReason.AUTH_ERROR
return context return
# next, check that we have all of the event's auth events. # next, check that we have all of the event's auth events.
# #
@ -1538,6 +1529,9 @@ class FederationEventHandler:
) )
# ... and check that the event passes auth at those auth events. # ... and check that the event passes auth at those auth events.
# https://spec.matrix.org/v1.3/server-server-api/#checks-performed-on-receipt-of-a-pdu:
# 4. Passes authorization rules based on the events auth events,
# otherwise it is rejected.
try: try:
await check_state_independent_auth_rules(self._store, event) await check_state_independent_auth_rules(self._store, event)
check_state_dependent_auth_rules(event, claimed_auth_events) check_state_dependent_auth_rules(event, claimed_auth_events)
@ -1546,55 +1540,90 @@ class FederationEventHandler:
"While checking auth of %r against auth_events: %s", event, e "While checking auth of %r against auth_events: %s", event, e
) )
context.rejected = RejectedReason.AUTH_ERROR context.rejected = RejectedReason.AUTH_ERROR
return context return
# now check auth against what we think the auth events *should* be. # now check the auth rules pass against the room state before the event
# https://spec.matrix.org/v1.3/server-server-api/#checks-performed-on-receipt-of-a-pdu:
# 5. Passes authorization rules based on the state before the event,
# otherwise it is rejected.
#
# ... however, if we only have partial state for the room, then there is a good
# chance that we'll be missing some of the state needed to auth the new event.
# So, we state-resolve the auth events that we are given against the state that
# we know about, which ensures things like bans are applied. (Note that we'll
# already have checked we have all the auth events, in
# _load_or_fetch_auth_events_for_event above)
if context.partial_state:
room_version = await self._store.get_room_version_id(event.room_id)
local_state_id_map = await context.get_prev_state_ids()
claimed_auth_events_id_map = {
(ev.type, ev.state_key): ev.event_id for ev in claimed_auth_events
}
state_for_auth_id_map = (
await self._state_resolution_handler.resolve_events_with_store(
event.room_id,
room_version,
[local_state_id_map, claimed_auth_events_id_map],
event_map=None,
state_res_store=StateResolutionStore(self._store),
)
)
else:
event_types = event_auth.auth_types_for_event(event.room_version, event) event_types = event_auth.auth_types_for_event(event.room_version, event)
prev_state_ids = await context.get_prev_state_ids( state_for_auth_id_map = await context.get_prev_state_ids(
StateFilter.from_types(event_types) StateFilter.from_types(event_types)
) )
auth_events_ids = self._event_auth_handler.compute_auth_events( calculated_auth_event_ids = self._event_auth_handler.compute_auth_events(
event, prev_state_ids, for_verification=True event, state_for_auth_id_map, for_verification=True
) )
auth_events_x = await self._store.get_events(auth_events_ids)
# if those are the same, we're done here.
if collections.Counter(event.auth_event_ids()) == collections.Counter(
calculated_auth_event_ids
):
return
# otherwise, re-run the auth checks based on what we calculated.
calculated_auth_events = await self._store.get_events_as_list(
calculated_auth_event_ids
)
# log the differences
claimed_auth_event_map = {(e.type, e.state_key): e for e in claimed_auth_events}
calculated_auth_event_map = { calculated_auth_event_map = {
(e.type, e.state_key): e for e in auth_events_x.values() (e.type, e.state_key): e for e in calculated_auth_events
} }
logger.info(
"event's auth_events are different to our calculated auth_events. "
"Claimed but not calculated: %s. Calculated but not claimed: %s",
[
ev
for k, ev in claimed_auth_event_map.items()
if k not in calculated_auth_event_map
or calculated_auth_event_map[k].event_id != ev.event_id
],
[
ev
for k, ev in calculated_auth_event_map.items()
if k not in claimed_auth_event_map
or claimed_auth_event_map[k].event_id != ev.event_id
],
)
try: try:
updated_auth_events = await self._update_auth_events_for_auth( check_state_dependent_auth_rules(event, calculated_auth_events)
event,
calculated_auth_event_map=calculated_auth_event_map,
)
except Exception:
# We don't really mind if the above fails, so lets not fail
# processing if it does. However, it really shouldn't fail so
# let's still log as an exception since we'll still want to fix
# any bugs.
logger.exception(
"Failed to double check auth events for %s with remote. "
"Ignoring failure and continuing processing of event.",
event.event_id,
)
updated_auth_events = None
if updated_auth_events:
context = await self._update_context_for_auth_events(
event, context, updated_auth_events
)
auth_events_for_auth = updated_auth_events
else:
auth_events_for_auth = calculated_auth_event_map
try:
check_state_dependent_auth_rules(event, auth_events_for_auth.values())
except AuthError as e: except AuthError as e:
logger.warning("Failed auth resolution for %r because %s", event, e) logger.warning(
"While checking auth of %r against room state before the event: %s",
event,
e,
)
context.rejected = RejectedReason.AUTH_ERROR context.rejected = RejectedReason.AUTH_ERROR
return context
async def _maybe_kick_guest_users(self, event: EventBase) -> None: async def _maybe_kick_guest_users(self, event: EventBase) -> None:
if event.type != EventTypes.GuestAccess: if event.type != EventTypes.GuestAccess:
return return
@ -1704,93 +1733,6 @@ class FederationEventHandler:
soft_failed_event_counter.inc() soft_failed_event_counter.inc()
event.internal_metadata.soft_failed = True event.internal_metadata.soft_failed = True
async def _update_auth_events_for_auth(
self,
event: EventBase,
calculated_auth_event_map: StateMap[EventBase],
) -> Optional[StateMap[EventBase]]:
"""Helper for _check_event_auth. See there for docs.
Checks whether a given event has the expected auth events. If it
doesn't then we talk to the remote server to compare state to see if
we can come to a consensus (e.g. if one server missed some valid
state).
This attempts to resolve any potential divergence of state between
servers, but is not essential and so failures should not block further
processing of the event.
Args:
event:
calculated_auth_event_map:
Our calculated auth_events based on the state of the room
at the event's position in the DAG.
Returns:
updated auth event map, or None if no changes are needed.
"""
assert not event.internal_metadata.outlier
# check for events which are in the event's claimed auth_events, but not
# in our calculated event map.
event_auth_events = set(event.auth_event_ids())
different_auth = event_auth_events.difference(
e.event_id for e in calculated_auth_event_map.values()
)
if not different_auth:
return None
logger.info(
"auth_events refers to events which are not in our calculated auth "
"chain: %s",
different_auth,
)
# XXX: currently this checks for redactions but I'm not convinced that is
# necessary?
different_events = await self._store.get_events_as_list(different_auth)
# double-check they're all in the same room - we should already have checked
# this but it doesn't hurt to check again.
for d in different_events:
assert (
d.room_id == event.room_id
), f"Event {event.event_id} refers to auth_event {d.event_id} which is in a different room"
# now we state-resolve between our own idea of the auth events, and the remote's
# idea of them.
local_state = calculated_auth_event_map.values()
remote_auth_events = dict(calculated_auth_event_map)
remote_auth_events.update({(d.type, d.state_key): d for d in different_events})
remote_state = remote_auth_events.values()
room_version = await self._store.get_room_version_id(event.room_id)
new_state = await self._state_handler.resolve_events(
room_version, (local_state, remote_state), event
)
different_state = {
(d.type, d.state_key): d
for d in new_state.values()
if calculated_auth_event_map.get((d.type, d.state_key)) != d
}
if not different_state:
logger.info("State res returned no new state")
return None
logger.info(
"After state res: updating auth_events with new state %s",
different_state.values(),
)
# take a copy of calculated_auth_event_map before we modify it.
auth_events = dict(calculated_auth_event_map)
auth_events.update(different_state)
return auth_events
async def _load_or_fetch_auth_events_for_event( async def _load_or_fetch_auth_events_for_event(
self, destination: str, event: EventBase self, destination: str, event: EventBase
) -> Collection[EventBase]: ) -> Collection[EventBase]:
@ -1888,61 +1830,6 @@ class FederationEventHandler:
await self._auth_and_persist_outliers(room_id, remote_auth_events) await self._auth_and_persist_outliers(room_id, remote_auth_events)
async def _update_context_for_auth_events(
self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase]
) -> EventContext:
"""Update the state_ids in an event context after auth event resolution,
storing the changes as a new state group.
Args:
event: The event we're handling the context for
context: initial event context
auth_events: Events to update in the event context.
Returns:
new event context
"""
# exclude the state key of the new event from the current_state in the context.
if event.is_state():
event_key: Optional[Tuple[str, str]] = (event.type, event.state_key)
else:
event_key = None
state_updates = {
k: a.event_id for k, a in auth_events.items() if k != event_key
}
current_state_ids = await context.get_current_state_ids()
current_state_ids = dict(current_state_ids) # type: ignore
current_state_ids.update(state_updates)
prev_state_ids = await context.get_prev_state_ids()
prev_state_ids = dict(prev_state_ids)
prev_state_ids.update({k: a.event_id for k, a in auth_events.items()})
# create a new state group as a delta from the existing one.
prev_group = context.state_group
state_group = await self._state_storage_controller.store_state_group(
event.event_id,
event.room_id,
prev_group=prev_group,
delta_ids=state_updates,
current_state_ids=current_state_ids,
)
return EventContext.with_state(
storage=self._storage_controllers,
state_group=state_group,
state_group_before_event=context.state_group_before_event,
state_delta_due_to_event=state_updates,
prev_group=prev_group,
delta_ids=state_updates,
partial_state=context.partial_state,
)
async def _run_push_actions_and_persist_event( async def _run_push_actions_and_persist_event(
self, event: EventBase, context: EventContext, backfilled: bool = False self, event: EventBase, context: EventContext, backfilled: bool = False
) -> None: ) -> None:

View File

@ -24,7 +24,6 @@ from typing import (
DefaultDict, DefaultDict,
Dict, Dict,
FrozenSet, FrozenSet,
Iterable,
List, List,
Mapping, Mapping,
Optional, Optional,
@ -422,31 +421,6 @@ class StateHandler:
) )
return result return result
async def resolve_events(
self,
room_version: str,
state_sets: Collection[Iterable[EventBase]],
event: EventBase,
) -> StateMap[EventBase]:
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
)
state_set_ids = [
{(ev.type, ev.state_key): ev.event_id for ev in st} for st in state_sets
]
state_map = {ev.event_id: ev for st in state_sets for ev in st}
new_state = await self._state_resolution_handler.resolve_events_with_store(
event.room_id,
room_version,
state_set_ids,
event_map=state_map,
state_res_store=StateResolutionStore(self.store),
)
return {key: state_map[ev_id] for key, ev_id in new_state.items()}
async def update_current_state(self, room_id: str) -> None: async def update_current_state(self, room_id: str) -> None:
"""Recalculates the current state for a room, and persists it. """Recalculates the current state for a room, and persists it.

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import List, cast from typing import cast
from unittest import TestCase from unittest import TestCase
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -50,8 +50,6 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
hs = self.setup_test_homeserver(federation_http_client=None) hs = self.setup_test_homeserver(federation_http_client=None)
self.handler = hs.get_federation_handler() self.handler = hs.get_federation_handler()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.state_storage_controller = hs.get_storage_controllers().state
self._event_auth_handler = hs.get_event_auth_handler()
return hs return hs
def test_exchange_revoked_invite(self) -> None: def test_exchange_revoked_invite(self) -> None:
@ -314,142 +312,6 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
) )
self.get_success(d) self.get_success(d)
def test_backfill_floating_outlier_membership_auth(self) -> None:
"""
As the local homeserver, check that we can properly process a federated
event from the OTHER_SERVER with auth_events that include a floating
membership event from the OTHER_SERVER.
Regression test, see #10439.
"""
OTHER_SERVER = "otherserver"
OTHER_USER = "@otheruser:" + OTHER_SERVER
# create the room
user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test")
room_id = self.helper.create_room_as(
room_creator=user_id,
is_public=True,
tok=tok,
extra_content={
"preset": "public_chat",
},
)
room_version = self.get_success(self.store.get_room_version(room_id))
prev_event_ids = self.get_success(self.store.get_prev_events_for_room(room_id))
(
most_recent_prev_event_id,
most_recent_prev_event_depth,
) = self.get_success(self.store.get_max_depth_of(prev_event_ids))
# mapping from (type, state_key) -> state_event_id
assert most_recent_prev_event_id is not None
prev_state_map = self.get_success(
self.state_storage_controller.get_state_ids_for_event(
most_recent_prev_event_id
)
)
# List of state event ID's
prev_state_ids = list(prev_state_map.values())
auth_event_ids = prev_state_ids
auth_events = list(
self.get_success(self.store.get_events(auth_event_ids)).values()
)
# build a floating outlier member state event
fake_prev_event_id = "$" + random_string(43)
member_event_dict = {
"type": EventTypes.Member,
"content": {
"membership": "join",
},
"state_key": OTHER_USER,
"room_id": room_id,
"sender": OTHER_USER,
"depth": most_recent_prev_event_depth,
"prev_events": [fake_prev_event_id],
"origin_server_ts": self.clock.time_msec(),
"signatures": {OTHER_SERVER: {"ed25519:key_version": "SomeSignatureHere"}},
}
builder = self.hs.get_event_builder_factory().for_room_version(
room_version, member_event_dict
)
member_event = self.get_success(
builder.build(
prev_event_ids=member_event_dict["prev_events"],
auth_event_ids=self._event_auth_handler.compute_auth_events(
builder,
prev_state_map,
for_verification=False,
),
depth=member_event_dict["depth"],
)
)
# Override the signature added from "test" homeserver that we created the event with
member_event.signatures = member_event_dict["signatures"]
# Add the new member_event to the StateMap
updated_state_map = dict(prev_state_map)
updated_state_map[
(member_event.type, member_event.state_key)
] = member_event.event_id
auth_events.append(member_event)
# build and send an event authed based on the member event
message_event_dict = {
"type": EventTypes.Message,
"content": {},
"room_id": room_id,
"sender": OTHER_USER,
"depth": most_recent_prev_event_depth,
"prev_events": prev_event_ids.copy(),
"origin_server_ts": self.clock.time_msec(),
"signatures": {OTHER_SERVER: {"ed25519:key_version": "SomeSignatureHere"}},
}
builder = self.hs.get_event_builder_factory().for_room_version(
room_version, message_event_dict
)
message_event = self.get_success(
builder.build(
prev_event_ids=message_event_dict["prev_events"],
auth_event_ids=self._event_auth_handler.compute_auth_events(
builder,
updated_state_map,
for_verification=False,
),
depth=message_event_dict["depth"],
)
)
# Override the signature added from "test" homeserver that we created the event with
message_event.signatures = message_event_dict["signatures"]
# Stub the /event_auth response from the OTHER_SERVER
async def get_event_auth(
destination: str, room_id: str, event_id: str
) -> List[EventBase]:
return [
event_from_pdu_json(ae.get_pdu_json(), room_version=room_version)
for ae in auth_events
]
self.handler.federation_client.get_event_auth = get_event_auth # type: ignore[assignment]
with LoggingContext("receive_pdu"):
# Fake the OTHER_SERVER federating the message event over to our local homeserver
d = run_in_background(
self.hs.get_federation_event_handler().on_receive_pdu,
OTHER_SERVER,
message_event,
)
self.get_success(d)
# Now try and get the events on our local homeserver
stored_event = self.get_success(
self.store.get_event(message_event.event_id, allow_none=True)
)
self.assertTrue(stored_event is not None)
@unittest.override_config( @unittest.override_config(
{"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}} {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
) )

View File

@ -21,7 +21,6 @@ from synapse.api.constants import EventTypes, LoginType, Membership
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.api.room_versions import RoomVersion from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import account, login, profile, room from synapse.rest.client import account, login, profile, room
@ -113,14 +112,8 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
# Have this homeserver skip event auth checks. This is necessary due to # Have this homeserver skip event auth checks. This is necessary due to
# event auth checks ensuring that events were signed by the sender's homeserver. # event auth checks ensuring that events were signed by the sender's homeserver.
async def _check_event_auth( async def _check_event_auth(origin: Any, event: Any, context: Any) -> None:
origin: str, pass
event: EventBase,
context: EventContext,
*args: Any,
**kwargs: Any,
) -> EventContext:
return context
hs.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[assignment] hs.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[assignment]

View File

@ -81,12 +81,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
self.handler = self.homeserver.get_federation_handler() self.handler = self.homeserver.get_federation_handler()
federation_event_handler = self.homeserver.get_federation_event_handler() federation_event_handler = self.homeserver.get_federation_event_handler()
async def _check_event_auth( async def _check_event_auth(origin, event, context):
origin, pass
event,
context,
):
return context
federation_event_handler._check_event_auth = _check_event_auth federation_event_handler._check_event_auth = _check_event_auth
self.client = self.homeserver.get_federation_client() self.client = self.homeserver.get_federation_client()