Cache empty responses from /user/devices ()

If we've never made a request to a remote homeserver, we should cache the response---even if the response is "this user has no devices".
This commit is contained in:
David Robertson 2022-01-05 13:33:28 +00:00 committed by GitHub
parent 0fb3dd0830
commit 88a78c6577
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 114 additions and 5 deletions
changelog.d
synapse
handlers
storage/databases/main
tests
handlers
test_utils

1
changelog.d/11587.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a long-standing bug where Synapse wouldn't cache a response indicating that a remote user has no devices.

View File

@ -948,8 +948,16 @@ class DeviceListUpdater:
devices = [] devices = []
ignore_devices = True ignore_devices = True
else: else:
prev_stream_id = await self.store.get_device_list_last_stream_id_for_remote(
user_id
)
cached_devices = await self.store.get_cached_devices_for_user(user_id) cached_devices = await self.store.get_cached_devices_for_user(user_id)
if cached_devices == {d["device_id"]: d for d in devices}:
# To ensure that a user with no devices is cached, we skip the resync only
# if we have a stream_id from previously writing a cache entry.
if prev_stream_id is not None and cached_devices == {
d["device_id"]: d for d in devices
}:
logging.info( logging.info(
"Skipping device list resync for %s, as our cache matches already", "Skipping device list resync for %s, as our cache matches already",
user_id, user_id,

View File

@ -713,7 +713,7 @@ class DeviceWorkerStore(SQLBaseStore):
@cached(max_entries=10000) @cached(max_entries=10000)
async def get_device_list_last_stream_id_for_remote( async def get_device_list_last_stream_id_for_remote(
self, user_id: str self, user_id: str
) -> Optional[Any]: ) -> Optional[str]:
"""Get the last stream_id we got for a user. May be None if we haven't """Get the last stream_id we got for a user. May be None if we haven't
got any information for them. got any information for them.
""" """
@ -729,7 +729,9 @@ class DeviceWorkerStore(SQLBaseStore):
cached_method_name="get_device_list_last_stream_id_for_remote", cached_method_name="get_device_list_last_stream_id_for_remote",
list_name="user_ids", list_name="user_ids",
) )
async def get_device_list_last_stream_id_for_remotes(self, user_ids: Iterable[str]): async def get_device_list_last_stream_id_for_remotes(
self, user_ids: Iterable[str]
) -> Dict[str, Optional[str]]:
rows = await self.db_pool.simple_select_many_batch( rows = await self.db_pool.simple_select_many_batch(
table="device_lists_remote_extremeties", table="device_lists_remote_extremeties",
column="user_id", column="user_id",
@ -1316,6 +1318,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
content: JsonDict, content: JsonDict,
stream_id: str, stream_id: str,
) -> None: ) -> None:
"""Delete, update or insert a cache entry for this (user, device) pair."""
if content.get("deleted"): if content.get("deleted"):
self.db_pool.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
@ -1375,6 +1378,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
def _update_remote_device_list_cache_txn( def _update_remote_device_list_cache_txn(
self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
) -> None: ) -> None:
"""Replace the list of cached devices for this user with the given list."""
self.db_pool.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id} txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
) )

View File

@ -13,8 +13,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Iterable
from unittest import mock from unittest import mock
from parameterized import parameterized
from signedjson import key as key, sign as sign from signedjson import key as key, sign as sign
from twisted.internet import defer from twisted.internet import defer
@ -23,6 +25,7 @@ from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable
class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
@ -765,6 +768,8 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
remote_user_id = "@test:other" remote_user_id = "@test:other"
local_user_id = "@test:test" local_user_id = "@test:test"
# Pretend we're sharing a room with the user we're querying. If not,
# `_query_devices_for_destination` will return early.
self.store.get_rooms_for_user = mock.Mock( self.store.get_rooms_for_user = mock.Mock(
return_value=defer.succeed({"some_room_id"}) return_value=defer.succeed({"some_room_id"})
) )
@ -831,3 +836,94 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
} }
}, },
) )
@parameterized.expand(
[
# The remote homeserver's response indicates that this user has 0/1/2 devices.
([],),
(["device_1"],),
(["device_1", "device_2"],),
]
)
def test_query_all_devices_caches_result(self, device_ids: Iterable[str]):
"""Test that requests for all of a remote user's devices are cached.
We do this by asserting that only one call over federation was made, and that
the two queries to the local homeserver produce the same response.
"""
local_user_id = "@test:test"
remote_user_id = "@test:other"
request_body = {"device_keys": {remote_user_id: []}}
response_devices = [
{
"device_id": device_id,
"keys": {
"algorithms": ["dummy"],
"device_id": device_id,
"keys": {f"dummy:{device_id}": "dummy"},
"signatures": {device_id: {f"dummy:{device_id}": "dummy"}},
"unsigned": {},
"user_id": "@test:other",
},
}
for device_id in device_ids
]
response_body = {
"devices": response_devices,
"user_id": remote_user_id,
"stream_id": 12345, # an integer, according to the spec
}
e2e_handler = self.hs.get_e2e_keys_handler()
# Pretend we're sharing a room with the user we're querying. If not,
# `_query_devices_for_destination` will return early.
mock_get_rooms = mock.patch.object(
self.store,
"get_rooms_for_user",
new_callable=mock.MagicMock,
return_value=make_awaitable(["some_room_id"]),
)
mock_request = mock.patch.object(
self.hs.get_federation_client(),
"query_user_devices",
new_callable=mock.MagicMock,
return_value=make_awaitable(response_body),
)
with mock_get_rooms, mock_request as mocked_federation_request:
# Make the first query and sanity check it succeeds.
response_1 = self.get_success(
e2e_handler.query_devices(
request_body,
timeout=10,
from_user_id=local_user_id,
from_device_id="some_device_id",
)
)
self.assertEqual(response_1["failures"], {})
# We should have made a federation request to do so.
mocked_federation_request.assert_called_once()
# Reset the mock so we can prove we don't make a second federation request.
mocked_federation_request.reset_mock()
# Repeat the query.
response_2 = self.get_success(
e2e_handler.query_devices(
request_body,
timeout=10,
from_user_id=local_user_id,
from_device_id="some_device_id",
)
)
self.assertEqual(response_2["failures"], {})
# We should not have made a second federation request.
mocked_federation_request.assert_not_called()
# The two requests to the local homeserver should be identical.
self.assertEqual(response_1, response_2)

View File

@ -19,7 +19,7 @@ import sys
import warnings import warnings
from asyncio import Future from asyncio import Future
from binascii import unhexlify from binascii import unhexlify
from typing import Any, Awaitable, Callable, TypeVar from typing import Awaitable, Callable, TypeVar
from unittest.mock import Mock from unittest.mock import Mock
import attr import attr
@ -46,7 +46,7 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
raise Exception("awaitable has not yet completed") raise Exception("awaitable has not yet completed")
def make_awaitable(result: Any) -> Awaitable[Any]: def make_awaitable(result: TV) -> Awaitable[TV]:
""" """
Makes an awaitable, suitable for mocking an `async` function. Makes an awaitable, suitable for mocking an `async` function.
This uses Futures as they can be awaited multiple times so can be returned This uses Futures as they can be awaited multiple times so can be returned