mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Cache empty responses from /user/devices
(#11587)
If we've never made a request to a remote homeserver, we should cache the response---even if the response is "this user has no devices".
This commit is contained in:
parent
0fb3dd0830
commit
88a78c6577
1
changelog.d/11587.bugfix
Normal file
1
changelog.d/11587.bugfix
Normal file
@ -0,0 +1 @@
|
|||||||
|
Fix a long-standing bug where Synapse wouldn't cache a response indicating that a remote user has no devices.
|
@ -948,8 +948,16 @@ class DeviceListUpdater:
|
|||||||
devices = []
|
devices = []
|
||||||
ignore_devices = True
|
ignore_devices = True
|
||||||
else:
|
else:
|
||||||
|
prev_stream_id = await self.store.get_device_list_last_stream_id_for_remote(
|
||||||
|
user_id
|
||||||
|
)
|
||||||
cached_devices = await self.store.get_cached_devices_for_user(user_id)
|
cached_devices = await self.store.get_cached_devices_for_user(user_id)
|
||||||
if cached_devices == {d["device_id"]: d for d in devices}:
|
|
||||||
|
# To ensure that a user with no devices is cached, we skip the resync only
|
||||||
|
# if we have a stream_id from previously writing a cache entry.
|
||||||
|
if prev_stream_id is not None and cached_devices == {
|
||||||
|
d["device_id"]: d for d in devices
|
||||||
|
}:
|
||||||
logging.info(
|
logging.info(
|
||||||
"Skipping device list resync for %s, as our cache matches already",
|
"Skipping device list resync for %s, as our cache matches already",
|
||||||
user_id,
|
user_id,
|
||||||
|
@ -713,7 +713,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
@cached(max_entries=10000)
|
@cached(max_entries=10000)
|
||||||
async def get_device_list_last_stream_id_for_remote(
|
async def get_device_list_last_stream_id_for_remote(
|
||||||
self, user_id: str
|
self, user_id: str
|
||||||
) -> Optional[Any]:
|
) -> Optional[str]:
|
||||||
"""Get the last stream_id we got for a user. May be None if we haven't
|
"""Get the last stream_id we got for a user. May be None if we haven't
|
||||||
got any information for them.
|
got any information for them.
|
||||||
"""
|
"""
|
||||||
@ -729,7 +729,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
cached_method_name="get_device_list_last_stream_id_for_remote",
|
cached_method_name="get_device_list_last_stream_id_for_remote",
|
||||||
list_name="user_ids",
|
list_name="user_ids",
|
||||||
)
|
)
|
||||||
async def get_device_list_last_stream_id_for_remotes(self, user_ids: Iterable[str]):
|
async def get_device_list_last_stream_id_for_remotes(
|
||||||
|
self, user_ids: Iterable[str]
|
||||||
|
) -> Dict[str, Optional[str]]:
|
||||||
rows = await self.db_pool.simple_select_many_batch(
|
rows = await self.db_pool.simple_select_many_batch(
|
||||||
table="device_lists_remote_extremeties",
|
table="device_lists_remote_extremeties",
|
||||||
column="user_id",
|
column="user_id",
|
||||||
@ -1316,6 +1318,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||||||
content: JsonDict,
|
content: JsonDict,
|
||||||
stream_id: str,
|
stream_id: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Delete, update or insert a cache entry for this (user, device) pair."""
|
||||||
if content.get("deleted"):
|
if content.get("deleted"):
|
||||||
self.db_pool.simple_delete_txn(
|
self.db_pool.simple_delete_txn(
|
||||||
txn,
|
txn,
|
||||||
@ -1375,6 +1378,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||||||
def _update_remote_device_list_cache_txn(
|
def _update_remote_device_list_cache_txn(
|
||||||
self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
|
self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Replace the list of cached devices for this user with the given list."""
|
||||||
self.db_pool.simple_delete_txn(
|
self.db_pool.simple_delete_txn(
|
||||||
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
|
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
|
||||||
)
|
)
|
||||||
|
@ -13,8 +13,10 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from typing import Iterable
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
|
from parameterized import parameterized
|
||||||
from signedjson import key as key, sign as sign
|
from signedjson import key as key, sign as sign
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
@ -23,6 +25,7 @@ from synapse.api.constants import RoomEncryptionAlgorithms
|
|||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
from tests.test_utils import make_awaitable
|
||||||
|
|
||||||
|
|
||||||
class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
@ -765,6 +768,8 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||||||
remote_user_id = "@test:other"
|
remote_user_id = "@test:other"
|
||||||
local_user_id = "@test:test"
|
local_user_id = "@test:test"
|
||||||
|
|
||||||
|
# Pretend we're sharing a room with the user we're querying. If not,
|
||||||
|
# `_query_devices_for_destination` will return early.
|
||||||
self.store.get_rooms_for_user = mock.Mock(
|
self.store.get_rooms_for_user = mock.Mock(
|
||||||
return_value=defer.succeed({"some_room_id"})
|
return_value=defer.succeed({"some_room_id"})
|
||||||
)
|
)
|
||||||
@ -831,3 +836,94 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@parameterized.expand(
|
||||||
|
[
|
||||||
|
# The remote homeserver's response indicates that this user has 0/1/2 devices.
|
||||||
|
([],),
|
||||||
|
(["device_1"],),
|
||||||
|
(["device_1", "device_2"],),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def test_query_all_devices_caches_result(self, device_ids: Iterable[str]):
|
||||||
|
"""Test that requests for all of a remote user's devices are cached.
|
||||||
|
|
||||||
|
We do this by asserting that only one call over federation was made, and that
|
||||||
|
the two queries to the local homeserver produce the same response.
|
||||||
|
"""
|
||||||
|
local_user_id = "@test:test"
|
||||||
|
remote_user_id = "@test:other"
|
||||||
|
request_body = {"device_keys": {remote_user_id: []}}
|
||||||
|
|
||||||
|
response_devices = [
|
||||||
|
{
|
||||||
|
"device_id": device_id,
|
||||||
|
"keys": {
|
||||||
|
"algorithms": ["dummy"],
|
||||||
|
"device_id": device_id,
|
||||||
|
"keys": {f"dummy:{device_id}": "dummy"},
|
||||||
|
"signatures": {device_id: {f"dummy:{device_id}": "dummy"}},
|
||||||
|
"unsigned": {},
|
||||||
|
"user_id": "@test:other",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for device_id in device_ids
|
||||||
|
]
|
||||||
|
|
||||||
|
response_body = {
|
||||||
|
"devices": response_devices,
|
||||||
|
"user_id": remote_user_id,
|
||||||
|
"stream_id": 12345, # an integer, according to the spec
|
||||||
|
}
|
||||||
|
|
||||||
|
e2e_handler = self.hs.get_e2e_keys_handler()
|
||||||
|
|
||||||
|
# Pretend we're sharing a room with the user we're querying. If not,
|
||||||
|
# `_query_devices_for_destination` will return early.
|
||||||
|
mock_get_rooms = mock.patch.object(
|
||||||
|
self.store,
|
||||||
|
"get_rooms_for_user",
|
||||||
|
new_callable=mock.MagicMock,
|
||||||
|
return_value=make_awaitable(["some_room_id"]),
|
||||||
|
)
|
||||||
|
mock_request = mock.patch.object(
|
||||||
|
self.hs.get_federation_client(),
|
||||||
|
"query_user_devices",
|
||||||
|
new_callable=mock.MagicMock,
|
||||||
|
return_value=make_awaitable(response_body),
|
||||||
|
)
|
||||||
|
|
||||||
|
with mock_get_rooms, mock_request as mocked_federation_request:
|
||||||
|
# Make the first query and sanity check it succeeds.
|
||||||
|
response_1 = self.get_success(
|
||||||
|
e2e_handler.query_devices(
|
||||||
|
request_body,
|
||||||
|
timeout=10,
|
||||||
|
from_user_id=local_user_id,
|
||||||
|
from_device_id="some_device_id",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(response_1["failures"], {})
|
||||||
|
|
||||||
|
# We should have made a federation request to do so.
|
||||||
|
mocked_federation_request.assert_called_once()
|
||||||
|
|
||||||
|
# Reset the mock so we can prove we don't make a second federation request.
|
||||||
|
mocked_federation_request.reset_mock()
|
||||||
|
|
||||||
|
# Repeat the query.
|
||||||
|
response_2 = self.get_success(
|
||||||
|
e2e_handler.query_devices(
|
||||||
|
request_body,
|
||||||
|
timeout=10,
|
||||||
|
from_user_id=local_user_id,
|
||||||
|
from_device_id="some_device_id",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(response_2["failures"], {})
|
||||||
|
|
||||||
|
# We should not have made a second federation request.
|
||||||
|
mocked_federation_request.assert_not_called()
|
||||||
|
|
||||||
|
# The two requests to the local homeserver should be identical.
|
||||||
|
self.assertEqual(response_1, response_2)
|
||||||
|
@ -19,7 +19,7 @@ import sys
|
|||||||
import warnings
|
import warnings
|
||||||
from asyncio import Future
|
from asyncio import Future
|
||||||
from binascii import unhexlify
|
from binascii import unhexlify
|
||||||
from typing import Any, Awaitable, Callable, TypeVar
|
from typing import Awaitable, Callable, TypeVar
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
@ -46,7 +46,7 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
|
|||||||
raise Exception("awaitable has not yet completed")
|
raise Exception("awaitable has not yet completed")
|
||||||
|
|
||||||
|
|
||||||
def make_awaitable(result: Any) -> Awaitable[Any]:
|
def make_awaitable(result: TV) -> Awaitable[TV]:
|
||||||
"""
|
"""
|
||||||
Makes an awaitable, suitable for mocking an `async` function.
|
Makes an awaitable, suitable for mocking an `async` function.
|
||||||
This uses Futures as they can be awaited multiple times so can be returned
|
This uses Futures as they can be awaited multiple times so can be returned
|
||||||
|
Loading…
Reference in New Issue
Block a user