Return read-only collections from @cached methods (#13755)

It's important that collections returned from `@cached` methods are not
modified, otherwise future retrievals from the cache will return the
modified collection.

This applies to the return values from `@cached` methods and the values
inside the dictionaries returned by `@cachedList` methods. It's not
necessary for the dictionaries returned by `@cachedList` methods
themselves to be read-only.

Signed-off-by: Sean Quah <seanq@matrix.org>
Co-authored-by: David Robertson <davidr@element.io>
This commit is contained in:
Sean Quah 2023-02-10 23:29:00 +00:00 committed by GitHub
parent 14be78d492
commit d0c713cc85
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 98 additions and 77 deletions

1
changelog.d/13755.misc Normal file
View File

@ -0,0 +1 @@
Re-type hint some collections as read-only.

View File

@ -15,7 +15,7 @@ import logging
import math import math
import resource import resource
import sys import sys
from typing import TYPE_CHECKING, List, Sized, Tuple from typing import TYPE_CHECKING, List, Mapping, Sized, Tuple
from prometheus_client import Gauge from prometheus_client import Gauge
@ -194,7 +194,7 @@ def start_phone_stats_home(hs: "HomeServer") -> None:
@wrap_as_background_process("generate_monthly_active_users") @wrap_as_background_process("generate_monthly_active_users")
async def generate_monthly_active_users() -> None: async def generate_monthly_active_users() -> None:
current_mau_count = 0 current_mau_count = 0
current_mau_count_by_service = {} current_mau_count_by_service: Mapping[str, int] = {}
reserved_users: Sized = () reserved_users: Sized = ()
store = hs.get_datastores().main store = hs.get_datastores().main
if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only: if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only:

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, List from typing import Any, Collection
from matrix_common.regex import glob_to_regex from matrix_common.regex import glob_to_regex
@ -70,7 +70,7 @@ class RoomDirectoryConfig(Config):
return False return False
def is_publishing_room_allowed( def is_publishing_room_allowed(
self, user_id: str, room_id: str, aliases: List[str] self, user_id: str, room_id: str, aliases: Collection[str]
) -> bool: ) -> bool:
"""Checks if the given user is allowed to publish the room """Checks if the given user is allowed to publish the room
@ -122,7 +122,7 @@ class _RoomDirectoryRule:
except Exception as e: except Exception as e:
raise ConfigError("Failed to parse glob into regex") from e raise ConfigError("Failed to parse glob into regex") from e
def matches(self, user_id: str, room_id: str, aliases: List[str]) -> bool: def matches(self, user_id: str, room_id: str, aliases: Collection[str]) -> bool:
"""Tests if this rule matches the given user_id, room_id and aliases. """Tests if this rule matches the given user_id, room_id and aliases.
Args: Args:

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple, Union
import attr import attr
from signedjson.types import SigningKey from signedjson.types import SigningKey
@ -103,7 +103,7 @@ class EventBuilder:
async def build( async def build(
self, self,
prev_event_ids: List[str], prev_event_ids: Collection[str],
auth_event_ids: Optional[List[str]], auth_event_ids: Optional[List[str]],
depth: Optional[int] = None, depth: Optional[int] = None,
) -> EventBase: ) -> EventBase:
@ -136,7 +136,7 @@ class EventBuilder:
format_version = self.room_version.event_format format_version = self.room_version.event_format
# The types of auth/prev events changes between event versions. # The types of auth/prev events changes between event versions.
prev_events: Union[List[str], List[Tuple[str, Dict[str, str]]]] prev_events: Union[Collection[str], List[Tuple[str, Dict[str, str]]]]
auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]] auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]]
if format_version == EventFormatVersions.ROOM_V1_V2: if format_version == EventFormatVersions.ROOM_V1_V2:
auth_events = await self._store.add_event_hashes(auth_event_ids) auth_events = await self._store.add_event_hashes(auth_event_ids)

View File

