Add additional type hints to the storage module. (#8980)

This commit is contained in:
Patrick Cloke 2020-12-30 08:09:53 -05:00 committed by GitHub
parent b8591899ab
commit 637282bb50
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 224 additions and 148 deletions

View file

@ -12,29 +12,34 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Optional
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Iterable, Optional
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.types import Connection
from synapse.types import JsonDict
from synapse.util import json_encoder
from . import engines
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
from synapse.storage.database import DatabasePool, LoggingTransaction
logger = logging.getLogger(__name__)
class BackgroundUpdatePerformance:
"""Tracks how long a background update is taking to update its items"""
def __init__(self, name):
def __init__(self, name: str):
self.name = name
self.total_item_count = 0
self.total_duration_ms = 0
self.avg_item_count = 0
self.avg_duration_ms = 0
self.total_duration_ms = 0.0
self.avg_item_count = 0.0
self.avg_duration_ms = 0.0
def update(self, item_count, duration_ms):
def update(self, item_count: int, duration_ms: float) -> None:
"""Update the stats after doing an update"""
self.total_item_count += item_count
self.total_duration_ms += duration_ms
@ -44,7 +49,7 @@ class BackgroundUpdatePerformance:
self.avg_item_count += 0.1 * (item_count - self.avg_item_count)
self.avg_duration_ms += 0.1 * (duration_ms - self.avg_duration_ms)
def average_items_per_ms(self):
def average_items_per_ms(self) -> Optional[float]:
"""An estimate of how many items are updated per millisecond.
Returns:
The number of items processed per ms, as a float
@ -58,7 +63,7 @@ class BackgroundUpdatePerformance:
# changes in how long the update process takes.
return float(self.avg_item_count) / float(self.avg_duration_ms)
def total_items_per_ms(self):
def total_items_per_ms(self) -> Optional[float]:
"""An estimate of how long it takes to do a single update.
Returns:
A duration in ms as a float
@ -83,21 +88,25 @@ class BackgroundUpdater:
BACKGROUND_UPDATE_INTERVAL_MS = 1000
BACKGROUND_UPDATE_DURATION_MS = 100
def __init__(self, hs, database):
def __init__(self, hs: "HomeServer", database: "DatabasePool"):
self._clock = hs.get_clock()
self.db_pool = database
# if a background update is currently running, its name.
self._current_background_update = None # type: Optional[str]
self._background_update_performance = {}
self._background_update_handlers = {}
self._background_update_performance = (
{}
) # type: Dict[str, BackgroundUpdatePerformance]
self._background_update_handlers = (
{}
) # type: Dict[str, Callable[[JsonDict, int], Awaitable[int]]]
self._all_done = False
def start_doing_background_updates(self):
def start_doing_background_updates(self) -> None:
run_as_background_process("background_updates", self.run_background_updates)
async def run_background_updates(self, sleep=True):
async def run_background_updates(self, sleep: bool = True) -> None:
logger.info("Starting background schema updates")
while True:
if sleep:
@ -148,7 +157,7 @@ class BackgroundUpdater:
return False
async def has_completed_background_update(self, update_name) -> bool:
async def has_completed_background_update(self, update_name: str) -> bool:
"""Check if the given background update has finished running.
"""
if self._all_done:
@ -173,8 +182,7 @@ class BackgroundUpdater:
Returns once some amount of work is done.
Args:
desired_duration_ms(float): How long we want to spend
updating.
desired_duration_ms: How long we want to spend updating.
Returns:
True if we have finished running all the background updates, otherwise False
"""
@ -220,6 +228,7 @@ class BackgroundUpdater:
return False
async def _do_background_update(self, desired_duration_ms: float) -> int:
assert self._current_background_update is not None
update_name = self._current_background_update
logger.info("Starting update batch on background update '%s'", update_name)
@ -273,7 +282,11 @@ class BackgroundUpdater:
return len(self._background_update_performance)
def register_background_update_handler(self, update_name, update_handler):
def register_background_update_handler(
self,
update_name: str,
update_handler: Callable[[JsonDict, int], Awaitable[int]],
):
"""Register a handler for doing a background update.
The handler should take two arguments:
@ -287,12 +300,12 @@ class BackgroundUpdater:
The handler is responsible for updating the progress of the update.
Args:
update_name(str): The name of the update that this code handles.
update_handler(function): The function that does the update.
update_name: The name of the update that this code handles.
update_handler: The function that does the update.
"""
self._background_update_handlers[update_name] = update_handler
def register_noop_background_update(self, update_name):
def register_noop_background_update(self, update_name: str) -> None:
"""Register a noop handler for a background update.
This is useful when we previously did a background update, but no
@ -302,10 +315,10 @@ class BackgroundUpdater:
also be called to clear the update.
Args:
update_name (str): Name of update
update_name: Name of update
"""
async def noop_update(progress, batch_size):
async def noop_update(progress: JsonDict, batch_size: int) -> int:
await self._end_background_update(update_name)
return 1
@ -313,14 +326,14 @@ class BackgroundUpdater:
def register_background_index_update(
self,
update_name,
index_name,
table,
columns,
where_clause=None,
unique=False,
psql_only=False,
):
update_name: str,
index_name: str,
table: str,
columns: Iterable[str],
where_clause: Optional[str] = None,
unique: bool = False,
psql_only: bool = False,
) -> None:
"""Helper for store classes to do a background index addition
To use:
@ -332,19 +345,19 @@ class BackgroundUpdater:
2. In the Store constructor, call this method
Args:
update_name (str): update_name to register for
index_name (str): name of index to add
table (str): table to add index to
columns (list[str]): columns/expressions to include in index
unique (bool): true to make a UNIQUE index
update_name: name of the background update to register for
index_name: name of index to add
table: table to add index to
columns: columns/expressions to include in index
unique: true to make a UNIQUE index
psql_only: true to only create this index on psql databases (useful
for virtual sqlite tables)
"""
def create_index_psql(conn):
def create_index_psql(conn: Connection) -> None:
conn.rollback()
# postgres insists on autocommit for the index
conn.set_session(autocommit=True)
conn.set_session(autocommit=True) # type: ignore
try:
c = conn.cursor()
@ -371,9 +384,9 @@ class BackgroundUpdater:
logger.debug("[SQL] %s", sql)
c.execute(sql)
finally:
conn.set_session(autocommit=False)
conn.set_session(autocommit=False) # type: ignore
def create_index_sqlite(conn):
def create_index_sqlite(conn: Connection) -> None:
# Sqlite doesn't support concurrent creation of indexes.
#
# We don't use partial indices on SQLite as it wasn't introduced
@ -399,7 +412,7 @@ class BackgroundUpdater:
c.execute(sql)
if isinstance(self.db_pool.engine, engines.PostgresEngine):
runner = create_index_psql
runner = create_index_psql # type: Optional[Callable[[Connection], None]]
elif psql_only:
runner = None
else:
@ -433,7 +446,9 @@ class BackgroundUpdater:
"background_updates", keyvalues={"update_name": update_name}
)
async def _background_update_progress(self, update_name: str, progress: dict):
async def _background_update_progress(
self, update_name: str, progress: dict
) -> None:
"""Update the progress of a background update
Args:
@ -441,20 +456,22 @@ class BackgroundUpdater:
progress: The progress of the update.
"""
return await self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"background_update_progress",
self._background_update_progress_txn,
update_name,
progress,
)
def _background_update_progress_txn(self, txn, update_name, progress):
def _background_update_progress_txn(
self, txn: "LoggingTransaction", update_name: str, progress: JsonDict
) -> None:
"""Update the progress of a background update
Args:
txn(cursor): The transaction.
update_name(str): The name of the background update task
progress(dict): The progress of the update.
txn: The transaction.
update_name: The name of the background update task
progress: The progress of the update.
"""
progress_json = json_encoder.encode(progress)