Add some type annotations in synapse.storage (#6987)

I cracked, and added some type definitions in synapse.storage.
This commit is contained in:
Richard van der Hoff 2020-02-27 11:53:40 +00:00 committed by GitHub
parent 3e99528f2b
commit 132b673dbe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 270 additions and 84 deletions

1
changelog.d/6987.misc Normal file
View File

@ -0,0 +1 @@
Add some type annotations to the database storage classes.

View File

@ -15,9 +15,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
import sys
import time import time
from typing import Iterable, Tuple from time import monotonic as monotonic_time
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
from six import iteritems, iterkeys, itervalues from six import iteritems, iterkeys, itervalues
from six.moves import intern, range from six.moves import intern, range
@ -32,22 +32,12 @@ from synapse.config.database import DatabaseConnectionConfig
from synapse.logging.context import LoggingContext, make_deferred_yieldable from synapse.logging.context import LoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
from synapse.util.stringutils import exception_to_unicode from synapse.util.stringutils import exception_to_unicode
# import a function which will return a monotonic time, in seconds
try:
# on python 3, use time.monotonic, since time.clock can go backwards
from time import monotonic as monotonic_time
except ImportError:
# ... but python 2 doesn't have it
from time import clock as monotonic_time
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try:
MAX_TXN_ID = sys.maxint - 1
except AttributeError:
# python 3 does not have a maximum int value # python 3 does not have a maximum int value
MAX_TXN_ID = 2 ** 63 - 1 MAX_TXN_ID = 2 ** 63 - 1
@ -77,7 +67,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
def make_pool( def make_pool(
reactor, db_config: DatabaseConnectionConfig, engine reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
) -> adbapi.ConnectionPool: ) -> adbapi.ConnectionPool:
"""Get the connection pool for the database. """Get the connection pool for the database.
""" """
@ -90,7 +80,9 @@ def make_pool(
) )
def make_conn(db_config: DatabaseConnectionConfig, engine): def make_conn(
db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
) -> Connection:
"""Make a new connection to the database and return it. """Make a new connection to the database and return it.
Returns: Returns:
@ -107,20 +99,27 @@ def make_conn(db_config: DatabaseConnectionConfig, engine):
return db_conn return db_conn
class LoggingTransaction(object): # The type of entry which goes on our after_callbacks and exception_callbacks lists.
#
# Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so
# that mypy sees the type but the runtime python doesn't.
_CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]]
class LoggingTransaction:
"""An object that almost-transparently proxies for the 'txn' object """An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute() passed to the constructor. Adds logging and metrics to the .execute()
method. method.
Args: Args:
txn: The database transcation object to wrap. txn: The database transcation object to wrap.
name (str): The name of this transactions for logging. name: The name of this transactions for logging.
database_engine (Sqlite3Engine|PostgresEngine) database_engine
after_callbacks(list|None): A list that callbacks will be appended to after_callbacks: A list that callbacks will be appended to
that have been added by `call_after` which should be run on that have been added by `call_after` which should be run on
successful completion of the transaction. None indicates that no successful completion of the transaction. None indicates that no
callbacks should be allowed to be scheduled to run. callbacks should be allowed to be scheduled to run.
exception_callbacks(list|None): A list that callbacks will be appended exception_callbacks: A list that callbacks will be appended
to that have been added by `call_on_exception` which should be run to that have been added by `call_on_exception` which should be run
if transaction ends with an error. None indicates that no callbacks if transaction ends with an error. None indicates that no callbacks
should be allowed to be scheduled to run. should be allowed to be scheduled to run.
@ -135,46 +134,67 @@ class LoggingTransaction(object):
] ]
def __init__( def __init__(
self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None self,
txn: Cursor,
name: str,
database_engine: BaseDatabaseEngine,
after_callbacks: Optional[List[_CallbackListEntry]] = None,
exception_callbacks: Optional[List[_CallbackListEntry]] = None,
): ):
object.__setattr__(self, "txn", txn) self.txn = txn
object.__setattr__(self, "name", name) self.name = name
object.__setattr__(self, "database_engine", database_engine) self.database_engine = database_engine
object.__setattr__(self, "after_callbacks", after_callbacks) self.after_callbacks = after_callbacks
object.__setattr__(self, "exception_callbacks", exception_callbacks) self.exception_callbacks = exception_callbacks
def call_after(self, callback, *args, **kwargs): def call_after(self, callback: "Callable[..., None]", *args, **kwargs):
"""Call the given callback on the main twisted thread after the """Call the given callback on the main twisted thread after the
transaction has finished. Used to invalidate the caches on the transaction has finished. Used to invalidate the caches on the
correct thread. correct thread.
""" """
# if self.after_callbacks is None, that means that whatever constructed the
# LoggingTransaction isn't expecting there to be any callbacks; assert that
# is not the case.
assert self.after_callbacks is not None
self.after_callbacks.append((callback, args, kwargs)) self.after_callbacks.append((callback, args, kwargs))
def call_on_exception(self, callback, *args, **kwargs): def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs):
# if self.exception_callbacks is None, that means that whatever constructed the
# LoggingTransaction isn't expecting there to be any callbacks; assert that
# is not the case.
assert self.exception_callbacks is not None
self.exception_callbacks.append((callback, args, kwargs)) self.exception_callbacks.append((callback, args, kwargs))
def __getattr__(self, name): def fetchall(self) -> List[Tuple]:
return getattr(self.txn, name) return self.txn.fetchall()
def __setattr__(self, name, value): def fetchone(self) -> Tuple:
setattr(self.txn, name, value) return self.txn.fetchone()
def __iter__(self): def __iter__(self) -> Iterator[Tuple]:
return self.txn.__iter__() return self.txn.__iter__()
@property
def rowcount(self) -> int:
return self.txn.rowcount
@property
def description(self) -> Any:
return self.txn.description
def execute_batch(self, sql, args): def execute_batch(self, sql, args):
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
from psycopg2.extras import execute_batch from psycopg2.extras import execute_batch # type: ignore
self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args) self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
else: else:
for val in args: for val in args:
self.execute(sql, val) self.execute(sql, val)
def execute(self, sql, *args): def execute(self, sql: str, *args: Any):
self._do_execute(self.txn.execute, sql, *args) self._do_execute(self.txn.execute, sql, *args)
def executemany(self, sql, *args): def executemany(self, sql: str, *args: Any):
self._do_execute(self.txn.executemany, sql, *args) self._do_execute(self.txn.executemany, sql, *args)
def _make_sql_one_line(self, sql): def _make_sql_one_line(self, sql):
@ -207,6 +227,9 @@ class LoggingTransaction(object):
sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs) sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
sql_query_timer.labels(sql.split()[0]).observe(secs) sql_query_timer.labels(sql.split()[0]).observe(secs)
def close(self):
self.txn.close()
class PerformanceCounters(object): class PerformanceCounters(object):
def __init__(self): def __init__(self):
@ -251,7 +274,9 @@ class Database(object):
_TXN_ID = 0 _TXN_ID = 0
def __init__(self, hs, database_config: DatabaseConnectionConfig, engine): def __init__(
self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
):
self.hs = hs self.hs = hs
self._clock = hs.get_clock() self._clock = hs.get_clock()
self._database_config = database_config self._database_config = database_config
@ -259,9 +284,9 @@ class Database(object):
self.updates = BackgroundUpdater(hs, self) self.updates = BackgroundUpdater(hs, self)
self._previous_txn_total_time = 0 self._previous_txn_total_time = 0.0
self._current_txn_total_time = 0 self._current_txn_total_time = 0.0
self._previous_loop_ts = 0 self._previous_loop_ts = 0.0
# TODO(paul): These can eventually be removed once the metrics code # TODO(paul): These can eventually be removed once the metrics code
# is running in mainline, and we have some nice monitoring frontends # is running in mainline, and we have some nice monitoring frontends
@ -463,23 +488,23 @@ class Database(object):
sql_txn_timer.labels(desc).observe(duration) sql_txn_timer.labels(desc).observe(duration)
@defer.inlineCallbacks @defer.inlineCallbacks
def runInteraction(self, desc, func, *args, **kwargs): def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any):
"""Starts a transaction on the database and runs a given function """Starts a transaction on the database and runs a given function
Arguments: Arguments:
desc (str): description of the transaction, for logging and metrics desc: description of the transaction, for logging and metrics
func (func): callback function, which will be called with a func: callback function, which will be called with a
database transaction (twisted.enterprise.adbapi.Transaction) as database transaction (twisted.enterprise.adbapi.Transaction) as
its first argument, followed by `args` and `kwargs`. its first argument, followed by `args` and `kwargs`.
args (list): positional args to pass to `func` args: positional args to pass to `func`
kwargs (dict): named args to pass to `func` kwargs: named args to pass to `func`
Returns: Returns:
Deferred: The result of func Deferred: The result of func
""" """
after_callbacks = [] after_callbacks = [] # type: List[_CallbackListEntry]
exception_callbacks = [] exception_callbacks = [] # type: List[_CallbackListEntry]
if LoggingContext.current_context() == LoggingContext.sentinel: if LoggingContext.current_context() == LoggingContext.sentinel:
logger.warning("Starting db txn '%s' from sentinel context", desc) logger.warning("Starting db txn '%s' from sentinel context", desc)
@ -505,15 +530,15 @@ class Database(object):
return result return result
@defer.inlineCallbacks @defer.inlineCallbacks
def runWithConnection(self, func, *args, **kwargs): def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any):
"""Wraps the .runWithConnection() method on the underlying db_pool. """Wraps the .runWithConnection() method on the underlying db_pool.
Arguments: Arguments:
func (func): callback function, which will be called with a func: callback function, which will be called with a
database connection (twisted.enterprise.adbapi.Connection) as database connection (twisted.enterprise.adbapi.Connection) as
its first argument, followed by `args` and `kwargs`. its first argument, followed by `args` and `kwargs`.
args (list): positional args to pass to `func` args: positional args to pass to `func`
kwargs (dict): named args to pass to `func` kwargs: named args to pass to `func`
Returns: Returns:
Deferred: The result of func Deferred: The result of func
@ -800,7 +825,7 @@ class Database(object):
return False return False
# We didn't find any existing rows, so insert a new one # We didn't find any existing rows, so insert a new one
allvalues = {} allvalues = {} # type: Dict[str, Any]
allvalues.update(keyvalues) allvalues.update(keyvalues)
allvalues.update(values) allvalues.update(values)
allvalues.update(insertion_values) allvalues.update(insertion_values)
@ -829,7 +854,7 @@ class Database(object):
Returns: Returns:
None None
""" """
allvalues = {} allvalues = {} # type: Dict[str, Any]
allvalues.update(keyvalues) allvalues.update(keyvalues)
allvalues.update(insertion_values) allvalues.update(insertion_values)
@ -916,7 +941,7 @@ class Database(object):
Returns: Returns:
None None
""" """
allnames = [] allnames = [] # type: List[str]
allnames.extend(key_names) allnames.extend(key_names)
allnames.extend(value_names) allnames.extend(value_names)
@ -1100,7 +1125,7 @@ class Database(object):
keyvalues : dict of column names and values to select the rows with keyvalues : dict of column names and values to select the rows with
retcols : list of strings giving the names of the columns to return retcols : list of strings giving the names of the columns to return
""" """
results = [] results = [] # type: List[Dict[str, Any]]
if not iterable: if not iterable:
return results return results
@ -1439,7 +1464,7 @@ class Database(object):
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.") raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
where_clause = "WHERE " if filters or keyvalues else "" where_clause = "WHERE " if filters or keyvalues else ""
arg_list = [] arg_list = [] # type: List[Any]
if filters: if filters:
where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters) where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
arg_list += list(filters.values()) arg_list += list(filters.values())