@ -23,6 +23,7 @@ from typing import (
Collection, Collection,
Dict, Dict,
List, List,
Mapping,
Optional, Optional,
Tuple, Tuple,
Union, Union,
@ -1512,7 +1513,7 @@ class FederationHandlerRegistry:
def _get_event_ids_for_partial_state_join( def _get_event_ids_for_partial_state_join(
join_event: EventBase, join_event: EventBase,
prev_state_ids: StateMap[str], prev_state_ids: StateMap[str],
summary: Dict[str, MemberSummary], summary: Mapping[str, MemberSummary],
) -> Collection[str]: ) -> Collection[str]:
"""Calculate state to be returned in a partial_state send_join """Calculate state to be returned in a partial_state send_join

View File

@ -14,7 +14,7 @@
import logging import logging
import string import string
from typing import TYPE_CHECKING, Iterable, List, Optional from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence
from typing_extensions import Literal from typing_extensions import Literal
@ -486,7 +486,7 @@ class DirectoryHandler:
) )
if canonical_alias: if canonical_alias:
# Ensure we do not mutate room_aliases. # Ensure we do not mutate room_aliases.
room_aliases = room_aliases + [canonical_alias] room_aliases = list(room_aliases) + [canonical_alias]
if not self.config.roomdirectory.is_publishing_room_allowed( if not self.config.roomdirectory.is_publishing_room_allowed(
user_id, room_id, room_aliases user_id, room_id, room_aliases
@ -529,7 +529,7 @@ class DirectoryHandler:
async def get_aliases_for_room( async def get_aliases_for_room(
self, requester: Requester, room_id: str self, requester: Requester, room_id: str
) -> List[str]: ) -> Sequence[str]:
""" """
Get a list of the aliases that currently point to this room on this server Get a list of the aliases that currently point to this room on this server
""" """

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Tuple
from synapse.api.constants import EduTypes, ReceiptTypes from synapse.api.constants import EduTypes, ReceiptTypes
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
@ -189,7 +189,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
@staticmethod @staticmethod
def filter_out_private_receipts( def filter_out_private_receipts(
rooms: List[JsonDict], user_id: str rooms: Sequence[JsonDict], user_id: str
) -> List[JsonDict]: ) -> List[JsonDict]:
""" """
Filters a list of serialized receipts (as returned by /sync and /initialSync) Filters a list of serialized receipts (as returned by /sync and /initialSync)

View File

@ -1928,6 +1928,6 @@ class RoomShutdownHandler:
return { return {
"kicked_users": kicked_users, "kicked_users": kicked_users,
"failed_to_kick_users": failed_to_kick_users, "failed_to_kick_users": failed_to_kick_users,
"local_aliases": aliases_for_room, "local_aliases": list(aliases_for_room),
"new_room_id": new_room_id, "new_room_id": new_room_id,
} }

View File

@ -1519,7 +1519,7 @@ class SyncHandler:
one_time_keys_count = await self.store.count_e2e_one_time_keys( one_time_keys_count = await self.store.count_e2e_one_time_keys(
user_id, device_id user_id, device_id
) )
unused_fallback_key_types = ( unused_fallback_key_types = list(
await self.store.get_e2e_unused_fallback_key_types(user_id, device_id) await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
) )
@ -2301,7 +2301,7 @@ class SyncHandler:
sync_result_builder: "SyncResultBuilder", sync_result_builder: "SyncResultBuilder",
room_builder: "RoomSyncResultBuilder", room_builder: "RoomSyncResultBuilder",
ephemeral: List[JsonDict], ephemeral: List[JsonDict],
tags: Optional[Dict[str, Dict[str, Any]]], tags: Optional[Mapping[str, Mapping[str, Any]]],
account_data: Mapping[str, JsonDict], account_data: Mapping[str, JsonDict],
always_include: bool = False, always_include: bool = False,
) -> None: ) -> None:

View File

