Add more types to synapse.storage.database. (#8127)

This commit is contained in:
Patrick Cloke 2020-08-20 09:00:59 -04:00 committed by GitHub
parent 731dfff347
commit 5eac0b7e76
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 367 additions and 222 deletions

1
changelog.d/8127.misc Normal file
View File

@ -0,0 +1 @@
Add type hints to `synapse.storage.database`.

View File

@ -28,6 +28,7 @@ from typing import (
Optional, Optional,
Tuple, Tuple,
TypeVar, TypeVar,
Union,
) )
from prometheus_client import Histogram from prometheus_client import Histogram
@ -125,7 +126,7 @@ class LoggingTransaction:
method. method.
Args: Args:
txn: The database transcation object to wrap. txn: The database transaction object to wrap.
name: The name of this transactions for logging. name: The name of this transactions for logging.
database_engine database_engine
after_callbacks: A list that callbacks will be appended to after_callbacks: A list that callbacks will be appended to
@ -160,7 +161,7 @@ class LoggingTransaction:
self.after_callbacks = after_callbacks self.after_callbacks = after_callbacks
self.exception_callbacks = exception_callbacks self.exception_callbacks = exception_callbacks
def call_after(self, callback: "Callable[..., None]", *args, **kwargs): def call_after(self, callback: "Callable[..., None]", *args: Any, **kwargs: Any):
"""Call the given callback on the main twisted thread after the """Call the given callback on the main twisted thread after the
transaction has finished. Used to invalidate the caches on the transaction has finished. Used to invalidate the caches on the
correct thread. correct thread.
@ -171,7 +172,9 @@ class LoggingTransaction:
assert self.after_callbacks is not None assert self.after_callbacks is not None
self.after_callbacks.append((callback, args, kwargs)) self.after_callbacks.append((callback, args, kwargs))
def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs): def call_on_exception(
self, callback: "Callable[..., None]", *args: Any, **kwargs: Any
):
# if self.exception_callbacks is None, that means that whatever constructed the # if self.exception_callbacks is None, that means that whatever constructed the
# LoggingTransaction isn't expecting there to be any callbacks; assert that # LoggingTransaction isn't expecting there to be any callbacks; assert that
# is not the case. # is not the case.
@ -195,7 +198,7 @@ class LoggingTransaction:
def description(self) -> Any: def description(self) -> Any:
return self.txn.description return self.txn.description
def execute_batch(self, sql, args): def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
from psycopg2.extras import execute_batch # type: ignore from psycopg2.extras import execute_batch # type: ignore
@ -204,17 +207,17 @@ class LoggingTransaction:
for val in args: for val in args:
self.execute(sql, val) self.execute(sql, val)
def execute(self, sql: str, *args: Any): def execute(self, sql: str, *args: Any) -> None:
self._do_execute(self.txn.execute, sql, *args) self._do_execute(self.txn.execute, sql, *args)
def executemany(self, sql: str, *args: Any): def executemany(self, sql: str, *args: Any) -> None:
self._do_execute(self.txn.executemany, sql, *args) self._do_execute(self.txn.executemany, sql, *args)
def _make_sql_one_line(self, sql: str) -> str: def _make_sql_one_line(self, sql: str) -> str:
"Strip newlines out of SQL so that the loggers in the DB are on one line" "Strip newlines out of SQL so that the loggers in the DB are on one line"
return " ".join(line.strip() for line in sql.splitlines() if line.strip()) return " ".join(line.strip() for line in sql.splitlines() if line.strip())
def _do_execute(self, func, sql, *args): def _do_execute(self, func, sql: str, *args: Any) -> None:
sql = self._make_sql_one_line(sql) sql = self._make_sql_one_line(sql)
# TODO(paul): Maybe use 'info' and 'debug' for values? # TODO(paul): Maybe use 'info' and 'debug' for values?
@ -240,7 +243,7 @@ class LoggingTransaction:
sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs) sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
sql_query_timer.labels(sql.split()[0]).observe(secs) sql_query_timer.labels(sql.split()[0]).observe(secs)
def close(self): def close(self) -> None:
self.txn.close() self.txn.close()
@ -249,13 +252,13 @@ class PerformanceCounters(object):
self.current_counters = {} self.current_counters = {}
self.previous_counters = {} self.previous_counters = {}
def update(self, key, duration_secs): def update(self, key: str, duration_secs: float) -> None:
count, cum_time = self.current_counters.get(key, (0, 0)) count, cum_time = self.current_counters.get(key, (0, 0))
count += 1 count += 1
cum_time += duration_secs cum_time += duration_secs
self.current_counters[key] = (count, cum_time) self.current_counters[key] = (count, cum_time)
def interval(self, interval_duration_secs, limit=3): def interval(self, interval_duration_secs: float, limit: int = 3) -> str:
counters = [] counters = []
for name, (count, cum_time) in self.current_counters.items(): for name, (count, cum_time) in self.current_counters.items():
prev_count, prev_time = self.previous_counters.get(name, (0, 0)) prev_count, prev_time = self.previous_counters.get(name, (0, 0))
@ -279,6 +282,9 @@ class PerformanceCounters(object):
return top_n_counters return top_n_counters
R = TypeVar("R")
class DatabasePool(object): class DatabasePool(object):
"""Wraps a single physical database and connection pool. """Wraps a single physical database and connection pool.
@ -327,12 +333,12 @@ class DatabasePool(object):
self._check_safe_to_upsert, self._check_safe_to_upsert,
) )
def is_running(self): def is_running(self) -> bool:
"""Is the database pool currently running """Is the database pool currently running
""" """
return self._db_pool.running return self._db_pool.running
async def _check_safe_to_upsert(self): async def _check_safe_to_upsert(self) -> None:
""" """
Is it safe to use native UPSERT? Is it safe to use native UPSERT?
@ -363,7 +369,7 @@ class DatabasePool(object):
self._check_safe_to_upsert, self._check_safe_to_upsert,
) )
def start_profiling(self): def start_profiling(self) -> None:
self._previous_loop_ts = monotonic_time() self._previous_loop_ts = monotonic_time()
def loop(): def loop():
@ -387,8 +393,15 @@ class DatabasePool(object):
self._clock.looping_call(loop, 10000) self._clock.looping_call(loop, 10000)
def new_transaction( def new_transaction(
self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs self,
): conn: Connection,
desc: str,
after_callbacks: List[_CallbackListEntry],
exception_callbacks: List[_CallbackListEntry],
func: "Callable[..., R]",
*args: Any,
**kwargs: Any
) -> R:
start = monotonic_time() start = monotonic_time()
txn_id = self._TXN_ID txn_id = self._TXN_ID
@ -537,7 +550,9 @@ class DatabasePool(object):
return result return result
async def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any) -> Any: async def runWithConnection(
self, func: "Callable[..., R]", *args: Any, **kwargs: Any
) -> R:
"""Wraps the .runWithConnection() method on the underlying db_pool. """Wraps the .runWithConnection() method on the underlying db_pool.
Arguments: Arguments:
@ -576,7 +591,7 @@ class DatabasePool(object):
) )
@staticmethod @staticmethod
def cursor_to_dict(cursor): def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]:
"""Converts a SQL cursor into an list of dicts. """Converts a SQL cursor into an list of dicts.
Args: Args:
@ -588,7 +603,7 @@ class DatabasePool(object):
results = [dict(zip(col_headers, row)) for row in cursor] results = [dict(zip(col_headers, row)) for row in cursor]
return results return results
def execute(self, desc, decoder, query, *args): def execute(self, desc: str, decoder: Callable, query: str, *args: Any):
"""Runs a single query for a result set. """Runs a single query for a result set.
Args: Args:
@ -597,7 +612,7 @@ class DatabasePool(object):
query - The query string to execute query - The query string to execute
*args - Query args. *args - Query args.
Returns: Returns:
The result of decoder(results) Deferred which results to the result of decoder(results)
""" """
def interaction(txn): def interaction(txn):
@ -612,7 +627,13 @@ class DatabasePool(object):
# "Simple" SQL API methods that operate on a single table with no JOINs, # "Simple" SQL API methods that operate on a single table with no JOINs,
# no complex WHERE clauses, just a dict of values for columns. # no complex WHERE clauses, just a dict of values for columns.
async def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"): async def simple_insert(
self,
table: str,
values: Dict[str, Any],
or_ignore: bool = False,
desc: str = "simple_insert",
) -> bool:
"""Executes an INSERT query on the named table. """Executes an INSERT query on the named table.
Args: Args:
@ -624,8 +645,7 @@ class DatabasePool(object):
desc: string giving a description of the transaction desc: string giving a description of the transaction
Returns: Returns:
bool: Whether the row was inserted or not. Only useful when Whether the row was inserted or not. Only useful when `or_ignore` is True
`or_ignore` is True
""" """
try: try:
await self.runInteraction(desc, self.simple_insert_txn, table, values) await self.runInteraction(desc, self.simple_insert_txn, table, values)
@ -638,7 +658,9 @@ class DatabasePool(object):
return True return True
@staticmethod @staticmethod
def simple_insert_txn(txn, table, values): def simple_insert_txn(
txn: LoggingTransaction, table: str, values: Dict[str, Any]
) -> None:
keys, vals = zip(*values.items()) keys, vals = zip(*values.items())
sql = "INSERT INTO %s (%s) VALUES(%s)" % ( sql = "INSERT INTO %s (%s) VALUES(%s)" % (
@ -649,11 +671,15 @@ class DatabasePool(object):
txn.execute(sql, vals) txn.execute(sql, vals)
def simple_insert_many(self, table, values, desc): def simple_insert_many(
self, table: str, values: List[Dict[str, Any]], desc: str
) -> defer.Deferred:
return self.runInteraction(desc, self.simple_insert_many_txn, table, values) return self.runInteraction(desc, self.simple_insert_many_txn, table, values)
@staticmethod @staticmethod
def simple_insert_many_txn(txn, table, values): def simple_insert_many_txn(
txn: LoggingTransaction, table: str, values: List[Dict[str, Any]]
) -> None:
if not values: if not values:
return return
@ -683,13 +709,13 @@ class DatabasePool(object):
async def simple_upsert( async def simple_upsert(
self, self,
table, table: str,
keyvalues, keyvalues: Dict[str, Any],
values, values: Dict[str, Any],
insertion_values={}, insertion_values: Dict[str, Any] = {},
desc="simple_upsert", desc: str = "simple_upsert",
lock=True, lock: bool = True,
): ) -> Optional[bool]:
""" """
`lock` should generally be set to True (the default), but can be set `lock` should generally be set to True (the default), but can be set
@ -703,16 +729,14 @@ class DatabasePool(object):
this table. this table.
Args: Args:
table (str): The table to upsert into table: The table to upsert into
keyvalues (dict): The unique key columns and their new values keyvalues: The unique key columns and their new values
values (dict): The nonunique columns and their new values values: The nonunique columns and their new values
insertion_values (dict): additional key/values to use only when insertion_values: additional key/values to use only when inserting
inserting lock: True to lock the table when doing the upsert.
lock (bool): True to lock the table when doing the upsert.
Returns: Returns:
None or bool: Native upserts always return None. Emulated Native upserts always return None. Emulated upserts return True if a
upserts return True if a new entry was created, False if an existing new entry was created, False if an existing one was updated.
one was updated.
""" """
attempts = 0 attempts = 0
while True: while True:
@ -739,29 +763,34 @@ class DatabasePool(object):
) )
def simple_upsert_txn( def simple_upsert_txn(
self, txn, table, keyvalues, values, insertion_values={}, lock=True self,
): txn: LoggingTransaction,
table: str,
keyvalues: Dict[str, Any],
values: Dict[str, Any],
insertion_values: Dict[str, Any] = {},
lock: bool = True,
) -> Optional[bool]:
""" """
Pick the UPSERT method which works best on the platform. Either the Pick the UPSERT method which works best on the platform. Either the
native one (Pg9.5+, recent SQLites), or fall back to an emulated method. native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
Args: Args:
txn: The transaction to use. txn: The transaction to use.
table (str): The table to upsert into table: The table to upsert into
keyvalues (dict): The unique key tables and their new values keyvalues: The unique key tables and their new values
values (dict): The nonunique columns and their new values values: The nonunique columns and their new values
insertion_values (dict): additional key/values to use only when insertion_values: additional key/values to use only when inserting
inserting lock: True to lock the table when doing the upsert.
lock (bool): True to lock the table when doing the upsert.
Returns: Returns:
None or bool: Native upserts always return None. Emulated Native upserts always return None. Emulated upserts return True if a
upserts return True if a new entry was created, False if an existing new entry was created, False if an existing one was updated.
one was updated.
""" """
if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables: if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
return self.simple_upsert_txn_native_upsert( self.simple_upsert_txn_native_upsert(
txn, table, keyvalues, values, insertion_values=insertion_values txn, table, keyvalues, values, insertion_values=insertion_values
) )
return None
else: else:
return self.simple_upsert_txn_emulated( return self.simple_upsert_txn_emulated(
txn, txn,
@ -773,18 +802,23 @@ class DatabasePool(object):
) )
def simple_upsert_txn_emulated( def simple_upsert_txn_emulated(
self, txn, table, keyvalues, values, insertion_values={}, lock=True self,
): txn: LoggingTransaction,
table: str,
keyvalues: Dict[str, Any],
values: Dict[str, Any],
insertion_values: Dict[str, Any] = {},
lock: bool = True,
) -> bool:
""" """
Args: Args:
table (str): The table to upsert into table: The table to upsert into
keyvalues (dict): The unique key tables and their new values keyvalues: The unique key tables and their new values
values (dict): The nonunique columns and their new values values: The nonunique columns and their new values
insertion_values (dict): additional key/values to use only when insertion_values: additional key/values to use only when inserting
inserting lock: True to lock the table when doing the upsert.
lock (bool): True to lock the table when doing the upsert.
Returns: Returns:
bool: Return True if a new entry was created, False if an existing Returns True if a new entry was created, False if an existing
one was updated. one was updated.
""" """
# We need to lock the table :(, unless we're *really* careful # We need to lock the table :(, unless we're *really* careful
@ -842,19 +876,21 @@ class DatabasePool(object):
return True return True
def simple_upsert_txn_native_upsert( def simple_upsert_txn_native_upsert(
self, txn, table, keyvalues, values, insertion_values={} self,
): txn: LoggingTransaction,
table: str,
keyvalues: Dict[str, Any],
values: Dict[str, Any],
insertion_values: Dict[str, Any] = {},
) -> None:
""" """
Use the native UPSERT functionality in recent PostgreSQL versions. Use the native UPSERT functionality in recent PostgreSQL versions.
Args: Args:
table (str): The table to upsert into table: The table to upsert into
keyvalues (dict): The unique key tables and their new values keyvalues: The unique key tables and their new values
values (dict): The nonunique columns and their new values values: The nonunique columns and their new values
insertion_values (dict): additional key/values to use only when insertion_values: additional key/values to use only when inserting
inserting
Returns:
None
""" """
allvalues = {} # type: Dict[str, Any] allvalues = {} # type: Dict[str, Any]
allvalues.update(keyvalues) allvalues.update(keyvalues)
@ -985,8 +1021,13 @@ class DatabasePool(object):
return txn.execute_batch(sql, args) return txn.execute_batch(sql, args)
def simple_select_one( def simple_select_one(
self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one" self,
): table: str,
keyvalues: Dict[str, Any],
retcols: Iterable[str],
allow_none: bool = False,
desc: str = "simple_select_one",
) -> defer.Deferred:
"""Executes a SELECT query on the named table, which is expected to """Executes a SELECT query on the named table, which is expected to
return a single row, returning multiple columns from it. return a single row, returning multiple columns from it.
@ -994,7 +1035,6 @@ class DatabasePool(object):
table: string giving the table name table: string giving the table name
keyvalues: dict of column names and values to select the row with keyvalues: dict of column names and values to select the row with
retcols: list of strings giving the names of the columns to return retcols: list of strings giving the names of the columns to return
allow_none: If true, return None instead of failing if the SELECT allow_none: If true, return None instead of failing if the SELECT
statement returns no rows statement returns no rows
""" """
@ -1004,12 +1044,12 @@ class DatabasePool(object):
def simple_select_one_onecol( def simple_select_one_onecol(
self, self,
table, table: str,
keyvalues, keyvalues: Dict[str, Any],
retcol, retcol: Iterable[str],
allow_none=False, allow_none: bool = False,
desc="simple_select_one_onecol", desc: str = "simple_select_one_onecol",
): ) -> defer.Deferred:
"""Executes a SELECT query on the named table, which is expected to """Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it. return a single row, returning a single column from it.
@ -1017,6 +1057,9 @@ class DatabasePool(object):
table: string giving the table name table: string giving the table name
keyvalues: dict of column names and values to select the row with keyvalues: dict of column names and values to select the row with
retcol: string giving the name of the column to return retcol: string giving the name of the column to return
allow_none: If true, return None instead of failing if the SELECT
statement returns no rows
desc: description of the transaction, for logging and metrics
""" """
return self.runInteraction( return self.runInteraction(
desc, desc,
@ -1029,8 +1072,13 @@ class DatabasePool(object):
@classmethod @classmethod
def simple_select_one_onecol_txn( def simple_select_one_onecol_txn(
cls, txn, table, keyvalues, retcol, allow_none=False cls,
): txn: LoggingTransaction,
table: str,
keyvalues: Dict[str, Any],
retcol: Iterable[str],
allow_none: bool = False,
) -> Optional[Any]:
ret = cls.simple_select_onecol_txn( ret = cls.simple_select_onecol_txn(
txn, table=table, keyvalues=keyvalues, retcol=retcol txn, table=table, keyvalues=keyvalues, retcol=retcol
) )
@ -1044,7 +1092,12 @@ class DatabasePool(object):
raise StoreError(404, "No row found") raise StoreError(404, "No row found")
@staticmethod @staticmethod
def simple_select_onecol_txn(txn, table, keyvalues, retcol): def simple_select_onecol_txn(
txn: LoggingTransaction,
table: str,
keyvalues: Dict[str, Any],
retcol: Iterable[str],
) -> List[Any]:
sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table} sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
if keyvalues: if keyvalues:
@ -1056,15 +1109,19 @@ class DatabasePool(object):
return [r[0] for r in txn] return [r[0] for r in txn]
def simple_select_onecol( def simple_select_onecol(
self, table, keyvalues, retcol, desc="simple_select_onecol" self,
): table: str,
keyvalues: Optional[Dict[str, Any]],
retcol: str,
desc: str = "simple_select_onecol",
) -> defer.Deferred:
"""Executes a SELECT query on the named table, which returns a list """Executes a SELECT query on the named table, which returns a list
comprising of the values of the named column from the selected rows. comprising of the values of the named column from the selected rows.
Args: Args:
table (str): table name table: table name
keyvalues (dict|None): column names and values to select the rows with keyvalues: column names and values to select the rows with
retcol (str): column whos value we wish to retrieve. retcol: column whos value we wish to retrieve.
Returns: Returns:
Deferred: Results in a list Deferred: Results in a list
@ -1073,16 +1130,22 @@ class DatabasePool(object):
desc, self.simple_select_onecol_txn, table, keyvalues, retcol desc, self.simple_select_onecol_txn, table, keyvalues, retcol
) )
def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"): def simple_select_list(
self,
table: str,
keyvalues: Optional[Dict[str, Any]],
retcols: Iterable[str],
desc: str = "simple_select_list",
) -> defer.Deferred:
"""Executes a SELECT query on the named table, which may return zero or """Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts. more rows, returning the result as a list of dicts.
Args: Args:
table (str): the table name table: the table name
keyvalues (dict[str, Any] | None): keyvalues:
column names and values to select the rows with, or None to not column names and values to select the rows with, or None to not
apply a WHERE clause. apply a WHERE clause.
retcols (iterable[str]): the names of the columns to return retcols: the names of the columns to return
Returns: Returns:
defer.Deferred: resolves to list[dict[str, Any]] defer.Deferred: resolves to list[dict[str, Any]]
""" """
@ -1091,17 +1154,23 @@ class DatabasePool(object):
) )
@classmethod @classmethod
def simple_select_list_txn(cls, txn, table, keyvalues, retcols): def simple_select_list_txn(
cls,
txn: LoggingTransaction,
table: str,
keyvalues: Optional[Dict[str, Any]],
retcols: Iterable[str],
) -> List[Dict[str, Any]]:
"""Executes a SELECT query on the named table, which may return zero or """Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts. more rows, returning the result as a list of dicts.
Args: Args:
txn: Transaction object txn: Transaction object
table (str): the table name table: the table name
keyvalues (dict[str, T] | None): keyvalues:
column names and values to select the rows with, or None to not column names and values to select the rows with, or None to not
apply a WHERE clause. apply a WHERE clause.
retcols (iterable[str]): the names of the columns to return retcols: the names of the columns to return
""" """
if keyvalues: if keyvalues:
sql = "SELECT %s FROM %s WHERE %s" % ( sql = "SELECT %s FROM %s WHERE %s" % (
@ -1118,14 +1187,14 @@ class DatabasePool(object):
async def simple_select_many_batch( async def simple_select_many_batch(
self, self,
table, table: str,
column, column: str,
iterable, iterable: Iterable[Any],
retcols, retcols: Iterable[str],
keyvalues={}, keyvalues: Dict[str, Any] = {},
desc="simple_select_many_batch", desc: str = "simple_select_many_batch",
batch_size=100, batch_size: int = 100,
): ) -> List[Any]:
"""Executes a SELECT query on the named table, which may return zero or """Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts. more rows, returning the result as a list of dicts.
@ -1165,7 +1234,15 @@ class DatabasePool(object):
return results return results
@classmethod @classmethod
def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols): def simple_select_many_txn(
cls,
txn: LoggingTransaction,
table: str,
column: str,
iterable: Iterable[Any],
keyvalues: Dict[str, Any],
retcols: Iterable[str],
) -> List[Dict[str, Any]]:
"""Executes a SELECT query on the named table, which may return zero or """Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts. more rows, returning the result as a list of dicts.
@ -1198,13 +1275,24 @@ class DatabasePool(object):
txn.execute(sql, values) txn.execute(sql, values)
return cls.cursor_to_dict(txn) return cls.cursor_to_dict(txn)
def simple_update(self, table, keyvalues, updatevalues, desc): def simple_update(
self,
table: str,
keyvalues: Dict[str, Any],
updatevalues: Dict[str, Any],
desc: str,
) -> defer.Deferred:
return self.runInteraction( return self.runInteraction(
desc, self.simple_update_txn, table, keyvalues, updatevalues desc, self.simple_update_txn, table, keyvalues, updatevalues
) )
@staticmethod @staticmethod
def simple_update_txn(txn, table, keyvalues, updatevalues): def simple_update_txn(
txn: LoggingTransaction,
table: str,
keyvalues: Dict[str, Any],
updatevalues: Dict[str, Any],
) -> int:
if keyvalues: if keyvalues:
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
else: else:
@ -1221,8 +1309,12 @@ class DatabasePool(object):
return txn.rowcount return txn.rowcount
def simple_update_one( def simple_update_one(
self, table, keyvalues, updatevalues, desc="simple_update_one" self,
): table: str,
keyvalues: Dict[str, Any],
updatevalues: Dict[str, Any],
desc: str = "simple_update_one",
) -> defer.Deferred:
"""Executes an UPDATE query on the named table, setting new values for """Executes an UPDATE query on the named table, setting new values for
columns in a row matching the key values. columns in a row matching the key values.
@ -1230,22 +1322,19 @@ class DatabasePool(object):
table: string giving the table name table: string giving the table name
keyvalues: dict of column names and values to select the row with keyvalues: dict of column names and values to select the row with
updatevalues: dict giving column names and values to update updatevalues: dict giving column names and values to update
retcols : optional list of column names to return
If present, retcols gives a list of column names on which to perform
a SELECT statement *before* performing the UPDATE statement. The values
of these will be returned in a dict.
These are performed within the same transaction, allowing an atomic
get-and-set. This can be used to implement compare-and-set by putting
the update column in the 'keyvalues' dict as well.
""" """
return self.runInteraction( return self.runInteraction(
desc, self.simple_update_one_txn, table, keyvalues, updatevalues desc, self.simple_update_one_txn, table, keyvalues, updatevalues
) )
@classmethod @classmethod
def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues): def simple_update_one_txn(
cls,
txn: LoggingTransaction,
table: str,
keyvalues: Dict[str, Any],
updatevalues: Dict[str, Any],
) -> None:
rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues) rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues)
if rowcount == 0: if rowcount == 0:
@ -1253,8 +1342,18 @@ class DatabasePool(object):
if rowcount > 1: if rowcount > 1:
raise StoreError(500, "More than one row matched (%s)" % (table,)) raise StoreError(500, "More than one row matched (%s)" % (table,))
# Ideally we could use the overload decorator here to specify that the
# return type is only optional if allow_none is True, but this does not work
# when you call a static method from an instance.
# See https://github.com/python/mypy/issues/7781
@staticmethod @staticmethod
def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False): def simple_select_one_txn(
txn: LoggingTransaction,
table: str,
keyvalues: Dict[str, Any],
retcols: Iterable[str],
allow_none: bool = False,
) -> Optional[Dict[str, Any]]:
select_sql = "SELECT %s FROM %s WHERE %s" % ( select_sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols), ", ".join(retcols),
table, table,
@ -1273,7 +1372,9 @@ class DatabasePool(object):
return dict(zip(retcols, row)) return dict(zip(retcols, row))
def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"): def simple_delete_one(
self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one"
) -> defer.Deferred:
"""Executes a DELETE query on the named table, expecting to delete a """Executes a DELETE query on the named table, expecting to delete a
single row. single row.
@ -1284,7 +1385,9 @@ class DatabasePool(object):
return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues) return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
@staticmethod @staticmethod
def simple_delete_one_txn(txn, table, keyvalues): def simple_delete_one_txn(
txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any]
) -> None:
"""Executes a DELETE query on the named table, expecting to delete a """Executes a DELETE query on the named table, expecting to delete a
single row. single row.
@ -1303,11 +1406,13 @@ class DatabasePool(object):
if txn.rowcount > 1: if txn.rowcount > 1:
raise StoreError(500, "More than one row matched (%s)" % (table,)) raise StoreError(500, "More than one row matched (%s)" % (table,))
def simple_delete(self, table, keyvalues, desc): def simple_delete(self, table: str, keyvalues: Dict[str, Any], desc: str):
return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues) return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues)
@staticmethod @staticmethod
def simple_delete_txn(txn, table, keyvalues): def simple_delete_txn(
txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any]
) -> int:
sql = "DELETE FROM %s WHERE %s" % ( sql = "DELETE FROM %s WHERE %s" % (
table, table,
" AND ".join("%s = ?" % (k,) for k in keyvalues), " AND ".join("%s = ?" % (k,) for k in keyvalues),
@ -1316,13 +1421,26 @@ class DatabasePool(object):
txn.execute(sql, list(keyvalues.values())) txn.execute(sql, list(keyvalues.values()))
return txn.rowcount return txn.rowcount
def simple_delete_many(self, table, column, iterable, keyvalues, desc): def simple_delete_many(
self,
table: str,
column: str,
iterable: Iterable[Any],
keyvalues: Dict[str, Any],
desc: str,
) -> defer.Deferred:
return self.runInteraction( return self.runInteraction(
desc, self.simple_delete_many_txn, table, column, iterable, keyvalues desc, self.simple_delete_many_txn, table, column, iterable, keyvalues
) )
@staticmethod @staticmethod
def simple_delete_many_txn(txn, table, column, iterable, keyvalues): def simple_delete_many_txn(
txn: LoggingTransaction,
table: str,
column: str,
iterable: Iterable[Any],
keyvalues: Dict[str, Any],
) -> int:
"""Executes a DELETE query on the named table. """Executes a DELETE query on the named table.
Filters rows by if value of `column` is in `iterable`. Filters rows by if value of `column` is in `iterable`.
@ -1335,7 +1453,7 @@ class DatabasePool(object):
keyvalues: dict of column names and values to select the rows with keyvalues: dict of column names and values to select the rows with
Returns: Returns:
int: Number rows deleted Number rows deleted
""" """
if not iterable: if not iterable:
return 0 return 0
@ -1356,8 +1474,14 @@ class DatabasePool(object):
return txn.rowcount return txn.rowcount
def get_cache_dict( def get_cache_dict(
self, db_conn, table, entity_column, stream_column, max_value, limit=100000 self,
): db_conn: Connection,
table: str,
entity_column: str,
stream_column: str,
max_value: int,
limit: int = 100000,
) -> Tuple[Dict[Any, int], int]:
# Fetch a mapping of room_id -> max stream position for "recent" rooms. # Fetch a mapping of room_id -> max stream position for "recent" rooms.
# It doesn't really matter how many we get, the StreamChangeCache will # It doesn't really matter how many we get, the StreamChangeCache will
# do the right thing to ensure it respects the max size of cache. # do the right thing to ensure it respects the max size of cache.
@ -1390,34 +1514,34 @@ class DatabasePool(object):
def simple_select_list_paginate( def simple_select_list_paginate(
self, self,
table, table: str,
orderby, orderby: str,
start, start: int,
limit, limit: int,
retcols, retcols: Iterable[str],
filters=None, filters: Optional[Dict[str, Any]] = None,
keyvalues=None, keyvalues: Optional[Dict[str, Any]] = None,
order_direction="ASC", order_direction: str = "ASC",
desc="simple_select_list_paginate", desc: str = "simple_select_list_paginate",
): ) -> defer.Deferred:
""" """
Executes a SELECT query on the named table with start and limit, Executes a SELECT query on the named table with start and limit,
of row numbers, which may return zero or number of rows from start to limit, of row numbers, which may return zero or number of rows from start to limit,
returning the result as a list of dicts. returning the result as a list of dicts.
Args: Args:
table (str): the table name table: the table name
filters (dict[str, T] | None): orderby: Column to order the results by.
start: Index to begin the query at.
limit: Number of results to return.
retcols: the names of the columns to return
filters:
column names and values to filter the rows with, or None to not column names and values to filter the rows with, or None to not
apply a WHERE ? LIKE ? clause. apply a WHERE ? LIKE ? clause.
keyvalues (dict[str, T] | None): keyvalues:
column names and values to select the rows with, or None to not column names and values to select the rows with, or None to not
apply a WHERE clause. apply a WHERE clause.
orderby (str): Column to order the results by. order_direction: Whether the results should be ordered "ASC" or "DESC".
start (int): Index to begin the query at.
limit (int): Number of results to return.
retcols (iterable[str]): the names of the columns to return
order_direction (str): Whether the results should be ordered "ASC" or "DESC".
Returns: Returns:
defer.Deferred: resolves to list[dict[str, Any]] defer.Deferred: resolves to list[dict[str, Any]]
""" """
@ -1437,16 +1561,16 @@ class DatabasePool(object):
@classmethod @classmethod
def simple_select_list_paginate_txn( def simple_select_list_paginate_txn(
cls, cls,
txn, txn: LoggingTransaction,
table, table: str,
orderby, orderby: str,
start, start: int,
limit, limit: int,
retcols, retcols: Iterable[str],
filters=None, filters: Optional[Dict[str, Any]] = None,
keyvalues=None, keyvalues: Optional[Dict[str, Any]] = None,
order_direction="ASC", order_direction: str = "ASC",
): ) -> List[Dict[str, Any]]:
""" """
Executes a SELECT query on the named table with start and limit, Executes a SELECT query on the named table with start and limit,
of row numbers, which may return zero or number of rows from start to limit, of row numbers, which may return zero or number of rows from start to limit,
@ -1458,20 +1582,21 @@ class DatabasePool(object):
Args: Args:
txn: Transaction object txn: Transaction object
table (str): the table name table: the table name
orderby (str): Column to order the results by. orderby: Column to order the results by.
start (int): Index to begin the query at. start: Index to begin the query at.
limit (int): Number of results to return. limit: Number of results to return.
retcols (iterable[str]): the names of the columns to return retcols: the names of the columns to return
filters (dict[str, T] | None): filters:
column names and values to filter the rows with, or None to not column names and values to filter the rows with, or None to not
apply a WHERE ? LIKE ? clause. apply a WHERE ? LIKE ? clause.
keyvalues (dict[str, T] | None): keyvalues:
column names and values to select the rows with, or None to not column names and values to select the rows with, or None to not
apply a WHERE clause. apply a WHERE clause.
order_direction (str): Whether the results should be ordered "ASC" or "DESC". order_direction: Whether the results should be ordered "ASC" or "DESC".
Returns: Returns:
defer.Deferred: resolves to list[dict[str, Any]] The result as a list of dictionaries.
""" """
if order_direction not in ["ASC", "DESC"]: if order_direction not in ["ASC", "DESC"]:
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.") raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
@ -1497,16 +1622,23 @@ class DatabasePool(object):
return cls.cursor_to_dict(txn) return cls.cursor_to_dict(txn)
def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"): def simple_search_list(
self,
table: str,
term: Optional[str],
col: str,
retcols: Iterable[str],
desc="simple_search_list",
):
"""Executes a SELECT query on the named table, which may return zero or """Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts. more rows, returning the result as a list of dicts.
Args: Args:
table (str): the table name table: the table name
term (str | None): term: term for searching the table matched to a column.
term for searching the table matched to a column. col: column to query term should be matched to
col (str): column to query term should be matched to retcols: the names of the columns to return
retcols (iterable[str]): the names of the columns to return
Returns: Returns:
defer.Deferred: resolves to list[dict[str, Any]] or None defer.Deferred: resolves to list[dict[str, Any]] or None
""" """
@ -1516,19 +1648,26 @@ class DatabasePool(object):
) )
@classmethod @classmethod
def simple_search_list_txn(cls, txn, table, term, col, retcols): def simple_search_list_txn(
cls,
txn: LoggingTransaction,
table: str,
term: Optional[str],
col: str,
retcols: Iterable[str],
) -> Union[List[Dict[str, Any]], int]:
"""Executes a SELECT query on the named table, which may return zero or """Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts. more rows, returning the result as a list of dicts.
Args: Args:
txn: Transaction object txn: Transaction object
table (str): the table name table: the table name
term (str | None): term: term for searching the table matched to a column.
term for searching the table matched to a column. col: column to query term should be matched to
col (str): column to query term should be matched to retcols: the names of the columns to return
retcols (iterable[str]): the names of the columns to return
Returns: Returns:
defer.Deferred: resolves to list[dict[str, Any]] or None 0 if no term is given, otherwise a list of dictionaries.
""" """
if term: if term:
sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col) sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
@ -1541,7 +1680,7 @@ class DatabasePool(object):
def make_in_list_sql_clause( def make_in_list_sql_clause(
database_engine, column: str, iterable: Iterable database_engine: BaseDatabaseEngine, column: str, iterable: Iterable
) -> Tuple[str, list]: ) -> Tuple[str, list]:
"""Returns an SQL clause that checks the given column is in the iterable. """Returns an SQL clause that checks the given column is in the iterable.

