mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Cache empty responses from /user/devices
(#11587)
If we've never made a request to a remote homeserver, we should cache the response---even if the response is "this user has no devices".
This commit is contained in:
parent
0fb3dd0830
commit
88a78c6577
1
changelog.d/11587.bugfix
Normal file
1
changelog.d/11587.bugfix
Normal file
@ -0,0 +1 @@
|
||||
Fix a long-standing bug where Synapse wouldn't cache a response indicating that a remote user has no devices.
|
@ -948,8 +948,16 @@ class DeviceListUpdater:
|
||||
devices = []
|
||||
ignore_devices = True
|
||||
else:
|
||||
prev_stream_id = await self.store.get_device_list_last_stream_id_for_remote(
|
||||
user_id
|
||||
)
|
||||
cached_devices = await self.store.get_cached_devices_for_user(user_id)
|
||||
if cached_devices == {d["device_id"]: d for d in devices}:
|
||||
|
||||
# To ensure that a user with no devices is cached, we skip the resync only
|
||||
# if we have a stream_id from previously writing a cache entry.
|
||||
if prev_stream_id is not None and cached_devices == {
|
||||
d["device_id"]: d for d in devices
|
||||
}:
|
||||
logging.info(
|
||||
"Skipping device list resync for %s, as our cache matches already",
|
||||
user_id,
|
||||
|
@ -713,7 +713,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
@cached(max_entries=10000)
|
||||
async def get_device_list_last_stream_id_for_remote(
|
||||
self, user_id: str
|
||||
) -> Optional[Any]:
|
||||
) -> Optional[str]:
|
||||
"""Get the last stream_id we got for a user. May be None if we haven't
|
||||
got any information for them.
|
||||
"""
|
||||
@ -729,7 +729,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
cached_method_name="get_device_list_last_stream_id_for_remote",
|
||||
list_name="user_ids",
|
||||
)
|
||||
async def get_device_list_last_stream_id_for_remotes(self, user_ids: Iterable[str]):
|
||||
async def get_device_list_last_stream_id_for_remotes(
|
||||
self, user_ids: Iterable[str]
|
||||
) -> Dict[str, Optional[str]]:
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
table="device_lists_remote_extremeties",
|
||||
column="user_id",
|
||||
@ -1316,6 +1318,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||
content: JsonDict,
|
||||
stream_id: str,
|
||||
) -> None:
|
||||
"""Delete, update or insert a cache entry for this (user, device) pair."""
|
||||
if content.get("deleted"):
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn,
|
||||
@ -1375,6 +1378,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||
def _update_remote_device_list_cache_txn(
|
||||
self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
|
||||
) -> None:
|
||||
"""Replace the list of cached devices for this user with the given list."""
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
|
||||
)
|
||||
|
@ -13,8 +13,10 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Iterable
|
||||
from unittest import mock
|
||||
|
||||
from parameterized import parameterized
|
||||
from signedjson import key as key, sign as sign
|
||||
|
||||
from twisted.internet import defer
|
||||
@ -23,6 +25,7 @@ from synapse.api.constants import RoomEncryptionAlgorithms
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
|
||||
from tests import unittest
|
||||
from tests.test_utils import make_awaitable
|
||||
|
||||
|
||||
class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||
@ -765,6 +768,8 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||
remote_user_id = "@test:other"
|
||||
local_user_id = "@test:test"
|
||||
|
||||
# Pretend we're sharing a room with the user we're querying. If not,
|
||||
# `_query_devices_for_destination` will return early.
|
||||
self.store.get_rooms_for_user = mock.Mock(
|
||||
return_value=defer.succeed({"some_room_id"})
|
||||
)
|
||||
@ -831,3 +836,94 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
# The remote homeserver's response indicates that this user has 0/1/2 devices.
|
||||
([],),
|
||||
(["device_1"],),
|
||||
(["device_1", "device_2"],),
|
||||
]
|
||||
)
|
||||
def test_query_all_devices_caches_result(self, device_ids: Iterable[str]):
|
||||
"""Test that requests for all of a remote user's devices are cached.
|
||||
|
||||
We do this by asserting that only one call over federation was made, and that
|
||||
the two queries to the local homeserver produce the same response.
|
||||
"""
|
||||
local_user_id = "@test:test"
|
||||
remote_user_id = "@test:other"
|
||||
request_body = {"device_keys": {remote_user_id: []}}
|
||||
|
||||
response_devices = [
|
||||
{
|
||||
"device_id": device_id,
|
||||
"keys": {
|
||||
"algorithms": ["dummy"],
|
||||
"device_id": device_id,
|
||||
"keys": {f"dummy:{device_id}": "dummy"},
|
||||
"signatures": {device_id: {f"dummy:{device_id}": "dummy"}},
|
||||
"unsigned": {},
|
||||
"user_id": "@test:other",
|
||||
},
|
||||
}
|
||||
for device_id in device_ids
|
||||
]
|
||||
|
||||
response_body = {
|
||||
"devices": response_devices,
|
||||
"user_id": remote_user_id,
|
||||
"stream_id": 12345, # an integer, according to the spec
|
||||
}
|
||||
|
||||
e2e_handler = self.hs.get_e2e_keys_handler()
|
||||
|
||||
# Pretend we're sharing a room with the user we're querying. If not,
|
||||
# `_query_devices_for_destination` will return early.
|
||||
mock_get_rooms = mock.patch.object(
|
||||
self.store,
|
||||
"get_rooms_for_user",
|
||||
new_callable=mock.MagicMock,
|
||||
return_value=make_awaitable(["some_room_id"]),
|
||||
)
|
||||
mock_request = mock.patch.object(
|
||||
self.hs.get_federation_client(),
|
||||
"query_user_devices",
|
||||
new_callable=mock.MagicMock,
|
||||
return_value=make_awaitable(response_body),
|
||||
)
|
||||
|
||||
with mock_get_rooms, mock_request as mocked_federation_request:
|
||||
# Make the first query and sanity check it succeeds.
|
||||
response_1 = self.get_success(
|
||||
e2e_handler.query_devices(
|
||||
request_body,
|
||||
timeout=10,
|
||||
from_user_id=local_user_id,
|
||||
from_device_id="some_device_id",
|
||||
)
|
||||
)
|
||||
self.assertEqual(response_1["failures"], {})
|
||||
|
||||
# We should have made a federation request to do so.
|
||||
mocked_federation_request.assert_called_once()
|
||||
|
||||
# Reset the mock so we can prove we don't make a second federation request.
|
||||
mocked_federation_request.reset_mock()
|
||||
|
||||
# Repeat the query.
|
||||
response_2 = self.get_success(
|
||||
e2e_handler.query_devices(
|
||||
request_body,
|
||||
timeout=10,
|
||||
from_user_id=local_user_id,
|
||||
from_device_id="some_device_id",
|
||||
)
|
||||
)
|
||||
self.assertEqual(response_2["failures"], {})
|
||||
|
||||
# We should not have made a second federation request.
|
||||
mocked_federation_request.assert_not_called()
|
||||
|
||||
# The two requests to the local homeserver should be identical.
|
||||
self.assertEqual(response_1, response_2)
|
||||
|
@ -19,7 +19,7 @@ import sys
|
||||
import warnings
|
||||
from asyncio import Future
|
||||
from binascii import unhexlify
|
||||
from typing import Any, Awaitable, Callable, TypeVar
|
||||
from typing import Awaitable, Callable, TypeVar
|
||||
from unittest.mock import Mock
|
||||
|
||||
import attr
|
||||
@ -46,7 +46,7 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
|
||||
raise Exception("awaitable has not yet completed")
|
||||
|
||||
|
||||
def make_awaitable(result: Any) -> Awaitable[Any]:
|
||||
def make_awaitable(result: TV) -> Awaitable[TV]:
|
||||
"""
|
||||
Makes an awaitable, suitable for mocking an `async` function.
|
||||
This uses Futures as they can be awaited multiple times so can be returned
|
||||
|
Loading…
Reference in New Issue
Block a user