@ -22,6 +22,7 @@ from typing import (
List, List,
Mapping, Mapping,
Optional, Optional,
Sequence,
Set, Set,
Tuple, Tuple,
Union, Union,
@ -149,7 +150,7 @@ class BulkPushRuleEvaluator:
# little, we can skip fetching a huge number of push rules in large rooms. # little, we can skip fetching a huge number of push rules in large rooms.
# This helps make joins and leaves faster. # This helps make joins and leaves faster.
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
local_users = [] local_users: Sequence[str] = []
# We never notify a user about their own actions. This is enforced in # We never notify a user about their own actions. This is enforced in
# `_action_for_event_by_user` in the loop over `rules_by_user`, but we # `_action_for_event_by_user` in the loop over `rules_by_user`, but we
# do the same check here to avoid unnecessary DB queries. # do the same check here to avoid unnecessary DB queries.
@ -184,7 +185,6 @@ class BulkPushRuleEvaluator:
if event.type == EventTypes.Member and event.membership == Membership.INVITE: if event.type == EventTypes.Member and event.membership == Membership.INVITE:
invited = event.state_key invited = event.state_key
if invited and self.hs.is_mine_id(invited) and invited not in local_users: if invited and self.hs.is_mine_id(invited) and invited not in local_users:
local_users = list(local_users)
local_users.append(invited) local_users.append(invited)
if not local_users: if not local_users:

View File

@ -226,7 +226,7 @@ class StateHandler:
return await ret.get_state(self._state_storage_controller, state_filter) return await ret.get_state(self._state_storage_controller, state_filter)
async def get_current_user_ids_in_room( async def get_current_user_ids_in_room(
self, room_id: str, latest_event_ids: List[str] self, room_id: str, latest_event_ids: Collection[str]
) -> Set[str]: ) -> Set[str]:
""" """
Get the users IDs who are currently in a room. Get the users IDs who are currently in a room.

View File

@ -14,6 +14,7 @@
import logging import logging
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
AbstractSet,
Any, Any,
Awaitable, Awaitable,
Callable, Callable,
@ -23,7 +24,6 @@ from typing import (
List, List,
Mapping, Mapping,
Optional, Optional,
Set,
Tuple, Tuple,
) )
@ -527,7 +527,7 @@ class StateStorageController:
) )
return state_map.get(key) return state_map.get(key)
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]: async def get_current_hosts_in_room(self, room_id: str) -> AbstractSet[str]:
"""Get current hosts in room based on current state. """Get current hosts in room based on current state.
Blocks until we have full state for the given room. This only happens for rooms Blocks until we have full state for the given room. This only happens for rooms
@ -584,7 +584,7 @@ class StateStorageController:
async def get_users_in_room_with_profiles( async def get_users_in_room_with_profiles(
self, room_id: str self, room_id: str
) -> Dict[str, ProfileInfo]: ) -> Mapping[str, ProfileInfo]:
""" """
Get the current users in the room with their profiles. Get the current users in the room with their profiles.
If the room is currently partial-stated, this will block until the room has If the room is currently partial-stated, this will block until the room has

View File

@ -240,7 +240,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
@cached(num_args=2, tree=True) @cached(num_args=2, tree=True)
async def get_account_data_for_room( async def get_account_data_for_room(
self, user_id: str, room_id: str self, user_id: str, room_id: str
) -> Dict[str, JsonDict]: ) -> Mapping[str, JsonDict]:
"""Get all the client account_data for a user for a room. """Get all the client account_data for a user for a room.
Args: Args:

View File

@ -166,7 +166,7 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore):
room_id: str, room_id: str,
app_service: "ApplicationService", app_service: "ApplicationService",
cache_context: _CacheContext, cache_context: _CacheContext,
) -> List[str]: ) -> Sequence[str]:
""" """
Get all users in a room that the appservice controls. Get all users in a room that the appservice controls.

View File

