Add some type hints to datastore (#12717)

This commit is contained in:
Dirk Klimpel 2022-05-17 16:29:06 +02:00 committed by GitHub
parent 942c30b16b
commit 6edefef602
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 254 additions and 161 deletions

View file

@ -37,7 +37,12 @@ from synapse.metrics.background_process_metrics import (
wrap_as_background_process,
)
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import Sqlite3Engine
from synapse.storage.roommember import (
@ -46,7 +51,7 @@ from synapse.storage.roommember import (
ProfileInfo,
RoomsForUser,
)
from synapse.types import PersistedEventPosition, get_domain_from_id
from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
@ -115,7 +120,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
@wrap_as_background_process("_count_known_servers")
async def _count_known_servers(self):
async def _count_known_servers(self) -> int:
"""
Count the servers that this server knows about.
@ -123,7 +128,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
`synapse_federation_known_servers` LaterGauge to collect.
"""
def _transact(txn):
def _transact(txn: LoggingTransaction) -> int:
if isinstance(self.database_engine, Sqlite3Engine):
query = """
SELECT COUNT(DISTINCT substr(out.user_id, pos+1))
@ -150,7 +155,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
self._known_servers_count = max([count, 1])
return self._known_servers_count
def _check_safe_current_state_events_membership_updated_txn(self, txn):
def _check_safe_current_state_events_membership_updated_txn(
self, txn: LoggingTransaction
) -> None:
"""Checks if it is safe to assume the new current_state_events
membership column is up to date
"""
@ -182,7 +189,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"get_users_in_room", self.get_users_in_room_txn, room_id
)
def get_users_in_room_txn(self, txn, room_id: str) -> List[str]:
def get_users_in_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[str]:
# If we can assume current_state_events.membership is up to date
# then we can avoid a join, which is a Very Good Thing given how
# frequently this function gets called.
@ -222,7 +229,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
A mapping from user ID to ProfileInfo.
"""
def _get_users_in_room_with_profiles(txn) -> Dict[str, ProfileInfo]:
def _get_users_in_room_with_profiles(
txn: LoggingTransaction,
) -> Dict[str, ProfileInfo]:
sql = """
SELECT state_key, display_name, avatar_url FROM room_memberships as m
INNER JOIN current_state_events as c
@ -250,7 +259,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
dict of membership states, pointing to a MemberSummary named tuple.
"""
def _get_room_summary_txn(txn):
def _get_room_summary_txn(
txn: LoggingTransaction,
) -> Dict[str, MemberSummary]:
# first get counts.
# We do this all in one transaction to keep the cache small.
# FIXME: get rid of this when we have room_stats
@ -279,7 +290,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""
txn.execute(sql, (room_id,))
res = {}
res: Dict[str, MemberSummary] = {}
for count, membership in txn:
res.setdefault(membership, MemberSummary([], count))
@ -400,7 +411,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
def _get_rooms_for_local_user_where_membership_is_txn(
self,
txn,
txn: LoggingTransaction,
user_id: str,
membership_list: List[str],
) -> List[RoomsForUser]:
@ -488,7 +499,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
def _get_rooms_for_user_with_stream_ordering_txn(
self, txn, user_id: str
self, txn: LoggingTransaction, user_id: str
) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
# We use `current_state_events` here and not `local_current_membership`
# as a) this gets called with remote users and b) this only gets called
@ -542,7 +553,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
def _get_rooms_for_users_with_stream_ordering_txn(
self, txn, user_ids: Collection[str]
self, txn: LoggingTransaction, user_ids: Collection[str]
) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
clause, args = make_in_list_sql_clause(
@ -575,7 +586,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(sql, [Membership.JOIN] + args)
result = {user_id: set() for user_id in user_ids}
result: Dict[str, Set[GetRoomsForUserWithStreamOrdering]] = {
user_id: set() for user_id in user_ids
}
for user_id, room_id, instance, stream_id in txn:
result[user_id].add(
GetRoomsForUserWithStreamOrdering(
@ -595,7 +608,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
if not user_ids:
return set()
def _get_users_server_still_shares_room_with_txn(txn):
def _get_users_server_still_shares_room_with_txn(
txn: LoggingTransaction,
) -> Set[str]:
sql = """
SELECT state_key FROM current_state_events
WHERE
@ -657,7 +672,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
async def get_joined_users_from_context(
self, event: EventBase, context: EventContext
) -> Dict[str, ProfileInfo]:
state_group = context.state_group
state_group: Union[object, int] = context.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
@ -666,14 +681,16 @@ class RoomMemberWorkerStore(EventsWorkerStore):
state_group = object()
current_state_ids = await context.get_current_state_ids()
assert current_state_ids is not None
assert state_group is not None
return await self._get_joined_users_from_context(
event.room_id, state_group, current_state_ids, event=event, context=context
)
async def get_joined_users_from_state(
self, room_id, state_entry
self, room_id: str, state_entry: "_StateCacheEntry"
) -> Dict[str, ProfileInfo]:
state_group = state_entry.state_group
state_group: Union[object, int] = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
@ -681,6 +698,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# To do this we set the state_group to a new object as object() != object()
state_group = object()
assert state_group is not None
with Measure(self._clock, "get_joined_users_from_state"):
return await self._get_joined_users_from_context(
room_id, state_group, state_entry.state, context=state_entry
@ -689,12 +707,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
@cached(num_args=2, cache_context=True, iterable=True, max_entries=100000)
async def _get_joined_users_from_context(
self,
room_id,
state_group,
current_state_ids,
cache_context,
event=None,
context=None,
room_id: str,
state_group: Union[object, int],
current_state_ids: StateMap[str],
cache_context: _CacheContext,
event: Optional[EventBase] = None,
context: Optional[Union[EventContext, "_StateCacheEntry"]] = None,
) -> Dict[str, ProfileInfo]:
# We don't use `state_group`, it's there so that we can cache based
# on it. However, it's important that it's never None, since two current_states
@ -765,14 +783,18 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return users_in_room
@cached(max_entries=10000)
def _get_joined_profile_from_event_id(self, event_id):
def _get_joined_profile_from_event_id(
self, event_id: str
) -> Optional[Tuple[str, ProfileInfo]]:
raise NotImplementedError()
@cachedList(
cached_method_name="_get_joined_profile_from_event_id",
list_name="event_ids",
)
async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
async def _get_joined_profiles_from_event_ids(
self, event_ids: Iterable[str]
) -> Dict[str, Optional[Tuple[str, ProfileInfo]]]:
"""For given set of member event_ids check if they point to a join
event and if so return the associated user and profile info.
@ -780,8 +802,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
event_ids: The member event IDs to lookup
Returns:
dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID
to `user_id` and ProfileInfo (or None if not join event).
Map from event ID to `user_id` and ProfileInfo (or None if not join event).
"""
rows = await self.db_pool.simple_select_many_batch(
@ -847,8 +868,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return True
async def get_joined_hosts(self, room_id: str, state_entry):
state_group = state_entry.state_group
async def get_joined_hosts(
self, room_id: str, state_entry: "_StateCacheEntry"
) -> FrozenSet[str]:
state_group: Union[object, int] = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
@ -856,6 +879,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# To do this we set the state_group to a new object as object() != object()
state_group = object()
assert state_group is not None
with Measure(self._clock, "get_joined_hosts"):
return await self._get_joined_hosts(
room_id, state_group, state_entry=state_entry
@ -863,7 +887,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
@cached(num_args=2, max_entries=10000, iterable=True)
async def _get_joined_hosts(
self, room_id: str, state_group: int, state_entry: "_StateCacheEntry"
self,
room_id: str,
state_group: Union[object, int],
state_entry: "_StateCacheEntry",
) -> FrozenSet[str]:
# We don't use `state_group`, it's there so that we can cache based on
# it. However, it's important that it's never None, since two
@ -881,7 +908,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# `get_joined_hosts` is called with the "current" state group for the
# room, and so consecutive calls will be for consecutive state groups
# which point to the previous state group.
cache = await self._get_joined_hosts_cache(room_id)
cache = await self._get_joined_hosts_cache(room_id) # type: ignore[misc]
# If the state group in the cache matches, we already have the data we need.
if state_entry.state_group == cache.state_group:
@ -897,6 +924,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
elif state_entry.prev_group == cache.state_group:
# The cached work is for the previous state group, so we work out
# the delta.
assert state_entry.delta_ids is not None
for (typ, state_key), event_id in state_entry.delta_ids.items():
if typ != EventTypes.Member:
continue
@ -942,7 +970,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
Returns False if they have since re-joined."""
def f(txn):
def f(txn: LoggingTransaction) -> int:
sql = (
"SELECT"
" COUNT(*)"
@ -973,7 +1001,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
The forgotten rooms.
"""
def _get_forgotten_rooms_for_user_txn(txn):
def _get_forgotten_rooms_for_user_txn(txn: LoggingTransaction) -> Set[str]:
# This is a slightly convoluted query that first looks up all rooms
# that the user has forgotten in the past, then rechecks that list
# to see if any have subsequently been updated. This is done so that
@ -1076,7 +1104,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
clause,
)
def _is_local_host_in_room_ignoring_users_txn(txn):
def _is_local_host_in_room_ignoring_users_txn(
txn: LoggingTransaction,
) -> bool:
txn.execute(sql, (room_id, Membership.JOIN, *args))
return bool(txn.fetchone())
@ -1110,15 +1140,17 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
where_clause="forgotten = 1",
)
async def _background_add_membership_profile(self, progress, batch_size):
async def _background_add_membership_profile(
self, progress: JsonDict, batch_size: int
) -> int:
target_min_stream_id = progress.get(
"target_min_stream_id_inclusive", self._min_stream_order_on_start
"target_min_stream_id_inclusive", self._min_stream_order_on_start # type: ignore[attr-defined]
)
max_stream_id = progress.get(
"max_stream_id_exclusive", self._stream_order_on_start + 1
"max_stream_id_exclusive", self._stream_order_on_start + 1 # type: ignore[attr-defined]
)
def add_membership_profile_txn(txn):
def add_membership_profile_txn(txn: LoggingTransaction) -> int:
sql = """
SELECT stream_ordering, event_id, events.room_id, event_json.json
FROM events
@ -1182,13 +1214,17 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
return result
async def _background_current_state_membership(self, progress, batch_size):
async def _background_current_state_membership(
self, progress: JsonDict, batch_size: int
) -> int:
"""Update the new membership column on current_state_events.
This works by iterating over all rooms in alphabetical order.
"""
def _background_current_state_membership_txn(txn, last_processed_room):
def _background_current_state_membership_txn(
txn: LoggingTransaction, last_processed_room: str
) -> Tuple[int, bool]:
processed = 0
while processed < batch_size:
txn.execute(
@ -1242,7 +1278,11 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
return row_count
class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
class RoomMemberStore(
RoomMemberWorkerStore,
RoomMemberBackgroundUpdateStore,
CacheInvalidationWorkerStore,
):
def __init__(
self,
database: DatabasePool,
@ -1254,7 +1294,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
async def forget(self, user_id: str, room_id: str) -> None:
"""Indicate that user_id wishes to discard history for room_id."""
def f(txn):
def f(txn: LoggingTransaction) -> None:
sql = (
"UPDATE"
" room_memberships"
@ -1288,5 +1328,5 @@ class _JoinedHostsCache:
# equal to anything else).
state_group: Union[object, int] = attr.Factory(object)
def __len__(self):
def __len__(self) -> int:
return sum(len(v) for v in self.hosts_to_joined_users.values())