Add helpers for getting prev and auth events (#4139)

* Add helpers for getting prev and auth events

This is in preparation for allowing the event format to change between
room versions.
This commit is contained in:
Erik Johnston 2018-11-05 13:35:15 +00:00 committed by Amber Brown
parent 0467384d2f
commit bc80b3f454
10 changed files with 62 additions and 45 deletions

1
changelog.d/4139.misc Normal file
View File

@ -0,0 +1 @@
Add helpers functions for getting prev and auth events of an event

View File

@ -200,11 +200,11 @@ def _is_membership_change_allowed(event, auth_events):
membership = event.content["membership"] membership = event.content["membership"]
# Check if this is the room creator joining: # Check if this is the room creator joining:
if len(event.prev_events) == 1 and Membership.JOIN == membership: if len(event.prev_event_ids()) == 1 and Membership.JOIN == membership:
# Get room creation event: # Get room creation event:
key = (EventTypes.Create, "", ) key = (EventTypes.Create, "", )
create = auth_events.get(key) create = auth_events.get(key)
if create and event.prev_events[0][0] == create.event_id: if create and event.prev_event_ids()[0] == create.event_id:
if create.content["creator"] == event.state_key: if create.content["creator"] == event.state_key:
return return

View File

@ -159,6 +159,24 @@ class EventBase(object):
def keys(self): def keys(self):
return six.iterkeys(self._event_dict) return six.iterkeys(self._event_dict)
def prev_event_ids(self):
"""Returns the list of prev event IDs. The order matches the order
specified in the event, though there is no meaning to it.
Returns:
list[str]: The list of event IDs of this event's prev_events
"""
return [e for e, _ in self.prev_events]
def auth_event_ids(self):
"""Returns the list of auth event IDs. The order matches the order
specified in the event, though there is no meaning to it.
Returns:
list[str]: The list of event IDs of this event's auth_events
"""
return [e for e, _ in self.auth_events]
class FrozenEvent(EventBase): class FrozenEvent(EventBase):
def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None): def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):

View File