@ -21,6 +21,7 @@ from typing import (
Dict, Dict,
Iterable, Iterable,
List, List,
Mapping,
Optional, Optional,
Set, Set,
Tuple, Tuple,
@ -202,7 +203,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
def get_device_stream_token(self) -> int: def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token() return self._device_list_id_gen.get_current_token()
async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int: async def count_devices_by_users(
self, user_ids: Optional[Collection[str]] = None
) -> int:
"""Retrieve number of all devices of given users. """Retrieve number of all devices of given users.
Only returns number of devices that are not marked as hidden. Only returns number of devices that are not marked as hidden.
@ -213,7 +216,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
""" """
def count_devices_by_users_txn( def count_devices_by_users_txn(
txn: LoggingTransaction, user_ids: List[str] txn: LoggingTransaction, user_ids: Collection[str]
) -> int: ) -> int:
sql = """ sql = """
SELECT count(*) SELECT count(*)
@ -747,7 +750,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
@cancellable @cancellable
async def get_user_devices_from_cache( async def get_user_devices_from_cache(
self, user_ids: Set[str], user_and_device_ids: List[Tuple[str, str]] self, user_ids: Set[str], user_and_device_ids: List[Tuple[str, str]]
) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]: ) -> Tuple[Set[str], Dict[str, Mapping[str, JsonDict]]]:
"""Get the devices (and keys if any) for remote users from the cache. """Get the devices (and keys if any) for remote users from the cache.
Args: Args:
@ -775,16 +778,18 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
user_ids_not_in_cache = unique_user_ids - user_ids_in_cache user_ids_not_in_cache = unique_user_ids - user_ids_in_cache
# First fetch all the users which all devices are to be returned. # First fetch all the users which all devices are to be returned.
results: Dict[str, Dict[str, JsonDict]] = {} results: Dict[str, Mapping[str, JsonDict]] = {}
for user_id in user_ids: for user_id in user_ids:
if user_id in user_ids_in_cache: if user_id in user_ids_in_cache:
results[user_id] = await self.get_cached_devices_for_user(user_id) results[user_id] = await self.get_cached_devices_for_user(user_id)
# Then fetch all device-specific requests, but skip users we've already # Then fetch all device-specific requests, but skip users we've already
# fetched all devices for. # fetched all devices for.
device_specific_results: Dict[str, Dict[str, JsonDict]] = {}
for user_id, device_id in user_and_device_ids: for user_id, device_id in user_and_device_ids:
if user_id in user_ids_in_cache and user_id not in user_ids: if user_id in user_ids_in_cache and user_id not in user_ids:
device = await self._get_cached_user_device(user_id, device_id) device = await self._get_cached_user_device(user_id, device_id)
results.setdefault(user_id, {})[device_id] = device device_specific_results.setdefault(user_id, {})[device_id] = device
results.update(device_specific_results)
set_tag("in_cache", str(results)) set_tag("in_cache", str(results))
set_tag("not_in_cache", str(user_ids_not_in_cache)) set_tag("not_in_cache", str(user_ids_not_in_cache))
@ -802,7 +807,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
return db_to_json(content) return db_to_json(content)
@cached() @cached()
async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]: async def get_cached_devices_for_user(self, user_id: str) -> Mapping[str, JsonDict]:
devices = await self.db_pool.simple_select_list( devices = await self.db_pool.simple_select_list(
table="device_lists_remote_cache", table="device_lists_remote_cache",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Sequence, Tuple
import attr import attr
@ -74,7 +74,7 @@ class DirectoryWorkerStore(CacheInvalidationWorkerStore):
) )
@cached(max_entries=5000) @cached(max_entries=5000)
async def get_aliases_for_room(self, room_id: str) -> List[str]: async def get_aliases_for_room(self, room_id: str) -> Sequence[str]:
return await self.db_pool.simple_select_onecol( return await self.db_pool.simple_select_onecol(
"room_aliases", "room_aliases",
{"room_id": room_id}, {"room_id": room_id},

View File

@ -20,7 +20,9 @@ from typing import (
Dict, Dict,
Iterable, Iterable,
List, List,
Mapping,
Optional, Optional,
Sequence,
Tuple, Tuple,
Union, Union,
cast, cast,
@ -691,7 +693,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
@cached(max_entries=10000) @cached(max_entries=10000)
async def get_e2e_unused_fallback_key_types( async def get_e2e_unused_fallback_key_types(
self, user_id: str, device_id: str self, user_id: str, device_id: str
) -> List[str]: ) -> Sequence[str]:
"""Returns the fallback key types that have an unused key. """Returns the fallback key types that have an unused key.
Args: Args:
@ -731,7 +733,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
return user_keys.get(key_type) return user_keys.get(key_type)
@cached(num_args=1) @cached(num_args=1)
def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Dict[str, JsonDict]: def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Mapping[str, JsonDict]:
"""Dummy function. Only used to make a cache for """Dummy function. Only used to make a cache for
_get_bare_e2e_cross_signing_keys_bulk. _get_bare_e2e_cross_signing_keys_bulk.
""" """
@ -744,7 +746,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
) )
async def _get_bare_e2e_cross_signing_keys_bulk( async def _get_bare_e2e_cross_signing_keys_bulk(
self, user_ids: Iterable[str] self, user_ids: Iterable[str]
) -> Dict[str, Optional[Dict[str, JsonDict]]]: ) -> Dict[str, Optional[Mapping[str, JsonDict]]]:
"""Returns the cross-signing keys for a set of users. The output of this """Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if function should be passed to _get_e2e_cross_signing_signatures_txn if
the signatures for the calling user need to be fetched. the signatures for the calling user need to be fetched.
@ -765,7 +767,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
) )
# The `Optional` comes from the `@cachedList` decorator. # The `Optional` comes from the `@cachedList` decorator.
return cast(Dict[str, Optional[Dict[str, JsonDict]]], result) return cast(Dict[str, Optional[Mapping[str, JsonDict]]], result)
def _get_bare_e2e_cross_signing_keys_bulk_txn( def _get_bare_e2e_cross_signing_keys_bulk_txn(
self, self,
@ -924,7 +926,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
@cancellable @cancellable
async def get_e2e_cross_signing_keys_bulk( async def get_e2e_cross_signing_keys_bulk(
self, user_ids: List[str], from_user_id: Optional[str] = None self, user_ids: List[str], from_user_id: Optional[str] = None
) -> Dict[str, Optional[Dict[str, JsonDict]]]: ) -> Dict[str, Optional[Mapping[str, JsonDict]]]:
"""Returns the cross-signing keys for a set of users. """Returns the cross-signing keys for a set of users.
Args: Args:
@ -940,11 +942,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids) result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
if from_user_id: if from_user_id:
result = await self.db_pool.runInteraction( result = cast(
"get_e2e_cross_signing_signatures", Dict[str, Optional[Mapping[str, JsonDict]]],
self._get_e2e_cross_signing_signatures_txn, await self.db_pool.runInteraction(
result, "get_e2e_cross_signing_signatures",
from_user_id, self._get_e2e_cross_signing_signatures_txn,
result,
from_user_id,
),
) )
return result return result

View File

@ -22,6 +22,7 @@ from typing import (
Iterable, Iterable,
List, List,
Optional, Optional,
Sequence,
Set, Set,
Tuple, Tuple,
cast, cast,
@ -1004,7 +1005,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
room_id, room_id,
) )
async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]: async def get_max_depth_of(
self, event_ids: Collection[str]
) -> Tuple[Optional[str], int]:
"""Returns the event ID and depth for the event that has the max depth from a set of event IDs """Returns the event ID and depth for the event that has the max depth from a set of event IDs
Args: Args:
@ -1141,7 +1144,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
) )
@cached(max_entries=5000, iterable=True) @cached(max_entries=5000, iterable=True)
async def get_latest_event_ids_in_room(self, room_id: str) -> List[str]: async def get_latest_event_ids_in_room(self, room_id: str) -> Sequence[str]:
return await self.db_pool.simple_select_onecol( return await self.db_pool.simple_select_onecol(
table="event_forward_extremities", table="event_forward_extremities",
keyvalues={"room_id": room_id}, keyvalues={"room_id": room_id},
@ -1171,7 +1174,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
@cancellable @cancellable
async def get_forward_extremities_for_room_at_stream_ordering( async def get_forward_extremities_for_room_at_stream_ordering(
self, room_id: str, stream_ordering: int self, room_id: str, stream_ordering: int
) -> List[str]: ) -> Sequence[str]:
"""For a given room_id and stream_ordering, return the forward """For a given room_id and stream_ordering, return the forward
extremeties of the room at that point in "time". extremeties of the room at that point in "time".
@ -1204,7 +1207,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
@cached(max_entries=5000, num_args=2) @cached(max_entries=5000, num_args=2)
async def _get_forward_extremeties_for_room( async def _get_forward_extremeties_for_room(
self, room_id: str, stream_ordering: int self, room_id: str, stream_ordering: int
) -> List[str]: ) -> Sequence[str]:
"""For a given room_id and stream_ordering, return the forward """For a given room_id and stream_ordering, return the forward
extremeties of the room at that point in "time". extremeties of the room at that point in "time".

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, cast
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage.database import ( from synapse.storage.database import (
@ -95,7 +95,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
return await self.db_pool.runInteraction("count_users", _count_users) return await self.db_pool.runInteraction("count_users", _count_users)
@cached(num_args=0) @cached(num_args=0)
async def get_monthly_active_count_by_service(self) -> Dict[str, int]: async def get_monthly_active_count_by_service(self) -> Mapping[str, int]:
"""Generates current count of monthly active users broken down by service. """Generates current count of monthly active users broken down by service.
A service is typically an appservice but also includes native matrix users. A service is typically an appservice but also includes native matrix users.
Since the `monthly_active_users` table is populated from the `user_ips` table Since the `monthly_active_users` table is populated from the `user_ips` table

View File

@ -21,7 +21,9 @@ from typing import (
Dict, Dict,
Iterable, Iterable,
List, List,
Mapping,
Optional, Optional,
Sequence,
Tuple, Tuple,
cast, cast,
) )
@ -288,7 +290,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
async def get_linearized_receipts_for_room( async def get_linearized_receipts_for_room(
self, room_id: str, to_key: int, from_key: Optional[int] = None self, room_id: str, to_key: int, from_key: Optional[int] = None
) -> List[dict]: ) -> Sequence[JsonDict]:
"""Get receipts for a single room for sending to clients. """Get receipts for a single room for sending to clients.
Args: Args:
@ -311,7 +313,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cached(tree=True) @cached(tree=True)
async def _get_linearized_receipts_for_room( async def _get_linearized_receipts_for_room(
self, room_id: str, to_key: int, from_key: Optional[int] = None self, room_id: str, to_key: int, from_key: Optional[int] = None
) -> List[JsonDict]: ) -> Sequence[JsonDict]:
"""See get_linearized_receipts_for_room""" """See get_linearized_receipts_for_room"""
def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
@ -354,7 +356,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
) )
async def _get_linearized_receipts_for_rooms( async def _get_linearized_receipts_for_rooms(
self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
) -> Dict[str, List[JsonDict]]: ) -> Dict[str, Sequence[JsonDict]]:
if not room_ids: if not room_ids:
return {} return {}
@ -416,7 +418,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
) )
async def get_linearized_receipts_for_all_rooms( async def get_linearized_receipts_for_all_rooms(
self, to_key: int, from_key: Optional[int] = None self, to_key: int, from_key: Optional[int] = None
) -> Dict[str, JsonDict]: ) -> Mapping[str, JsonDict]:
"""Get receipts for all rooms between two stream_ids, up """Get receipts for all rooms between two stream_ids, up
to a limit of the latest 100 read receipts. to a limit of the latest 100 read receipts.