View File

@ -19,6 +19,7 @@ from canonicaljson import json
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import LoggingTransaction
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import stringutils as stringutils from synapse.util import stringutils as stringutils
@ -214,14 +215,16 @@ class UIAuthWorkerStore(SQLBaseStore):
value, value,
) )
def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any): def _set_ui_auth_session_data_txn(
self, txn: LoggingTransaction, session_id: str, key: str, value: Any
):
# Get the current value. # Get the current value.
result = self.db_pool.simple_select_one_txn( result = self.db_pool.simple_select_one_txn(
txn, txn,
table="ui_auth_sessions", table="ui_auth_sessions",
keyvalues={"session_id": session_id}, keyvalues={"session_id": session_id},
retcols=("serverdict",), retcols=("serverdict",),
) ) # type: Dict[str, Any] # type: ignore
# Update it and add it back to the database. # Update it and add it back to the database.
serverdict = db_to_json(result["serverdict"]) serverdict = db_to_json(result["serverdict"])
@ -275,7 +278,9 @@ class UIAuthStore(UIAuthWorkerStore):
expiration_time, expiration_time,
) )
def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int): def _delete_old_ui_auth_sessions_txn(
self, txn: LoggingTransaction, expiration_time: int
):
# Get the expired sessions. # Get the expired sessions.
sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?" sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?"
txn.execute(sql, [expiration_time]) txn.execute(sql, [expiration_time])