Convert storage test cases to HomeserverTestCase. (#9736)

This commit is contained in:
Patrick Cloke 2021-04-06 07:21:02 -04:00 committed by GitHub
parent e2b8a90897
commit e7b769aea1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 265 additions and 499 deletions

1
changelog.d/9736.misc Normal file
View File

@ -0,0 +1 @@
Convert various testcases to `HomeserverTestCase`.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd # Copyright 2016-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,32 +13,21 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
import synapse.api.errors import synapse.api.errors
import tests.unittest from tests.unittest import HomeserverTestCase
import tests.utils
class DeviceStoreTestCase(tests.unittest.TestCase): class DeviceStoreTestCase(HomeserverTestCase):
def __init__(self, *args, **kwargs): def prepare(self, reactor, clock, hs):
super().__init__(*args, **kwargs)
self.store = None # type: synapse.storage.DataStore
@defer.inlineCallbacks
def setUp(self):
hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks
def test_store_new_device(self): def test_store_new_device(self):
yield defer.ensureDeferred( self.get_success(
self.store.store_device("user_id", "device_id", "display_name") self.store.store_device("user_id", "device_id", "display_name")
) )
res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) res = self.get_success(self.store.get_device("user_id", "device_id"))
self.assertDictContainsSubset( self.assertDictContainsSubset(
{ {
"user_id": "user_id", "user_id": "user_id",
@ -48,19 +37,18 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
res, res,
) )
@defer.inlineCallbacks
def test_get_devices_by_user(self): def test_get_devices_by_user(self):
yield defer.ensureDeferred( self.get_success(
self.store.store_device("user_id", "device1", "display_name 1") self.store.store_device("user_id", "device1", "display_name 1")
) )
yield defer.ensureDeferred( self.get_success(
self.store.store_device("user_id", "device2", "display_name 2") self.store.store_device("user_id", "device2", "display_name 2")
) )
yield defer.ensureDeferred( self.get_success(
self.store.store_device("user_id2", "device3", "display_name 3") self.store.store_device("user_id2", "device3", "display_name 3")
) )
res = yield defer.ensureDeferred(self.store.get_devices_by_user("user_id")) res = self.get_success(self.store.get_devices_by_user("user_id"))
self.assertEqual(2, len(res.keys())) self.assertEqual(2, len(res.keys()))
self.assertDictContainsSubset( self.assertDictContainsSubset(
{ {
@ -79,43 +67,41 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
res["device2"], res["device2"],
) )
@defer.inlineCallbacks
def test_count_devices_by_users(self): def test_count_devices_by_users(self):
yield defer.ensureDeferred( self.get_success(
self.store.store_device("user_id", "device1", "display_name 1") self.store.store_device("user_id", "device1", "display_name 1")
) )
yield defer.ensureDeferred( self.get_success(
self.store.store_device("user_id", "device2", "display_name 2") self.store.store_device("user_id", "device2", "display_name 2")
) )
yield defer.ensureDeferred( self.get_success(
self.store.store_device("user_id2", "device3", "display_name 3") self.store.store_device("user_id2", "device3", "display_name 3")
) )
res = yield defer.ensureDeferred(self.store.count_devices_by_users()) res = self.get_success(self.store.count_devices_by_users())
self.assertEqual(0, res) self.assertEqual(0, res)
res = yield defer.ensureDeferred(self.store.count_devices_by_users(["unknown"])) res = self.get_success(self.store.count_devices_by_users(["unknown"]))
self.assertEqual(0, res) self.assertEqual(0, res)
res = yield defer.ensureDeferred(self.store.count_devices_by_users(["user_id"])) res = self.get_success(self.store.count_devices_by_users(["user_id"]))
self.assertEqual(2, res) self.assertEqual(2, res)
res = yield defer.ensureDeferred( res = self.get_success(
self.store.count_devices_by_users(["user_id", "user_id2"]) self.store.count_devices_by_users(["user_id", "user_id2"])
) )
self.assertEqual(3, res) self.assertEqual(3, res)
@defer.inlineCallbacks
def test_get_device_updates_by_remote(self): def test_get_device_updates_by_remote(self):
device_ids = ["device_id1", "device_id2"] device_ids = ["device_id1", "device_id2"]
# Add two device updates with a single stream_id # Add two device updates with a single stream_id
yield defer.ensureDeferred( self.get_success(
self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"]) self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
) )
# Get all device updates ever meant for this remote # Get all device updates ever meant for this remote
now_stream_id, device_updates = yield defer.ensureDeferred( now_stream_id, device_updates = self.get_success(
self.store.get_device_updates_by_remote("somehost", -1, limit=100) self.store.get_device_updates_by_remote("somehost", -1, limit=100)
) )
@ -131,37 +117,35 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
} }
self.assertEqual(received_device_ids, set(expected_device_ids)) self.assertEqual(received_device_ids, set(expected_device_ids))
@defer.inlineCallbacks
def test_update_device(self): def test_update_device(self):
yield defer.ensureDeferred( self.get_success(
self.store.store_device("user_id", "device_id", "display_name 1") self.store.store_device("user_id", "device_id", "display_name 1")
) )
res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) res = self.get_success(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"]) self.assertEqual("display_name 1", res["display_name"])
# do a no-op first # do a no-op first
yield defer.ensureDeferred(self.store.update_device("user_id", "device_id")) self.get_success(self.store.update_device("user_id", "device_id"))
res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) res = self.get_success(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"]) self.assertEqual("display_name 1", res["display_name"])
# do the update # do the update
yield defer.ensureDeferred( self.get_success(
self.store.update_device( self.store.update_device(
"user_id", "device_id", new_display_name="display_name 2" "user_id", "device_id", new_display_name="display_name 2"
) )
) )
# check it worked # check it worked
res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) res = self.get_success(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 2", res["display_name"]) self.assertEqual("display_name 2", res["display_name"])
@defer.inlineCallbacks
def test_update_unknown_device(self): def test_update_unknown_device(self):
with self.assertRaises(synapse.api.errors.StoreError) as cm: exc = self.get_failure(
yield defer.ensureDeferred(
self.store.update_device( self.store.update_device(
"user_id", "unknown_device_id", new_display_name="display_name 2" "user_id", "unknown_device_id", new_display_name="display_name 2"
),
synapse.api.errors.StoreError,
) )
) self.assertEqual(404, exc.value.code)
self.assertEqual(404, cm.exception.code)

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,28 +13,20 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.types import RoomAlias, RoomID from synapse.types import RoomAlias, RoomID
from tests import unittest from tests.unittest import HomeserverTestCase
from tests.utils import setup_test_homeserver
class DirectoryStoreTestCase(unittest.TestCase): class DirectoryStoreTestCase(HomeserverTestCase):
@defer.inlineCallbacks def prepare(self, reactor, clock, hs):
def setUp(self):
hs = yield setup_test_homeserver(self.addCleanup)
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.room = RoomID.from_string("!abcde:test") self.room = RoomID.from_string("!abcde:test")
self.alias = RoomAlias.from_string("#my-room:test") self.alias = RoomAlias.from_string("#my-room:test")
@defer.inlineCallbacks
def test_room_to_alias(self): def test_room_to_alias(self):
yield defer.ensureDeferred( self.get_success(
self.store.create_room_alias_association( self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
) )
@ -42,16 +34,11 @@ class DirectoryStoreTestCase(unittest.TestCase):
self.assertEquals( self.assertEquals(
["#my-room:test"], ["#my-room:test"],
( (self.get_success(self.store.get_aliases_for_room(self.room.to_string()))),
yield defer.ensureDeferred(
self.store.get_aliases_for_room(self.room.to_string())
)
),
) )
@defer.inlineCallbacks
def test_alias_to_room(self): def test_alias_to_room(self):
yield defer.ensureDeferred( self.get_success(
self.store.create_room_alias_association( self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
) )
@ -59,28 +46,19 @@ class DirectoryStoreTestCase(unittest.TestCase):
self.assertObjectHasAttributes( self.assertObjectHasAttributes(
{"room_id": self.room.to_string(), "servers": ["test"]}, {"room_id": self.room.to_string(), "servers": ["test"]},
( (self.get_success(self.store.get_association_from_room_alias(self.alias))),
yield defer.ensureDeferred(
self.store.get_association_from_room_alias(self.alias)
)
),
) )
@defer.inlineCallbacks
def test_delete_alias(self): def test_delete_alias(self):
yield defer.ensureDeferred( self.get_success(
self.store.create_room_alias_association( self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
) )
) )
room_id = yield defer.ensureDeferred(self.store.delete_room_alias(self.alias)) room_id = self.get_success(self.store.delete_room_alias(self.alias))
self.assertEqual(self.room.to_string(), room_id) self.assertEqual(self.room.to_string(), room_id)
self.assertIsNone( self.assertIsNone(
( (self.get_success(self.store.get_association_from_room_alias(self.alias)))
yield defer.ensureDeferred(
self.store.get_association_from_room_alias(self.alias)
)
)
) )

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd # Copyright 2016-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,30 +13,22 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer from tests.unittest import HomeserverTestCase
import tests.unittest
import tests.utils
class EndToEndKeyStoreTestCase(tests.unittest.TestCase): class EndToEndKeyStoreTestCase(HomeserverTestCase):
@defer.inlineCallbacks def prepare(self, reactor, clock, hs):
def setUp(self):
hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks
def test_key_without_device_name(self): def test_key_without_device_name(self):
now = 1470174257070 now = 1470174257070
json = {"key": "value"} json = {"key": "value"}
yield defer.ensureDeferred(self.store.store_device("user", "device", None)) self.get_success(self.store.store_device("user", "device", None))
yield defer.ensureDeferred( self.get_success(self.store.set_e2e_device_keys("user", "device", now, json))
self.store.set_e2e_device_keys("user", "device", now, json)
)
res = yield defer.ensureDeferred( res = self.get_success(
self.store.get_e2e_device_keys_for_cs_api((("user", "device"),)) self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
) )
self.assertIn("user", res) self.assertIn("user", res)
@ -44,38 +36,32 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
dev = res["user"]["device"] dev = res["user"]["device"]
self.assertDictContainsSubset(json, dev) self.assertDictContainsSubset(json, dev)
@defer.inlineCallbacks
def test_reupload_key(self): def test_reupload_key(self):
now = 1470174257070 now = 1470174257070
json = {"key": "value"} json = {"key": "value"}
yield defer.ensureDeferred(self.store.store_device("user", "device", None)) self.get_success(self.store.store_device("user", "device", None))
changed = yield defer.ensureDeferred( changed = self.get_success(
self.store.set_e2e_device_keys("user", "device", now, json) self.store.set_e2e_device_keys("user", "device", now, json)
) )
self.assertTrue(changed) self.assertTrue(changed)
# If we try to upload the same key then we should be told nothing # If we try to upload the same key then we should be told nothing
# changed # changed
changed = yield defer.ensureDeferred( changed = self.get_success(
self.store.set_e2e_device_keys("user", "device", now, json) self.store.set_e2e_device_keys("user", "device", now, json)
) )
self.assertFalse(changed) self.assertFalse(changed)
@defer.inlineCallbacks
def test_get_key_with_device_name(self): def test_get_key_with_device_name(self):
now = 1470174257070 now = 1470174257070
json = {"key": "value"} json = {"key": "value"}
yield defer.ensureDeferred( self.get_success(self.store.set_e2e_device_keys("user", "device", now, json))
self.store.set_e2e_device_keys("user", "device", now, json) self.get_success(self.store.store_device("user", "device", "display_name"))
)
yield defer.ensureDeferred(
self.store.store_device("user", "device", "display_name")
)
res = yield defer.ensureDeferred( res = self.get_success(
self.store.get_e2e_device_keys_for_cs_api((("user", "device"),)) self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
) )
self.assertIn("user", res) self.assertIn("user", res)
@ -85,29 +71,28 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
{"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev {"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev
) )
@defer.inlineCallbacks
def test_multiple_devices(self): def test_multiple_devices(self):
now = 1470174257070 now = 1470174257070
yield defer.ensureDeferred(self.store.store_device("user1", "device1", None)) self.get_success(self.store.store_device("user1", "device1", None))
yield defer.ensureDeferred(self.store.store_device("user1", "device2", None)) self.get_success(self.store.store_device("user1", "device2", None))
yield defer.ensureDeferred(self.store.store_device("user2", "device1", None)) self.get_success(self.store.store_device("user2", "device1", None))
yield defer.ensureDeferred(self.store.store_device("user2", "device2", None)) self.get_success(self.store.store_device("user2", "device2", None))
yield defer.ensureDeferred( self.get_success(
self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"}) self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
) )
yield defer.ensureDeferred( self.get_success(
self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"}) self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
) )
yield defer.ensureDeferred( self.get_success(
self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"}) self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
) )
yield defer.ensureDeferred( self.get_success(
self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"}) self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
) )
res = yield defer.ensureDeferred( res = self.get_success(
self.store.get_e2e_device_keys_for_cs_api( self.store.get_e2e_device_keys_for_cs_api(
(("user1", "device1"), ("user2", "device2")) (("user1", "device1"), ("user2", "device2"))
) )

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd # Copyright 2016-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -15,10 +15,7 @@
from mock import Mock from mock import Mock
from twisted.internet import defer from tests.unittest import HomeserverTestCase
import tests.unittest
import tests.utils
USER_ID = "@user:example.com" USER_ID = "@user:example.com"
@ -30,37 +27,31 @@ HIGHLIGHT = [
] ]
class EventPushActionsStoreTestCase(tests.unittest.TestCase): class EventPushActionsStoreTestCase(HomeserverTestCase):
@defer.inlineCallbacks def prepare(self, reactor, clock, hs):
def setUp(self):
hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.persist_events_store = hs.get_datastores().persist_events self.persist_events_store = hs.get_datastores().persist_events
@defer.inlineCallbacks
def test_get_unread_push_actions_for_user_in_range_for_http(self): def test_get_unread_push_actions_for_user_in_range_for_http(self):
yield defer.ensureDeferred( self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_http( self.store.get_unread_push_actions_for_user_in_range_for_http(
USER_ID, 0, 1000, 20 USER_ID, 0, 1000, 20
) )
) )
@defer.inlineCallbacks
def test_get_unread_push_actions_for_user_in_range_for_email(self): def test_get_unread_push_actions_for_user_in_range_for_email(self):
yield defer.ensureDeferred( self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_email( self.store.get_unread_push_actions_for_user_in_range_for_email(
USER_ID, 0, 1000, 20 USER_ID, 0, 1000, 20
) )
) )
@defer.inlineCallbacks
def test_count_aggregation(self): def test_count_aggregation(self):
room_id = "!foo:example.com" room_id = "!foo:example.com"
user_id = "@user1235:example.com" user_id = "@user1235:example.com"
@defer.inlineCallbacks
def _assert_counts(noitf_count, highlight_count): def _assert_counts(noitf_count, highlight_count):
counts = yield defer.ensureDeferred( counts = self.get_success(
self.store.db_pool.runInteraction( self.store.db_pool.runInteraction(
"", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0 "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
) )
@ -74,7 +65,6 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
}, },
) )
@defer.inlineCallbacks
def _inject_actions(stream, action): def _inject_actions(stream, action):
event = Mock() event = Mock()
event.room_id = room_id event.room_id = room_id
@ -82,14 +72,14 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
event.internal_metadata.stream_ordering = stream event.internal_metadata.stream_ordering = stream
event.depth = stream event.depth = stream
yield defer.ensureDeferred( self.get_success(
self.store.add_push_actions_to_staging( self.store.add_push_actions_to_staging(
event.event_id, event.event_id,
{user_id: action}, {user_id: action},
False, False,
) )
) )
yield defer.ensureDeferred( self.get_success(
self.store.db_pool.runInteraction( self.store.db_pool.runInteraction(
"", "",
self.persist_events_store._set_push_actions_for_event_and_users_txn, self.persist_events_store._set_push_actions_for_event_and_users_txn,
@ -99,14 +89,14 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
) )
def _rotate(stream): def _rotate(stream):
return defer.ensureDeferred( self.get_success(
self.store.db_pool.runInteraction( self.store.db_pool.runInteraction(
"", self.store._rotate_notifs_before_txn, stream "", self.store._rotate_notifs_before_txn, stream
) )
) )
def _mark_read(stream, depth): def _mark_read(stream, depth):
return defer.ensureDeferred( self.get_success(
self.store.db_pool.runInteraction( self.store.db_pool.runInteraction(
"", "",
self.store._remove_old_push_actions_before_txn, self.store._remove_old_push_actions_before_txn,
@ -116,49 +106,48 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
) )
) )
yield _assert_counts(0, 0) _assert_counts(0, 0)
yield _inject_actions(1, PlAIN_NOTIF) _inject_actions(1, PlAIN_NOTIF)
yield _assert_counts(1, 0) _assert_counts(1, 0)
yield _rotate(2) _rotate(2)
yield _assert_counts(1, 0) _assert_counts(1, 0)
yield _inject_actions(3, PlAIN_NOTIF) _inject_actions(3, PlAIN_NOTIF)
yield _assert_counts(2, 0) _assert_counts(2, 0)
yield _rotate(4) _rotate(4)
yield _assert_counts(2, 0) _assert_counts(2, 0)
yield _inject_actions(5, PlAIN_NOTIF) _inject_actions(5, PlAIN_NOTIF)
yield _mark_read(3, 3) _mark_read(3, 3)
yield _assert_counts(1, 0) _assert_counts(1, 0)
yield _mark_read(5, 5) _mark_read(5, 5)
yield _assert_counts(0, 0) _assert_counts(0, 0)
yield _inject_actions(6, PlAIN_NOTIF) _inject_actions(6, PlAIN_NOTIF)
yield _rotate(7) _rotate(7)
yield defer.ensureDeferred( self.get_success(
self.store.db_pool.simple_delete( self.store.db_pool.simple_delete(
table="event_push_actions", keyvalues={"1": 1}, desc="" table="event_push_actions", keyvalues={"1": 1}, desc=""
) )
) )
yield _assert_counts(1, 0) _assert_counts(1, 0)
yield _mark_read(7, 7) _mark_read(7, 7)
yield _assert_counts(0, 0) _assert_counts(0, 0)
yield _inject_actions(8, HIGHLIGHT) _inject_actions(8, HIGHLIGHT)
yield _assert_counts(1, 1) _assert_counts(1, 1)
yield _rotate(9) _rotate(9)
yield _assert_counts(1, 1) _assert_counts(1, 1)
yield _rotate(10) _rotate(10)
yield _assert_counts(1, 1) _assert_counts(1, 1)
@defer.inlineCallbacks
def test_find_first_stream_ordering_after_ts(self): def test_find_first_stream_ordering_after_ts(self):
def add_event(so, ts): def add_event(so, ts):
return defer.ensureDeferred( self.get_success(
self.store.db_pool.simple_insert( self.store.db_pool.simple_insert(
"events", "events",
{ {
@ -177,24 +166,16 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
) )
# start with the base case where there are no events in the table # start with the base case where there are no events in the table
r = yield defer.ensureDeferred( r = self.get_success(self.store.find_first_stream_ordering_after_ts(11))
self.store.find_first_stream_ordering_after_ts(11)
)
self.assertEqual(r, 0) self.assertEqual(r, 0)
# now with one event # now with one event
yield add_event(2, 10) add_event(2, 10)
r = yield defer.ensureDeferred( r = self.get_success(self.store.find_first_stream_ordering_after_ts(9))
self.store.find_first_stream_ordering_after_ts(9)
)
self.assertEqual(r, 2) self.assertEqual(r, 2)
r = yield defer.ensureDeferred( r = self.get_success(self.store.find_first_stream_ordering_after_ts(10))
self.store.find_first_stream_ordering_after_ts(10)
)
self.assertEqual(r, 2) self.assertEqual(r, 2)
r = yield defer.ensureDeferred( r = self.get_success(self.store.find_first_stream_ordering_after_ts(11))
self.store.find_first_stream_ordering_after_ts(11)
)
self.assertEqual(r, 3) self.assertEqual(r, 3)
# add a bunch of dummy events to the events table # add a bunch of dummy events to the events table
@ -205,39 +186,27 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
(10, 130), (10, 130),
(20, 140), (20, 140),
): ):
yield add_event(stream_ordering, ts) add_event(stream_ordering, ts)
r = yield defer.ensureDeferred( r = self.get_success(self.store.find_first_stream_ordering_after_ts(110))
self.store.find_first_stream_ordering_after_ts(110)
)
self.assertEqual(r, 3, "First event after 110ms should be 3, was %i" % r) self.assertEqual(r, 3, "First event after 110ms should be 3, was %i" % r)
# 4 and 5 are both after 120: we want 4 rather than 5 # 4 and 5 are both after 120: we want 4 rather than 5
r = yield defer.ensureDeferred( r = self.get_success(self.store.find_first_stream_ordering_after_ts(120))
self.store.find_first_stream_ordering_after_ts(120)
)
self.assertEqual(r, 4, "First event after 120ms should be 4, was %i" % r) self.assertEqual(r, 4, "First event after 120ms should be 4, was %i" % r)
r = yield defer.ensureDeferred( r = self.get_success(self.store.find_first_stream_ordering_after_ts(129))
self.store.find_first_stream_ordering_after_ts(129)
)
self.assertEqual(r, 10, "First event after 129ms should be 10, was %i" % r) self.assertEqual(r, 10, "First event after 129ms should be 10, was %i" % r)
# check we can get the last event # check we can get the last event
r = yield defer.ensureDeferred( r = self.get_success(self.store.find_first_stream_ordering_after_ts(140))
self.store.find_first_stream_ordering_after_ts(140)
)
self.assertEqual(r, 20, "First event after 14ms should be 20, was %i" % r) self.assertEqual(r, 20, "First event after 14ms should be 20, was %i" % r)
# off the end # off the end
r = yield defer.ensureDeferred( r = self.get_success(self.store.find_first_stream_ordering_after_ts(160))
self.store.find_first_stream_ordering_after_ts(160)
)
self.assertEqual(r, 21) self.assertEqual(r, 21)
# check we can find an event at ordering zero # check we can find an event at ordering zero
yield add_event(0, 5) add_event(0, 5)
r = yield defer.ensureDeferred( r = self.get_success(self.store.find_first_stream_ordering_after_ts(1))
self.store.find_first_stream_ordering_after_ts(1)
)
self.assertEqual(r, 0) self.assertEqual(r, 0)

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,59 +13,50 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.types import UserID from synapse.types import UserID
from tests import unittest from tests import unittest
from tests.utils import setup_test_homeserver
class ProfileStoreTestCase(unittest.TestCase): class ProfileStoreTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks def prepare(self, reactor, clock, hs):
def setUp(self):
hs = yield setup_test_homeserver(self.addCleanup)
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.u_frank = UserID.from_string("@frank:test") self.u_frank = UserID.from_string("@frank:test")
@defer.inlineCallbacks
def test_displayname(self): def test_displayname(self):
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) self.get_success(self.store.create_profile(self.u_frank.localpart))
yield defer.ensureDeferred( self.get_success(
self.store.set_profile_displayname(self.u_frank.localpart, "Frank") self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
) )
self.assertEquals( self.assertEquals(
"Frank", "Frank",
( (
yield defer.ensureDeferred( self.get_success(
self.store.get_profile_displayname(self.u_frank.localpart) self.store.get_profile_displayname(self.u_frank.localpart)
) )
), ),
) )
# test set to None # test set to None
yield defer.ensureDeferred( self.get_success(
self.store.set_profile_displayname(self.u_frank.localpart, None) self.store.set_profile_displayname(self.u_frank.localpart, None)
) )
self.assertIsNone( self.assertIsNone(
( (
yield defer.ensureDeferred( self.get_success(
self.store.get_profile_displayname(self.u_frank.localpart) self.store.get_profile_displayname(self.u_frank.localpart)
) )
) )
) )
@defer.inlineCallbacks
def test_avatar_url(self): def test_avatar_url(self):
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) self.get_success(self.store.create_profile(self.u_frank.localpart))
yield defer.ensureDeferred( self.get_success(
self.store.set_profile_avatar_url( self.store.set_profile_avatar_url(
self.u_frank.localpart, "http://my.site/here" self.u_frank.localpart, "http://my.site/here"
) )
@ -74,20 +65,20 @@ class ProfileStoreTestCase(unittest.TestCase):
self.assertEquals( self.assertEquals(
"http://my.site/here", "http://my.site/here",
( (
yield defer.ensureDeferred( self.get_success(
self.store.get_profile_avatar_url(self.u_frank.localpart) self.store.get_profile_avatar_url(self.u_frank.localpart)
) )
), ),
) )
# test set to None # test set to None
yield defer.ensureDeferred( self.get_success(
self.store.set_profile_avatar_url(self.u_frank.localpart, None) self.store.set_profile_avatar_url(self.u_frank.localpart, None)
) )
self.assertIsNone( self.assertIsNone(
( (
yield defer.ensureDeferred( self.get_success(
self.store.get_profile_avatar_url(self.u_frank.localpart) self.store.get_profile_avatar_url(self.u_frank.localpart)
) )
) )

View File

@ -1,6 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2021 The Matrix.org Foundation C.I.C.
# Copyright 2019 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -16,8 +15,6 @@
from canonicaljson import json from canonicaljson import json
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.types import RoomID, UserID from synapse.types import RoomID, UserID
@ -230,10 +227,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self._base_builder = base_builder self._base_builder = base_builder
self._event_id = event_id self._event_id = event_id
@defer.inlineCallbacks async def build(self, prev_event_ids, auth_event_ids):
def build(self, prev_event_ids, auth_event_ids): built_event = await self._base_builder.build(
built_event = yield defer.ensureDeferred( prev_event_ids, auth_event_ids
self._base_builder.build(prev_event_ids, auth_event_ids)
) )
built_event._event_id = self._event_id built_event._event_id = self._event_id

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,21 +13,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.api.constants import UserTypes from synapse.api.constants import UserTypes
from synapse.api.errors import ThreepidValidationError from synapse.api.errors import ThreepidValidationError
from tests import unittest from tests.unittest import HomeserverTestCase
from tests.utils import setup_test_homeserver
class RegistrationStoreTestCase(unittest.TestCase): class RegistrationStoreTestCase(HomeserverTestCase):
@defer.inlineCallbacks def prepare(self, reactor, clock, hs):
def setUp(self):
hs = yield setup_test_homeserver(self.addCleanup)
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.user_id = "@my-user:test" self.user_id = "@my-user:test"
@ -35,9 +28,8 @@ class RegistrationStoreTestCase(unittest.TestCase):
self.pwhash = "{xx1}123456789" self.pwhash = "{xx1}123456789"
self.device_id = "akgjhdjklgshg" self.device_id = "akgjhdjklgshg"
@defer.inlineCallbacks
def test_register(self): def test_register(self):
yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash)) self.get_success(self.store.register_user(self.user_id, self.pwhash))
self.assertEquals( self.assertEquals(
{ {
@ -49,93 +41,81 @@ class RegistrationStoreTestCase(unittest.TestCase):
"consent_version": None, "consent_version": None,
"consent_server_notice_sent": None, "consent_server_notice_sent": None,
"appservice_id": None, "appservice_id": None,
"creation_ts": 1000, "creation_ts": 0,
"user_type": None, "user_type": None,
"deactivated": 0, "deactivated": 0,
"shadow_banned": 0, "shadow_banned": 0,
}, },
(yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))), (self.get_success(self.store.get_user_by_id(self.user_id))),
) )
@defer.inlineCallbacks
def test_add_tokens(self): def test_add_tokens(self):
yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash)) self.get_success(self.store.register_user(self.user_id, self.pwhash))
yield defer.ensureDeferred( self.get_success(
self.store.add_access_token_to_user( self.store.add_access_token_to_user(
self.user_id, self.tokens[1], self.device_id, valid_until_ms=None self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
) )
) )
result = yield defer.ensureDeferred( result = self.get_success(self.store.get_user_by_access_token(self.tokens[1]))
self.store.get_user_by_access_token(self.tokens[1])
)
self.assertEqual(result.user_id, self.user_id) self.assertEqual(result.user_id, self.user_id)
self.assertEqual(result.device_id, self.device_id) self.assertEqual(result.device_id, self.device_id)
self.assertIsNotNone(result.token_id) self.assertIsNotNone(result.token_id)
@defer.inlineCallbacks
def test_user_delete_access_tokens(self): def test_user_delete_access_tokens(self):
# add some tokens # add some tokens
yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash)) self.get_success(self.store.register_user(self.user_id, self.pwhash))
yield defer.ensureDeferred( self.get_success(
self.store.add_access_token_to_user( self.store.add_access_token_to_user(
self.user_id, self.tokens[0], device_id=None, valid_until_ms=None self.user_id, self.tokens[0], device_id=None, valid_until_ms=None
) )
) )
yield defer.ensureDeferred( self.get_success(
self.store.add_access_token_to_user( self.store.add_access_token_to_user(
self.user_id, self.tokens[1], self.device_id, valid_until_ms=None self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
) )
) )
# now delete some # now delete some
yield defer.ensureDeferred( self.get_success(
self.store.user_delete_access_tokens(self.user_id, device_id=self.device_id) self.store.user_delete_access_tokens(self.user_id, device_id=self.device_id)
) )
# check they were deleted # check they were deleted
user = yield defer.ensureDeferred( user = self.get_success(self.store.get_user_by_access_token(self.tokens[1]))
self.store.get_user_by_access_token(self.tokens[1])
)
self.assertIsNone(user, "access token was not deleted by device_id") self.assertIsNone(user, "access token was not deleted by device_id")
# check the one not associated with the device was not deleted # check the one not associated with the device was not deleted
user = yield defer.ensureDeferred( user = self.get_success(self.store.get_user_by_access_token(self.tokens[0]))
self.store.get_user_by_access_token(self.tokens[0])
)
self.assertEqual(self.user_id, user.user_id) self.assertEqual(self.user_id, user.user_id)
# now delete the rest # now delete the rest
yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id)) self.get_success(self.store.user_delete_access_tokens(self.user_id))
user = yield defer.ensureDeferred( user = self.get_success(self.store.get_user_by_access_token(self.tokens[0]))
self.store.get_user_by_access_token(self.tokens[0])
)
self.assertIsNone(user, "access token was not deleted without device_id") self.assertIsNone(user, "access token was not deleted without device_id")
@defer.inlineCallbacks
def test_is_support_user(self): def test_is_support_user(self):
TEST_USER = "@test:test" TEST_USER = "@test:test"
SUPPORT_USER = "@support:test" SUPPORT_USER = "@support:test"
res = yield defer.ensureDeferred(self.store.is_support_user(None)) res = self.get_success(self.store.is_support_user(None))
self.assertFalse(res) self.assertFalse(res)
yield defer.ensureDeferred( self.get_success(
self.store.register_user(user_id=TEST_USER, password_hash=None) self.store.register_user(user_id=TEST_USER, password_hash=None)
) )
res = yield defer.ensureDeferred(self.store.is_support_user(TEST_USER)) res = self.get_success(self.store.is_support_user(TEST_USER))
self.assertFalse(res) self.assertFalse(res)
yield defer.ensureDeferred( self.get_success(
self.store.register_user( self.store.register_user(
user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT
) )
) )
res = yield defer.ensureDeferred(self.store.is_support_user(SUPPORT_USER)) res = self.get_success(self.store.is_support_user(SUPPORT_USER))
self.assertTrue(res) self.assertTrue(res)
@defer.inlineCallbacks
def test_3pid_inhibit_invalid_validation_session_error(self): def test_3pid_inhibit_invalid_validation_session_error(self):
"""Tests that enabling the configuration option to inhibit 3PID errors on """Tests that enabling the configuration option to inhibit 3PID errors on
/requestToken also inhibits validation errors caused by an unknown session ID. /requestToken also inhibits validation errors caused by an unknown session ID.
@ -143,30 +123,28 @@ class RegistrationStoreTestCase(unittest.TestCase):
# Check that, with the config setting set to false (the default value), a # Check that, with the config setting set to false (the default value), a
# validation error is caused by the unknown session ID. # validation error is caused by the unknown session ID.
try: e = self.get_failure(
yield defer.ensureDeferred(
self.store.validate_threepid_session( self.store.validate_threepid_session(
"fake_sid", "fake_sid",
"fake_client_secret", "fake_client_secret",
"fake_token", "fake_token",
0, 0,
),
ThreepidValidationError,
) )
) self.assertEquals(e.value.msg, "Unknown session_id", e)
except ThreepidValidationError as e:
self.assertEquals(e.msg, "Unknown session_id", e)
# Set the config setting to true. # Set the config setting to true.
self.store._ignore_unknown_session_error = True self.store._ignore_unknown_session_error = True
# Check that now the validation error is caused by the token not matching. # Check that now the validation error is caused by the token not matching.
try: e = self.get_failure(
yield defer.ensureDeferred(
self.store.validate_threepid_session( self.store.validate_threepid_session(
"fake_sid", "fake_sid",
"fake_client_secret", "fake_client_secret",
"fake_token", "fake_token",
0, 0,
),
ThreepidValidationError,
) )
) self.assertEquals(e.value.msg, "Validation token not found or has expired", e)
except ThreepidValidationError as e:
self.assertEquals(e.msg, "Validation token not found or has expired", e)

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,22 +13,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.types import RoomAlias, RoomID, UserID from synapse.types import RoomAlias, RoomID, UserID
from tests import unittest from tests.unittest import HomeserverTestCase
from tests.utils import setup_test_homeserver
class RoomStoreTestCase(unittest.TestCase): class RoomStoreTestCase(HomeserverTestCase):
@defer.inlineCallbacks def prepare(self, reactor, clock, hs):
def setUp(self):
hs = yield setup_test_homeserver(self.addCleanup)
# We can't test RoomStore on its own without the DirectoryStore, for # We can't test RoomStore on its own without the DirectoryStore, for
# management of the 'room_aliases' table # management of the 'room_aliases' table
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -37,7 +30,7 @@ class RoomStoreTestCase(unittest.TestCase):
self.alias = RoomAlias.from_string("#a-room-name:test") self.alias = RoomAlias.from_string("#a-room-name:test")
self.u_creator = UserID.from_string("@creator:test") self.u_creator = UserID.from_string("@creator:test")
yield defer.ensureDeferred( self.get_success(
self.store.store_room( self.store.store_room(
self.room.to_string(), self.room.to_string(),
room_creator_user_id=self.u_creator.to_string(), room_creator_user_id=self.u_creator.to_string(),
@ -46,7 +39,6 @@ class RoomStoreTestCase(unittest.TestCase):
) )
) )
@defer.inlineCallbacks
def test_get_room(self): def test_get_room(self):
self.assertDictContainsSubset( self.assertDictContainsSubset(
{ {
@ -54,16 +46,12 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(), "creator": self.u_creator.to_string(),
"is_public": True, "is_public": True,
}, },
(yield defer.ensureDeferred(self.store.get_room(self.room.to_string()))), (self.get_success(self.store.get_room(self.room.to_string()))),
) )
@defer.inlineCallbacks
def test_get_room_unknown_room(self): def test_get_room_unknown_room(self):
self.assertIsNone( self.assertIsNone((self.get_success(self.store.get_room("!uknown:test"))))
(yield defer.ensureDeferred(self.store.get_room("!uknown:test")))
)
@defer.inlineCallbacks
def test_get_room_with_stats(self): def test_get_room_with_stats(self):
self.assertDictContainsSubset( self.assertDictContainsSubset(
{ {
@ -71,29 +59,17 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(), "creator": self.u_creator.to_string(),
"public": True, "public": True,
}, },
( (self.get_success(self.store.get_room_with_stats(self.room.to_string()))),
yield defer.ensureDeferred(
self.store.get_room_with_stats(self.room.to_string())
)
),
) )
@defer.inlineCallbacks
def test_get_room_with_stats_unknown_room(self): def test_get_room_with_stats_unknown_room(self):
self.assertIsNone( self.assertIsNone(
( (self.get_success(self.store.get_room_with_stats("!uknown:test"))),
yield defer.ensureDeferred(
self.store.get_room_with_stats("!uknown:test")
)
),
) )
class RoomEventsStoreTestCase(unittest.TestCase): class RoomEventsStoreTestCase(HomeserverTestCase):
@defer.inlineCallbacks def prepare(self, reactor, clock, hs):
def setUp(self):
hs = setup_test_homeserver(self.addCleanup)
# Room events need the full datastore, for persist_event() and # Room events need the full datastore, for persist_event() and
# get_room_state() # get_room_state()
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -102,7 +78,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
self.room = RoomID.from_string("!abcde:test") self.room = RoomID.from_string("!abcde:test")
yield defer.ensureDeferred( self.get_success(
self.store.store_room( self.store.store_room(
self.room.to_string(), self.room.to_string(),
room_creator_user_id="@creator:text", room_creator_user_id="@creator:text",
@ -111,23 +87,21 @@ class RoomEventsStoreTestCase(unittest.TestCase):
) )
) )
@defer.inlineCallbacks
def inject_room_event(self, **kwargs): def inject_room_event(self, **kwargs):
yield defer.ensureDeferred( self.get_success(
self.storage.persistence.persist_event( self.storage.persistence.persist_event(
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
) )
) )
@defer.inlineCallbacks
def STALE_test_room_name(self): def STALE_test_room_name(self):
name = "A-Room-Name" name = "A-Room-Name"
yield self.inject_room_event( self.inject_room_event(
etype=EventTypes.Name, name=name, content={"name": name}, depth=1 etype=EventTypes.Name, name=name, content={"name": name}, depth=1
) )
state = yield defer.ensureDeferred( state = self.get_success(
self.store.get_current_state(room_id=self.room.to_string()) self.store.get_current_state(room_id=self.room.to_string())
) )
@ -137,15 +111,14 @@ class RoomEventsStoreTestCase(unittest.TestCase):
state[0], state[0],
) )
@defer.inlineCallbacks
def STALE_test_room_topic(self): def STALE_test_room_topic(self):
topic = "A place for things" topic = "A place for things"
yield self.inject_room_event( self.inject_room_event(
etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1 etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1
) )
state = yield defer.ensureDeferred( state = self.get_success(
self.store.get_current_state(room_id=self.room.to_string()) self.store.get_current_state(room_id=self.room.to_string())
) )

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd # Copyright 2018-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -15,24 +15,18 @@
import logging import logging
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import RoomID, UserID from synapse.types import RoomID, UserID
import tests.unittest from tests.unittest import HomeserverTestCase
import tests.utils
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class StateStoreTestCase(tests.unittest.TestCase): class StateStoreTestCase(HomeserverTestCase):
@defer.inlineCallbacks def prepare(self, reactor, clock, hs):
def setUp(self):
hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.storage = hs.get_storage() self.storage = hs.get_storage()
self.state_datastore = self.storage.state.stores.state self.state_datastore = self.storage.state.stores.state
@ -44,7 +38,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room = RoomID.from_string("!abc123:test") self.room = RoomID.from_string("!abc123:test")
yield defer.ensureDeferred( self.get_success(
self.store.store_room( self.store.store_room(
self.room.to_string(), self.room.to_string(),
room_creator_user_id="@creator:text", room_creator_user_id="@creator:text",
@ -53,7 +47,6 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )
) )
@defer.inlineCallbacks
def inject_state_event(self, room, sender, typ, state_key, content): def inject_state_event(self, room, sender, typ, state_key, content):
builder = self.event_builder_factory.for_room_version( builder = self.event_builder_factory.for_room_version(
RoomVersions.V1, RoomVersions.V1,
@ -66,13 +59,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
}, },
) )
event, context = yield defer.ensureDeferred( event, context = self.get_success(
self.event_creation_handler.create_new_client_event(builder) self.event_creation_handler.create_new_client_event(builder)
) )
yield defer.ensureDeferred( self.get_success(self.storage.persistence.persist_event(event, context))
self.storage.persistence.persist_event(event, context)
)
return event return event
@ -82,16 +73,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(s1[t].event_id, s2[t].event_id) self.assertEqual(s1[t].event_id, s2[t].event_id)
self.assertEqual(len(s1), len(s2)) self.assertEqual(len(s1), len(s2))
@defer.inlineCallbacks
def test_get_state_groups_ids(self): def test_get_state_groups_ids(self):
e1 = yield self.inject_state_event( e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
self.room, self.u_alice, EventTypes.Create, "", {} e2 = self.inject_state_event(
)
e2 = yield self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
) )
state_group_map = yield defer.ensureDeferred( state_group_map = self.get_success(
self.storage.state.get_state_groups_ids(self.room, [e2.event_id]) self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
) )
self.assertEqual(len(state_group_map), 1) self.assertEqual(len(state_group_map), 1)
@ -101,16 +89,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
{(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id}, {(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id},
) )
@defer.inlineCallbacks
def test_get_state_groups(self): def test_get_state_groups(self):
e1 = yield self.inject_state_event( e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
self.room, self.u_alice, EventTypes.Create, "", {} e2 = self.inject_state_event(
)
e2 = yield self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
) )
state_group_map = yield defer.ensureDeferred( state_group_map = self.get_success(
self.storage.state.get_state_groups(self.room, [e2.event_id]) self.storage.state.get_state_groups(self.room, [e2.event_id])
) )
self.assertEqual(len(state_group_map), 1) self.assertEqual(len(state_group_map), 1)
@ -118,32 +103,29 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id}) self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
@defer.inlineCallbacks
def test_get_state_for_event(self): def test_get_state_for_event(self):
# this defaults to a linear DAG as each new injection defaults to whatever # this defaults to a linear DAG as each new injection defaults to whatever
# forward extremities are currently in the DB for this room. # forward extremities are currently in the DB for this room.
e1 = yield self.inject_state_event( e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
self.room, self.u_alice, EventTypes.Create, "", {} e2 = self.inject_state_event(
)
e2 = yield self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
) )
e3 = yield self.inject_state_event( e3 = self.inject_state_event(
self.room, self.room,
self.u_alice, self.u_alice,
EventTypes.Member, EventTypes.Member,
self.u_alice.to_string(), self.u_alice.to_string(),
{"membership": Membership.JOIN}, {"membership": Membership.JOIN},
) )
e4 = yield self.inject_state_event( e4 = self.inject_state_event(
self.room, self.room,
self.u_bob, self.u_bob,
EventTypes.Member, EventTypes.Member,
self.u_bob.to_string(), self.u_bob.to_string(),
{"membership": Membership.JOIN}, {"membership": Membership.JOIN},
) )
e5 = yield self.inject_state_event( e5 = self.inject_state_event(
self.room, self.room,
self.u_bob, self.u_bob,
EventTypes.Member, EventTypes.Member,
@ -152,9 +134,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )
# check we get the full state as of the final event # check we get the full state as of the final event
state = yield defer.ensureDeferred( state = self.get_success(self.storage.state.get_state_for_event(e5.event_id))
self.storage.state.get_state_for_event(e5.event_id)
)
self.assertIsNotNone(e4) self.assertIsNotNone(e4)
@ -170,7 +150,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )
# check we can filter to the m.room.name event (with a '' state key) # check we can filter to the m.room.name event (with a '' state key)
state = yield defer.ensureDeferred( state = self.get_success(
self.storage.state.get_state_for_event( self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
) )
@ -179,7 +159,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can filter to the m.room.name event (with a wildcard None state key) # check we can filter to the m.room.name event (with a wildcard None state key)
state = yield defer.ensureDeferred( state = self.get_success(
self.storage.state.get_state_for_event( self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
) )
@ -188,7 +168,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can grab the m.room.member events (with a wildcard None state key) # check we can grab the m.room.member events (with a wildcard None state key)
state = yield defer.ensureDeferred( state = self.get_success(
self.storage.state.get_state_for_event( self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
) )
@ -200,7 +180,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# check we can grab a specific room member without filtering out the # check we can grab a specific room member without filtering out the
# other event types # other event types
state = yield defer.ensureDeferred( state = self.get_success(
self.storage.state.get_state_for_event( self.storage.state.get_state_for_event(
e5.event_id, e5.event_id,
state_filter=StateFilter( state_filter=StateFilter(
@ -220,7 +200,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )
# check that we can grab everything except members # check that we can grab everything except members
state = yield defer.ensureDeferred( state = self.get_success(
self.storage.state.get_state_for_event( self.storage.state.get_state_for_event(
e5.event_id, e5.event_id,
state_filter=StateFilter( state_filter=StateFilter(
@ -238,17 +218,14 @@ class StateStoreTestCase(tests.unittest.TestCase):
####################################################### #######################################################
room_id = self.room.to_string() room_id = self.room.to_string()
group_ids = yield defer.ensureDeferred( group_ids = self.get_success(
self.storage.state.get_state_groups_ids(room_id, [e5.event_id]) self.storage.state.get_state_groups_ids(room_id, [e5.event_id])
) )
group = list(group_ids.keys())[0] group = list(group_ids.keys())[0]
# test _get_state_for_group_using_cache correctly filters out members # test _get_state_for_group_using_cache correctly filters out members
# with types=[] # with types=[]
( (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache, self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
@ -265,10 +242,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict, state_dict,
) )
( (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache, self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
@ -281,10 +255,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members # test _get_state_for_group_using_cache correctly filters in members
# with wildcard types # with wildcard types
( (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache, self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
@ -301,10 +272,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict, state_dict,
) )
( (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache, self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
@ -324,10 +292,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members # test _get_state_for_group_using_cache correctly filters in members
# with specific types # with specific types
( (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache, self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
@ -344,10 +309,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict, state_dict,
) )
( (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache, self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
@ -360,10 +322,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members # test _get_state_for_group_using_cache correctly filters in members
# with specific types # with specific types
( (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache, self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
@ -413,10 +372,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters out members # test _get_state_for_group_using_cache correctly filters out members
# with types=[] # with types=[]
room_id = self.room.to_string() room_id = self.room.to_string()
( (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache, self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
@ -428,10 +384,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
room_id = self.room.to_string() room_id = self.room.to_string()
( (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache, self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
@ -444,10 +397,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members # test _get_state_for_group_using_cache correctly filters in members
# wildcard types # wildcard types
( (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache, self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
@ -458,10 +408,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False) self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
( (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache, self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
@ -480,10 +427,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members # test _get_state_for_group_using_cache correctly filters in members
# with specific types # with specific types
( (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache, self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
@ -494,10 +438,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False) self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
( (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache, self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
@ -510,10 +451,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members # test _get_state_for_group_using_cache correctly filters in members
# with specific types # with specific types
( (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache, self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
@ -524,10 +462,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False) self.assertEqual(is_all, False)
self.assertDictEqual({}, state_dict) self.assertDictEqual({}, state_dict)
( (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache, self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd # Copyright 2018-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,10 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer from tests.unittest import HomeserverTestCase, override_config
from tests import unittest
from tests.utils import setup_test_homeserver
ALICE = "@alice:a" ALICE = "@alice:a"
BOB = "@bob:b" BOB = "@bob:b"
@ -25,46 +22,31 @@ BOBBY = "@bobby:a"
BELA = "@somenickname:a" BELA = "@somenickname:a"
class UserDirectoryStoreTestCase(unittest.TestCase): class UserDirectoryStoreTestCase(HomeserverTestCase):
@defer.inlineCallbacks def prepare(self, reactor, clock, hs):
def setUp(self): self.store = hs.get_datastore()
self.hs = yield setup_test_homeserver(self.addCleanup)
self.store = self.hs.get_datastore()
# alice and bob are both in !room_id. bobby is not but shares # alice and bob are both in !room_id. bobby is not but shares
# a homeserver with alice. # a homeserver with alice.
yield defer.ensureDeferred( self.get_success(self.store.update_profile_in_user_dir(ALICE, "alice", None))
self.store.update_profile_in_user_dir(ALICE, "alice", None) self.get_success(self.store.update_profile_in_user_dir(BOB, "bob", None))
) self.get_success(self.store.update_profile_in_user_dir(BOBBY, "bobby", None))
yield defer.ensureDeferred( self.get_success(self.store.update_profile_in_user_dir(BELA, "Bela", None))
self.store.update_profile_in_user_dir(BOB, "bob", None) self.get_success(self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB)))
)
yield defer.ensureDeferred(
self.store.update_profile_in_user_dir(BOBBY, "bobby", None)
)
yield defer.ensureDeferred(
self.store.update_profile_in_user_dir(BELA, "Bela", None)
)
yield defer.ensureDeferred(
self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))
)
@defer.inlineCallbacks
def test_search_user_dir(self): def test_search_user_dir(self):
# normally when alice searches the directory she should just find # normally when alice searches the directory she should just find
# bob because bobby doesn't share a room with her. # bob because bobby doesn't share a room with her.
r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10)) r = self.get_success(self.store.search_user_dir(ALICE, "bob", 10))
self.assertFalse(r["limited"]) self.assertFalse(r["limited"])
self.assertEqual(1, len(r["results"])) self.assertEqual(1, len(r["results"]))
self.assertDictEqual( self.assertDictEqual(
r["results"][0], {"user_id": BOB, "display_name": "bob", "avatar_url": None} r["results"][0], {"user_id": BOB, "display_name": "bob", "avatar_url": None}
) )
@defer.inlineCallbacks @override_config({"user_directory": {"search_all_users": True}})
def test_search_user_dir_all_users(self): def test_search_user_dir_all_users(self):
self.hs.config.user_directory_search_all_users = True r = self.get_success(self.store.search_user_dir(ALICE, "bob", 10))
try:
r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
self.assertFalse(r["limited"]) self.assertFalse(r["limited"])
self.assertEqual(2, len(r["results"])) self.assertEqual(2, len(r["results"]))
self.assertDictEqual( self.assertDictEqual(
@ -75,23 +57,17 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
r["results"][1], r["results"][1],
{"user_id": BOBBY, "display_name": "bobby", "avatar_url": None}, {"user_id": BOBBY, "display_name": "bobby", "avatar_url": None},
) )
finally:
self.hs.config.user_directory_search_all_users = False
@defer.inlineCallbacks @override_config({"user_directory": {"search_all_users": True}})
def test_search_user_dir_stop_words(self): def test_search_user_dir_stop_words(self):
"""Tests that a user can look up another user by searching for the start if its """Tests that a user can look up another user by searching for the start if its
display name even if that name happens to be a common English word that would display name even if that name happens to be a common English word that would
usually be ignored in full text searches. usually be ignored in full text searches.
""" """
self.hs.config.user_directory_search_all_users = True r = self.get_success(self.store.search_user_dir(ALICE, "be", 10))
try:
r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "be", 10))
self.assertFalse(r["limited"]) self.assertFalse(r["limited"])
self.assertEqual(1, len(r["results"])) self.assertEqual(1, len(r["results"]))
self.assertDictEqual( self.assertDictEqual(
r["results"][0], r["results"][0],
{"user_id": BELA, "display_name": "Bela", "avatar_url": None}, {"user_id": BELA, "display_name": "Bela", "avatar_url": None},
) )
finally:
self.hs.config.user_directory_search_all_users = False