View File

@ -16,7 +16,7 @@
import logging import logging
import random import random
import re import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
import attr import attr
@ -192,7 +192,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
) )
@cached() @cached()
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: async def get_user_by_id(self, user_id: str) -> Optional[Mapping[str, Any]]:
"""Deprecated: use get_userinfo_by_id instead""" """Deprecated: use get_userinfo_by_id instead"""
def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]: def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:

View File

@ -22,6 +22,7 @@ from typing import (
List, List,
Mapping, Mapping,
Optional, Optional,
Sequence,
Set, Set,
Tuple, Tuple,
Union, Union,
@ -171,7 +172,7 @@ class RelationsWorkerStore(SQLBaseStore):
direction: Direction = Direction.BACKWARDS, direction: Direction = Direction.BACKWARDS,
from_token: Optional[StreamToken] = None, from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None, to_token: Optional[StreamToken] = None,
) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]: ) -> Tuple[Sequence[_RelatedEvent], Optional[StreamToken]]:
"""Get a list of relations for an event, ordered by topological ordering. """Get a list of relations for an event, ordered by topological ordering.
Args: Args:
@ -397,7 +398,9 @@ class RelationsWorkerStore(SQLBaseStore):
return result is not None return result is not None
@cached() @cached()
async def get_aggregation_groups_for_event(self, event_id: str) -> List[JsonDict]: async def get_aggregation_groups_for_event(
self, event_id: str
) -> Sequence[JsonDict]:
raise NotImplementedError() raise NotImplementedError()
@cachedList( @cachedList(

View File

@ -24,6 +24,7 @@ from typing import (
List, List,
Mapping, Mapping,
Optional, Optional,
Sequence,
Set, Set,
Tuple, Tuple,
Union, Union,
@ -153,7 +154,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return self._known_servers_count return self._known_servers_count
@cached(max_entries=100000, iterable=True) @cached(max_entries=100000, iterable=True)
async def get_users_in_room(self, room_id: str) -> List[str]: async def get_users_in_room(self, room_id: str) -> Sequence[str]:
"""Returns a list of users in the room. """Returns a list of users in the room.
Will return inaccurate results for rooms with partial state, since the state for Will return inaccurate results for rooms with partial state, since the state for
@ -190,9 +191,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
) )
@cached() @cached()
def get_user_in_room_with_profile( def get_user_in_room_with_profile(self, room_id: str, user_id: str) -> ProfileInfo:
self, room_id: str, user_id: str
) -> Dict[str, ProfileInfo]:
raise NotImplementedError() raise NotImplementedError()
@cachedList( @cachedList(
@ -246,7 +245,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
@cached(max_entries=100000, iterable=True) @cached(max_entries=100000, iterable=True)
async def get_users_in_room_with_profiles( async def get_users_in_room_with_profiles(
self, room_id: str self, room_id: str
) -> Dict[str, ProfileInfo]: ) -> Mapping[str, ProfileInfo]:
"""Get a mapping from user ID to profile information for all users in a given room. """Get a mapping from user ID to profile information for all users in a given room.
The profile information comes directly from this room's `m.room.member` The profile information comes directly from this room's `m.room.member`
@ -285,7 +284,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
) )
@cached(max_entries=100000) @cached(max_entries=100000)
async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]: async def get_room_summary(self, room_id: str) -> Mapping[str, MemberSummary]:
"""Get the details of a room roughly suitable for use by the room """Get the details of a room roughly suitable for use by the room
summary extension to /sync. Useful when lazy loading room members. summary extension to /sync. Useful when lazy loading room members.
Args: Args:
@ -357,7 +356,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
@cached() @cached()
async def get_invited_rooms_for_local_user( async def get_invited_rooms_for_local_user(
self, user_id: str self, user_id: str
) -> List[RoomsForUser]: ) -> Sequence[RoomsForUser]:
"""Get all the rooms the *local* user is invited to. """Get all the rooms the *local* user is invited to.
Args: Args:
@ -475,7 +474,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return results return results
@cached(iterable=True) @cached(iterable=True)
async def get_local_users_in_room(self, room_id: str) -> List[str]: async def get_local_users_in_room(self, room_id: str) -> Sequence[str]:
""" """
Retrieves a list of the current roommembers who are local to the server. Retrieves a list of the current roommembers who are local to the server.
""" """
@ -791,7 +790,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""Returns the set of users who share a room with `user_id`""" """Returns the set of users who share a room with `user_id`"""
room_ids = await self.get_rooms_for_user(user_id) room_ids = await self.get_rooms_for_user(user_id)
user_who_share_room = set() user_who_share_room: Set[str] = set()
for room_id in room_ids: for room_id in room_ids:
user_ids = await self.get_users_in_room(room_id) user_ids = await self.get_users_in_room(room_id)
user_who_share_room.update(user_ids) user_who_share_room.update(user_ids)
@ -953,7 +952,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return True return True
@cached(iterable=True, max_entries=10000) @cached(iterable=True, max_entries=10000)
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]: async def get_current_hosts_in_room(self, room_id: str) -> AbstractSet[str]:
"""Get current hosts in room based on current state.""" """Get current hosts in room based on current state."""
# First we check if we already have `get_users_in_room` in the cache, as # First we check if we already have `get_users_in_room` in the cache, as

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Collection, Dict, List, Tuple from typing import Collection, Dict, List, Mapping, Tuple
from unpaddedbase64 import encode_base64 from unpaddedbase64 import encode_base64
@ -26,7 +26,7 @@ from synapse.util.caches.descriptors import cached, cachedList
class SignatureWorkerStore(EventsWorkerStore): class SignatureWorkerStore(EventsWorkerStore):
@cached() @cached()
def get_event_reference_hash(self, event_id: str) -> Dict[str, Dict[str, bytes]]: def get_event_reference_hash(self, event_id: str) -> Mapping[str, bytes]:
# This is a dummy function to allow get_event_reference_hashes # This is a dummy function to allow get_event_reference_hashes
# to use its cache # to use its cache
raise NotImplementedError() raise NotImplementedError()
@ -36,7 +36,7 @@ class SignatureWorkerStore(EventsWorkerStore):
) )
async def get_event_reference_hashes( async def get_event_reference_hashes(
self, event_ids: Collection[str] self, event_ids: Collection[str]
) -> Dict[str, Dict[str, bytes]]: ) -> Mapping[str, Mapping[str, bytes]]:
"""Get all hashes for given events. """Get all hashes for given events.
Args: Args:

View File

@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any, Dict, Iterable, List, Tuple, cast from typing import Any, Dict, Iterable, List, Mapping, Tuple, cast
from synapse.api.constants import AccountDataTypes from synapse.api.constants import AccountDataTypes
from synapse.replication.tcp.streams import AccountDataStream from synapse.replication.tcp.streams import AccountDataStream
@ -32,7 +32,9 @@ logger = logging.getLogger(__name__)
class TagsWorkerStore(AccountDataWorkerStore): class TagsWorkerStore(AccountDataWorkerStore):
@cached() @cached()
async def get_tags_for_user(self, user_id: str) -> Dict[str, Dict[str, JsonDict]]: async def get_tags_for_user(
self, user_id: str
) -> Mapping[str, Mapping[str, JsonDict]]:
"""Get all the tags for a user. """Get all the tags for a user.
@ -107,7 +109,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
async def get_updated_tags( async def get_updated_tags(
self, user_id: str, stream_id: int self, user_id: str, stream_id: int
) -> Dict[str, Dict[str, JsonDict]]: ) -> Mapping[str, Mapping[str, JsonDict]]:
"""Get all the tags for the rooms where the tags have changed since the """Get all the tags for the rooms where the tags have changed since the
given version given version

View File

@ -16,9 +16,9 @@ import logging
import re import re
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Dict,
Iterable, Iterable,
List, List,
Mapping,
Optional, Optional,
Sequence, Sequence,
Set, Set,
@ -586,7 +586,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
) )
@cached() @cached()
async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, str]]: async def get_user_in_directory(self, user_id: str) -> Optional[Mapping[str, str]]:
return await self.db_pool.simple_select_one( return await self.db_pool.simple_select_one(
table="user_directory", table="user_directory",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List from typing import List, Sequence
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -558,7 +558,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
def _check_invite_and_join_status( def _check_invite_and_join_status(
self, user_id: str, expected_invites: int, expected_memberships: int self, user_id: str, expected_invites: int, expected_memberships: int
) -> List[RoomsForUser]: ) -> Sequence[RoomsForUser]:
"""Check invite and room membership status of a user. """Check invite and room membership status of a user.
Args Args