Return the previous stream token if a non-member event is a duplicate. (#8093)

This commit is contained in:
Patrick Cloke 2020-08-18 07:53:23 -04:00 committed by GitHub
parent 8b6c176aee
commit 25e55d2598
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 10 deletions

1
changelog.d/8093.bugfix Normal file
View File

@ -0,0 +1 @@
Return the previous stream token if a non-member event is a duplicate.

View File

@ -667,14 +667,14 @@ class EventCreationHandler(object):
assert self.hs.is_mine(user), "User must be our own: %s" % (user,) assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if event.is_state(): if event.is_state():
prev_state = await self.deduplicate_state_event(event, context) prev_event = await self.deduplicate_state_event(event, context)
if prev_state is not None: if prev_event is not None:
logger.info( logger.info(
"Not bothering to persist state event %s duplicated by %s", "Not bothering to persist state event %s duplicated by %s",
event.event_id, event.event_id,
prev_state.event_id, prev_event.event_id,
) )
return prev_state return await self.store.get_stream_token_for_event(prev_event.event_id)
return await self.handle_new_client_event( return await self.handle_new_client_event(
requester=requester, event=event, context=context, ratelimit=ratelimit requester=requester, event=event, context=context, ratelimit=ratelimit
@ -682,27 +682,32 @@ class EventCreationHandler(object):
async def deduplicate_state_event( async def deduplicate_state_event(
self, event: EventBase, context: EventContext self, event: EventBase, context: EventContext
) -> None: ) -> Optional[EventBase]:
""" """
Checks whether event is in the latest resolved state in context. Checks whether event is in the latest resolved state in context.
If so, returns the version of the event in context. Args:
Otherwise, returns None. event: The event to check for duplication.
context: The event context.
Returns:
The previous verion of the event is returned, if it is found in the
event context. Otherwise, None is returned.
""" """
prev_state_ids = await context.get_prev_state_ids() prev_state_ids = await context.get_prev_state_ids()
prev_event_id = prev_state_ids.get((event.type, event.state_key)) prev_event_id = prev_state_ids.get((event.type, event.state_key))
if not prev_event_id: if not prev_event_id:
return return None
prev_event = await self.store.get_event(prev_event_id, allow_none=True) prev_event = await self.store.get_event(prev_event_id, allow_none=True)
if not prev_event: if not prev_event:
return return None
if prev_event and event.user_id == prev_event.user_id: if prev_event and event.user_id == prev_event.user_id:
prev_content = encode_canonical_json(prev_event.content) prev_content = encode_canonical_json(prev_event.content)
next_content = encode_canonical_json(event.content) next_content = encode_canonical_json(event.content)
if prev_content == next_content: if prev_content == next_content:
return prev_event return prev_event
return return None
async def create_and_send_nonmember_event( async def create_and_send_nonmember_event(
self, self,