@ -183,9 +183,7 @@ class TransactionQueue(object):
# banned then it won't receive the event because it won't # banned then it won't receive the event because it won't
# be in the room after the ban. # be in the room after the ban.
destinations = yield self.state.get_current_hosts_in_room( destinations = yield self.state.get_current_hosts_in_room(
event.room_id, latest_event_ids=[ event.room_id, latest_event_ids=event.prev_event_ids(),
prev_id for prev_id, _ in event.prev_events
],
) )
except Exception: except Exception:
logger.exception( logger.exception(

View File

@ -239,7 +239,7 @@ class FederationHandler(BaseHandler):
room_id, event_id, min_depth, room_id, event_id, min_depth,
) )
prevs = {e_id for e_id, _ in pdu.prev_events} prevs = set(pdu.prev_event_ids())
seen = yield self.store.have_seen_events(prevs) seen = yield self.store.have_seen_events(prevs)
if min_depth and pdu.depth < min_depth: if min_depth and pdu.depth < min_depth:
@ -607,7 +607,7 @@ class FederationHandler(BaseHandler):
if e.event_id in seen_ids: if e.event_id in seen_ids:
continue continue
e.internal_metadata.outlier = True e.internal_metadata.outlier = True
auth_ids = [e_id for e_id, _ in e.auth_events] auth_ids = e.auth_event_ids()
auth = { auth = {
(e.type, e.state_key): e for e in auth_chain (e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids or e.type == EventTypes.Create if e.event_id in auth_ids or e.type == EventTypes.Create
@ -726,7 +726,7 @@ class FederationHandler(BaseHandler):
edges = [ edges = [
ev.event_id ev.event_id
for ev in events for ev in events
if set(e_id for e_id, _ in ev.prev_events) - event_ids if set(ev.prev_event_ids()) - event_ids
] ]
logger.info( logger.info(
@ -753,7 +753,7 @@ class FederationHandler(BaseHandler):
required_auth = set( required_auth = set(
a_id a_id
for event in events + list(state_events.values()) + list(auth_events.values()) for event in events + list(state_events.values()) + list(auth_events.values())
for a_id, _ in event.auth_events for a_id in event.auth_event_ids()
) )
auth_events.update({ auth_events.update({
e_id: event_map[e_id] for e_id in required_auth if e_id in event_map e_id: event_map[e_id] for e_id in required_auth if e_id in event_map
@ -769,7 +769,7 @@ class FederationHandler(BaseHandler):
auth_events.update(ret_events) auth_events.update(ret_events)
required_auth.update( required_auth.update(
a_id for event in ret_events.values() for a_id, _ in event.auth_events a_id for event in ret_events.values() for a_id in event.auth_event_ids()
) )
missing_auth = required_auth - set(auth_events) missing_auth = required_auth - set(auth_events)
@ -796,7 +796,7 @@ class FederationHandler(BaseHandler):
required_auth.update( required_auth.update(
a_id a_id
for event in results if event for event in results if event
for a_id, _ in event.auth_events for a_id in event.auth_event_ids()
) )
missing_auth = required_auth - set(auth_events) missing_auth = required_auth - set(auth_events)
@ -816,7 +816,7 @@ class FederationHandler(BaseHandler):
"auth_events": { "auth_events": {
(auth_events[a_id].type, auth_events[a_id].state_key): (auth_events[a_id].type, auth_events[a_id].state_key):
auth_events[a_id] auth_events[a_id]
for a_id, _ in a.auth_events for a_id in a.auth_event_ids()
if a_id in auth_events if a_id in auth_events
} }
}) })
@ -828,7 +828,7 @@ class FederationHandler(BaseHandler):
"auth_events": { "auth_events": {
(auth_events[a_id].type, auth_events[a_id].state_key): (auth_events[a_id].type, auth_events[a_id].state_key):
auth_events[a_id] auth_events[a_id]
for a_id, _ in event_map[e_id].auth_events for a_id in event_map[e_id].auth_event_ids()
if a_id in auth_events if a_id in auth_events
} }
}) })
@ -1041,17 +1041,17 @@ class FederationHandler(BaseHandler):
Raises: Raises:
SynapseError if the event does not pass muster SynapseError if the event does not pass muster
""" """
if len(ev.prev_events) > 20: if len(ev.prev_event_ids()) > 20:
logger.warn("Rejecting event %s which has %i prev_events", logger.warn("Rejecting event %s which has %i prev_events",
ev.event_id, len(ev.prev_events)) ev.event_id, len(ev.prev_event_ids()))
raise SynapseError( raise SynapseError(
http_client.BAD_REQUEST, http_client.BAD_REQUEST,
"Too many prev_events", "Too many prev_events",
) )
if len(ev.auth_events) > 10: if len(ev.auth_event_ids()) > 10:
logger.warn("Rejecting event %s which has %i auth_events", logger.warn("Rejecting event %s which has %i auth_events",
ev.event_id, len(ev.auth_events)) ev.event_id, len(ev.auth_event_ids()))
raise SynapseError( raise SynapseError(
http_client.BAD_REQUEST, http_client.BAD_REQUEST,
"Too many auth_events", "Too many auth_events",
@ -1076,7 +1076,7 @@ class FederationHandler(BaseHandler):
def on_event_auth(self, event_id): def on_event_auth(self, event_id):
event = yield self.store.get_event(event_id) event = yield self.store.get_event(event_id)
auth = yield self.store.get_auth_chain( auth = yield self.store.get_auth_chain(
[auth_id for auth_id, _ in event.auth_events], [auth_id for auth_id in event.auth_event_ids()],
include_given=True include_given=True
) )
defer.returnValue([e for e in auth]) defer.returnValue([e for e in auth])
@ -1698,7 +1698,7 @@ class FederationHandler(BaseHandler):
missing_auth_events = set() missing_auth_events = set()
for e in itertools.chain(auth_events, state, [event]): for e in itertools.chain(auth_events, state, [event]):
for e_id, _ in e.auth_events: for e_id in e.auth_event_ids():
if e_id not in event_map: if e_id not in event_map:
missing_auth_events.add(e_id) missing_auth_events.add(e_id)
@ -1717,7 +1717,7 @@ class FederationHandler(BaseHandler):
for e in itertools.chain(auth_events, state, [event]): for e in itertools.chain(auth_events, state, [event]):
auth_for_e = { auth_for_e = {
(event_map[e_id].type, event_map[e_id].state_key): event_map[e_id] (event_map[e_id].type, event_map[e_id].state_key): event_map[e_id]
for e_id, _ in e.auth_events for e_id in e.auth_event_ids()
if e_id in event_map if e_id in event_map
} }
if create_event: if create_event:
@ -1785,10 +1785,10 @@ class FederationHandler(BaseHandler):
# This is a hack to fix some old rooms where the initial join event # This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events. # didn't reference the create event in its auth events.
if event.type == EventTypes.Member and not event.auth_events: if event.type == EventTypes.Member and not event.auth_event_ids():
if len(event.prev_events) == 1 and event.depth < 5: if len(event.prev_event_ids()) == 1 and event.depth < 5:
c = yield self.store.get_event( c = yield self.store.get_event(
event.prev_events[0][0], event.prev_event_ids()[0],
allow_none=True, allow_none=True,
) )
if c and c.type == EventTypes.Create: if c and c.type == EventTypes.Create:
@ -1835,7 +1835,7 @@ class FederationHandler(BaseHandler):
# Now get the current auth_chain for the event. # Now get the current auth_chain for the event.
local_auth_chain = yield self.store.get_auth_chain( local_auth_chain = yield self.store.get_auth_chain(
[auth_id for auth_id, _ in event.auth_events], [auth_id for auth_id in event.auth_event_ids()],
include_given=True include_given=True
) )
@ -1891,7 +1891,7 @@ class FederationHandler(BaseHandler):
""" """
# Check if we have all the auth events. # Check if we have all the auth events.
current_state = set(e.event_id for e in auth_events.values()) current_state = set(e.event_id for e in auth_events.values())
event_auth_events = set(e_id for e_id, _ in event.auth_events) event_auth_events = set(event.auth_event_ids())
if event.is_state(): if event.is_state():
event_key = (event.type, event.state_key) event_key = (event.type, event.state_key)
@ -1935,7 +1935,7 @@ class FederationHandler(BaseHandler):
continue continue
try: try:
auth_ids = [e_id for e_id, _ in e.auth_events] auth_ids = e.auth_event_ids()
auth = { auth = {
(e.type, e.state_key): e for e in remote_auth_chain (e.type, e.state_key): e for e in remote_auth_chain
if e.event_id in auth_ids or e.type == EventTypes.Create if e.event_id in auth_ids or e.type == EventTypes.Create
@ -1956,7 +1956,7 @@ class FederationHandler(BaseHandler):
pass pass
have_events = yield self.store.get_seen_events_with_rejections( have_events = yield self.store.get_seen_events_with_rejections(
[e_id for e_id, _ in event.auth_events] event.auth_event_ids()
) )
seen_events = set(have_events.keys()) seen_events = set(have_events.keys())
except Exception: except Exception:
@ -2058,7 +2058,7 @@ class FederationHandler(BaseHandler):
continue continue
try: try:
auth_ids = [e_id for e_id, _ in ev.auth_events] auth_ids = ev.auth_event_ids()
auth = { auth = {
(e.type, e.state_key): e (e.type, e.state_key): e
for e in result["auth_chain"] for e in result["auth_chain"]
@ -2250,7 +2250,7 @@ class FederationHandler(BaseHandler):
missing_remote_ids = [e.event_id for e in missing_remotes] missing_remote_ids = [e.event_id for e in missing_remotes]
base_remote_rejected = list(missing_remotes) base_remote_rejected = list(missing_remotes)
for e in missing_remotes: for e in missing_remotes:
for e_id, _ in e.auth_events: for e_id in e.auth_event_ids():
if e_id in missing_remote_ids: if e_id in missing_remote_ids:
try: try:
base_remote_rejected.remove(e) base_remote_rejected.remove(e)

View File

@ -261,7 +261,7 @@ class StateHandler(object):
logger.debug("calling resolve_state_groups from compute_event_context") logger.debug("calling resolve_state_groups from compute_event_context")
entry = yield self.resolve_state_groups_for_events( entry = yield self.resolve_state_groups_for_events(
event.room_id, [e for e, _ in event.prev_events], event.room_id, event.prev_event_ids(),
) )
prev_state_ids = entry.state prev_state_ids = entry.state

View File

@ -159,7 +159,7 @@ def _get_power_level_for_sender(event_id, event_map, state_res_store):
event = yield _get_event(event_id, event_map, state_res_store) event = yield _get_event(event_id, event_map, state_res_store)
pl = None pl = None
for aid, _ in event.auth_events: for aid in event.auth_event_ids():
aev = yield _get_event(aid, event_map, state_res_store) aev = yield _get_event(aid, event_map, state_res_store)
if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
pl = aev pl = aev
@ -167,7 +167,7 @@ def _get_power_level_for_sender(event_id, event_map, state_res_store):
if pl is None: if pl is None:
# Couldn't find power level. Check if they're the creator of the room # Couldn't find power level. Check if they're the creator of the room
for aid, _ in event.auth_events: for aid in event.auth_event_ids():
aev = yield _get_event(aid, event_map, state_res_store) aev = yield _get_event(aid, event_map, state_res_store)
if (aev.type, aev.state_key) == (EventTypes.Create, ""): if (aev.type, aev.state_key) == (EventTypes.Create, ""):
if aev.content.get("creator") == event.sender: if aev.content.get("creator") == event.sender:
@ -299,7 +299,7 @@ def _add_event_and_auth_chain_to_graph(graph, event_id, event_map,
graph.setdefault(eid, set()) graph.setdefault(eid, set())
event = yield _get_event(eid, event_map, state_res_store) event = yield _get_event(eid, event_map, state_res_store)
for aid, _ in event.auth_events: for aid in event.auth_event_ids():
if aid in auth_diff: if aid in auth_diff:
if aid not in graph: if aid not in graph:
state.append(aid) state.append(aid)
@ -369,7 +369,7 @@ def _iterative_auth_checks(event_ids, base_state, event_map, state_res_store):
event = event_map[event_id] event = event_map[event_id]
auth_events = {} auth_events = {}
for aid, _ in event.auth_events: for aid in event.auth_event_ids():
ev = yield _get_event(aid, event_map, state_res_store) ev = yield _get_event(aid, event_map, state_res_store)
if ev.rejected_reason is None: if ev.rejected_reason is None:
@ -417,9 +417,9 @@ def _mainline_sort(event_ids, resolved_power_event_id, event_map,
while pl: while pl:
mainline.append(pl) mainline.append(pl)
pl_ev = yield _get_event(pl, event_map, state_res_store) pl_ev = yield _get_event(pl, event_map, state_res_store)
auth_events = pl_ev.auth_events auth_events = pl_ev.auth_event_ids()
pl = None pl = None
for aid, _ in auth_events: for aid in auth_events:
ev = yield _get_event(aid, event_map, state_res_store) ev = yield _get_event(aid, event_map, state_res_store)
if (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""): if (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""):
pl = aid pl = aid
@ -464,10 +464,10 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor
if depth is not None: if depth is not None:
defer.returnValue(depth) defer.returnValue(depth)
auth_events = event.auth_events auth_events = event.auth_event_ids()
event = None event = None
for aid, _ in auth_events: for aid in auth_events:
aev = yield _get_event(aid, event_map, state_res_store) aev = yield _get_event(aid, event_map, state_res_store)
if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
event = aev event = aev

View File

@ -477,7 +477,7 @@ class EventFederationStore(EventFederationWorkerStore):
"is_state": False, "is_state": False,
} }
for ev in events for ev in events
for e_id, _ in ev.prev_events for e_id in ev.prev_event_ids()
], ],
) )
@ -510,7 +510,7 @@ class EventFederationStore(EventFederationWorkerStore):
txn.executemany(query, [ txn.executemany(query, [
(e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False) (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
for ev in events for e_id, _ in ev.prev_events for ev in events for e_id in ev.prev_event_ids()
if not ev.internal_metadata.is_outlier() if not ev.internal_metadata.is_outlier()
]) ])

View File

@ -416,7 +416,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
) )
if len_1: if len_1:
all_single_prev_not_state = all( all_single_prev_not_state = all(
len(event.prev_events) == 1 len(event.prev_event_ids()) == 1
and not event.is_state() and not event.is_state()
for event, ctx in ev_ctx_rm for event, ctx in ev_ctx_rm
) )
@ -440,7 +440,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
# guess this by looking at the prev_events and checking # guess this by looking at the prev_events and checking
# if they match the current forward extremities. # if they match the current forward extremities.
for ev, _ in ev_ctx_rm: for ev, _ in ev_ctx_rm:
prev_event_ids = set(e for e, _ in ev.prev_events) prev_event_ids = set(ev.prev_event_ids())
if latest_event_ids == prev_event_ids: if latest_event_ids == prev_event_ids:
state_delta_reuse_delta_counter.inc() state_delta_reuse_delta_counter.inc()
break break
@ -551,7 +551,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
result.difference_update( result.difference_update(
e_id e_id
for event in new_events for event in new_events
for e_id, _ in event.prev_events for e_id in event.prev_event_ids()
) )
# Finally, remove any events which are prev_events of any existing events. # Finally, remove any events which are prev_events of any existing events.
@ -869,7 +869,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
"auth_id": auth_id, "auth_id": auth_id,
} }
for event, _ in events_and_contexts for event, _ in events_and_contexts
for auth_id, _ in event.auth_events for auth_id in event.auth_event_ids()
if event.is_state() if event.is_state()
], ],
) )

View File

@ -753,7 +753,7 @@ class TestStateResolutionStore(object):
result.add(event_id) result.add(event_id)
event = self.event_map[event_id] event = self.event_map[event_id]
for aid, _ in event.auth_events: for aid in event.auth_event_ids():
stack.append(aid) stack.append(aid)
return list(result) return list(result)