Add tests for redactions

This commit is contained in:
Mark Haines 2016-04-07 16:26:52 +01:00
parent 8c82b06904
commit ceb599e789
4 changed files with 54 additions and 5 deletions

View File

@ -69,6 +69,7 @@ class SlavedEventStore(BaseSlavedStore):
"_get_current_state_for_key" "_get_current_state_for_key"
] ]
get_event = DataStore.get_event.__func__
get_current_state = DataStore.get_current_state.__func__ get_current_state = DataStore.get_current_state.__func__
get_current_state_for_key = DataStore.get_current_state_for_key.__func__ get_current_state_for_key = DataStore.get_current_state_for_key.__func__
get_rooms_for_user_where_membership_is = ( get_rooms_for_user_where_membership_is = (
@ -103,7 +104,7 @@ class SlavedEventStore(BaseSlavedStore):
def stream_positions(self): def stream_positions(self):
result = super(SlavedEventStore, self).stream_positions() result = super(SlavedEventStore, self).stream_positions()
result["events"] = self._stream_id_gen.get_current_token() result["events"] = self._stream_id_gen.get_current_token()
result["backfilled"] = self._backfill_id_gen.get_current_token() result["backfill"] = self._backfill_id_gen.get_current_token()
return result return result
def process_replication(self, result): def process_replication(self, result):
@ -145,7 +146,6 @@ class SlavedEventStore(BaseSlavedStore):
position = row[0] position = row[0]
internal = json.loads(row[1]) internal = json.loads(row[1])
event_json = json.loads(row[2]) event_json = json.loads(row[2])
event = FrozenEvent(event_json, internal_metadata_dict=internal) event = FrozenEvent(event_json, internal_metadata_dict=internal)
self._invalidate_caches_for_event( self._invalidate_caches_for_event(
event, backfilled, reset_state=position in state_resets event, backfilled, reset_state=position in state_resets

View File

@ -112,7 +112,7 @@ class StreamIdGenerator(object):
self._current + self._step * (n + 1), self._current + self._step * (n + 1),
self._step self._step
) )
self._current += n self._current += n * self._step
for next_id in next_ids: for next_id in next_ids:
self._unfinished_ids.append(next_id) self._unfinished_ids.append(next_id)

View File

@ -51,7 +51,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
def check(self, method, args, expected_result=None): def check(self, method, args, expected_result=None):
master_result = yield getattr(self.master_store, method)(*args) master_result = yield getattr(self.master_store, method)(*args)
slaved_result = yield getattr(self.slaved_store, method)(*args) slaved_result = yield getattr(self.slaved_store, method)(*args)
self.assertEqual(master_result, slaved_result)
if expected_result is not None: if expected_result is not None:
self.assertEqual(master_result, expected_result) self.assertEqual(master_result, expected_result)
self.assertEqual(slaved_result, expected_result) self.assertEqual(slaved_result, expected_result)
self.assertEqual(master_result, slaved_result)

View File

@ -205,13 +205,59 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
[join3] [join3]
) )
@defer.inlineCallbacks
def test_redactions(self):
yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.persist(type="m.room.member", key=USER_ID, membership="join")
msg = yield self.persist(
type="m.room.message", msgtype="m.text", body="Hello"
)
yield self.replicate()
yield self.check("get_event", [msg.event_id], msg)
redaction = yield self.persist(
type="m.room.redaction", redacts=msg.event_id
)
yield self.replicate()
msg_dict = msg.get_dict()
msg_dict["content"] = {}
msg_dict["unsigned"]["redacted_by"] = redaction.event_id
msg_dict["unsigned"]["redacted_because"] = redaction
redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
yield self.check("get_event", [msg.event_id], redacted)
@defer.inlineCallbacks
def test_backfilled_redactions(self):
yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.persist(type="m.room.member", key=USER_ID, membership="join")
msg = yield self.persist(
type="m.room.message", msgtype="m.text", body="Hello"
)
yield self.replicate()
yield self.check("get_event", [msg.event_id], msg)
redaction = yield self.persist(
type="m.room.redaction", redacts=msg.event_id, backfill=True
)
yield self.replicate()
msg_dict = msg.get_dict()
msg_dict["content"] = {}
msg_dict["unsigned"]["redacted_by"] = redaction.event_id
msg_dict["unsigned"]["redacted_because"] = redaction
redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
yield self.check("get_event", [msg.event_id], redacted)
event_id = 0 event_id = 0
@defer.inlineCallbacks @defer.inlineCallbacks
def persist( def persist(
self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, internal={}, self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, internal={},
state=None, reset_state=False, backfill=False, state=None, reset_state=False, backfill=False,
depth=None, prev_events=[], auth_events=[], prev_state=[], depth=None, prev_events=[], auth_events=[], prev_state=[], redacts=None,
**content **content
): ):
""" """
@ -236,6 +282,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
event_dict["state_key"] = key event_dict["state_key"] = key
event_dict["prev_state"] = prev_state event_dict["prev_state"] = prev_state
if redacts is not None:
event_dict["redacts"] = redacts
event = FrozenEvent(event_dict, internal_metadata_dict=internal) event = FrozenEvent(event_dict, internal_metadata_dict=internal)
self.event_id += 1 self.event_id += 1