View File

@ -12,29 +12,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import importlib
import platform import platform
from ._base import IncorrectDatabaseSetup from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
from .postgres import PostgresEngine from .postgres import PostgresEngine
from .sqlite import Sqlite3Engine from .sqlite import Sqlite3Engine
SUPPORTED_MODULE = {"sqlite3": Sqlite3Engine, "psycopg2": PostgresEngine}
def create_engine(database_config) -> BaseDatabaseEngine:
def create_engine(database_config):
name = database_config["name"] name = database_config["name"]
engine_class = SUPPORTED_MODULE.get(name, None)
if engine_class: if name == "sqlite3":
import sqlite3
return Sqlite3Engine(sqlite3, database_config)
if name == "psycopg2":
# pypy requires psycopg2cffi rather than psycopg2 # pypy requires psycopg2cffi rather than psycopg2
if name == "psycopg2" and platform.python_implementation() == "PyPy": if platform.python_implementation() == "PyPy":
name = "psycopg2cffi" import psycopg2cffi as psycopg2 # type: ignore
module = importlib.import_module(name) else:
return engine_class(module, database_config) import psycopg2 # type: ignore
return PostgresEngine(psycopg2, database_config)
raise RuntimeError("Unsupported database engine '%s'" % (name,)) raise RuntimeError("Unsupported database engine '%s'" % (name,))
__all__ = ["create_engine", "IncorrectDatabaseSetup"] __all__ = ["create_engine", "BaseDatabaseEngine", "IncorrectDatabaseSetup"]

