sanity-checking for events used in state res (#6531)

When we perform state resolution, check that all of the events involved are in
the right room.
This commit is contained in:
Richard van der Hoff 2019-12-13 12:55:32 +00:00 committed by Richard van der Hoff
parent 6577f2d887
commit 83895316d4
6 changed files with 128 additions and 43 deletions

1
changelog.d/6531.misc Normal file
View File

@ -0,0 +1 @@
Improve sanity-checking when receiving events over federation.

View File

@ -397,6 +397,7 @@ class FederationHandler(BaseHandler):
event_map[x.event_id] = x event_map[x.event_id] = x
state_map = yield resolve_events_with_store( state_map = yield resolve_events_with_store(
room_id,
room_version, room_version,
state_maps, state_maps,
event_map, event_map,

View File

@ -16,7 +16,7 @@
import logging import logging
from collections import namedtuple from collections import namedtuple
from typing import Iterable, Optional from typing import Dict, Iterable, List, Optional, Tuple
from six import iteritems, itervalues from six import iteritems, itervalues
@ -416,6 +416,7 @@ class StateHandler(object):
with Measure(self.clock, "state._resolve_events"): with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_store( new_state = yield resolve_events_with_store(
event.room_id,
room_version, room_version,
state_set_ids, state_set_ids,
event_map=state_map, event_map=state_map,
@ -461,7 +462,7 @@ class StateResolutionHandler(object):
not be called for a single state group not be called for a single state group
Args: Args:
room_id (str): room we are resolving for (used for logging) room_id (str): room we are resolving for (used for logging and sanity checks)
room_version (str): version of the room room_version (str): version of the room
state_groups_ids (dict[int, dict[(str, str), str]]): state_groups_ids (dict[int, dict[(str, str), str]]):
map from state group id to the state in that state group map from state group id to the state in that state group
@ -517,6 +518,7 @@ class StateResolutionHandler(object):
logger.info("Resolving conflicted state for %r", room_id) logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"): with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_store( new_state = yield resolve_events_with_store(
room_id,
room_version, room_version,
list(itervalues(state_groups_ids)), list(itervalues(state_groups_ids)),
event_map=event_map, event_map=event_map,
@ -588,36 +590,44 @@ def _make_state_cache_entry(new_state, state_groups_ids):
) )
def resolve_events_with_store(room_version, state_sets, event_map, state_res_store): def resolve_events_with_store(
room_id: str,
room_version: str,
state_sets: List[Dict[Tuple[str, str], str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "StateResolutionStore",
):
""" """
Args: Args:
room_version(str): Version of the room room_id: the room we are working in
state_sets(list): List of dicts of (type, state_key) -> event_id, room_version: Version of the room
state_sets: List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve. which are the different state groups to resolve.
event_map(dict[str,FrozenEvent]|None): event_map:
a dict from event_id to event, for any events that we happen to a dict from event_id to event, for any events that we happen to
have in flight (eg, those currently being persisted). This will be have in flight (eg, those currently being persisted). This will be
used as a starting point fof finding the state we need; any missing used as a starting point fof finding the state we need; any missing
events will be requested via state_map_factory. events will be requested via state_map_factory.
If None, all events will be fetched via state_map_factory. If None, all events will be fetched via state_res_store.
state_res_store (StateResolutionStore) state_res_store: a place to fetch events from
Returns Returns:
Deferred[dict[(str, str), str]]: Deferred[dict[(str, str), str]]:
a map from (type, state_key) to event_id. a map from (type, state_key) to event_id.
""" """
v = KNOWN_ROOM_VERSIONS[room_version] v = KNOWN_ROOM_VERSIONS[room_version]
if v.state_res == StateResolutionVersions.V1: if v.state_res == StateResolutionVersions.V1:
return v1.resolve_events_with_store( return v1.resolve_events_with_store(
state_sets, event_map, state_res_store.get_events room_id, state_sets, event_map, state_res_store.get_events
) )
else: else:
return v2.resolve_events_with_store( return v2.resolve_events_with_store(
room_version, state_sets, event_map, state_res_store room_id, room_version, state_sets, event_map, state_res_store
) )

View File

@ -15,6 +15,7 @@
import hashlib import hashlib
import logging import logging
from typing import Callable, Dict, List, Optional, Tuple
from six import iteritems, iterkeys, itervalues from six import iteritems, iterkeys, itervalues
@ -24,6 +25,7 @@ from synapse import event_auth
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -32,13 +34,20 @@ POWER_KEY = (EventTypes.PowerLevels, "")
@defer.inlineCallbacks @defer.inlineCallbacks
def resolve_events_with_store(state_sets, event_map, state_map_factory): def resolve_events_with_store(
room_id: str,
state_sets: List[Dict[Tuple[str, str], str]],
event_map: Optional[Dict[str, EventBase]],
state_map_factory: Callable,
):
""" """
Args: Args:
state_sets(list): List of dicts of (type, state_key) -> event_id, room_id: the room we are working in
state_sets: List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve. which are the different state groups to resolve.
event_map(dict[str,FrozenEvent]|None): event_map:
a dict from event_id to event, for any events that we happen to a dict from event_id to event, for any events that we happen to
have in flight (eg, those currently being persisted). This will be have in flight (eg, those currently being persisted). This will be
used as a starting point fof finding the state we need; any missing used as a starting point fof finding the state we need; any missing
@ -46,11 +55,11 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory):
If None, all events will be fetched via state_map_factory. If None, all events will be fetched via state_map_factory.
state_map_factory(func): will be called state_map_factory: will be called
with a list of event_ids that are needed, and should return with with a list of event_ids that are needed, and should return with
a Deferred of dict of event_id to event. a Deferred of dict of event_id to event.
Returns Returns:
Deferred[dict[(str, str), str]]: Deferred[dict[(str, str), str]]:
a map from (type, state_key) to event_id. a map from (type, state_key) to event_id.
""" """
@ -76,6 +85,14 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory):
if event_map is not None: if event_map is not None:
state_map.update(event_map) state_map.update(event_map)
# everything in the state map should be in the right room
for event in state_map.values():
if event.room_id != room_id:
raise Exception(
"Attempting to state-resolve for room %s with event %s which is in %s"
% (room_id, event.event_id, event.room_id,)
)
# get the ids of the auth events which allow us to authenticate the # get the ids of the auth events which allow us to authenticate the
# conflicted state, picking only from the unconflicting state. # conflicted state, picking only from the unconflicting state.
# #
@ -95,6 +112,13 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory):
) )
state_map_new = yield state_map_factory(new_needed_events) state_map_new = yield state_map_factory(new_needed_events)
for event in state_map_new.values():
if event.room_id != room_id:
raise Exception(
"Attempting to state-resolve for room %s with event %s which is in %s"
% (room_id, event.event_id, event.room_id,)
)
state_map.update(state_map_new) state_map.update(state_map_new)
return _resolve_with_state( return _resolve_with_state(

View File

@ -16,29 +16,40 @@
import heapq import heapq
import itertools import itertools
import logging import logging
from typing import Dict, List, Optional, Tuple
from six import iteritems, itervalues from six import iteritems, itervalues
from twisted.internet import defer from twisted.internet import defer
import synapse.state
from synapse import event_auth from synapse import event_auth
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.events import EventBase
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@defer.inlineCallbacks @defer.inlineCallbacks
def resolve_events_with_store(room_version, state_sets, event_map, state_res_store): def resolve_events_with_store(
room_id: str,
room_version: str,
state_sets: List[Dict[Tuple[str, str], str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "synapse.state.StateResolutionStore",
):
"""Resolves the state using the v2 state resolution algorithm """Resolves the state using the v2 state resolution algorithm
Args: Args:
room_version (str): The room version room_id: the room we are working in
state_sets(list): List of dicts of (type, state_key) -> event_id, room_version: The room version
state_sets: List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve. which are the different state groups to resolve.
event_map(dict[str,FrozenEvent]|None): event_map:
a dict from event_id to event, for any events that we happen to a dict from event_id to event, for any events that we happen to
have in flight (eg, those currently being persisted). This will be have in flight (eg, those currently being persisted). This will be
used as a starting point fof finding the state we need; any missing used as a starting point fof finding the state we need; any missing
@ -46,9 +57,9 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
If None, all events will be fetched via state_res_store. If None, all events will be fetched via state_res_store.
state_res_store (StateResolutionStore) state_res_store:
Returns Returns:
Deferred[dict[(str, str), str]]: Deferred[dict[(str, str), str]]:
a map from (type, state_key) to event_id. a map from (type, state_key) to event_id.
""" """
@ -84,6 +95,14 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
) )
event_map.update(events) event_map.update(events)
# everything in the event map should be in the right room
for event in event_map.values():
if event.room_id != room_id:
raise Exception(
"Attempting to state-resolve for room %s with event %s which is in %s"
% (room_id, event.event_id, event.room_id,)
)
full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map) full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map)
logger.debug("%d full_conflicted_set entries", len(full_conflicted_set)) logger.debug("%d full_conflicted_set entries", len(full_conflicted_set))
@ -94,13 +113,14 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
) )
sorted_power_events = yield _reverse_topological_power_sort( sorted_power_events = yield _reverse_topological_power_sort(
power_events, event_map, state_res_store, full_conflicted_set room_id, power_events, event_map, state_res_store, full_conflicted_set
) )
logger.debug("sorted %d power events", len(sorted_power_events)) logger.debug("sorted %d power events", len(sorted_power_events))
# Now sequentially auth each one # Now sequentially auth each one
resolved_state = yield _iterative_auth_checks( resolved_state = yield _iterative_auth_checks(
room_id,
room_version, room_version,
sorted_power_events, sorted_power_events,
unconflicted_state, unconflicted_state,
@ -121,13 +141,18 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
pl = resolved_state.get((EventTypes.PowerLevels, ""), None) pl = resolved_state.get((EventTypes.PowerLevels, ""), None)
leftover_events = yield _mainline_sort( leftover_events = yield _mainline_sort(
leftover_events, pl, event_map, state_res_store room_id, leftover_events, pl, event_map, state_res_store
) )
logger.debug("resolving remaining events") logger.debug("resolving remaining events")
resolved_state = yield _iterative_auth_checks( resolved_state = yield _iterative_auth_checks(
room_version, leftover_events, resolved_state, event_map, state_res_store room_id,
room_version,
leftover_events,
resolved_state,
event_map,
state_res_store,
) )
logger.debug("resolved") logger.debug("resolved")
@ -141,11 +166,12 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_power_level_for_sender(event_id, event_map, state_res_store): def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
"""Return the power level of the sender of the given event according to """Return the power level of the sender of the given event according to
their auth events. their auth events.
Args: Args:
room_id (str)
event_id (str) event_id (str)
event_map (dict[str,FrozenEvent]) event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore) state_res_store (StateResolutionStore)
@ -153,11 +179,11 @@ def _get_power_level_for_sender(event_id, event_map, state_res_store):
Returns: Returns:
Deferred[int] Deferred[int]
""" """
event = yield _get_event(event_id, event_map, state_res_store) event = yield _get_event(room_id, event_id, event_map, state_res_store)
pl = None pl = None
for aid in event.auth_event_ids(): for aid in event.auth_event_ids():
aev = yield _get_event(aid, event_map, state_res_store) aev = yield _get_event(room_id, aid, event_map, state_res_store)
if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
pl = aev pl = aev
break break
@ -165,7 +191,7 @@ def _get_power_level_for_sender(event_id, event_map, state_res_store):
if pl is None: if pl is None:
# Couldn't find power level. Check if they're the creator of the room # Couldn't find power level. Check if they're the creator of the room
for aid in event.auth_event_ids(): for aid in event.auth_event_ids():
aev = yield _get_event(aid, event_map, state_res_store) aev = yield _get_event(room_id, aid, event_map, state_res_store)
if (aev.type, aev.state_key) == (EventTypes.Create, ""): if (aev.type, aev.state_key) == (EventTypes.Create, ""):
if aev.content.get("creator") == event.sender: if aev.content.get("creator") == event.sender:
return 100 return 100
@ -279,7 +305,7 @@ def _is_power_event(event):
@defer.inlineCallbacks @defer.inlineCallbacks
def _add_event_and_auth_chain_to_graph( def _add_event_and_auth_chain_to_graph(
graph, event_id, event_map, state_res_store, auth_diff graph, room_id, event_id, event_map, state_res_store, auth_diff
): ):
"""Helper function for _reverse_topological_power_sort that add the event """Helper function for _reverse_topological_power_sort that add the event
and its auth chain (that is in the auth diff) to the graph and its auth chain (that is in the auth diff) to the graph
@ -287,6 +313,7 @@ def _add_event_and_auth_chain_to_graph(
Args: Args:
graph (dict[str, set[str]]): A map from event ID to the events auth graph (dict[str, set[str]]): A map from event ID to the events auth
event IDs event IDs
room_id (str): the room we are working in
event_id (str): Event to add to the graph event_id (str): Event to add to the graph
event_map (dict[str,FrozenEvent]) event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore) state_res_store (StateResolutionStore)
@ -298,7 +325,7 @@ def _add_event_and_auth_chain_to_graph(
eid = state.pop() eid = state.pop()
graph.setdefault(eid, set()) graph.setdefault(eid, set())
event = yield _get_event(eid, event_map, state_res_store) event = yield _get_event(room_id, eid, event_map, state_res_store)
for aid in event.auth_event_ids(): for aid in event.auth_event_ids():
if aid in auth_diff: if aid in auth_diff:
if aid not in graph: if aid not in graph:
@ -308,11 +335,14 @@ def _add_event_and_auth_chain_to_graph(
@defer.inlineCallbacks @defer.inlineCallbacks
def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_diff): def _reverse_topological_power_sort(
room_id, event_ids, event_map, state_res_store, auth_diff
):
"""Returns a list of the event_ids sorted by reverse topological ordering, """Returns a list of the event_ids sorted by reverse topological ordering,
and then by power level and origin_server_ts and then by power level and origin_server_ts
Args: Args:
room_id (str): the room we are working in
event_ids (list[str]): The events to sort event_ids (list[str]): The events to sort
event_map (dict[str,FrozenEvent]) event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore) state_res_store (StateResolutionStore)
@ -325,12 +355,14 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_
graph = {} graph = {}
for event_id in event_ids: for event_id in event_ids:
yield _add_event_and_auth_chain_to_graph( yield _add_event_and_auth_chain_to_graph(
graph, event_id, event_map, state_res_store, auth_diff graph, room_id, event_id, event_map, state_res_store, auth_diff
) )
event_to_pl = {} event_to_pl = {}
for event_id in graph: for event_id in graph:
pl = yield _get_power_level_for_sender(event_id, event_map, state_res_store) pl = yield _get_power_level_for_sender(
room_id, event_id, event_map, state_res_store
)
event_to_pl[event_id] = pl event_to_pl[event_id] = pl
def _get_power_order(event_id): def _get_power_order(event_id):
@ -348,12 +380,13 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_
@defer.inlineCallbacks @defer.inlineCallbacks
def _iterative_auth_checks( def _iterative_auth_checks(
room_version, event_ids, base_state, event_map, state_res_store room_id, room_version, event_ids, base_state, event_map, state_res_store
): ):
"""Sequentially apply auth checks to each event in given list, updating the """Sequentially apply auth checks to each event in given list, updating the
state as it goes along. state as it goes along.
Args: Args:
room_id (str)
room_version (str) room_version (str)
event_ids (list[str]): Ordered list of events to apply auth checks to event_ids (list[str]): Ordered list of events to apply auth checks to
base_state (dict[tuple[str, str], str]): The set of state to start with base_state (dict[tuple[str, str], str]): The set of state to start with
@ -370,7 +403,7 @@ def _iterative_auth_checks(
auth_events = {} auth_events = {}
for aid in event.auth_event_ids(): for aid in event.auth_event_ids():
ev = yield _get_event(aid, event_map, state_res_store) ev = yield _get_event(room_id, aid, event_map, state_res_store)
if ev.rejected_reason is None: if ev.rejected_reason is None:
auth_events[(ev.type, ev.state_key)] = ev auth_events[(ev.type, ev.state_key)] = ev
@ -378,7 +411,7 @@ def _iterative_auth_checks(
for key in event_auth.auth_types_for_event(event): for key in event_auth.auth_types_for_event(event):
if key in resolved_state: if key in resolved_state:
ev_id = resolved_state[key] ev_id = resolved_state[key]
ev = yield _get_event(ev_id, event_map, state_res_store) ev = yield _get_event(room_id, ev_id, event_map, state_res_store)
if ev.rejected_reason is None: if ev.rejected_reason is None:
auth_events[key] = event_map[ev_id] auth_events[key] = event_map[ev_id]
@ -400,11 +433,14 @@ def _iterative_auth_checks(
@defer.inlineCallbacks @defer.inlineCallbacks
def _mainline_sort(event_ids, resolved_power_event_id, event_map, state_res_store): def _mainline_sort(
room_id, event_ids, resolved_power_event_id, event_map, state_res_store
):
"""Returns a sorted list of event_ids sorted by mainline ordering based on """Returns a sorted list of event_ids sorted by mainline ordering based on
the given event resolved_power_event_id the given event resolved_power_event_id
Args: Args:
room_id (str): room we're working in
event_ids (list[str]): Events to sort event_ids (list[str]): Events to sort
resolved_power_event_id (str): The final resolved power level event ID resolved_power_event_id (str): The final resolved power level event ID
event_map (dict[str,FrozenEvent]) event_map (dict[str,FrozenEvent])
@ -417,11 +453,11 @@ def _mainline_sort(event_ids, resolved_power_event_id, event_map, state_res_stor
pl = resolved_power_event_id pl = resolved_power_event_id
while pl: while pl:
mainline.append(pl) mainline.append(pl)
pl_ev = yield _get_event(pl, event_map, state_res_store) pl_ev = yield _get_event(room_id, pl, event_map, state_res_store)
auth_events = pl_ev.auth_event_ids() auth_events = pl_ev.auth_event_ids()
pl = None pl = None
for aid in auth_events: for aid in auth_events:
ev = yield _get_event(aid, event_map, state_res_store) ev = yield _get_event(room_id, aid, event_map, state_res_store)
if (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""): if (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""):
pl = aid pl = aid
break break
@ -457,6 +493,8 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor
Deferred[int] Deferred[int]
""" """
room_id = event.room_id
# We do an iterative search, replacing `event with the power level in its # We do an iterative search, replacing `event with the power level in its
# auth events (if any) # auth events (if any)
while event: while event:
@ -468,7 +506,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor
event = None event = None
for aid in auth_events: for aid in auth_events:
aev = yield _get_event(aid, event_map, state_res_store) aev = yield _get_event(room_id, aid, event_map, state_res_store)
if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
event = aev event = aev
break break
@ -478,11 +516,12 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_event(event_id, event_map, state_res_store): def _get_event(room_id, event_id, event_map, state_res_store):
"""Helper function to look up event in event_map, falling back to looking """Helper function to look up event in event_map, falling back to looking
it up in the store it up in the store
Args: Args:
room_id (str)
event_id (str) event_id (str)
event_map (dict[str,FrozenEvent]) event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore) state_res_store (StateResolutionStore)
@ -493,7 +532,14 @@ def _get_event(event_id, event_map, state_res_store):
if event_id not in event_map: if event_id not in event_map:
events = yield state_res_store.get_events([event_id], allow_rejected=True) events = yield state_res_store.get_events([event_id], allow_rejected=True)
event_map.update(events) event_map.update(events)
return event_map[event_id] event = event_map[event_id]
assert event is not None
if event.room_id != room_id:
raise Exception(
"In state res for room %s, event %s is in %s"
% (room_id, event_id, event.room_id)
)
return event
def lexicographical_topological_sort(graph, key): def lexicographical_topological_sort(graph, key):

View File

@ -58,6 +58,7 @@ class FakeEvent(object):
self.type = type self.type = type
self.state_key = state_key self.state_key = state_key
self.content = content self.content = content
self.room_id = ROOM_ID
def to_event(self, auth_events, prev_events): def to_event(self, auth_events, prev_events):
"""Given the auth_events and prev_events, convert to a Frozen Event """Given the auth_events and prev_events, convert to a Frozen Event
@ -418,6 +419,7 @@ class StateTestCase(unittest.TestCase):
state_before = dict(state_at_event[prev_events[0]]) state_before = dict(state_at_event[prev_events[0]])
else: else:
state_d = resolve_events_with_store( state_d = resolve_events_with_store(
ROOM_ID,
RoomVersions.V2.identifier, RoomVersions.V2.identifier,
[state_at_event[n] for n in prev_events], [state_at_event[n] for n in prev_events],
event_map=event_map, event_map=event_map,
@ -565,6 +567,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
# Test that we correctly handle passing `None` as the event_map # Test that we correctly handle passing `None` as the event_map
state_d = resolve_events_with_store( state_d = resolve_events_with_store(
ROOM_ID,
RoomVersions.V2.identifier, RoomVersions.V2.identifier,
[self.state_at_bob, self.state_at_charlie], [self.state_at_bob, self.state_at_charlie],
event_map=None, event_map=None,