Add some type hints to datastore (#12717)

This commit is contained in:
Dirk Klimpel 2022-05-17 16:29:06 +02:00 committed by GitHub
parent 942c30b16b
commit 6edefef602
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 254 additions and 161 deletions

View file

@ -14,16 +14,19 @@
import calendar
import logging
import time
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING, Dict, List, Tuple, cast
from synapse.metrics import GaugeBucketCollector
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
)
from synapse.storage.types import Cursor
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -73,7 +76,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
@wrap_as_background_process("read_forward_extremities")
async def _read_forward_extremities(self) -> None:
def fetch(txn):
def fetch(txn: LoggingTransaction) -> List[Tuple[int, int]]:
txn.execute(
"""
SELECT t1.c, t2.c
@ -86,7 +89,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
) t2 ON t1.room_id = t2.room_id
"""
)
return txn.fetchall()
return cast(List[Tuple[int, int]], txn.fetchall())
res = await self.db_pool.runInteraction("read_forward_extremities", fetch)
@ -104,20 +107,20 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
call to this function, it will return None.
"""
def _count_messages(txn):
def _count_messages(txn: LoggingTransaction) -> int:
sql = """
SELECT COUNT(*) FROM events
WHERE type = 'm.room.encrypted'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
async def count_daily_sent_e2ee_messages(self) -> int:
def _count_messages(txn):
def _count_messages(txn: LoggingTransaction) -> int:
# This is good enough as if you have silly characters in your own
# hostname then that's your own fault.
like_clause = "%:" + self.hs.hostname
@ -130,7 +133,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"""
txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction(
@ -138,14 +141,14 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
)
async def count_daily_active_e2ee_rooms(self) -> int:
def _count(txn):
def _count(txn: LoggingTransaction) -> int:
sql = """
SELECT COUNT(DISTINCT room_id) FROM events
WHERE type = 'm.room.encrypted'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction(
@ -160,20 +163,20 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
call to this function, it will return None.
"""
def _count_messages(txn):
def _count_messages(txn: LoggingTransaction) -> int:
sql = """
SELECT COUNT(*) FROM events
WHERE type = 'm.room.message'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction("count_messages", _count_messages)
async def count_daily_sent_messages(self) -> int:
def _count_messages(txn):
def _count_messages(txn: LoggingTransaction) -> int:
# This is good enough as if you have silly characters in your own
# hostname then that's your own fault.
like_clause = "%:" + self.hs.hostname
@ -186,7 +189,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"""
txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction(
@ -194,14 +197,14 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
)
async def count_daily_active_rooms(self) -> int:
def _count(txn):
def _count(txn: LoggingTransaction) -> int:
sql = """
SELECT COUNT(DISTINCT room_id) FROM events
WHERE type = 'm.room.message'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction("count_daily_active_rooms", _count)
@ -227,7 +230,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"count_monthly_users", self._count_users, thirty_days_ago
)
def _count_users(self, txn: Cursor, time_from: int) -> int:
def _count_users(self, txn: LoggingTransaction, time_from: int) -> int:
"""
Returns number of users seen in the past time_from period
"""
@ -242,7 +245,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
# Mypy knows that fetchone() might return None if there are no rows.
# We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always
# returns exactly one row.
(count,) = txn.fetchone() # type: ignore[misc]
(count,) = cast(Tuple[int], txn.fetchone())
return count
async def count_r30_users(self) -> Dict[str, int]:
@ -256,7 +259,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
A mapping of counts globally as well as broken out by platform.
"""
def _count_r30_users(txn):
def _count_r30_users(txn: LoggingTransaction) -> Dict[str, int]:
thirty_days_in_secs = 86400 * 30
now = int(self._clock.time())
thirty_days_ago_in_secs = now - thirty_days_in_secs
@ -321,7 +324,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
results["all"] = count
return results
@ -348,7 +351,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
- "web" (any web application -- it's not possible to distinguish Element Web here)
"""
def _count_r30v2_users(txn):
def _count_r30v2_users(txn: LoggingTransaction) -> Dict[str, int]:
thirty_days_in_secs = 86400 * 30
now = int(self._clock.time())
sixty_days_ago_in_secs = now - 2 * thirty_days_in_secs
@ -445,11 +448,8 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
thirty_days_in_secs * 1000,
),
)
row = txn.fetchone()
if row is None:
results["all"] = 0
else:
results["all"] = row[0]
(count,) = cast(Tuple[int], txn.fetchone())
results["all"] = count
return results
@ -471,7 +471,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
Generates daily visit data for use in cohort/ retention analysis
"""
def _generate_user_daily_visits(txn):
def _generate_user_daily_visits(txn: LoggingTransaction) -> None:
logger.info("Calling _generate_user_daily_visits")
today_start = self._get_start_of_day()
a_day_in_milliseconds = 24 * 60 * 60 * 1000