Add type hints to synapse/storage/databases/main (#11984)

This commit is contained in:
Dirk Klimpel 2022-02-21 17:03:06 +01:00 committed by GitHub
parent 99f6d79fe1
commit 7c82da27aa
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 79 additions and 53 deletions

View file

@ -12,15 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple, cast
from synapse.api.presence import PresenceState, UserPresenceState
from synapse.replication.tcp.streams import PresenceStream
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Connection
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.iterutils import batch_iter
@ -35,7 +43,7 @@ class PresenceBackgroundUpdateStore(SQLBaseStore):
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
) -> None:
super().__init__(database, db_conn, hs)
# Used by `PresenceStore._get_active_presence()`
@ -54,11 +62,14 @@ class PresenceStore(PresenceBackgroundUpdateStore):
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
) -> None:
super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
self._presence_id_gen: AbstractStreamIdGenerator
self._can_persist_presence = (
hs.get_instance_name() in hs.config.worker.writers.presence
self._instance_name in hs.config.worker.writers.presence
)
if isinstance(database.engine, PostgresEngine):
@ -109,7 +120,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
return stream_orderings[-1], self._presence_id_gen.get_current_token()
def _update_presence_txn(self, txn, stream_orderings, presence_states):
def _update_presence_txn(
self, txn: LoggingTransaction, stream_orderings, presence_states
) -> None:
for stream_id, state in zip(stream_orderings, presence_states):
txn.call_after(
self.presence_stream_cache.entity_has_changed, state.user_id, stream_id
@ -183,19 +196,23 @@ class PresenceStore(PresenceBackgroundUpdateStore):
if last_id == current_id:
return [], current_id, False
def get_all_presence_updates_txn(txn):
def get_all_presence_updates_txn(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, list]], int, bool]:
sql = """
SELECT stream_id, user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts,
status_msg,
currently_active
status_msg, currently_active
FROM presence_stream
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
LIMIT ?
"""
txn.execute(sql, (last_id, current_id, limit))
updates = [(row[0], row[1:]) for row in txn]
updates = cast(
List[Tuple[int, list]],
[(row[0], row[1:]) for row in txn],
)
upper_bound = current_id
limited = False
@ -210,7 +227,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
)
@cached()
def _get_presence_for_user(self, user_id):
def _get_presence_for_user(self, user_id: str) -> None:
raise NotImplementedError()
@cachedList(
@ -218,7 +235,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
list_name="user_ids",
num_args=1,
)
async def get_presence_for_users(self, user_ids):
async def get_presence_for_users(
self, user_ids: Iterable[str]
) -> Dict[str, UserPresenceState]:
rows = await self.db_pool.simple_select_many_batch(
table="presence_stream",
column="user_id",
@ -257,7 +276,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
True if the user should have full presence sent to them, False otherwise.
"""
def _should_user_receive_full_presence_with_token_txn(txn):
def _should_user_receive_full_presence_with_token_txn(
txn: LoggingTransaction,
) -> bool:
sql = """
SELECT 1 FROM users_to_send_full_presence_to
WHERE user_id = ?
@ -271,7 +292,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
_should_user_receive_full_presence_with_token_txn,
)
async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]):
async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]) -> None:
"""Adds to the list of users who should receive a full snapshot of presence
upon their next sync.
@ -353,10 +374,10 @@ class PresenceStore(PresenceBackgroundUpdateStore):
return users_to_state
def get_current_presence_token(self):
def get_current_presence_token(self) -> int:
return self._presence_id_gen.get_current_token()
def _get_active_presence(self, db_conn: Connection):
def _get_active_presence(self, db_conn: Connection) -> List[UserPresenceState]:
"""Fetch non-offline presence from the database so that we can register
the appropriate time outs.
"""
@ -379,12 +400,12 @@ class PresenceStore(PresenceBackgroundUpdateStore):
return [UserPresenceState(**row) for row in rows]
def take_presence_startup_info(self):
def take_presence_startup_info(self) -> List[UserPresenceState]:
active_on_startup = self._presence_on_startup
self._presence_on_startup = None
self._presence_on_startup = []
return active_on_startup
def process_replication_rows(self, stream_name, instance_name, token, rows):
def process_replication_rows(self, stream_name, instance_name, token, rows) -> None:
if stream_name == PresenceStream.NAME:
self._presence_id_gen.advance(instance_name, token)
for row in rows: