Convert storage test cases to HomeserverTestCase. (#9736)

This commit is contained in:
Patrick Cloke 2021-04-06 07:21:02 -04:00 committed by GitHub
parent e2b8a90897
commit e7b769aea1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 265 additions and 499 deletions

1
changelog.d/9736.misc Normal file
View File

@ -0,0 +1 @@
Convert various testcases to `HomeserverTestCase`.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,32 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
import synapse.api.errors
import tests.unittest
import tests.utils
from tests.unittest import HomeserverTestCase
class DeviceStoreTestCase(tests.unittest.TestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.store = None # type: synapse.storage.DataStore
@defer.inlineCallbacks
def setUp(self):
hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
class DeviceStoreTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
@defer.inlineCallbacks
def test_store_new_device(self):
yield defer.ensureDeferred(
self.get_success(
self.store.store_device("user_id", "device_id", "display_name")
)
res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
res = self.get_success(self.store.get_device("user_id", "device_id"))
self.assertDictContainsSubset(
{
"user_id": "user_id",
@ -48,19 +37,18 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
res,
)
@defer.inlineCallbacks
def test_get_devices_by_user(self):
yield defer.ensureDeferred(
self.get_success(
self.store.store_device("user_id", "device1", "display_name 1")
)
yield defer.ensureDeferred(
self.get_success(
self.store.store_device("user_id", "device2", "display_name 2")
)
yield defer.ensureDeferred(
self.get_success(
self.store.store_device("user_id2", "device3", "display_name 3")
)
res = yield defer.ensureDeferred(self.store.get_devices_by_user("user_id"))
res = self.get_success(self.store.get_devices_by_user("user_id"))
self.assertEqual(2, len(res.keys()))
self.assertDictContainsSubset(
{
@ -79,43 +67,41 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
res["device2"],
)
@defer.inlineCallbacks
def test_count_devices_by_users(self):
yield defer.ensureDeferred(
self.get_success(
self.store.store_device("user_id", "device1", "display_name 1")
)
yield defer.ensureDeferred(
self.get_success(
self.store.store_device("user_id", "device2", "display_name 2")
)
yield defer.ensureDeferred(
self.get_success(
self.store.store_device("user_id2", "device3", "display_name 3")
)
res = yield defer.ensureDeferred(self.store.count_devices_by_users())
res = self.get_success(self.store.count_devices_by_users())
self.assertEqual(0, res)
res = yield defer.ensureDeferred(self.store.count_devices_by_users(["unknown"]))
res = self.get_success(self.store.count_devices_by_users(["unknown"]))
self.assertEqual(0, res)
res = yield defer.ensureDeferred(self.store.count_devices_by_users(["user_id"]))
res = self.get_success(self.store.count_devices_by_users(["user_id"]))
self.assertEqual(2, res)
res = yield defer.ensureDeferred(
res = self.get_success(
self.store.count_devices_by_users(["user_id", "user_id2"])
)
self.assertEqual(3, res)
@defer.inlineCallbacks
def test_get_device_updates_by_remote(self):
device_ids = ["device_id1", "device_id2"]
# Add two device updates with a single stream_id
yield defer.ensureDeferred(
self.get_success(
self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
)
# Get all device updates ever meant for this remote
now_stream_id, device_updates = yield defer.ensureDeferred(
now_stream_id, device_updates = self.get_success(
self.store.get_device_updates_by_remote("somehost", -1, limit=100)
)
@ -131,37 +117,35 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
}
self.assertEqual(received_device_ids, set(expected_device_ids))
@defer.inlineCallbacks
def test_update_device(self):
yield defer.ensureDeferred(
self.get_success(
self.store.store_device("user_id", "device_id", "display_name 1")
)
res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
res = self.get_success(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"])
# do a no-op first
yield defer.ensureDeferred(self.store.update_device("user_id", "device_id"))
res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.get_success(self.store.update_device("user_id", "device_id"))
res = self.get_success(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"])
# do the update
yield defer.ensureDeferred(
self.get_success(
self.store.update_device(
"user_id", "device_id", new_display_name="display_name 2"
)
)
# check it worked
res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
res = self.get_success(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 2", res["display_name"])
@defer.inlineCallbacks
def test_update_unknown_device(self):
with self.assertRaises(synapse.api.errors.StoreError) as cm:
yield defer.ensureDeferred(
exc = self.get_failure(
self.store.update_device(
"user_id", "unknown_device_id", new_display_name="display_name 2"
),
synapse.api.errors.StoreError,
)
)
self.assertEqual(404, cm.exception.code)
self.assertEqual(404, exc.value.code)

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,28 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.types import RoomAlias, RoomID
from tests import unittest
from tests.utils import setup_test_homeserver
from tests.unittest import HomeserverTestCase
class DirectoryStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
hs = yield setup_test_homeserver(self.addCleanup)
class DirectoryStoreTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.room = RoomID.from_string("!abcde:test")
self.alias = RoomAlias.from_string("#my-room:test")
@defer.inlineCallbacks
def test_room_to_alias(self):
yield defer.ensureDeferred(
self.get_success(
self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
)
@ -42,16 +34,11 @@ class DirectoryStoreTestCase(unittest.TestCase):
self.assertEquals(
["#my-room:test"],
(
yield defer.ensureDeferred(
self.store.get_aliases_for_room(self.room.to_string())
)
),
(self.get_success(self.store.get_aliases_for_room(self.room.to_string()))),
)
@defer.inlineCallbacks
def test_alias_to_room(self):
yield defer.ensureDeferred(
self.get_success(
self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
)
@ -59,28 +46,19 @@ class DirectoryStoreTestCase(unittest.TestCase):
self.assertObjectHasAttributes(
{"room_id": self.room.to_string(), "servers": ["test"]},
(
yield defer.ensureDeferred(
self.store.get_association_from_room_alias(self.alias)
)
),
(self.get_success(self.store.get_association_from_room_alias(self.alias))),
)
@defer.inlineCallbacks
def test_delete_alias(self):
yield defer.ensureDeferred(
self.get_success(
self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
)
)
room_id = yield defer.ensureDeferred(self.store.delete_room_alias(self.alias))
room_id = self.get_success(self.store.delete_room_alias(self.alias))
self.assertEqual(self.room.to_string(), room_id)
self.assertIsNone(
(
yield defer.ensureDeferred(
self.store.get_association_from_room_alias(self.alias)
)
)
(self.get_success(self.store.get_association_from_room_alias(self.alias)))
)

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,30 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
import tests.unittest
import tests.utils
from tests.unittest import HomeserverTestCase
class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
class EndToEndKeyStoreTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
@defer.inlineCallbacks
def test_key_without_device_name(self):
now = 1470174257070
json = {"key": "value"}
yield defer.ensureDeferred(self.store.store_device("user", "device", None))
self.get_success(self.store.store_device("user", "device", None))
yield defer.ensureDeferred(
self.store.set_e2e_device_keys("user", "device", now, json)
)
self.get_success(self.store.set_e2e_device_keys("user", "device", now, json))
res = yield defer.ensureDeferred(
res = self.get_success(
self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
)
self.assertIn("user", res)
@ -44,38 +36,32 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
dev = res["user"]["device"]
self.assertDictContainsSubset(json, dev)
@defer.inlineCallbacks
def test_reupload_key(self):
now = 1470174257070
json = {"key": "value"}
yield defer.ensureDeferred(self.store.store_device("user", "device", None))
self.get_success(self.store.store_device("user", "device", None))
changed = yield defer.ensureDeferred(
changed = self.get_success(
self.store.set_e2e_device_keys("user", "device", now, json)
)
self.assertTrue(changed)
# If we try to upload the same key then we should be told nothing
# changed
changed = yield defer.ensureDeferred(
changed = self.get_success(
self.store.set_e2e_device_keys("user", "device", now, json)
)
self.assertFalse(changed)
@defer.inlineCallbacks
def test_get_key_with_device_name(self):
now = 1470174257070
json = {"key": "value"}
yield defer.ensureDeferred(
self.store.set_e2e_device_keys("user", "device", now, json)
)
yield defer.ensureDeferred(
self.store.store_device("user", "device", "display_name")
)
self.get_success(self.store.set_e2e_device_keys("user", "device", now, json))
self.get_success(self.store.store_device("user", "device", "display_name"))
res = yield defer.ensureDeferred(
res = self.get_success(
self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
)
self.assertIn("user", res)
@ -85,29 +71,28 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
{"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev
)
@defer.inlineCallbacks
def test_multiple_devices(self):
now = 1470174257070
yield defer.ensureDeferred(self.store.store_device("user1", "device1", None))
yield defer.ensureDeferred(self.store.store_device("user1", "device2", None))
yield defer.ensureDeferred(self.store.store_device("user2", "device1", None))
yield defer.ensureDeferred(self.store.store_device("user2", "device2", None))
self.get_success(self.store.store_device("user1", "device1", None))
self.get_success(self.store.store_device("user1", "device2", None))
self.get_success(self.store.store_device("user2", "device1", None))
self.get_success(self.store.store_device("user2", "device2", None))
yield defer.ensureDeferred(
self.get_success(
self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
)
yield defer.ensureDeferred(
self.get_success(
self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
)
yield defer.ensureDeferred(
self.get_success(
self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
)
yield defer.ensureDeferred(
self.get_success(
self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
)
res = yield defer.ensureDeferred(
res = self.get_success(
self.store.get_e2e_device_keys_for_cs_api(
(("user1", "device1"), ("user2", "device2"))
)

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -15,10 +15,7 @@
from mock import Mock
from twisted.internet import defer
import tests.unittest
import tests.utils
from tests.unittest import HomeserverTestCase
USER_ID = "@user:example.com"
@ -30,37 +27,31 @@ HIGHLIGHT = [
]
class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
class EventPushActionsStoreTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.persist_events_store = hs.get_datastores().persist_events
@defer.inlineCallbacks
def test_get_unread_push_actions_for_user_in_range_for_http(self):
yield defer.ensureDeferred(
self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_http(
USER_ID, 0, 1000, 20
)
)
@defer.inlineCallbacks
def test_get_unread_push_actions_for_user_in_range_for_email(self):
yield defer.ensureDeferred(
self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_email(
USER_ID, 0, 1000, 20
)
)
@defer.inlineCallbacks
def test_count_aggregation(self):
room_id = "!foo:example.com"
user_id = "@user1235:example.com"
@defer.inlineCallbacks
def _assert_counts(noitf_count, highlight_count):
counts = yield defer.ensureDeferred(
counts = self.get_success(
self.store.db_pool.runInteraction(
"", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
)
@ -74,7 +65,6 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
},
)
@defer.inlineCallbacks
def _inject_actions(stream, action):
event = Mock()
event.room_id = room_id
@ -82,14 +72,14 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
event.internal_metadata.stream_ordering = stream
event.depth = stream
yield defer.ensureDeferred(
self.get_success(
self.store.add_push_actions_to_staging(
event.event_id,
{user_id: action},
False,
)
)
yield defer.ensureDeferred(
self.get_success(
self.store.db_pool.runInteraction(
"",
self.persist_events_store._set_push_actions_for_event_and_users_txn,
@ -99,14 +89,14 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
)
def _rotate(stream):
return defer.ensureDeferred(
self.get_success(
self.store.db_pool.runInteraction(
"", self.store._rotate_notifs_before_txn, stream
)
)
def _mark_read(stream, depth):
return defer.ensureDeferred(
self.get_success(
self.store.db_pool.runInteraction(
"",
self.store._remove_old_push_actions_before_txn,
@ -116,49 +106,48 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
)
)
yield _assert_counts(0, 0)
yield _inject_actions(1, PlAIN_NOTIF)
yield _assert_counts(1, 0)
yield _rotate(2)
yield _assert_counts(1, 0)
_assert_counts(0, 0)
_inject_actions(1, PlAIN_NOTIF)
_assert_counts(1, 0)
_rotate(2)
_assert_counts(1, 0)
yield _inject_actions(3, PlAIN_NOTIF)
yield _assert_counts(2, 0)
yield _rotate(4)
yield _assert_counts(2, 0)
_inject_actions(3, PlAIN_NOTIF)
_assert_counts(2, 0)
_rotate(4)
_assert_counts(2, 0)
yield _inject_actions(5, PlAIN_NOTIF)
yield _mark_read(3, 3)
yield _assert_counts(1, 0)
_inject_actions(5, PlAIN_NOTIF)
_mark_read(3, 3)
_assert_counts(1, 0)
yield _mark_read(5, 5)
yield _assert_counts(0, 0)
_mark_read(5, 5)
_assert_counts(0, 0)
yield _inject_actions(6, PlAIN_NOTIF)
yield _rotate(7)
_inject_actions(6, PlAIN_NOTIF)
_rotate(7)
yield defer.ensureDeferred(
self.get_success(
self.store.db_pool.simple_delete(
table="event_push_actions", keyvalues={"1": 1}, desc=""
)
)
yield _assert_counts(1, 0)
_assert_counts(1, 0)
yield _mark_read(7, 7)
yield _assert_counts(0, 0)
_mark_read(7, 7)
_assert_counts(0, 0)
yield _inject_actions(8, HIGHLIGHT)
yield _assert_counts(1, 1)
yield _rotate(9)
yield _assert_counts(1, 1)
yield _rotate(10)
yield _assert_counts(1, 1)
_inject_actions(8, HIGHLIGHT)
_assert_counts(1, 1)
_rotate(9)
_assert_counts(1, 1)
_rotate(10)
_assert_counts(1, 1)
@defer.inlineCallbacks
def test_find_first_stream_ordering_after_ts(self):
def add_event(so, ts):
return defer.ensureDeferred(
self.get_success(
self.store.db_pool.simple_insert(
"events",
{
@ -177,24 +166,16 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
)
# start with the base case where there are no events in the table
r = yield defer.ensureDeferred(
self.store.find_first_stream_ordering_after_ts(11)
)
r = self.get_success(self.store.find_first_stream_ordering_after_ts(11))
self.assertEqual(r, 0)
# now with one event
yield add_event(2, 10)
r = yield defer.ensureDeferred(
self.store.find_first_stream_ordering_after_ts(9)
)
add_event(2, 10)
r = self.get_success(self.store.find_first_stream_ordering_after_ts(9))
self.assertEqual(r, 2)
r = yield defer.ensureDeferred(
self.store.find_first_stream_ordering_after_ts(10)
)
r = self.get_success(self.store.find_first_stream_ordering_after_ts(10))
self.assertEqual(r, 2)
r = yield defer.ensureDeferred(
self.store.find_first_stream_ordering_after_ts(11)
)
r = self.get_success(self.store.find_first_stream_ordering_after_ts(11))
self.assertEqual(r, 3)
# add a bunch of dummy events to the events table
@ -205,39 +186,27 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
(10, 130),
(20, 140),
):
yield add_event(stream_ordering, ts)
add_event(stream_ordering, ts)
r = yield defer.ensureDeferred(
self.store.find_first_stream_ordering_after_ts(110)
)
r = self.get_success(self.store.find_first_stream_ordering_after_ts(110))
self.assertEqual(r, 3, "First event after 110ms should be 3, was %i" % r)
# 4 and 5 are both after 120: we want 4 rather than 5
r = yield defer.ensureDeferred(
self.store.find_first_stream_ordering_after_ts(120)
)
r = self.get_success(self.store.find_first_stream_ordering_after_ts(120))
self.assertEqual(r, 4, "First event after 120ms should be 4, was %i" % r)
r = yield defer.ensureDeferred(
self.store.find_first_stream_ordering_after_ts(129)
)
r = self.get_success(self.store.find_first_stream_ordering_after_ts(129))
self.assertEqual(r, 10, "First event after 129ms should be 10, was %i" % r)
# check we can get the last event
r = yield defer.ensureDeferred(
self.store.find_first_stream_ordering_after_ts(140)
)
r = self.get_success(self.store.find_first_stream_ordering_after_ts(140))
self.assertEqual(r, 20, "First event after 14ms should be 20, was %i" % r)
# off the end
r = yield defer.ensureDeferred(
self.store.find_first_stream_ordering_after_ts(160)
)
r = self.get_success(self.store.find_first_stream_ordering_after_ts(160))
self.assertEqual(r, 21)
# check we can find an event at ordering zero
yield add_event(0, 5)
r = yield defer.ensureDeferred(
self.store.find_first_stream_ordering_after_ts(1)
)
add_event(0, 5)
r = self.get_success(self.store.find_first_stream_ordering_after_ts(1))
self.assertEqual(r, 0)

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,59 +13,50 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.types import UserID
from tests import unittest
from tests.utils import setup_test_homeserver
class ProfileStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
hs = yield setup_test_homeserver(self.addCleanup)
class ProfileStoreTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.u_frank = UserID.from_string("@frank:test")
@defer.inlineCallbacks
def test_displayname(self):
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
self.get_success(self.store.create_profile(self.u_frank.localpart))
yield defer.ensureDeferred(
self.get_success(
self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
)
self.assertEquals(
"Frank",
(
yield defer.ensureDeferred(
self.get_success(
self.store.get_profile_displayname(self.u_frank.localpart)
)
),
)
# test set to None
yield defer.ensureDeferred(
self.get_success(
self.store.set_profile_displayname(self.u_frank.localpart, None)
)
self.assertIsNone(
(
yield defer.ensureDeferred(
self.get_success(
self.store.get_profile_displayname(self.u_frank.localpart)
)
)
)
@defer.inlineCallbacks
def test_avatar_url(self):
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
self.get_success(self.store.create_profile(self.u_frank.localpart))
yield defer.ensureDeferred(
self.get_success(
self.store.set_profile_avatar_url(
self.u_frank.localpart, "http://my.site/here"
)
@ -74,20 +65,20 @@ class ProfileStoreTestCase(unittest.TestCase):
self.assertEquals(
"http://my.site/here",
(
yield defer.ensureDeferred(
self.get_success(
self.store.get_profile_avatar_url(self.u_frank.localpart)
)
),
)
# test set to None
yield defer.ensureDeferred(
self.get_success(
self.store.set_profile_avatar_url(self.u_frank.localpart, None)
)
self.assertIsNone(
(
yield defer.ensureDeferred(
self.get_success(
self.store.get_profile_avatar_url(self.u_frank.localpart)
)
)

View File

@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -16,8 +15,6 @@
from canonicaljson import json
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.types import RoomID, UserID
@ -230,10 +227,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self._base_builder = base_builder
self._event_id = event_id
@defer.inlineCallbacks
def build(self, prev_event_ids, auth_event_ids):
built_event = yield defer.ensureDeferred(
self._base_builder.build(prev_event_ids, auth_event_ids)
async def build(self, prev_event_ids, auth_event_ids):
built_event = await self._base_builder.build(
prev_event_ids, auth_event_ids
)
built_event._event_id = self._event_id

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,21 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.constants import UserTypes
from synapse.api.errors import ThreepidValidationError
from tests import unittest
from tests.utils import setup_test_homeserver
from tests.unittest import HomeserverTestCase
class RegistrationStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
hs = yield setup_test_homeserver(self.addCleanup)
class RegistrationStoreTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.user_id = "@my-user:test"
@ -35,9 +28,8 @@ class RegistrationStoreTestCase(unittest.TestCase):
self.pwhash = "{xx1}123456789"
self.device_id = "akgjhdjklgshg"
@defer.inlineCallbacks
def test_register(self):
yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
self.get_success(self.store.register_user(self.user_id, self.pwhash))
self.assertEquals(
{
@ -49,93 +41,81 @@ class RegistrationStoreTestCase(unittest.TestCase):
"consent_version": None,
"consent_server_notice_sent": None,
"appservice_id": None,
"creation_ts": 1000,
"creation_ts": 0,
"user_type": None,
"deactivated": 0,
"shadow_banned": 0,
},
(yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))),
(self.get_success(self.store.get_user_by_id(self.user_id))),
)
@defer.inlineCallbacks
def test_add_tokens(self):
yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
yield defer.ensureDeferred(
self.get_success(self.store.register_user(self.user_id, self.pwhash))
self.get_success(
self.store.add_access_token_to_user(
self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
)
)
result = yield defer.ensureDeferred(
self.store.get_user_by_access_token(self.tokens[1])
)
result = self.get_success(self.store.get_user_by_access_token(self.tokens[1]))
self.assertEqual(result.user_id, self.user_id)
self.assertEqual(result.device_id, self.device_id)
self.assertIsNotNone(result.token_id)
@defer.inlineCallbacks
def test_user_delete_access_tokens(self):
# add some tokens
yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
yield defer.ensureDeferred(
self.get_success(self.store.register_user(self.user_id, self.pwhash))
self.get_success(
self.store.add_access_token_to_user(
self.user_id, self.tokens[0], device_id=None, valid_until_ms=None
)
)
yield defer.ensureDeferred(
self.get_success(
self.store.add_access_token_to_user(
self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
)
)
# now delete some
yield defer.ensureDeferred(
self.get_success(
self.store.user_delete_access_tokens(self.user_id, device_id=self.device_id)
)
# check they were deleted
user = yield defer.ensureDeferred(
self.store.get_user_by_access_token(self.tokens[1])
)
user = self.get_success(self.store.get_user_by_access_token(self.tokens[1]))
self.assertIsNone(user, "access token was not deleted by device_id")
# check the one not associated with the device was not deleted
user = yield defer.ensureDeferred(
self.store.get_user_by_access_token(self.tokens[0])
)
user = self.get_success(self.store.get_user_by_access_token(self.tokens[0]))
self.assertEqual(self.user_id, user.user_id)
# now delete the rest
yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id))
self.get_success(self.store.user_delete_access_tokens(self.user_id))
user = yield defer.ensureDeferred(
self.store.get_user_by_access_token(self.tokens[0])
)
user = self.get_success(self.store.get_user_by_access_token(self.tokens[0]))
self.assertIsNone(user, "access token was not deleted without device_id")
@defer.inlineCallbacks
def test_is_support_user(self):
TEST_USER = "@test:test"
SUPPORT_USER = "@support:test"
res = yield defer.ensureDeferred(self.store.is_support_user(None))
res = self.get_success(self.store.is_support_user(None))
self.assertFalse(res)
yield defer.ensureDeferred(
self.get_success(
self.store.register_user(user_id=TEST_USER, password_hash=None)
)
res = yield defer.ensureDeferred(self.store.is_support_user(TEST_USER))
res = self.get_success(self.store.is_support_user(TEST_USER))
self.assertFalse(res)
yield defer.ensureDeferred(
self.get_success(
self.store.register_user(
user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT
)
)
res = yield defer.ensureDeferred(self.store.is_support_user(SUPPORT_USER))
res = self.get_success(self.store.is_support_user(SUPPORT_USER))
self.assertTrue(res)
@defer.inlineCallbacks
def test_3pid_inhibit_invalid_validation_session_error(self):
"""Tests that enabling the configuration option to inhibit 3PID errors on
/requestToken also inhibits validation errors caused by an unknown session ID.
@ -143,30 +123,28 @@ class RegistrationStoreTestCase(unittest.TestCase):
# Check that, with the config setting set to false (the default value), a
# validation error is caused by the unknown session ID.
try:
yield defer.ensureDeferred(
e = self.get_failure(
self.store.validate_threepid_session(
"fake_sid",
"fake_client_secret",
"fake_token",
0,
),
ThreepidValidationError,
)
)
except ThreepidValidationError as e:
self.assertEquals(e.msg, "Unknown session_id", e)
self.assertEquals(e.value.msg, "Unknown session_id", e)
# Set the config setting to true.
self.store._ignore_unknown_session_error = True
# Check that now the validation error is caused by the token not matching.
try:
yield defer.ensureDeferred(
e = self.get_failure(
self.store.validate_threepid_session(
"fake_sid",
"fake_client_secret",
"fake_token",
0,
),
ThreepidValidationError,
)
)
except ThreepidValidationError as e:
self.assertEquals(e.msg, "Validation token not found or has expired", e)
self.assertEquals(e.value.msg, "Validation token not found or has expired", e)

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,22 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.room_versions import RoomVersions
from synapse.types import RoomAlias, RoomID, UserID
from tests import unittest
from tests.utils import setup_test_homeserver
from tests.unittest import HomeserverTestCase
class RoomStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
hs = yield setup_test_homeserver(self.addCleanup)
class RoomStoreTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
# We can't test RoomStore on its own without the DirectoryStore, for
# management of the 'room_aliases' table
self.store = hs.get_datastore()
@ -37,7 +30,7 @@ class RoomStoreTestCase(unittest.TestCase):
self.alias = RoomAlias.from_string("#a-room-name:test")
self.u_creator = UserID.from_string("@creator:test")
yield defer.ensureDeferred(
self.get_success(
self.store.store_room(
self.room.to_string(),
room_creator_user_id=self.u_creator.to_string(),
@ -46,7 +39,6 @@ class RoomStoreTestCase(unittest.TestCase):
)
)
@defer.inlineCallbacks
def test_get_room(self):
self.assertDictContainsSubset(
{
@ -54,16 +46,12 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(),
"is_public": True,
},
(yield defer.ensureDeferred(self.store.get_room(self.room.to_string()))),
(self.get_success(self.store.get_room(self.room.to_string()))),
)
@defer.inlineCallbacks
def test_get_room_unknown_room(self):
self.assertIsNone(
(yield defer.ensureDeferred(self.store.get_room("!uknown:test")))
)
self.assertIsNone((self.get_success(self.store.get_room("!uknown:test"))))
@defer.inlineCallbacks
def test_get_room_with_stats(self):
self.assertDictContainsSubset(
{
@ -71,29 +59,17 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(),
"public": True,
},
(
yield defer.ensureDeferred(
self.store.get_room_with_stats(self.room.to_string())
)
),
(self.get_success(self.store.get_room_with_stats(self.room.to_string()))),
)
@defer.inlineCallbacks
def test_get_room_with_stats_unknown_room(self):
self.assertIsNone(
(
yield defer.ensureDeferred(
self.store.get_room_with_stats("!uknown:test")
)
),
(self.get_success(self.store.get_room_with_stats("!uknown:test"))),
)
class RoomEventsStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
hs = setup_test_homeserver(self.addCleanup)
class RoomEventsStoreTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
# Room events need the full datastore, for persist_event() and
# get_room_state()
self.store = hs.get_datastore()
@ -102,7 +78,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
self.room = RoomID.from_string("!abcde:test")
yield defer.ensureDeferred(
self.get_success(
self.store.store_room(
self.room.to_string(),
room_creator_user_id="@creator:text",
@ -111,23 +87,21 @@ class RoomEventsStoreTestCase(unittest.TestCase):
)
)
@defer.inlineCallbacks
def inject_room_event(self, **kwargs):
yield defer.ensureDeferred(
self.get_success(
self.storage.persistence.persist_event(
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
)
)
@defer.inlineCallbacks
def STALE_test_room_name(self):
name = "A-Room-Name"
yield self.inject_room_event(
self.inject_room_event(
etype=EventTypes.Name, name=name, content={"name": name}, depth=1
)
state = yield defer.ensureDeferred(
state = self.get_success(
self.store.get_current_state(room_id=self.room.to_string())
)
@ -137,15 +111,14 @@ class RoomEventsStoreTestCase(unittest.TestCase):
state[0],
)
@defer.inlineCallbacks
def STALE_test_room_topic(self):
topic = "A place for things"
yield self.inject_room_event(
self.inject_room_event(
etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1
)
state = yield defer.ensureDeferred(
state = self.get_success(
self.store.get_current_state(room_id=self.room.to_string())
)

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -15,24 +15,18 @@
import logging
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.storage.state import StateFilter
from synapse.types import RoomID, UserID
import tests.unittest
import tests.utils
from tests.unittest import HomeserverTestCase
logger = logging.getLogger(__name__)
class StateStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
class StateStoreTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.state_datastore = self.storage.state.stores.state
@ -44,7 +38,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room = RoomID.from_string("!abc123:test")
yield defer.ensureDeferred(
self.get_success(
self.store.store_room(
self.room.to_string(),
room_creator_user_id="@creator:text",
@ -53,7 +47,6 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
)
@defer.inlineCallbacks
def inject_state_event(self, room, sender, typ, state_key, content):
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
@ -66,13 +59,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
},
)
event, context = yield defer.ensureDeferred(
event, context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
yield defer.ensureDeferred(
self.storage.persistence.persist_event(event, context)
)
self.get_success(self.storage.persistence.persist_event(event, context))
return event
@ -82,16 +73,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(s1[t].event_id, s2[t].event_id)
self.assertEqual(len(s1), len(s2))
@defer.inlineCallbacks
def test_get_state_groups_ids(self):
e1 = yield self.inject_state_event(
self.room, self.u_alice, EventTypes.Create, "", {}
)
e2 = yield self.inject_state_event(
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
e2 = self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
state_group_map = yield defer.ensureDeferred(
state_group_map = self.get_success(
self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
)
self.assertEqual(len(state_group_map), 1)
@ -101,16 +89,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
{(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id},
)
@defer.inlineCallbacks
def test_get_state_groups(self):
e1 = yield self.inject_state_event(
self.room, self.u_alice, EventTypes.Create, "", {}
)
e2 = yield self.inject_state_event(
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
e2 = self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
state_group_map = yield defer.ensureDeferred(
state_group_map = self.get_success(
self.storage.state.get_state_groups(self.room, [e2.event_id])
)
self.assertEqual(len(state_group_map), 1)
@ -118,32 +103,29 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
@defer.inlineCallbacks
def test_get_state_for_event(self):
# this defaults to a linear DAG as each new injection defaults to whatever
# forward extremities are currently in the DB for this room.
e1 = yield self.inject_state_event(
self.room, self.u_alice, EventTypes.Create, "", {}
)
e2 = yield self.inject_state_event(
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
e2 = self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
e3 = yield self.inject_state_event(
e3 = self.inject_state_event(
self.room,
self.u_alice,
EventTypes.Member,
self.u_alice.to_string(),
{"membership": Membership.JOIN},
)
e4 = yield self.inject_state_event(
e4 = self.inject_state_event(
self.room,
self.u_bob,
EventTypes.Member,
self.u_bob.to_string(),
{"membership": Membership.JOIN},
)
e5 = yield self.inject_state_event(
e5 = self.inject_state_event(
self.room,
self.u_bob,
EventTypes.Member,
@ -152,9 +134,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check we get the full state as of the final event
state = yield defer.ensureDeferred(
self.storage.state.get_state_for_event(e5.event_id)
)
state = self.get_success(self.storage.state.get_state_for_event(e5.event_id))
self.assertIsNotNone(e4)
@ -170,7 +150,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check we can filter to the m.room.name event (with a '' state key)
state = yield defer.ensureDeferred(
state = self.get_success(
self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
)
@ -179,7 +159,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can filter to the m.room.name event (with a wildcard None state key)
state = yield defer.ensureDeferred(
state = self.get_success(
self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
)
@ -188,7 +168,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can grab the m.room.member events (with a wildcard None state key)
state = yield defer.ensureDeferred(
state = self.get_success(
self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
)
@ -200,7 +180,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# check we can grab a specific room member without filtering out the
# other event types
state = yield defer.ensureDeferred(
state = self.get_success(
self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
@ -220,7 +200,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check that we can grab everything except members
state = yield defer.ensureDeferred(
state = self.get_success(
self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
@ -238,17 +218,14 @@ class StateStoreTestCase(tests.unittest.TestCase):
#######################################################
room_id = self.room.to_string()
group_ids = yield defer.ensureDeferred(
group_ids = self.get_success(
self.storage.state.get_state_groups_ids(room_id, [e5.event_id])
)
group = list(group_ids.keys())[0]
# test _get_state_for_group_using_cache correctly filters out members
# with types=[]
(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
(state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@ -265,10 +242,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
(state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@ -281,10 +255,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with wildcard types
(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
(state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@ -301,10 +272,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
(state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@ -324,10 +292,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
(state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@ -344,10 +309,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
(state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@ -360,10 +322,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
(state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@ -413,10 +372,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters out members
# with types=[]
room_id = self.room.to_string()
(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
(state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@ -428,10 +384,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
room_id = self.room.to_string()
(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
(state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@ -444,10 +397,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# wildcard types
(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
(state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@ -458,10 +408,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
(state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@ -480,10 +427,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
(state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@ -494,10 +438,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
(state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@ -510,10 +451,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
(state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@ -524,10 +462,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({}, state_dict)
(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
(state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,10 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from tests import unittest
from tests.utils import setup_test_homeserver
from tests.unittest import HomeserverTestCase, override_config
ALICE = "@alice:a"
BOB = "@bob:b"
@ -25,46 +22,31 @@ BOBBY = "@bobby:a"
BELA = "@somenickname:a"
class UserDirectoryStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.hs = yield setup_test_homeserver(self.addCleanup)
self.store = self.hs.get_datastore()
class UserDirectoryStoreTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
# alice and bob are both in !room_id. bobby is not but shares
# a homeserver with alice.
yield defer.ensureDeferred(
self.store.update_profile_in_user_dir(ALICE, "alice", None)
)
yield defer.ensureDeferred(
self.store.update_profile_in_user_dir(BOB, "bob", None)
)
yield defer.ensureDeferred(
self.store.update_profile_in_user_dir(BOBBY, "bobby", None)
)
yield defer.ensureDeferred(
self.store.update_profile_in_user_dir(BELA, "Bela", None)
)
yield defer.ensureDeferred(
self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))
)
self.get_success(self.store.update_profile_in_user_dir(ALICE, "alice", None))
self.get_success(self.store.update_profile_in_user_dir(BOB, "bob", None))
self.get_success(self.store.update_profile_in_user_dir(BOBBY, "bobby", None))
self.get_success(self.store.update_profile_in_user_dir(BELA, "Bela", None))
self.get_success(self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB)))
@defer.inlineCallbacks
def test_search_user_dir(self):
# normally when alice searches the directory she should just find
# bob because bobby doesn't share a room with her.
r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
r = self.get_success(self.store.search_user_dir(ALICE, "bob", 10))
self.assertFalse(r["limited"])
self.assertEqual(1, len(r["results"]))
self.assertDictEqual(
r["results"][0], {"user_id": BOB, "display_name": "bob", "avatar_url": None}
)
@defer.inlineCallbacks
@override_config({"user_directory": {"search_all_users": True}})
def test_search_user_dir_all_users(self):
self.hs.config.user_directory_search_all_users = True
try:
r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
r = self.get_success(self.store.search_user_dir(ALICE, "bob", 10))
self.assertFalse(r["limited"])
self.assertEqual(2, len(r["results"]))
self.assertDictEqual(
@ -75,23 +57,17 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
r["results"][1],
{"user_id": BOBBY, "display_name": "bobby", "avatar_url": None},
)
finally:
self.hs.config.user_directory_search_all_users = False
@defer.inlineCallbacks
@override_config({"user_directory": {"search_all_users": True}})
def test_search_user_dir_stop_words(self):
"""Tests that a user can look up another user by searching for the start if its
display name even if that name happens to be a common English word that would
usually be ignored in full text searches.
"""
self.hs.config.user_directory_search_all_users = True
try:
r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "be", 10))
r = self.get_success(self.store.search_user_dir(ALICE, "be", 10))
self.assertFalse(r["limited"])
self.assertEqual(1, len(r["results"]))
self.assertDictEqual(
r["results"][0],
{"user_id": BELA, "display_name": "Bela", "avatar_url": None},
)
finally:
self.hs.config.user_directory_search_all_users = False