View File

@ -12,7 +12,94 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import abc
from typing import Generic, TypeVar
from synapse.storage.types import Connection
class IncorrectDatabaseSetup(RuntimeError): class IncorrectDatabaseSetup(RuntimeError):
pass pass
ConnectionType = TypeVar("ConnectionType", bound=Connection)
class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
def __init__(self, module, database_config: dict):
self.module = module
@property
@abc.abstractmethod
def single_threaded(self) -> bool:
...
@property
@abc.abstractmethod
def can_native_upsert(self) -> bool:
"""
Do we support native UPSERTs?
"""
...
@property
@abc.abstractmethod
def supports_tuple_comparison(self) -> bool:
"""
Do we support comparing tuples, i.e. `(a, b) > (c, d)`?
"""
...
@property
@abc.abstractmethod
def supports_using_any_list(self) -> bool:
"""
Do we support using `a = ANY(?)` and passing a list
"""
...
@abc.abstractmethod
def check_database(
self, db_conn: ConnectionType, allow_outdated_version: bool = False
) -> None:
...
@abc.abstractmethod
def check_new_database(self, txn) -> None:
"""Gets called when setting up a brand new database. This allows us to
apply stricter checks on new databases versus existing database.
"""
...
@abc.abstractmethod
def convert_param_style(self, sql: str) -> str:
...
@abc.abstractmethod
def on_new_connection(self, db_conn: ConnectionType) -> None:
...
@abc.abstractmethod
def is_deadlock(self, error: Exception) -> bool:
...
@abc.abstractmethod
def is_connection_closed(self, conn: ConnectionType) -> bool:
...
@abc.abstractmethod
def lock_table(self, txn, table: str) -> None:
...
@abc.abstractmethod
def get_next_state_group_id(self, txn) -> int:
"""Returns an int that can be used as a new state_group ID
"""
...
@property
@abc.abstractmethod
def server_version(self) -> str:
"""Gets a string giving the server version. For example: '3.22.0'
"""
...

