Initial stab at implementing a batched get_missing_pdus request

This commit is contained in:
Erik Johnston 2015-02-19 17:24:14 +00:00
parent 894a89d99b
commit 0ac2a79faa
3 changed files with 135 additions and 9 deletions

View File

@ -305,6 +305,78 @@ class FederationServer(FederationBase):
(200, send_content) (200, send_content)
) )
@defer.inlineCallbacks
def get_missing_events(self, origin, room_id, earliest_events,
latest_events, limit, min_depth):
limit = max(limit, 50)
min_depth = max(min_depth, 0)
missing_events = yield self.store.get_missing_events(
room_id=room_id,
earliest_events=earliest_events,
latest_events=latest_events,
limit=limit,
min_depth=min_depth,
)
known_ids = {e.event_id for e in missing_events} | {earliest_events}
back_edges = {
e for e in missing_events
if {i for i, h in e.prev_events.items()} <= known_ids
}
decoded_auth_events = set()
state = {}
auth_events = set()
auth_and_state = {}
for event in back_edges:
state_pdus = yield self.handler.get_state_for_pdu(
origin, room_id, event.event_id,
do_auth=False,
)
state[event.event_id] = [s.event_id for s in state_pdus]
auth_and_state.update({
s.event_id: s for s in state_pdus
})
state_ids = {pdu.event_id for pdu in state_pdus}
prev_ids = {i for i, h in event.prev_events.items()}
partial_auth_chain = yield self.store.get_auth_chain(
state_ids | prev_ids, have_ids=decoded_auth_events.keys()
)
for p in partial_auth_chain:
p.signatures.update(
compute_event_signature(
p,
self.hs.hostname,
self.hs.config.signing_key[0]
)
)
auth_events.update(
a.event_id for a in partial_auth_chain
)
auth_and_state.update({
a.event_id: a for a in partial_auth_chain
})
time_now = self._clock.time_msec()
defer.returnValue({
"events": [ev.get_pdu_json(time_now) for ev in missing_events],
"state_for_events": state,
"auth_events": auth_events,
"event_map": {
k: ev.get_pdu_json(time_now)
for k, ev in auth_and_state.items()
},
})
@log_function @log_function
def _get_persisted_pdu(self, origin, event_id, do_auth=True): def _get_persisted_pdu(self, origin, event_id, do_auth=True):
""" Get a PDU from the database with given origin and id. """ Get a PDU from the database with given origin and id.

View File

@ -581,12 +581,13 @@ class FederationHandler(BaseHandler):
defer.returnValue(event) defer.returnValue(event)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_for_pdu(self, origin, room_id, event_id): def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True):
yield run_on_reactor() yield run_on_reactor()
in_room = yield self.auth.check_host_in_room(room_id, origin) if do_auth:
if not in_room: in_room = yield self.auth.check_host_in_room(room_id, origin)
raise AuthError(403, "Host not in room.") if not in_room:
raise AuthError(403, "Host not in room.")
state_groups = yield self.store.get_state_groups( state_groups = yield self.store.get_state_groups(
[event_id] [event_id]

View File

@ -32,15 +32,15 @@ class EventFederationStore(SQLBaseStore):
and backfilling from another server respectively. and backfilling from another server respectively.
""" """
def get_auth_chain(self, event_ids): def get_auth_chain(self, event_ids, have_ids=set()):
return self.runInteraction( return self.runInteraction(
"get_auth_chain", "get_auth_chain",
self._get_auth_chain_txn, self._get_auth_chain_txn,
event_ids event_ids, have_ids
) )
def _get_auth_chain_txn(self, txn, event_ids): def _get_auth_chain_txn(self, txn, event_ids, have_ids):
results = self._get_auth_chain_ids_txn(txn, event_ids) results = self._get_auth_chain_ids_txn(txn, event_ids, have_ids)
return self._get_events_txn(txn, results) return self._get_events_txn(txn, results)
@ -51,8 +51,9 @@ class EventFederationStore(SQLBaseStore):
event_ids event_ids
) )
def _get_auth_chain_ids_txn(self, txn, event_ids): def _get_auth_chain_ids_txn(self, txn, event_ids, have_ids):
results = set() results = set()
have_ids = set(have_ids)
base_sql = ( base_sql = (
"SELECT auth_id FROM event_auth WHERE event_id = ?" "SELECT auth_id FROM event_auth WHERE event_id = ?"
@ -64,6 +65,10 @@ class EventFederationStore(SQLBaseStore):
for f in front: for f in front:
txn.execute(base_sql, (f,)) txn.execute(base_sql, (f,))
new_front.update([r[0] for r in txn.fetchall()]) new_front.update([r[0] for r in txn.fetchall()])
new_front -= results
new_front -= have_ids
front = new_front front = new_front
results.update(front) results.update(front)
@ -378,3 +383,51 @@ class EventFederationStore(SQLBaseStore):
event_results += new_front event_results += new_front
return self._get_events_txn(txn, event_results) return self._get_events_txn(txn, event_results)
def get_missing_events(self, room_id, earliest_events, latest_events,
limit, min_depth):
return self.runInteraction(
"get_missing_events",
self._get_missing_events,
room_id, earliest_events, latest_events, limit, min_depth
)
def _get_missing_events(self, txn, room_id, earliest_events, latest_events,
limit, min_depth):
earliest_events = set(earliest_events)
front = set(latest_events) - earliest_events
event_results = set()
query = (
"SELECT prev_event_id FROM event_edges "
"WHERE room_id = ? AND event_id = ? AND is_state = 0 "
"LIMIT ?"
)
while front and len(event_results) < limit:
new_front = set()
for event_id in front:
txn.execute(
query,
(room_id, event_id, limit - len(event_results))
)
for e_id, in txn.fetchall():
new_front.add(e_id)
new_front -= earliest_events
new_front -= event_results
front = new_front
event_results |= new_front
events = self._get_events_txn(txn, event_results)
events = sorted(
[ev for ev in events if ev.depth >= min_depth],
key=lambda e: e.depth,
)
return events[:limit]