Make federation return the old current state, so that we can use it to do auth

This commit is contained in:
Erik Johnston 2014-10-30 11:53:35 +00:00
parent ef9c4476a0
commit da511334d2
2 changed files with 32 additions and 10 deletions

View File

@ -112,7 +112,7 @@ class FederationHandler(BaseHandler):
is_new_state = yield self.state_handler.annotate_state_groups( is_new_state = yield self.state_handler.annotate_state_groups(
event, event,
state=state old_state=state
) )
logger.debug("Event: %s", event) logger.debug("Event: %s", event)
@ -240,7 +240,7 @@ class FederationHandler(BaseHandler):
is_new_state = yield self.state_handler.annotate_state_groups( is_new_state = yield self.state_handler.annotate_state_groups(
event, event,
state=state old_state=state
) )
logger.debug("do_invite_join event: %s", event) logger.debug("do_invite_join event: %s", event)
@ -279,7 +279,10 @@ class FederationHandler(BaseHandler):
del self.room_queues[room_id] del self.room_queues[room_id]
for p in room_queue: for p in room_queue:
yield self.on_receive_pdu(p, backfilled=False) try:
yield self.on_receive_pdu(p, backfilled=False)
except:
pass
defer.returnValue(True) defer.returnValue(True)
@ -355,15 +358,30 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_for_pdu(self, pdu_id, pdu_origin): def get_state_for_pdu(self, pdu_id, pdu_origin):
event_id = encode_event_id(pdu_id, pdu_origin)
state_groups = yield self.store.get_state_groups( state_groups = yield self.store.get_state_groups(
[encode_event_id(pdu_id, pdu_origin)] [event_id]
) )
if state_groups: if state_groups:
results = {
(e.type, e.state_key): e for e in state_groups[0].state
}
event = yield self.store.get_event(event_id)
if hasattr(event, "state_key"):
# Get previous state
if hasattr(event, "prev_state") and event.prev_state:
prev_event = yield self.store.get_event(event.prev_state)
results[(event.type, event.state_key)] = prev_event
else:
del results[(event.type, event.state_key)]
defer.returnValue( defer.returnValue(
[ [
self.pdu_codec.pdu_from_event(s) self.pdu_codec.pdu_from_event(s)
for s in state_groups[0].state for s in results.values()
] ]
) )
else: else:

View File

@ -128,11 +128,15 @@ class StateHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def annotate_state_groups(self, event, state=None): def annotate_state_groups(self, event, old_state=None):
if state: if old_state:
event.state_group = None event.state_group = None
event.old_state_events = None event.old_state_events = old_state
event.state_events = {(s.type, s.state_key): s for s in state} event.state_events = {(s.type, s.state_key): s for s in old_state}
if hasattr(event, "state_key"):
event.state_events[(event.type, event.state_key)] = event
defer.returnValue(False) defer.returnValue(False)
return return
@ -163,7 +167,7 @@ class StateHandler(object):
event_ids = [ event_ids = [
e_id e_id
for e_id, _ in events for e_id, _, _ in events
] ]
res = yield self.resolve_state_groups(event_ids) res = yield self.resolve_state_groups(event_ids)