Add type hints to synapse/storage/databases/main (#11984)

This commit is contained in:
Dirk Klimpel 2022-02-21 17:03:06 +01:00 committed by GitHub
parent 99f6d79fe1
commit 7c82da27aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 79 additions and 53 deletions

1
changelog.d/11984.misc Normal file
View File

@ -0,0 +1 @@
Add missing type hints to storage classes.

View File

@ -31,14 +31,11 @@ exclude = (?x)
|synapse/storage/databases/main/group_server.py |synapse/storage/databases/main/group_server.py
|synapse/storage/databases/main/metrics.py |synapse/storage/databases/main/metrics.py
|synapse/storage/databases/main/monthly_active_users.py |synapse/storage/databases/main/monthly_active_users.py
|synapse/storage/databases/main/presence.py
|synapse/storage/databases/main/purge_events.py
|synapse/storage/databases/main/push_rule.py |synapse/storage/databases/main/push_rule.py
|synapse/storage/databases/main/receipts.py |synapse/storage/databases/main/receipts.py
|synapse/storage/databases/main/roommember.py |synapse/storage/databases/main/roommember.py
|synapse/storage/databases/main/search.py |synapse/storage/databases/main/search.py
|synapse/storage/databases/main/state.py |synapse/storage/databases/main/state.py
|synapse/storage/databases/main/user_directory.py
|synapse/storage/schema/ |synapse/storage/schema/
|tests/api/test_auth.py |tests/api/test_auth.py

View File

@ -204,25 +204,27 @@ class BasePresenceHandler(abc.ABC):
Returns: Returns:
dict: `user_id` -> `UserPresenceState` dict: `user_id` -> `UserPresenceState`
""" """
states = { states = {}
user_id: self.user_to_current_state.get(user_id, None) missing = []
for user_id in user_ids for user_id in user_ids:
} state = self.user_to_current_state.get(user_id, None)
if state:
states[user_id] = state
else:
missing.append(user_id)
missing = [user_id for user_id, state in states.items() if not state]
if missing: if missing:
# There are things not in our in memory cache. Lets pull them out of # There are things not in our in memory cache. Lets pull them out of
# the database. # the database.
res = await self.store.get_presence_for_users(missing) res = await self.store.get_presence_for_users(missing)
states.update(res) states.update(res)
missing = [user_id for user_id, state in states.items() if not state] for user_id in missing:
if missing: # if user has no state in database, create the state
new = { if not res.get(user_id, None):
user_id: UserPresenceState.default(user_id) for user_id in missing new_state = UserPresenceState.default(user_id)
} states[user_id] = new_state
states.update(new) self.user_to_current_state[user_id] = new_state
self.user_to_current_state.update(new)
return states return states

View File

@ -12,15 +12,23 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple, cast
from synapse.api.presence import PresenceState, UserPresenceState from synapse.api.presence import PresenceState, UserPresenceState
from synapse.replication.tcp.streams import PresenceStream from synapse.replication.tcp.streams import PresenceStream
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Connection from synapse.storage.types import Connection
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
@ -35,7 +43,7 @@ class PresenceBackgroundUpdateStore(SQLBaseStore):
database: DatabasePool, database: DatabasePool,
db_conn: LoggingDatabaseConnection, db_conn: LoggingDatabaseConnection,
hs: "HomeServer", hs: "HomeServer",
): ) -> None:
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
# Used by `PresenceStore._get_active_presence()` # Used by `PresenceStore._get_active_presence()`
@ -54,11 +62,14 @@ class PresenceStore(PresenceBackgroundUpdateStore):
database: DatabasePool, database: DatabasePool,
db_conn: LoggingDatabaseConnection, db_conn: LoggingDatabaseConnection,
hs: "HomeServer", hs: "HomeServer",
): ) -> None:
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
self._presence_id_gen: AbstractStreamIdGenerator
self._can_persist_presence = ( self._can_persist_presence = (
hs.get_instance_name() in hs.config.worker.writers.presence self._instance_name in hs.config.worker.writers.presence
) )
if isinstance(database.engine, PostgresEngine): if isinstance(database.engine, PostgresEngine):
@ -109,7 +120,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
return stream_orderings[-1], self._presence_id_gen.get_current_token() return stream_orderings[-1], self._presence_id_gen.get_current_token()
def _update_presence_txn(self, txn, stream_orderings, presence_states): def _update_presence_txn(
self, txn: LoggingTransaction, stream_orderings, presence_states
) -> None:
for stream_id, state in zip(stream_orderings, presence_states): for stream_id, state in zip(stream_orderings, presence_states):
txn.call_after( txn.call_after(
self.presence_stream_cache.entity_has_changed, state.user_id, stream_id self.presence_stream_cache.entity_has_changed, state.user_id, stream_id
@ -183,19 +196,23 @@ class PresenceStore(PresenceBackgroundUpdateStore):
if last_id == current_id: if last_id == current_id:
return [], current_id, False return [], current_id, False
def get_all_presence_updates_txn(txn): def get_all_presence_updates_txn(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, list]], int, bool]:
sql = """ sql = """
SELECT stream_id, user_id, state, last_active_ts, SELECT stream_id, user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts, last_federation_update_ts, last_user_sync_ts,
status_msg, status_msg, currently_active
currently_active
FROM presence_stream FROM presence_stream
WHERE ? < stream_id AND stream_id <= ? WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC ORDER BY stream_id ASC
LIMIT ? LIMIT ?
""" """
txn.execute(sql, (last_id, current_id, limit)) txn.execute(sql, (last_id, current_id, limit))
updates = [(row[0], row[1:]) for row in txn] updates = cast(
List[Tuple[int, list]],
[(row[0], row[1:]) for row in txn],
)
upper_bound = current_id upper_bound = current_id
limited = False limited = False
@ -210,7 +227,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
) )
@cached() @cached()
def _get_presence_for_user(self, user_id): def _get_presence_for_user(self, user_id: str) -> None:
raise NotImplementedError() raise NotImplementedError()
@cachedList( @cachedList(
@ -218,7 +235,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
list_name="user_ids", list_name="user_ids",
num_args=1, num_args=1,
) )
async def get_presence_for_users(self, user_ids): async def get_presence_for_users(
self, user_ids: Iterable[str]
) -> Dict[str, UserPresenceState]:
rows = await self.db_pool.simple_select_many_batch( rows = await self.db_pool.simple_select_many_batch(
table="presence_stream", table="presence_stream",
column="user_id", column="user_id",
@ -257,7 +276,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
True if the user should have full presence sent to them, False otherwise. True if the user should have full presence sent to them, False otherwise.
""" """
def _should_user_receive_full_presence_with_token_txn(txn): def _should_user_receive_full_presence_with_token_txn(
txn: LoggingTransaction,
) -> bool:
sql = """ sql = """
SELECT 1 FROM users_to_send_full_presence_to SELECT 1 FROM users_to_send_full_presence_to
WHERE user_id = ? WHERE user_id = ?
@ -271,7 +292,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
_should_user_receive_full_presence_with_token_txn, _should_user_receive_full_presence_with_token_txn,
) )
async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]): async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]) -> None:
"""Adds to the list of users who should receive a full snapshot of presence """Adds to the list of users who should receive a full snapshot of presence
upon their next sync. upon their next sync.
@ -353,10 +374,10 @@ class PresenceStore(PresenceBackgroundUpdateStore):
return users_to_state return users_to_state
def get_current_presence_token(self): def get_current_presence_token(self) -> int:
return self._presence_id_gen.get_current_token() return self._presence_id_gen.get_current_token()
def _get_active_presence(self, db_conn: Connection): def _get_active_presence(self, db_conn: Connection) -> List[UserPresenceState]:
"""Fetch non-offline presence from the database so that we can register """Fetch non-offline presence from the database so that we can register
the appropriate time outs. the appropriate time outs.
""" """
@ -379,12 +400,12 @@ class PresenceStore(PresenceBackgroundUpdateStore):
return [UserPresenceState(**row) for row in rows] return [UserPresenceState(**row) for row in rows]
def take_presence_startup_info(self): def take_presence_startup_info(self) -> List[UserPresenceState]:
active_on_startup = self._presence_on_startup active_on_startup = self._presence_on_startup
self._presence_on_startup = None self._presence_on_startup = []
return active_on_startup return active_on_startup
def process_replication_rows(self, stream_name, instance_name, token, rows): def process_replication_rows(self, stream_name, instance_name, token, rows) -> None:
if stream_name == PresenceStream.NAME: if stream_name == PresenceStream.NAME:
self._presence_id_gen.advance(instance_name, token) self._presence_id_gen.advance(instance_name, token)
for row in rows: for row in rows:

View File

@ -13,9 +13,10 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any, List, Set, Tuple from typing import Any, List, Set, Tuple, cast
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main import CacheInvalidationWorkerStore from synapse.storage.databases.main import CacheInvalidationWorkerStore
from synapse.storage.databases.main.state import StateGroupWorkerStore from synapse.storage.databases.main.state import StateGroupWorkerStore
from synapse.types import RoomStreamToken from synapse.types import RoomStreamToken
@ -55,7 +56,11 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
) )
def _purge_history_txn( def _purge_history_txn(
self, txn, room_id: str, token: RoomStreamToken, delete_local_events: bool self,
txn: LoggingTransaction,
room_id: str,
token: RoomStreamToken,
delete_local_events: bool,
) -> Set[int]: ) -> Set[int]:
# Tables that should be pruned: # Tables that should be pruned:
# event_auth # event_auth
@ -273,7 +278,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
""", """,
(room_id,), (room_id,),
) )
(min_depth,) = txn.fetchone() (min_depth,) = cast(Tuple[int], txn.fetchone())
logger.info("[purge] updating room_depth to %d", min_depth) logger.info("[purge] updating room_depth to %d", min_depth)
@ -318,7 +323,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
"purge_room", self._purge_room_txn, room_id "purge_room", self._purge_room_txn, room_id
) )
def _purge_room_txn(self, txn, room_id: str) -> List[int]: def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]:
# First we fetch all the state groups that should be deleted, before # First we fetch all the state groups that should be deleted, before
# we delete that information. # we delete that information.
txn.execute( txn.execute(

View File

@ -58,7 +58,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
database: DatabasePool, database: DatabasePool,
db_conn: LoggingDatabaseConnection, db_conn: LoggingDatabaseConnection,
hs: "HomeServer", hs: "HomeServer",
): ) -> None:
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.server_name = hs.hostname self.server_name = hs.hostname
@ -234,10 +234,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
processed_event_count = 0 processed_event_count = 0
for room_id, event_count in rooms_to_work_on: for room_id, event_count in rooms_to_work_on:
is_in_room = await self.is_host_joined(room_id, self.server_name) is_in_room = await self.is_host_joined(room_id, self.server_name) # type: ignore[attr-defined]
if is_in_room: if is_in_room:
users_with_profile = await self.get_users_in_room_with_profiles(room_id) users_with_profile = await self.get_users_in_room_with_profiles(room_id) # type: ignore[attr-defined]
# Throw away users excluded from the directory. # Throw away users excluded from the directory.
users_with_profile = { users_with_profile = {
user_id: profile user_id: profile
@ -368,7 +368,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
for user_id in users_to_work_on: for user_id in users_to_work_on:
if await self.should_include_local_user_in_dir(user_id): if await self.should_include_local_user_in_dir(user_id):
profile = await self.get_profileinfo(get_localpart_from_id(user_id)) profile = await self.get_profileinfo(get_localpart_from_id(user_id)) # type: ignore[attr-defined]
await self.update_profile_in_user_dir( await self.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url user_id, profile.display_name, profile.avatar_url
) )
@ -397,7 +397,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# technically it could be DM-able. In the future, this could potentially # technically it could be DM-able. In the future, this could potentially
# be configurable per-appservice whether the appservice sender can be # be configurable per-appservice whether the appservice sender can be
# contacted. # contacted.
if self.get_app_service_by_user_id(user) is not None: if self.get_app_service_by_user_id(user) is not None: # type: ignore[attr-defined]
return False return False
# We're opting to exclude appservice users (anyone matching the user # We're opting to exclude appservice users (anyone matching the user
@ -405,17 +405,17 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# they could be DM-able. In the future, this could potentially # they could be DM-able. In the future, this could potentially
# be configurable per-appservice whether the appservice users can be # be configurable per-appservice whether the appservice users can be
# contacted. # contacted.
if self.get_if_app_services_interested_in_user(user): if self.get_if_app_services_interested_in_user(user): # type: ignore[attr-defined]
# TODO we might want to make this configurable for each app service # TODO we might want to make this configurable for each app service
return False return False
# Support users are for diagnostics and should not appear in the user directory. # Support users are for diagnostics and should not appear in the user directory.
if await self.is_support_user(user): if await self.is_support_user(user): # type: ignore[attr-defined]
return False return False
# Deactivated users aren't contactable, so should not appear in the user directory. # Deactivated users aren't contactable, so should not appear in the user directory.
try: try:
if await self.get_user_deactivated_status(user): if await self.get_user_deactivated_status(user): # type: ignore[attr-defined]
return False return False
except StoreError: except StoreError:
# No such user in the users table. No need to do this when calling # No such user in the users table. No need to do this when calling
@ -433,20 +433,20 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
(EventTypes.RoomHistoryVisibility, ""), (EventTypes.RoomHistoryVisibility, ""),
) )
current_state_ids = await self.get_filtered_current_state_ids( current_state_ids = await self.get_filtered_current_state_ids( # type: ignore[attr-defined]
room_id, StateFilter.from_types(types_to_filter) room_id, StateFilter.from_types(types_to_filter)
) )
join_rules_id = current_state_ids.get((EventTypes.JoinRules, "")) join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
if join_rules_id: if join_rules_id:
join_rule_ev = await self.get_event(join_rules_id, allow_none=True) join_rule_ev = await self.get_event(join_rules_id, allow_none=True) # type: ignore[attr-defined]
if join_rule_ev: if join_rule_ev:
if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC: if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC:
return True return True
hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, "")) hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, ""))
if hist_vis_id: if hist_vis_id:
hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True) hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True) # type: ignore[attr-defined]
if hist_vis_ev: if hist_vis_ev:
if ( if (
hist_vis_ev.content.get("history_visibility") hist_vis_ev.content.get("history_visibility")

View File

@ -51,7 +51,7 @@ from synapse.util.stringutils import parse_and_validate_server_name
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.appservice.api import ApplicationService from synapse.appservice.api import ApplicationService
from synapse.storage.databases.main import DataStore from synapse.storage.databases.main import DataStore, PurgeEventsStore
# Define a state map type from type/state_key to T (usually an event ID or # Define a state map type from type/state_key to T (usually an event ID or
# event) # event)
@ -485,7 +485,7 @@ class RoomStreamToken:
) )
@classmethod @classmethod
async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken": async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken":
try: try:
if string[0] == "s": if string[0] == "s":
return cls(topological=None, stream=int(string[1:])) return cls(topological=None, stream=int(string[1:]))
@ -502,7 +502,7 @@ class RoomStreamToken:
instance_id = int(key) instance_id = int(key)
pos = int(value) pos = int(value)
instance_name = await store.get_name_from_instance_id(instance_id) instance_name = await store.get_name_from_instance_id(instance_id) # type: ignore[attr-defined]
instance_map[instance_name] = pos instance_map[instance_name] = pos
return cls( return cls(