Merge pull request #5207 from matrix-org/erikj/reactions_redactions

Correctly update aggregation counts after redaction
This commit is contained in:
Erik Johnston 2019-05-20 12:36:06 +01:00 committed by GitHub
commit 9ad246e6d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 125 additions and 12 deletions

View File

@ -1325,6 +1325,9 @@ class EventsStore(
txn, event.room_id, event.redacts txn, event.room_id, event.redacts
) )
# Remove from relations table.
self._handle_redaction(txn, event.redacts)
# Update the event_forward_extremities, event_backward_extremities and # Update the event_forward_extremities, event_backward_extremities and
# event_edges tables. # event_edges tables.
self._handle_mult_prev_events( self._handle_mult_prev_events(

View File

@ -415,3 +415,20 @@ class RelationsStore(RelationsWorkerStore):
if rel_type == RelationTypes.REPLACES: if rel_type == RelationTypes.REPLACES:
txn.call_after(self.get_applicable_edit.invalidate, (parent_id,)) txn.call_after(self.get_applicable_edit.invalidate, (parent_id,))
def _handle_redaction(self, txn, redacted_event_id):
"""Handles receiving a redaction and checking whether we need to remove
any redacted relations from the database.
Args:
txn
redacted_event_id (str): The event that was redacted.
"""
self._simple_delete_txn(
txn,
table="event_relations",
keyvalues={
"event_id": redacted_event_id,
}
)

View File

@ -19,19 +19,22 @@ import json
import six import six
from synapse.api.constants import EventTypes, RelationTypes from synapse.api.constants import EventTypes, RelationTypes
from synapse.rest import admin
from synapse.rest.client.v1 import login, room from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import relations from synapse.rest.client.v2_alpha import register, relations
from tests import unittest from tests import unittest
class RelationsTestCase(unittest.HomeserverTestCase): class RelationsTestCase(unittest.HomeserverTestCase):
user_id = "@alice:test"
servlets = [ servlets = [
relations.register_servlets, relations.register_servlets,
room.register_servlets, room.register_servlets,
login.register_servlets, login.register_servlets,
register.register_servlets,
admin.register_servlets_for_client_rest_resource,
] ]
hijack_auth = False
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor, clock):
# We need to enable msc1849 support for aggregations # We need to enable msc1849 support for aggregations
@ -40,8 +43,12 @@ class RelationsTestCase(unittest.HomeserverTestCase):
return self.setup_test_homeserver(config=config) return self.setup_test_homeserver(config=config)
def prepare(self, reactor, clock, hs): def prepare(self, reactor, clock, hs):
self.room = self.helper.create_room_as(self.user_id) self.user_id, self.user_token = self._create_user("alice")
res = self.helper.send(self.room, body="Hi!") self.user2_id, self.user2_token = self._create_user("bob")
self.room = self.helper.create_room_as(self.user_id, tok=self.user_token)
self.helper.join(self.room, user=self.user2_id, tok=self.user2_token)
res = self.helper.send(self.room, body="Hi!", tok=self.user_token)
self.parent_id = res["event_id"] self.parent_id = res["event_id"]
def test_send_relation(self): def test_send_relation(self):
@ -55,7 +62,9 @@ class RelationsTestCase(unittest.HomeserverTestCase):
event_id = channel.json_body["event_id"] event_id = channel.json_body["event_id"]
request, channel = self.make_request( request, channel = self.make_request(
"GET", "/rooms/%s/event/%s" % (self.room, event_id) "GET",
"/rooms/%s/event/%s" % (self.room, event_id),
access_token=self.user_token,
) )
self.render(request) self.render(request)
self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(200, channel.code, channel.json_body)
@ -95,6 +104,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
"GET", "GET",
"/_matrix/client/unstable/rooms/%s/relations/%s?limit=1" "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1"
% (self.room, self.parent_id), % (self.room, self.parent_id),
access_token=self.user_token,
) )
self.render(request) self.render(request)
self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(200, channel.code, channel.json_body)
@ -135,6 +145,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
"GET", "GET",
"/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s" "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s"
% (self.room, self.parent_id, from_token), % (self.room, self.parent_id, from_token),
access_token=self.user_token,
) )
self.render(request) self.render(request)
self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(200, channel.code, channel.json_body)
@ -156,15 +167,32 @@ class RelationsTestCase(unittest.HomeserverTestCase):
"""Test that we can paginate annotation groups correctly. """Test that we can paginate annotation groups correctly.
""" """
# We need to create ten separate users to send each reaction.
access_tokens = [self.user_token, self.user2_token]
idx = 0
while len(access_tokens) < 10:
user_id, token = self._create_user("test" + str(idx))
idx += 1
self.helper.join(self.room, user=user_id, tok=token)
access_tokens.append(token)
idx = 0
sent_groups = {u"👍": 10, u"a": 7, u"b": 5, u"c": 3, u"d": 2, u"e": 1} sent_groups = {u"👍": 10, u"a": 7, u"b": 5, u"c": 3, u"d": 2, u"e": 1}
for key in itertools.chain.from_iterable( for key in itertools.chain.from_iterable(
itertools.repeat(key, num) for key, num in sent_groups.items() itertools.repeat(key, num) for key, num in sent_groups.items()
): ):
channel = self._send_relation( channel = self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", key=key RelationTypes.ANNOTATION,
"m.reaction",
key=key,
access_token=access_tokens[idx],
) )
self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(200, channel.code, channel.json_body)
idx += 1
idx %= len(access_tokens)
prev_token = None prev_token = None
found_groups = {} found_groups = {}
for _ in range(20): for _ in range(20):
@ -176,6 +204,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
"GET", "GET",
"/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1%s" "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1%s"
% (self.room, self.parent_id, from_token), % (self.room, self.parent_id, from_token),
access_token=self.user_token,
) )
self.render(request) self.render(request)
self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(200, channel.code, channel.json_body)
@ -236,6 +265,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
encoded_key, encoded_key,
from_token, from_token,
), ),
access_token=self.user_token,
) )
self.render(request) self.render(request)
self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(200, channel.code, channel.json_body)
@ -263,7 +293,9 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(200, channel.code, channel.json_body)
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") channel = self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
)
self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(200, channel.code, channel.json_body)
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
@ -273,6 +305,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
"GET", "GET",
"/_matrix/client/unstable/rooms/%s/aggregations/%s" "/_matrix/client/unstable/rooms/%s/aggregations/%s"
% (self.room, self.parent_id), % (self.room, self.parent_id),
access_token=self.user_token,
) )
self.render(request) self.render(request)
self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(200, channel.code, channel.json_body)
@ -287,6 +320,43 @@ class RelationsTestCase(unittest.HomeserverTestCase):
}, },
) )
def test_aggregation_redactions(self):
"""Test that annotations get correctly aggregated after a redaction.
"""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
self.assertEquals(200, channel.code, channel.json_body)
to_redact_event_id = channel.json_body["event_id"]
channel = self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
)
self.assertEquals(200, channel.code, channel.json_body)
# Now lets redact one of the 'a' reactions
request, channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/%s/redact/%s" % (self.room, to_redact_event_id),
access_token=self.user_token,
content={},
)
self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
request, channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/aggregations/%s"
% (self.room, self.parent_id),
access_token=self.user_token,
)
self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
self.assertEquals(
channel.json_body,
{"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
)
def test_aggregation_must_be_annotation(self): def test_aggregation_must_be_annotation(self):
"""Test that aggregations must be annotations. """Test that aggregations must be annotations.
""" """
@ -295,6 +365,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
"GET", "GET",
"/_matrix/client/unstable/rooms/%s/aggregations/%s/m.replace?limit=1" "/_matrix/client/unstable/rooms/%s/aggregations/%s/m.replace?limit=1"
% (self.room, self.parent_id), % (self.room, self.parent_id),
access_token=self.user_token,
) )
self.render(request) self.render(request)
self.assertEquals(400, channel.code, channel.json_body) self.assertEquals(400, channel.code, channel.json_body)
@ -307,7 +378,9 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(200, channel.code, channel.json_body)
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") channel = self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
)
self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(200, channel.code, channel.json_body)
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
@ -322,7 +395,9 @@ class RelationsTestCase(unittest.HomeserverTestCase):
reply_2 = channel.json_body["event_id"] reply_2 = channel.json_body["event_id"]
request, channel = self.make_request( request, channel = self.make_request(
"GET", "/rooms/%s/event/%s" % (self.room, self.parent_id) "GET",
"/rooms/%s/event/%s" % (self.room, self.parent_id),
access_token=self.user_token,
) )
self.render(request) self.render(request)
self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(200, channel.code, channel.json_body)
@ -357,7 +432,9 @@ class RelationsTestCase(unittest.HomeserverTestCase):
edit_event_id = channel.json_body["event_id"] edit_event_id = channel.json_body["event_id"]
request, channel = self.make_request( request, channel = self.make_request(
"GET", "/rooms/%s/event/%s" % (self.room, self.parent_id) "GET",
"/rooms/%s/event/%s" % (self.room, self.parent_id),
access_token=self.user_token,
) )
self.render(request) self.render(request)
self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(200, channel.code, channel.json_body)
@ -407,7 +484,9 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(200, channel.code, channel.json_body)
request, channel = self.make_request( request, channel = self.make_request(
"GET", "/rooms/%s/event/%s" % (self.room, self.parent_id) "GET",
"/rooms/%s/event/%s" % (self.room, self.parent_id),
access_token=self.user_token,
) )
self.render(request) self.render(request)
self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(200, channel.code, channel.json_body)
@ -419,7 +498,9 @@ class RelationsTestCase(unittest.HomeserverTestCase):
{RelationTypes.REPLACES: {"event_id": edit_event_id}}, {RelationTypes.REPLACES: {"event_id": edit_event_id}},
) )
def _send_relation(self, relation_type, event_type, key=None, content={}): def _send_relation(
self, relation_type, event_type, key=None, content={}, access_token=None
):
"""Helper function to send a relation pointing at `self.parent_id` """Helper function to send a relation pointing at `self.parent_id`
Args: Args:
@ -428,10 +509,15 @@ class RelationsTestCase(unittest.HomeserverTestCase):
key (str|None): The aggregation key used for m.annotation relation key (str|None): The aggregation key used for m.annotation relation
type. type.
content(dict|None): The content of the created event. content(dict|None): The content of the created event.
access_token (str|None): The access token used to send the relation,
defaults to `self.user_token`
Returns: Returns:
FakeChannel FakeChannel
""" """
if not access_token:
access_token = self.user_token
query = "" query = ""
if key: if key:
query = "?key=" + six.moves.urllib.parse.quote_plus(key.encode("utf-8")) query = "?key=" + six.moves.urllib.parse.quote_plus(key.encode("utf-8"))
@ -441,6 +527,13 @@ class RelationsTestCase(unittest.HomeserverTestCase):
"/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s" "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
% (self.room, self.parent_id, relation_type, event_type, query), % (self.room, self.parent_id, relation_type, event_type, query),
json.dumps(content).encode("utf-8"), json.dumps(content).encode("utf-8"),
access_token=access_token,
) )
self.render(request) self.render(request)
return channel return channel
def _create_user(self, localpart):
user_id = self.register_user(localpart, "abc123")
access_token = self.login(localpart, "abc123")
return user_id, access_token