View File

@ -15,16 +15,14 @@
import logging import logging
from ._base import IncorrectDatabaseSetup from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PostgresEngine(object): class PostgresEngine(BaseDatabaseEngine):
single_threaded = False
def __init__(self, database_module, database_config): def __init__(self, database_module, database_config):
self.module = database_module super().__init__(database_module, database_config)
self.module.extensions.register_type(self.module.extensions.UNICODE) self.module.extensions.register_type(self.module.extensions.UNICODE)
# Disables passing `bytes` to txn.execute, c.f. #6186. If you do # Disables passing `bytes` to txn.execute, c.f. #6186. If you do
@ -36,6 +34,10 @@ class PostgresEngine(object):
self.synchronous_commit = database_config.get("synchronous_commit", True) self.synchronous_commit = database_config.get("synchronous_commit", True)
self._version = None # unknown as yet self._version = None # unknown as yet
@property
def single_threaded(self) -> bool:
return False
def check_database(self, db_conn, allow_outdated_version: bool = False): def check_database(self, db_conn, allow_outdated_version: bool = False):
# Get the version of PostgreSQL that we're using. As per the psycopg2 # Get the version of PostgreSQL that we're using. As per the psycopg2
# docs: The number is formed by converting the major, minor, and # docs: The number is formed by converting the major, minor, and

