mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-05-02 11:06:07 -04:00
Add some type annotations in synapse.storage
(#6987)
I cracked, and added some type definitions in synapse.storage.
This commit is contained in:
parent
3e99528f2b
commit
132b673dbe
8 changed files with 270 additions and 84 deletions
|
@ -15,9 +15,9 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import Iterable, Tuple
|
||||
from time import monotonic as monotonic_time
|
||||
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
|
||||
|
||||
from six import iteritems, iterkeys, itervalues
|
||||
from six.moves import intern, range
|
||||
|
@ -32,24 +32,14 @@ from synapse.config.database import DatabaseConnectionConfig
|
|||
from synapse.logging.context import LoggingContext, make_deferred_yieldable
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage.background_updates import BackgroundUpdater
|
||||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
|
||||
from synapse.storage.types import Connection, Cursor
|
||||
from synapse.util.stringutils import exception_to_unicode
|
||||
|
||||
# import a function which will return a monotonic time, in seconds
|
||||
try:
|
||||
# on python 3, use time.monotonic, since time.clock can go backwards
|
||||
from time import monotonic as monotonic_time
|
||||
except ImportError:
|
||||
# ... but python 2 doesn't have it
|
||||
from time import clock as monotonic_time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
MAX_TXN_ID = sys.maxint - 1
|
||||
except AttributeError:
|
||||
# python 3 does not have a maximum int value
|
||||
MAX_TXN_ID = 2 ** 63 - 1
|
||||
# python 3 does not have a maximum int value
|
||||
MAX_TXN_ID = 2 ** 63 - 1
|
||||
|
||||
sql_logger = logging.getLogger("synapse.storage.SQL")
|
||||
transaction_logger = logging.getLogger("synapse.storage.txn")
|
||||
|
@ -77,7 +67,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
|
|||
|
||||
|
||||
def make_pool(
|
||||
reactor, db_config: DatabaseConnectionConfig, engine
|
||||
reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
|
||||
) -> adbapi.ConnectionPool:
|
||||
"""Get the connection pool for the database.
|
||||
"""
|
||||
|
@ -90,7 +80,9 @@ def make_pool(
|
|||
)
|
||||
|
||||
|
||||
def make_conn(db_config: DatabaseConnectionConfig, engine):
|
||||
def make_conn(
|
||||
db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
|
||||
) -> Connection:
|
||||
"""Make a new connection to the database and return it.
|
||||
|
||||
Returns:
|
||||
|
@ -107,20 +99,27 @@ def make_conn(db_config: DatabaseConnectionConfig, engine):
|
|||
return db_conn
|
||||
|
||||
|
||||
class LoggingTransaction(object):
|
||||
# The type of entry which goes on our after_callbacks and exception_callbacks lists.
|
||||
#
|
||||
# Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so
|
||||
# that mypy sees the type but the runtime python doesn't.
|
||||
_CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]]
|
||||
|
||||
|
||||
class LoggingTransaction:
|
||||
"""An object that almost-transparently proxies for the 'txn' object
|
||||
passed to the constructor. Adds logging and metrics to the .execute()
|
||||
method.
|
||||
|
||||
Args:
|
||||
txn: The database transcation object to wrap.
|
||||
name (str): The name of this transactions for logging.
|
||||
database_engine (Sqlite3Engine|PostgresEngine)
|
||||
after_callbacks(list|None): A list that callbacks will be appended to
|
||||
name: The name of this transactions for logging.
|
||||
database_engine
|
||||
after_callbacks: A list that callbacks will be appended to
|
||||
that have been added by `call_after` which should be run on
|
||||
successful completion of the transaction. None indicates that no
|
||||
callbacks should be allowed to be scheduled to run.
|
||||
exception_callbacks(list|None): A list that callbacks will be appended
|
||||
exception_callbacks: A list that callbacks will be appended
|
||||
to that have been added by `call_on_exception` which should be run
|
||||
if transaction ends with an error. None indicates that no callbacks
|
||||
should be allowed to be scheduled to run.
|
||||
|
@ -135,46 +134,67 @@ class LoggingTransaction(object):
|
|||
]
|
||||
|
||||
def __init__(
|
||||
self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None
|
||||
self,
|
||||
txn: Cursor,
|
||||
name: str,
|
||||
database_engine: BaseDatabaseEngine,
|
||||
after_callbacks: Optional[List[_CallbackListEntry]] = None,
|
||||
exception_callbacks: Optional[List[_CallbackListEntry]] = None,
|
||||
):
|
||||
object.__setattr__(self, "txn", txn)
|
||||
object.__setattr__(self, "name", name)
|
||||
object.__setattr__(self, "database_engine", database_engine)
|
||||
object.__setattr__(self, "after_callbacks", after_callbacks)
|
||||
object.__setattr__(self, "exception_callbacks", exception_callbacks)
|
||||
self.txn = txn
|
||||
self.name = name
|
||||
self.database_engine = database_engine
|
||||
self.after_callbacks = after_callbacks
|
||||
self.exception_callbacks = exception_callbacks
|
||||
|
||||
def call_after(self, callback, *args, **kwargs):
|
||||
def call_after(self, callback: "Callable[..., None]", *args, **kwargs):
|
||||
"""Call the given callback on the main twisted thread after the
|
||||
transaction has finished. Used to invalidate the caches on the
|
||||
correct thread.
|
||||
"""
|
||||
# if self.after_callbacks is None, that means that whatever constructed the
|
||||
# LoggingTransaction isn't expecting there to be any callbacks; assert that
|
||||
# is not the case.
|
||||
assert self.after_callbacks is not None
|
||||
self.after_callbacks.append((callback, args, kwargs))
|
||||
|
||||
def call_on_exception(self, callback, *args, **kwargs):
|
||||
def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs):
|
||||
# if self.exception_callbacks is None, that means that whatever constructed the
|
||||
# LoggingTransaction isn't expecting there to be any callbacks; assert that
|
||||
# is not the case.
|
||||
assert self.exception_callbacks is not None
|
||||
self.exception_callbacks.append((callback, args, kwargs))
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self.txn, name)
|
||||
def fetchall(self) -> List[Tuple]:
|
||||
return self.txn.fetchall()
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
setattr(self.txn, name, value)
|
||||
def fetchone(self) -> Tuple:
|
||||
return self.txn.fetchone()
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Iterator[Tuple]:
|
||||
return self.txn.__iter__()
|
||||
|
||||
@property
|
||||
def rowcount(self) -> int:
|
||||
return self.txn.rowcount
|
||||
|
||||
@property
|
||||
def description(self) -> Any:
|
||||
return self.txn.description
|
||||
|
||||
def execute_batch(self, sql, args):
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
from psycopg2.extras import execute_batch
|
||||
from psycopg2.extras import execute_batch # type: ignore
|
||||
|
||||
self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
|
||||
else:
|
||||
for val in args:
|
||||
self.execute(sql, val)
|
||||
|
||||
def execute(self, sql, *args):
|
||||
def execute(self, sql: str, *args: Any):
|
||||
self._do_execute(self.txn.execute, sql, *args)
|
||||
|
||||
def executemany(self, sql, *args):
|
||||
def executemany(self, sql: str, *args: Any):
|
||||
self._do_execute(self.txn.executemany, sql, *args)
|
||||
|
||||
def _make_sql_one_line(self, sql):
|
||||
|
@ -207,6 +227,9 @@ class LoggingTransaction(object):
|
|||
sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
|
||||
sql_query_timer.labels(sql.split()[0]).observe(secs)
|
||||
|
||||
def close(self):
|
||||
self.txn.close()
|
||||
|
||||
|
||||
class PerformanceCounters(object):
|
||||
def __init__(self):
|
||||
|
@ -251,7 +274,9 @@ class Database(object):
|
|||
|
||||
_TXN_ID = 0
|
||||
|
||||
def __init__(self, hs, database_config: DatabaseConnectionConfig, engine):
|
||||
def __init__(
|
||||
self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
|
||||
):
|
||||
self.hs = hs
|
||||
self._clock = hs.get_clock()
|
||||
self._database_config = database_config
|
||||
|
@ -259,9 +284,9 @@ class Database(object):
|
|||
|
||||
self.updates = BackgroundUpdater(hs, self)
|
||||
|
||||
self._previous_txn_total_time = 0
|
||||
self._current_txn_total_time = 0
|
||||
self._previous_loop_ts = 0
|
||||
self._previous_txn_total_time = 0.0
|
||||
self._current_txn_total_time = 0.0
|
||||
self._previous_loop_ts = 0.0
|
||||
|
||||
# TODO(paul): These can eventually be removed once the metrics code
|
||||
# is running in mainline, and we have some nice monitoring frontends
|
||||
|
@ -463,23 +488,23 @@ class Database(object):
|
|||
sql_txn_timer.labels(desc).observe(duration)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def runInteraction(self, desc, func, *args, **kwargs):
|
||||
def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any):
|
||||
"""Starts a transaction on the database and runs a given function
|
||||
|
||||
Arguments:
|
||||
desc (str): description of the transaction, for logging and metrics
|
||||
func (func): callback function, which will be called with a
|
||||
desc: description of the transaction, for logging and metrics
|
||||
func: callback function, which will be called with a
|
||||
database transaction (twisted.enterprise.adbapi.Transaction) as
|
||||
its first argument, followed by `args` and `kwargs`.
|
||||
|
||||
args (list): positional args to pass to `func`
|
||||
kwargs (dict): named args to pass to `func`
|
||||
args: positional args to pass to `func`
|
||||
kwargs: named args to pass to `func`
|
||||
|
||||
Returns:
|
||||
Deferred: The result of func
|
||||
"""
|
||||
after_callbacks = []
|
||||
exception_callbacks = []
|
||||
after_callbacks = [] # type: List[_CallbackListEntry]
|
||||
exception_callbacks = [] # type: List[_CallbackListEntry]
|
||||
|
||||
if LoggingContext.current_context() == LoggingContext.sentinel:
|
||||
logger.warning("Starting db txn '%s' from sentinel context", desc)
|
||||
|
@ -505,15 +530,15 @@ class Database(object):
|
|||
return result
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def runWithConnection(self, func, *args, **kwargs):
|
||||
def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any):
|
||||
"""Wraps the .runWithConnection() method on the underlying db_pool.
|
||||
|
||||
Arguments:
|
||||
func (func): callback function, which will be called with a
|
||||
func: callback function, which will be called with a
|
||||
database connection (twisted.enterprise.adbapi.Connection) as
|
||||
its first argument, followed by `args` and `kwargs`.
|
||||
args (list): positional args to pass to `func`
|
||||
kwargs (dict): named args to pass to `func`
|
||||
args: positional args to pass to `func`
|
||||
kwargs: named args to pass to `func`
|
||||
|
||||
Returns:
|
||||
Deferred: The result of func
|
||||
|
@ -800,7 +825,7 @@ class Database(object):
|
|||
return False
|
||||
|
||||
# We didn't find any existing rows, so insert a new one
|
||||
allvalues = {}
|
||||
allvalues = {} # type: Dict[str, Any]
|
||||
allvalues.update(keyvalues)
|
||||
allvalues.update(values)
|
||||
allvalues.update(insertion_values)
|
||||
|
@ -829,7 +854,7 @@ class Database(object):
|
|||
Returns:
|
||||
None
|
||||
"""
|
||||
allvalues = {}
|
||||
allvalues = {} # type: Dict[str, Any]
|
||||
allvalues.update(keyvalues)
|
||||
allvalues.update(insertion_values)
|
||||
|
||||
|
@ -916,7 +941,7 @@ class Database(object):
|
|||
Returns:
|
||||
None
|
||||
"""
|
||||
allnames = []
|
||||
allnames = [] # type: List[str]
|
||||
allnames.extend(key_names)
|
||||
allnames.extend(value_names)
|
||||
|
||||
|
@ -1100,7 +1125,7 @@ class Database(object):
|
|||
keyvalues : dict of column names and values to select the rows with
|
||||
retcols : list of strings giving the names of the columns to return
|
||||
"""
|
||||
results = []
|
||||
results = [] # type: List[Dict[str, Any]]
|
||||
|
||||
if not iterable:
|
||||
return results
|
||||
|
@ -1439,7 +1464,7 @@ class Database(object):
|
|||
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
|
||||
|
||||
where_clause = "WHERE " if filters or keyvalues else ""
|
||||
arg_list = []
|
||||
arg_list = [] # type: List[Any]
|
||||
if filters:
|
||||
where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
|
||||
arg_list += list(filters.values())
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue