Merge pull request #680 from matrix-org/markjh/remove_is_new_state

Remove the is_new_state argument to persist event.
This commit is contained in:
Mark Haines 2016-03-31 15:14:48 +01:00
commit 03e406eefc
3 changed files with 56 additions and 55 deletions

View File

@ -33,6 +33,9 @@ class _EventInternalMetadata(object):
def is_outlier(self): def is_outlier(self):
return getattr(self, "outlier", False) return getattr(self, "outlier", False)
def is_invite_from_remote(self):
return getattr(self, "invite_from_remote", False)
def _event_dict_property(key): def _event_dict_property(key):
def getter(self): def getter(self):

View File

@ -102,8 +102,7 @@ class FederationHandler(BaseHandler):
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def on_receive_pdu(self, origin, pdu, state=None, def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None):
auth_chain=None):
""" Called by the ReplicationLayer when we have a new pdu. We need to """ Called by the ReplicationLayer when we have a new pdu. We need to
do auth checks and put it through the StateHandler. do auth checks and put it through the StateHandler.
""" """
@ -174,11 +173,7 @@ class FederationHandler(BaseHandler):
}) })
seen_ids.add(e.event_id) seen_ids.add(e.event_id)
yield self._handle_new_events( yield self._handle_new_events(origin, event_infos)
origin,
event_infos,
outliers=True
)
try: try:
context, event_stream_id, max_stream_id = yield self._handle_new_event( context, event_stream_id, max_stream_id = yield self._handle_new_event(
@ -761,6 +756,7 @@ class FederationHandler(BaseHandler):
event = pdu event = pdu
event.internal_metadata.outlier = True event.internal_metadata.outlier = True
event.internal_metadata.invite_from_remote = True
event.signatures.update( event.signatures.update(
compute_event_signature( compute_event_signature(
@ -1069,9 +1065,6 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def _handle_new_event(self, origin, event, state=None, auth_events=None): def _handle_new_event(self, origin, event, state=None, auth_events=None):
outlier = event.internal_metadata.is_outlier()
context = yield self._prep_event( context = yield self._prep_event(
origin, event, origin, event,
state=state, state=state,
@ -1087,14 +1080,12 @@ class FederationHandler(BaseHandler):
event_stream_id, max_stream_id = yield self.store.persist_event( event_stream_id, max_stream_id = yield self.store.persist_event(
event, event,
context=context, context=context,
is_new_state=not outlier,
) )
defer.returnValue((context, event_stream_id, max_stream_id)) defer.returnValue((context, event_stream_id, max_stream_id))
@defer.inlineCallbacks @defer.inlineCallbacks
def _handle_new_events(self, origin, event_infos, backfilled=False, def _handle_new_events(self, origin, event_infos, backfilled=False):
outliers=False):
contexts = yield defer.gatherResults( contexts = yield defer.gatherResults(
[ [
self._prep_event( self._prep_event(
@ -1113,7 +1104,6 @@ class FederationHandler(BaseHandler):
for ev_info, context in itertools.izip(event_infos, contexts) for ev_info, context in itertools.izip(event_infos, contexts)
], ],
backfilled=backfilled, backfilled=backfilled,
is_new_state=(not outliers and not backfilled),
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -1176,7 +1166,6 @@ class FederationHandler(BaseHandler):
(e, events_to_context[e.event_id]) (e, events_to_context[e.event_id])
for e in itertools.chain(auth_events, state) for e in itertools.chain(auth_events, state)
], ],
is_new_state=False,
) )
new_event_context = yield self.state_handler.compute_event_context( new_event_context = yield self.state_handler.compute_event_context(
@ -1185,7 +1174,6 @@ class FederationHandler(BaseHandler):
event_stream_id, max_stream_id = yield self.store.persist_event( event_stream_id, max_stream_id = yield self.store.persist_event(
event, new_event_context, event, new_event_context,
is_new_state=True,
current_state=state, current_state=state,
) )

View File

@ -61,8 +61,7 @@ class EventsStore(SQLBaseStore):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def persist_events(self, events_and_contexts, backfilled=False, def persist_events(self, events_and_contexts, backfilled=False):
is_new_state=True):
if not events_and_contexts: if not events_and_contexts:
return return
@ -110,13 +109,11 @@ class EventsStore(SQLBaseStore):
self._persist_events_txn, self._persist_events_txn,
events_and_contexts=chunk, events_and_contexts=chunk,
backfilled=backfilled, backfilled=backfilled,
is_new_state=is_new_state,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def persist_event(self, event, context, def persist_event(self, event, context, current_state=None):
is_new_state=True, current_state=None):
try: try:
with self._stream_id_gen.get_next() as stream_ordering: with self._stream_id_gen.get_next() as stream_ordering:
@ -128,7 +125,6 @@ class EventsStore(SQLBaseStore):
self._persist_event_txn, self._persist_event_txn,
event=event, event=event,
context=context, context=context,
is_new_state=is_new_state,
current_state=current_state, current_state=current_state,
) )
except _RollbackButIsFineException: except _RollbackButIsFineException:
@ -194,8 +190,7 @@ class EventsStore(SQLBaseStore):
defer.returnValue({e.event_id: e for e in events}) defer.returnValue({e.event_id: e for e in events})
@log_function @log_function
def _persist_event_txn(self, txn, event, context, def _persist_event_txn(self, txn, event, context, current_state):
is_new_state, current_state):
# We purposefully do this first since if we include a `current_state` # We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table # key, we *want* to update the `current_state_events` table
if current_state: if current_state:
@ -236,12 +231,10 @@ class EventsStore(SQLBaseStore):
txn, txn,
[(event, context)], [(event, context)],
backfilled=False, backfilled=False,
is_new_state=is_new_state,
) )
@log_function @log_function
def _persist_events_txn(self, txn, events_and_contexts, backfilled, def _persist_events_txn(self, txn, events_and_contexts, backfilled):
is_new_state):
depth_updates = {} depth_updates = {}
for event, context in events_and_contexts: for event, context in events_and_contexts:
# Remove the any existing cache entries for the event_ids # Remove the any existing cache entries for the event_ids
@ -452,10 +445,9 @@ class EventsStore(SQLBaseStore):
txn, [event for event, _ in events_and_contexts] txn, [event for event, _ in events_and_contexts]
) )
state_events_and_contexts = filter( state_events_and_contexts = [
lambda i: i[0].is_state(), ec for ec in events_and_contexts if ec[0].is_state()
events_and_contexts, ]
)
state_values = [] state_values = []
for event, context in state_events_and_contexts: for event, context in state_events_and_contexts:
@ -493,9 +485,27 @@ class EventsStore(SQLBaseStore):
], ],
) )
if is_new_state: if backfilled:
# Backfilled events come before the current state so we don't need
# to update the current state table
return
for event, _ in state_events_and_contexts: for event, _ in state_events_and_contexts:
if not context.rejected: if (not event.internal_metadata.is_invite_from_remote()
and event.internal_metadata.is_outlier()):
# Outlier events generally shouldn't clobber the current state.
# However invites from remote servers for rooms we aren't in
# are a bit special: they don't come with any associated
# state so are technically an outlier, however all the
# client-facing code assumes that they are in the current
# state table so we insert the event anyway.
continue
if context.rejected:
# If the event failed its auth checks then it shouldn't
# clobber the current state.
continue
txn.call_after( txn.call_after(
self._get_current_state_for_key.invalidate, self._get_current_state_for_key.invalidate,
(event.room_id, event.type, event.state_key,) (event.room_id, event.type, event.state_key,)