View File

@ -12,16 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sqlite3
import struct import struct
import threading import threading
from synapse.storage.engines import BaseDatabaseEngine
class Sqlite3Engine(object):
single_threaded = True
class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection]):
def __init__(self, database_module, database_config): def __init__(self, database_module, database_config):
self.module = database_module super().__init__(database_module, database_config)
database = database_config.get("args", {}).get("database") database = database_config.get("args", {}).get("database")
self._is_in_memory = database in (None, ":memory:",) self._is_in_memory = database in (None, ":memory:",)
@ -31,6 +31,10 @@ class Sqlite3Engine(object):
self._current_state_group_id = None self._current_state_group_id = None
self._current_state_group_id_lock = threading.Lock() self._current_state_group_id_lock = threading.Lock()
@property
def single_threaded(self) -> bool:
return True
@property @property
def can_native_upsert(self): def can_native_upsert(self):
""" """
@ -68,7 +72,6 @@ class Sqlite3Engine(object):
return sql return sql
def on_new_connection(self, db_conn): def on_new_connection(self, db_conn):
# We need to import here to avoid an import loop. # We need to import here to avoid an import loop.
from synapse.storage.prepare_database import prepare_database from synapse.storage.prepare_database import prepare_database

65
synapse/storage/types.py Normal file
View File

@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Iterable, Iterator, List, Tuple
from typing_extensions import Protocol
"""
Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
"""
class Cursor(Protocol):
def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any:
...
def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any:
...
def fetchall(self) -> List[Tuple]:
...
def fetchone(self) -> Tuple:
...
@property
def description(self) -> Any:
return None
@property
def rowcount(self) -> int:
return 0
def __iter__(self) -> Iterator[Tuple]:
...
def close(self) -> None:
...
class Connection(Protocol):
def cursor(self) -> Cursor:
...
def close(self) -> None:
...
def commit(self) -> None:
...
def rollback(self, *args, **kwargs) -> None:
...

View File

@ -168,7 +168,6 @@ commands=
coverage html coverage html
[testenv:mypy] [testenv:mypy]
basepython = python3.7
skip_install = True skip_install = True
deps = deps =
{[base]deps} {[base]deps}
@ -179,7 +178,8 @@ env =
extras = all extras = all
commands = mypy \ commands = mypy \
synapse/api \ synapse/api \
synapse/config/ \ synapse/appservice \
synapse/config \
synapse/events/spamcheck.py \ synapse/events/spamcheck.py \
synapse/federation/sender \ synapse/federation/sender \
synapse/federation/transport \ synapse/federation/transport \
@ -192,6 +192,7 @@ commands = mypy \
synapse/rest \ synapse/rest \
synapse/spam_checker_api \ synapse/spam_checker_api \
synapse/storage/engines \ synapse/storage/engines \
synapse/storage/database.py \
synapse/streams synapse/streams
# To find all folders that pass mypy you run: # To find all folders that pass mypy you run: