mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2025-05-18 00:20:24 -04:00
Add some type hints to datastore (#12717)
This commit is contained in:
parent
942c30b16b
commit
6edefef602
10 changed files with 254 additions and 161 deletions
|
@ -37,7 +37,12 @@ from synapse.metrics.background_process_metrics import (
|
|||
wrap_as_background_process,
|
||||
)
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
|
||||
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
)
|
||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||
from synapse.storage.engines import Sqlite3Engine
|
||||
from synapse.storage.roommember import (
|
||||
|
@ -46,7 +51,7 @@ from synapse.storage.roommember import (
|
|||
ProfileInfo,
|
||||
RoomsForUser,
|
||||
)
|
||||
from synapse.types import PersistedEventPosition, get_domain_from_id
|
||||
from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain_from_id
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.caches import intern_string
|
||||
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
|
||||
|
@ -115,7 +120,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
)
|
||||
|
||||
@wrap_as_background_process("_count_known_servers")
|
||||
async def _count_known_servers(self):
|
||||
async def _count_known_servers(self) -> int:
|
||||
"""
|
||||
Count the servers that this server knows about.
|
||||
|
||||
|
@ -123,7 +128,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
`synapse_federation_known_servers` LaterGauge to collect.
|
||||
"""
|
||||
|
||||
def _transact(txn):
|
||||
def _transact(txn: LoggingTransaction) -> int:
|
||||
if isinstance(self.database_engine, Sqlite3Engine):
|
||||
query = """
|
||||
SELECT COUNT(DISTINCT substr(out.user_id, pos+1))
|
||||
|
@ -150,7 +155,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
self._known_servers_count = max([count, 1])
|
||||
return self._known_servers_count
|
||||
|
||||
def _check_safe_current_state_events_membership_updated_txn(self, txn):
|
||||
def _check_safe_current_state_events_membership_updated_txn(
|
||||
self, txn: LoggingTransaction
|
||||
) -> None:
|
||||
"""Checks if it is safe to assume the new current_state_events
|
||||
membership column is up to date
|
||||
"""
|
||||
|
@ -182,7 +189,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
"get_users_in_room", self.get_users_in_room_txn, room_id
|
||||
)
|
||||
|
||||
def get_users_in_room_txn(self, txn, room_id: str) -> List[str]:
|
||||
def get_users_in_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[str]:
|
||||
# If we can assume current_state_events.membership is up to date
|
||||
# then we can avoid a join, which is a Very Good Thing given how
|
||||
# frequently this function gets called.
|
||||
|
@ -222,7 +229,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
A mapping from user ID to ProfileInfo.
|
||||
"""
|
||||
|
||||
def _get_users_in_room_with_profiles(txn) -> Dict[str, ProfileInfo]:
|
||||
def _get_users_in_room_with_profiles(
|
||||
txn: LoggingTransaction,
|
||||
) -> Dict[str, ProfileInfo]:
|
||||
sql = """
|
||||
SELECT state_key, display_name, avatar_url FROM room_memberships as m
|
||||
INNER JOIN current_state_events as c
|
||||
|
@ -250,7 +259,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
dict of membership states, pointing to a MemberSummary named tuple.
|
||||
"""
|
||||
|
||||
def _get_room_summary_txn(txn):
|
||||
def _get_room_summary_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Dict[str, MemberSummary]:
|
||||
# first get counts.
|
||||
# We do this all in one transaction to keep the cache small.
|
||||
# FIXME: get rid of this when we have room_stats
|
||||
|
@ -279,7 +290,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
"""
|
||||
|
||||
txn.execute(sql, (room_id,))
|
||||
res = {}
|
||||
res: Dict[str, MemberSummary] = {}
|
||||
for count, membership in txn:
|
||||
res.setdefault(membership, MemberSummary([], count))
|
||||
|
||||
|
@ -400,7 +411,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
|
||||
def _get_rooms_for_local_user_where_membership_is_txn(
|
||||
self,
|
||||
txn,
|
||||
txn: LoggingTransaction,
|
||||
user_id: str,
|
||||
membership_list: List[str],
|
||||
) -> List[RoomsForUser]:
|
||||
|
@ -488,7 +499,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
)
|
||||
|
||||
def _get_rooms_for_user_with_stream_ordering_txn(
|
||||
self, txn, user_id: str
|
||||
self, txn: LoggingTransaction, user_id: str
|
||||
) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
|
||||
# We use `current_state_events` here and not `local_current_membership`
|
||||
# as a) this gets called with remote users and b) this only gets called
|
||||
|
@ -542,7 +553,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
)
|
||||
|
||||
def _get_rooms_for_users_with_stream_ordering_txn(
|
||||
self, txn, user_ids: Collection[str]
|
||||
self, txn: LoggingTransaction, user_ids: Collection[str]
|
||||
) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
|
||||
|
||||
clause, args = make_in_list_sql_clause(
|
||||
|
@ -575,7 +586,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
|
||||
txn.execute(sql, [Membership.JOIN] + args)
|
||||
|
||||
result = {user_id: set() for user_id in user_ids}
|
||||
result: Dict[str, Set[GetRoomsForUserWithStreamOrdering]] = {
|
||||
user_id: set() for user_id in user_ids
|
||||
}
|
||||
for user_id, room_id, instance, stream_id in txn:
|
||||
result[user_id].add(
|
||||
GetRoomsForUserWithStreamOrdering(
|
||||
|
@ -595,7 +608,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
if not user_ids:
|
||||
return set()
|
||||
|
||||
def _get_users_server_still_shares_room_with_txn(txn):
|
||||
def _get_users_server_still_shares_room_with_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Set[str]:
|
||||
sql = """
|
||||
SELECT state_key FROM current_state_events
|
||||
WHERE
|
||||
|
@ -657,7 +672,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
async def get_joined_users_from_context(
|
||||
self, event: EventBase, context: EventContext
|
||||
) -> Dict[str, ProfileInfo]:
|
||||
state_group = context.state_group
|
||||
state_group: Union[object, int] = context.state_group
|
||||
if not state_group:
|
||||
# If state_group is None it means it has yet to be assigned a
|
||||
# state group, i.e. we need to make sure that calls with a state_group
|
||||
|
@ -666,14 +681,16 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
state_group = object()
|
||||
|
||||
current_state_ids = await context.get_current_state_ids()
|
||||
assert current_state_ids is not None
|
||||
assert state_group is not None
|
||||
return await self._get_joined_users_from_context(
|
||||
event.room_id, state_group, current_state_ids, event=event, context=context
|
||||
)
|
||||
|
||||
async def get_joined_users_from_state(
|
||||
self, room_id, state_entry
|
||||
self, room_id: str, state_entry: "_StateCacheEntry"
|
||||
) -> Dict[str, ProfileInfo]:
|
||||
state_group = state_entry.state_group
|
||||
state_group: Union[object, int] = state_entry.state_group
|
||||
if not state_group:
|
||||
# If state_group is None it means it has yet to be assigned a
|
||||
# state group, i.e. we need to make sure that calls with a state_group
|
||||
|
@ -681,6 +698,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
# To do this we set the state_group to a new object as object() != object()
|
||||
state_group = object()
|
||||
|
||||
assert state_group is not None
|
||||
with Measure(self._clock, "get_joined_users_from_state"):
|
||||
return await self._get_joined_users_from_context(
|
||||
room_id, state_group, state_entry.state, context=state_entry
|
||||
|
@ -689,12 +707,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
@cached(num_args=2, cache_context=True, iterable=True, max_entries=100000)
|
||||
async def _get_joined_users_from_context(
|
||||
self,
|
||||
room_id,
|
||||
state_group,
|
||||
current_state_ids,
|
||||
cache_context,
|
||||
event=None,
|
||||
context=None,
|
||||
room_id: str,
|
||||
state_group: Union[object, int],
|
||||
current_state_ids: StateMap[str],
|
||||
cache_context: _CacheContext,
|
||||
event: Optional[EventBase] = None,
|
||||
context: Optional[Union[EventContext, "_StateCacheEntry"]] = None,
|
||||
) -> Dict[str, ProfileInfo]:
|
||||
# We don't use `state_group`, it's there so that we can cache based
|
||||
# on it. However, it's important that it's never None, since two current_states
|
||||
|
@ -765,14 +783,18 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
return users_in_room
|
||||
|
||||
@cached(max_entries=10000)
|
||||
def _get_joined_profile_from_event_id(self, event_id):
|
||||
def _get_joined_profile_from_event_id(
|
||||
self, event_id: str
|
||||
) -> Optional[Tuple[str, ProfileInfo]]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@cachedList(
|
||||
cached_method_name="_get_joined_profile_from_event_id",
|
||||
list_name="event_ids",
|
||||
)
|
||||
async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
|
||||
async def _get_joined_profiles_from_event_ids(
|
||||
self, event_ids: Iterable[str]
|
||||
) -> Dict[str, Optional[Tuple[str, ProfileInfo]]]:
|
||||
"""For given set of member event_ids check if they point to a join
|
||||
event and if so return the associated user and profile info.
|
||||
|
||||
|
@ -780,8 +802,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
event_ids: The member event IDs to lookup
|
||||
|
||||
Returns:
|
||||
dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID
|
||||
to `user_id` and ProfileInfo (or None if not join event).
|
||||
Map from event ID to `user_id` and ProfileInfo (or None if not join event).
|
||||
"""
|
||||
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
|
@ -847,8 +868,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
|
||||
return True
|
||||
|
||||
async def get_joined_hosts(self, room_id: str, state_entry):
|
||||
state_group = state_entry.state_group
|
||||
async def get_joined_hosts(
|
||||
self, room_id: str, state_entry: "_StateCacheEntry"
|
||||
) -> FrozenSet[str]:
|
||||
state_group: Union[object, int] = state_entry.state_group
|
||||
if not state_group:
|
||||
# If state_group is None it means it has yet to be assigned a
|
||||
# state group, i.e. we need to make sure that calls with a state_group
|
||||
|
@ -856,6 +879,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
# To do this we set the state_group to a new object as object() != object()
|
||||
state_group = object()
|
||||
|
||||
assert state_group is not None
|
||||
with Measure(self._clock, "get_joined_hosts"):
|
||||
return await self._get_joined_hosts(
|
||||
room_id, state_group, state_entry=state_entry
|
||||
|
@ -863,7 +887,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
|
||||
@cached(num_args=2, max_entries=10000, iterable=True)
|
||||
async def _get_joined_hosts(
|
||||
self, room_id: str, state_group: int, state_entry: "_StateCacheEntry"
|
||||
self,
|
||||
room_id: str,
|
||||
state_group: Union[object, int],
|
||||
state_entry: "_StateCacheEntry",
|
||||
) -> FrozenSet[str]:
|
||||
# We don't use `state_group`, it's there so that we can cache based on
|
||||
# it. However, its important that its never None, since two
|
||||
|
@ -881,7 +908,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
# `get_joined_hosts` is called with the "current" state group for the
|
||||
# room, and so consecutive calls will be for consecutive state groups
|
||||
# which point to the previous state group.
|
||||
cache = await self._get_joined_hosts_cache(room_id)
|
||||
cache = await self._get_joined_hosts_cache(room_id) # type: ignore[misc]
|
||||
|
||||
# If the state group in the cache matches, we already have the data we need.
|
||||
if state_entry.state_group == cache.state_group:
|
||||
|
@ -897,6 +924,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
elif state_entry.prev_group == cache.state_group:
|
||||
# The cached work is for the previous state group, so we work out
|
||||
# the delta.
|
||||
assert state_entry.delta_ids is not None
|
||||
for (typ, state_key), event_id in state_entry.delta_ids.items():
|
||||
if typ != EventTypes.Member:
|
||||
continue
|
||||
|
@ -942,7 +970,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
|
||||
Returns False if they have since re-joined."""
|
||||
|
||||
def f(txn):
|
||||
def f(txn: LoggingTransaction) -> int:
|
||||
sql = (
|
||||
"SELECT"
|
||||
" COUNT(*)"
|
||||
|
@ -973,7 +1001,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
The forgotten rooms.
|
||||
"""
|
||||
|
||||
def _get_forgotten_rooms_for_user_txn(txn):
|
||||
def _get_forgotten_rooms_for_user_txn(txn: LoggingTransaction) -> Set[str]:
|
||||
# This is a slightly convoluted query that first looks up all rooms
|
||||
# that the user has forgotten in the past, then rechecks that list
|
||||
# to see if any have subsequently been updated. This is done so that
|
||||
|
@ -1076,7 +1104,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
clause,
|
||||
)
|
||||
|
||||
def _is_local_host_in_room_ignoring_users_txn(txn):
|
||||
def _is_local_host_in_room_ignoring_users_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> bool:
|
||||
txn.execute(sql, (room_id, Membership.JOIN, *args))
|
||||
|
||||
return bool(txn.fetchone())
|
||||
|
@ -1110,15 +1140,17 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
|
|||
where_clause="forgotten = 1",
|
||||
)
|
||||
|
||||
async def _background_add_membership_profile(self, progress, batch_size):
|
||||
async def _background_add_membership_profile(
|
||||
self, progress: JsonDict, batch_size: int
|
||||
) -> int:
|
||||
target_min_stream_id = progress.get(
|
||||
"target_min_stream_id_inclusive", self._min_stream_order_on_start
|
||||
"target_min_stream_id_inclusive", self._min_stream_order_on_start # type: ignore[attr-defined]
|
||||
)
|
||||
max_stream_id = progress.get(
|
||||
"max_stream_id_exclusive", self._stream_order_on_start + 1
|
||||
"max_stream_id_exclusive", self._stream_order_on_start + 1 # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
def add_membership_profile_txn(txn):
|
||||
def add_membership_profile_txn(txn: LoggingTransaction) -> int:
|
||||
sql = """
|
||||
SELECT stream_ordering, event_id, events.room_id, event_json.json
|
||||
FROM events
|
||||
|
@ -1182,13 +1214,17 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
|
|||
|
||||
return result
|
||||
|
||||
async def _background_current_state_membership(self, progress, batch_size):
|
||||
async def _background_current_state_membership(
|
||||
self, progress: JsonDict, batch_size: int
|
||||
) -> int:
|
||||
"""Update the new membership column on current_state_events.
|
||||
|
||||
This works by iterating over all rooms in alphebetical order.
|
||||
"""
|
||||
|
||||
def _background_current_state_membership_txn(txn, last_processed_room):
|
||||
def _background_current_state_membership_txn(
|
||||
txn: LoggingTransaction, last_processed_room: str
|
||||
) -> Tuple[int, bool]:
|
||||
processed = 0
|
||||
while processed < batch_size:
|
||||
txn.execute(
|
||||
|
@ -1242,7 +1278,11 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
|
|||
return row_count
|
||||
|
||||
|
||||
class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
|
||||
class RoomMemberStore(
|
||||
RoomMemberWorkerStore,
|
||||
RoomMemberBackgroundUpdateStore,
|
||||
CacheInvalidationWorkerStore,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
database: DatabasePool,
|
||||
|
@ -1254,7 +1294,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
|
|||
async def forget(self, user_id: str, room_id: str) -> None:
|
||||
"""Indicate that user_id wishes to discard history for room_id."""
|
||||
|
||||
def f(txn):
|
||||
def f(txn: LoggingTransaction) -> None:
|
||||
sql = (
|
||||
"UPDATE"
|
||||
" room_memberships"
|
||||
|
@ -1288,5 +1328,5 @@ class _JoinedHostsCache:
|
|||
# equal to anything else).
|
||||
state_group: Union[object, int] = attr.Factory(object)
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return sum(len(v) for v in self.hosts_to_joined_users.values())
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue