Fixup bsaed on review comments

This commit is contained in:
Erik Johnston 2019-05-17 15:48:04 +01:00
parent d46aab3fa8
commit 5dbff34509
3 changed files with 19 additions and 19 deletions

View File

@ -368,9 +368,7 @@ class EventClientSerializer(object):
edit = None edit = None
if event.type == EventTypes.Message: if event.type == EventTypes.Message:
edit = yield self.store.get_applicable_edit( edit = yield self.store.get_applicable_edit(event_id)
event.event_id, event.type, event.sender,
)
if edit: if edit:
# If there is an edit replace the content, preserving existing # If there is an edit replace the content, preserving existing

View File

@ -143,4 +143,4 @@ class SlavedEventStore(EventFederationWorkerStore,
if relates_to: if relates_to:
self.get_relations_for_event.invalidate_many((relates_to,)) self.get_relations_for_event.invalidate_many((relates_to,))
self.get_aggregation_groups_for_event.invalidate_many((relates_to,)) self.get_aggregation_groups_for_event.invalidate_many((relates_to,))
self.get_applicable_edit.invalidate_many((relates_to,)) self.get_applicable_edit.invalidate((relates_to,))

View File

@ -19,7 +19,7 @@ import attr
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, RelationTypes from synapse.api.constants import RelationTypes
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.stream import generate_pagination_where_clause from synapse.storage.stream import generate_pagination_where_clause
@ -314,8 +314,8 @@ class RelationsWorkerStore(SQLBaseStore):
"get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
) )
@cachedInlineCallbacks(tree=True) @cachedInlineCallbacks()
def get_applicable_edit(self, event_id, event_type, sender): def get_applicable_edit(self, event_id):
"""Get the most recent edit (if any) that has happened for the given """Get the most recent edit (if any) that has happened for the given
event. event.
@ -323,8 +323,6 @@ class RelationsWorkerStore(SQLBaseStore):
Args: Args:
event_id (str): The original event ID event_id (str): The original event ID
event_type (str): The original event type
sender (str): The original event sender
Returns: Returns:
Deferred[EventBase|None]: Returns the most recent edit, if any. Deferred[EventBase|None]: Returns the most recent edit, if any.
@ -332,26 +330,28 @@ class RelationsWorkerStore(SQLBaseStore):
# We only allow edits for `m.room.message` events that have the same sender # We only allow edits for `m.room.message` events that have the same sender
# and event type. We can't assert these things during regular event auth so # and event type. We can't assert these things during regular event auth so
# we have to do the post hoc. # we have to do the checks post hoc.
if event_type != EventTypes.Message:
return
# Fetches latest edit that has the same type and sender as the
# original, and is an `m.room.message`.
sql = """ sql = """
SELECT event_id, origin_server_ts FROM events SELECT edit.event_id FROM events AS edit
INNER JOIN event_relations USING (event_id) INNER JOIN event_relations USING (event_id)
INNER JOIN events AS original ON
original.event_id = relates_to_id
AND edit.type = original.type
AND edit.sender = original.sender
WHERE WHERE
relates_to_id = ? relates_to_id = ?
AND relation_type = ? AND relation_type = ?
AND type = ? AND edit.type = 'm.room.message'
AND sender = ? ORDER by edit.origin_server_ts DESC, edit.event_id DESC
ORDER by origin_server_ts DESC, event_id DESC
LIMIT 1 LIMIT 1
""" """
def _get_applicable_edit_txn(txn): def _get_applicable_edit_txn(txn):
txn.execute( txn.execute(
sql, (event_id, RelationTypes.REPLACES, event_type, sender) sql, (event_id, RelationTypes.REPLACES,)
) )
row = txn.fetchone() row = txn.fetchone()
if row: if row:
@ -412,4 +412,6 @@ class RelationsStore(RelationsWorkerStore):
txn.call_after( txn.call_after(
self.get_aggregation_groups_for_event.invalidate_many, (parent_id,) self.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
) )
txn.call_after(self.get_applicable_edit.invalidate_many, (parent_id,))
if rel_type == RelationTypes.REPLACES:
txn.call_after(self.get_applicable_edit.invalidate, (parent_id,))