Add more types to synapse.storage.database. (#8127)

This commit is contained in:
Patrick Cloke 2020-08-20 09:00:59 -04:00 committed by GitHub
parent 731dfff347
commit 5eac0b7e76
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 367 additions and 222 deletions

1
changelog.d/8127.misc Normal file
View File

@ -0,0 +1 @@
Add type hints to `synapse.storage.database`.

View File

@ -28,6 +28,7 @@ from typing import (
Optional,
Tuple,
TypeVar,
Union,
)
from prometheus_client import Histogram
@ -125,7 +126,7 @@ class LoggingTransaction:
method.
Args:
txn: The database transcation object to wrap.
txn: The database transaction object to wrap.
name: The name of this transactions for logging.
database_engine
after_callbacks: A list that callbacks will be appended to
@ -160,7 +161,7 @@ class LoggingTransaction:
self.after_callbacks = after_callbacks
self.exception_callbacks = exception_callbacks
def call_after(self, callback: "Callable[..., None]", *args, **kwargs):
def call_after(self, callback: "Callable[..., None]", *args: Any, **kwargs: Any):
"""Call the given callback on the main twisted thread after the
transaction has finished. Used to invalidate the caches on the
correct thread.
@ -171,7 +172,9 @@ class LoggingTransaction:
assert self.after_callbacks is not None
self.after_callbacks.append((callback, args, kwargs))
def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs):
def call_on_exception(
self, callback: "Callable[..., None]", *args: Any, **kwargs: Any
):
# if self.exception_callbacks is None, that means that whatever constructed the
# LoggingTransaction isn't expecting there to be any callbacks; assert that
# is not the case.
@ -195,7 +198,7 @@ class LoggingTransaction:
def description(self) -> Any:
return self.txn.description
def execute_batch(self, sql, args):
def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
if isinstance(self.database_engine, PostgresEngine):
from psycopg2.extras import execute_batch # type: ignore
@ -204,17 +207,17 @@ class LoggingTransaction:
for val in args:
self.execute(sql, val)
def execute(self, sql: str, *args: Any):
def execute(self, sql: str, *args: Any) -> None:
self._do_execute(self.txn.execute, sql, *args)
def executemany(self, sql: str, *args: Any):
def executemany(self, sql: str, *args: Any) -> None:
self._do_execute(self.txn.executemany, sql, *args)
def _make_sql_one_line(self, sql: str) -> str:
"Strip newlines out of SQL so that the loggers in the DB are on one line"
return " ".join(line.strip() for line in sql.splitlines() if line.strip())
def _do_execute(self, func, sql, *args):
def _do_execute(self, func, sql: str, *args: Any) -> None:
sql = self._make_sql_one_line(sql)
# TODO(paul): Maybe use 'info' and 'debug' for values?
@ -240,7 +243,7 @@ class LoggingTransaction:
sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
sql_query_timer.labels(sql.split()[0]).observe(secs)
def close(self):
def close(self) -> None:
self.txn.close()
@ -249,13 +252,13 @@ class PerformanceCounters(object):
self.current_counters = {}
self.previous_counters = {}
def update(self, key, duration_secs):
def update(self, key: str, duration_secs: float) -> None:
count, cum_time = self.current_counters.get(key, (0, 0))
count += 1
cum_time += duration_secs
self.current_counters[key] = (count, cum_time)
def interval(self, interval_duration_secs, limit=3):
def interval(self, interval_duration_secs: float, limit: int = 3) -> str:
counters = []
for name, (count, cum_time) in self.current_counters.items():
prev_count, prev_time = self.previous_counters.get(name, (0, 0))
@ -279,6 +282,9 @@ class PerformanceCounters(object):
return top_n_counters
R = TypeVar("R")
class DatabasePool(object):
"""Wraps a single physical database and connection pool.
@ -327,12 +333,12 @@ class DatabasePool(object):
self._check_safe_to_upsert,
)
def is_running(self):
def is_running(self) -> bool:
"""Is the database pool currently running
"""
return self._db_pool.running
async def _check_safe_to_upsert(self):
async def _check_safe_to_upsert(self) -> None:
"""
Is it safe to use native UPSERT?
@ -363,7 +369,7 @@ class DatabasePool(object):
self._check_safe_to_upsert,
)
def start_profiling(self):
def start_profiling(self) -> None:
self._previous_loop_ts = monotonic_time()
def loop():
@ -387,8 +393,15 @@ class DatabasePool(object):
self._clock.looping_call(loop, 10000)
def new_transaction(
self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
):
self,
conn: Connection,
desc: str,
after_callbacks: List[_CallbackListEntry],
exception_callbacks: List[_CallbackListEntry],
func: "Callable[..., R]",
*args: Any,
**kwargs: Any
) -> R:
start = monotonic_time()
txn_id = self._TXN_ID
@ -537,7 +550,9 @@ class DatabasePool(object):
return result
async def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any) -> Any:
async def runWithConnection(
self, func: "Callable[..., R]", *args: Any, **kwargs: Any
) -> R:
"""Wraps the .runWithConnection() method on the underlying db_pool.
Arguments:
@ -576,7 +591,7 @@ class DatabasePool(object):
)
@staticmethod
def cursor_to_dict(cursor):
def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]:
"""Converts a SQL cursor into an list of dicts.
Args:
@ -588,7 +603,7 @@ class DatabasePool(object):
results = [dict(zip(col_headers, row)) for row in cursor]
return results
def execute(self, desc, decoder, query, *args):
def execute(self, desc: str, decoder: Callable, query: str, *args: Any):
"""Runs a single query for a result set.
Args:
@ -597,7 +612,7 @@ class DatabasePool(object):
query - The query string to execute
*args - Query args.
Returns:
The result of decoder(results)
Deferred which results to the result of decoder(results)
"""
def interaction(txn):
@ -612,7 +627,13 @@ class DatabasePool(object):
# "Simple" SQL API methods that operate on a single table with no JOINs,
# no complex WHERE clauses, just a dict of values for columns.
async def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
async def simple_insert(
self,
table: str,
values: Dict[str, Any],
or_ignore: bool = False,
desc: str = "simple_insert",
) -> bool:
"""Executes an INSERT query on the named table.
Args:
@ -624,8 +645,7 @@ class DatabasePool(object):
desc: string giving a description of the transaction
Returns:
bool: Whether the row was inserted or not. Only useful when
`or_ignore` is True
Whether the row was inserted or not. Only useful when `or_ignore` is True
"""
try:
await self.runInteraction(desc, self.simple_insert_txn, table, values)
@ -638,7 +658,9 @@ class DatabasePool(object):
return True
@staticmethod
def simple_insert_txn(txn, table, values):
def simple_insert_txn(
txn: LoggingTransaction, table: str, values: Dict[str, Any]
) -> None:
keys, vals = zip(*values.items())
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
@ -649,11 +671,15 @@ class DatabasePool(object):
txn.execute(sql, vals)
def simple_insert_many(self, table, values, desc):
def simple_insert_many(
self, table: str, values: List[Dict[str, Any]], desc: str
) -> defer.Deferred:
return self.runInteraction(desc, self.simple_insert_many_txn, table, values)
@staticmethod
def simple_insert_many_txn(txn, table, values):
def simple_insert_many_txn(
txn: LoggingTransaction, table: str, values: List[Dict[str, Any]]
) -> None:
if not values:
return
@ -683,13 +709,13 @@ class DatabasePool(object):
async def simple_upsert(
self,
table,
keyvalues,
values,
insertion_values={},
desc="simple_upsert",
lock=True,
):
table: str,
keyvalues: Dict[str, Any],
values: Dict[str, Any],
insertion_values: Dict[str, Any] = {},
desc: str = "simple_upsert",
lock: bool = True,
) -> Optional[bool]:
"""
`lock` should generally be set to True (the default), but can be set
@ -703,16 +729,14 @@ class DatabasePool(object):
this table.
Args:
table (str): The table to upsert into
keyvalues (dict): The unique key columns and their new values
values (dict): The nonunique columns and their new values
insertion_values (dict): additional key/values to use only when
inserting
lock (bool): True to lock the table when doing the upsert.
table: The table to upsert into
keyvalues: The unique key columns and their new values
values: The nonunique columns and their new values
insertion_values: additional key/values to use only when inserting
lock: True to lock the table when doing the upsert.
Returns:
None or bool: Native upserts always return None. Emulated
upserts return True if a new entry was created, False if an existing
one was updated.
Native upserts always return None. Emulated upserts return True if a
new entry was created, False if an existing one was updated.
"""
attempts = 0
while True:
@ -739,29 +763,34 @@ class DatabasePool(object):
)
def simple_upsert_txn(
self, txn, table, keyvalues, values, insertion_values={}, lock=True
):
self,
txn: LoggingTransaction,
table: str,
keyvalues: Dict[str, Any],
values: Dict[str, Any],
insertion_values: Dict[str, Any] = {},
lock: bool = True,
) -> Optional[bool]:
"""
Pick the UPSERT method which works best on the platform. Either the
native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
Args:
txn: The transaction to use.
table (str): The table to upsert into
keyvalues (dict): The unique key tables and their new values
values (dict): The nonunique columns and their new values
insertion_values (dict): additional key/values to use only when
inserting
lock (bool): True to lock the table when doing the upsert.
table: The table to upsert into
keyvalues: The unique key tables and their new values
values: The nonunique columns and their new values
insertion_values: additional key/values to use only when inserting
lock: True to lock the table when doing the upsert.
Returns:
None or bool: Native upserts always return None. Emulated
upserts return True if a new entry was created, False if an existing
one was updated.
Native upserts always return None. Emulated upserts return True if a
new entry was created, False if an existing one was updated.
"""
if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
return self.simple_upsert_txn_native_upsert(
self.simple_upsert_txn_native_upsert(
txn, table, keyvalues, values, insertion_values=insertion_values
)
return None
else:
return self.simple_upsert_txn_emulated(
txn,
@ -773,18 +802,23 @@ class DatabasePool(object):
)
def simple_upsert_txn_emulated(
self, txn, table, keyvalues, values, insertion_values={}, lock=True
):
self,
txn: LoggingTransaction,
table: str,
keyvalues: Dict[str, Any],
values: Dict[str, Any],
insertion_values: Dict[str, Any] = {},
lock: bool = True,
) -> bool:
"""
Args:
table (str): The table to upsert into
keyvalues (dict): The unique key tables and their new values
values (dict): The nonunique columns and their new values
insertion_values (dict): additional key/values to use only when
inserting
lock (bool): True to lock the table when doing the upsert.
table: The table to upsert into
keyvalues: The unique key tables and their new values
values: The nonunique columns and their new values
insertion_values: additional key/values to use only when inserting
lock: True to lock the table when doing the upsert.
Returns:
bool: Return True if a new entry was created, False if an existing
Returns True if a new entry was created, False if an existing
one was updated.
"""
# We need to lock the table :(, unless we're *really* careful
@ -842,19 +876,21 @@ class DatabasePool(object):
return True
def simple_upsert_txn_native_upsert(
self, txn, table, keyvalues, values, insertion_values={}
):
self,
txn: LoggingTransaction,
table: str,
keyvalues: Dict[str, Any],
values: Dict[str, Any],
insertion_values: Dict[str, Any] = {},
) -> None:
"""
Use the native UPSERT functionality in recent PostgreSQL versions.
Args:
table (str): The table to upsert into
keyvalues (dict): The unique key tables and their new values
values (dict): The nonunique columns and their new values
insertion_values (dict): additional key/values to use only when
inserting
Returns:
None
table: The table to upsert into
keyvalues: The unique key tables and their new values
values: The nonunique columns and their new values
insertion_values: additional key/values to use only when inserting
"""
allvalues = {} # type: Dict[str, Any]
allvalues.update(keyvalues)
@ -985,8 +1021,13 @@ class DatabasePool(object):
return txn.execute_batch(sql, args)
def simple_select_one(
self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one"
):
self,
table: str,
keyvalues: Dict[str, Any],
retcols: Iterable[str],
allow_none: bool = False,
desc: str = "simple_select_one",
) -> defer.Deferred:
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning multiple columns from it.
@ -994,7 +1035,6 @@ class DatabasePool(object):
table: string giving the table name
keyvalues: dict of column names and values to select the row with
retcols: list of strings giving the names of the columns to return
allow_none: If true, return None instead of failing if the SELECT
statement returns no rows
"""
@ -1004,12 +1044,12 @@ class DatabasePool(object):
def simple_select_one_onecol(
self,
table,
keyvalues,
retcol,
allow_none=False,
desc="simple_select_one_onecol",
):
table: str,
keyvalues: Dict[str, Any],
retcol: Iterable[str],
allow_none: bool = False,
desc: str = "simple_select_one_onecol",
) -> defer.Deferred:
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it.
@ -1017,6 +1057,9 @@ class DatabasePool(object):
table: string giving the table name
keyvalues: dict of column names and values to select the row with
retcol: string giving the name of the column to return
allow_none: If true, return None instead of failing if the SELECT
statement returns no rows
desc: description of the transaction, for logging and metrics
"""
return self.runInteraction(
desc,
@ -1029,8 +1072,13 @@ class DatabasePool(object):
@classmethod
def simple_select_one_onecol_txn(
cls, txn, table, keyvalues, retcol, allow_none=False
):
cls,
txn: LoggingTransaction,
table: str,
keyvalues: Dict[str, Any],
retcol: Iterable[str],
allow_none: bool = False,
) -> Optional[Any]:
ret = cls.simple_select_onecol_txn(
txn, table=table, keyvalues=keyvalues, retcol=retcol
)
@ -1044,7 +1092,12 @@ class DatabasePool(object):
raise StoreError(404, "No row found")
@staticmethod
def simple_select_onecol_txn(txn, table, keyvalues, retcol):
def simple_select_onecol_txn(
txn: LoggingTransaction,
table: str,
keyvalues: Dict[str, Any],
retcol: Iterable[str],
) -> List[Any]:
sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
if keyvalues:
@ -1056,15 +1109,19 @@ class DatabasePool(object):
return [r[0] for r in txn]
def simple_select_onecol(
self, table, keyvalues, retcol, desc="simple_select_onecol"
):
self,
table: str,
keyvalues: Optional[Dict[str, Any]],
retcol: str,
desc: str = "simple_select_onecol",
) -> defer.Deferred:
"""Executes a SELECT query on the named table, which returns a list
comprising of the values of the named column from the selected rows.
Args:
table (str): table name
keyvalues (dict|None): column names and values to select the rows with
retcol (str): column whos value we wish to retrieve.
table: table name
keyvalues: column names and values to select the rows with
retcol: column whos value we wish to retrieve.
Returns:
Deferred: Results in a list
@ -1073,16 +1130,22 @@ class DatabasePool(object):
desc, self.simple_select_onecol_txn, table, keyvalues, retcol
)
def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"):
def simple_select_list(
self,
table: str,
keyvalues: Optional[Dict[str, Any]],
retcols: Iterable[str],
desc: str = "simple_select_list",
) -> defer.Deferred:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Args:
table (str): the table name
keyvalues (dict[str, Any] | None):
table: the table name
keyvalues:
column names and values to select the rows with, or None to not
apply a WHERE clause.
retcols (iterable[str]): the names of the columns to return
retcols: the names of the columns to return
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
@ -1091,17 +1154,23 @@ class DatabasePool(object):
)
@classmethod
def simple_select_list_txn(cls, txn, table, keyvalues, retcols):
def simple_select_list_txn(
cls,
txn: LoggingTransaction,
table: str,
keyvalues: Optional[Dict[str, Any]],
retcols: Iterable[str],
) -> List[Dict[str, Any]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Args:
txn: Transaction object
table (str): the table name
keyvalues (dict[str, T] | None):
table: the table name
keyvalues:
column names and values to select the rows with, or None to not
apply a WHERE clause.
retcols (iterable[str]): the names of the columns to return
retcols: the names of the columns to return
"""
if keyvalues:
sql = "SELECT %s FROM %s WHERE %s" % (
@ -1118,14 +1187,14 @@ class DatabasePool(object):
async def simple_select_many_batch(
self,
table,
column,
iterable,
retcols,
keyvalues={},
desc="simple_select_many_batch",
batch_size=100,
):
table: str,
column: str,
iterable: Iterable[Any],
retcols: Iterable[str],
keyvalues: Dict[str, Any] = {},
desc: str = "simple_select_many_batch",
batch_size: int = 100,
) -> List[Any]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@ -1165,7 +1234,15 @@ class DatabasePool(object):
return results
@classmethod
def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
def simple_select_many_txn(
cls,
txn: LoggingTransaction,
table: str,
column: str,
iterable: Iterable[Any],
keyvalues: Dict[str, Any],
retcols: Iterable[str],
) -> List[Dict[str, Any]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@ -1198,13 +1275,24 @@ class DatabasePool(object):
txn.execute(sql, values)
return cls.cursor_to_dict(txn)
def simple_update(self, table, keyvalues, updatevalues, desc):
def simple_update(
self,
table: str,
keyvalues: Dict[str, Any],
updatevalues: Dict[str, Any],
desc: str,
) -> defer.Deferred:
return self.runInteraction(
desc, self.simple_update_txn, table, keyvalues, updatevalues
)
@staticmethod
def simple_update_txn(txn, table, keyvalues, updatevalues):
def simple_update_txn(
txn: LoggingTransaction,
table: str,
keyvalues: Dict[str, Any],
updatevalues: Dict[str, Any],
) -> int:
if keyvalues:
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
else:
@ -1221,8 +1309,12 @@ class DatabasePool(object):
return txn.rowcount
def simple_update_one(
self, table, keyvalues, updatevalues, desc="simple_update_one"
):
self,
table: str,
keyvalues: Dict[str, Any],
updatevalues: Dict[str, Any],
desc: str = "simple_update_one",
) -> defer.Deferred:
"""Executes an UPDATE query on the named table, setting new values for
columns in a row matching the key values.
@ -1230,22 +1322,19 @@ class DatabasePool(object):
table: string giving the table name
keyvalues: dict of column names and values to select the row with
updatevalues: dict giving column names and values to update
retcols : optional list of column names to return
If present, retcols gives a list of column names on which to perform
a SELECT statement *before* performing the UPDATE statement. The values
of these will be returned in a dict.
These are performed within the same transaction, allowing an atomic
get-and-set. This can be used to implement compare-and-set by putting
the update column in the 'keyvalues' dict as well.
"""
return self.runInteraction(
desc, self.simple_update_one_txn, table, keyvalues, updatevalues
)
@classmethod
def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
def simple_update_one_txn(
cls,
txn: LoggingTransaction,
table: str,
keyvalues: Dict[str, Any],
updatevalues: Dict[str, Any],
) -> None:
rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues)
if rowcount == 0:
@ -1253,8 +1342,18 @@ class DatabasePool(object):
if rowcount > 1:
raise StoreError(500, "More than one row matched (%s)" % (table,))
# Ideally we could use the overload decorator here to specify that the
# return type is only optional if allow_none is True, but this does not work
# when you call a static method from an instance.
# See https://github.com/python/mypy/issues/7781
@staticmethod
def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
def simple_select_one_txn(
txn: LoggingTransaction,
table: str,
keyvalues: Dict[str, Any],
retcols: Iterable[str],
allow_none: bool = False,
) -> Optional[Dict[str, Any]]:
select_sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
table,
@ -1273,7 +1372,9 @@ class DatabasePool(object):
return dict(zip(retcols, row))
def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"):
def simple_delete_one(
self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one"
) -> defer.Deferred:
"""Executes a DELETE query on the named table, expecting to delete a
single row.
@ -1284,7 +1385,9 @@ class DatabasePool(object):
return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
@staticmethod
def simple_delete_one_txn(txn, table, keyvalues):
def simple_delete_one_txn(
txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any]
) -> None:
"""Executes a DELETE query on the named table, expecting to delete a
single row.
@ -1303,11 +1406,13 @@ class DatabasePool(object):
if txn.rowcount > 1:
raise StoreError(500, "More than one row matched (%s)" % (table,))
def simple_delete(self, table, keyvalues, desc):
def simple_delete(self, table: str, keyvalues: Dict[str, Any], desc: str):
return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues)
@staticmethod
def simple_delete_txn(txn, table, keyvalues):
def simple_delete_txn(
txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any]
) -> int:
sql = "DELETE FROM %s WHERE %s" % (
table,
" AND ".join("%s = ?" % (k,) for k in keyvalues),
@ -1316,13 +1421,26 @@ class DatabasePool(object):
txn.execute(sql, list(keyvalues.values()))
return txn.rowcount
def simple_delete_many(self, table, column, iterable, keyvalues, desc):
def simple_delete_many(
self,
table: str,
column: str,
iterable: Iterable[Any],
keyvalues: Dict[str, Any],
desc: str,
) -> defer.Deferred:
return self.runInteraction(
desc, self.simple_delete_many_txn, table, column, iterable, keyvalues
)
@staticmethod
def simple_delete_many_txn(txn, table, column, iterable, keyvalues):
def simple_delete_many_txn(
txn: LoggingTransaction,
table: str,
column: str,
iterable: Iterable[Any],
keyvalues: Dict[str, Any],
) -> int:
"""Executes a DELETE query on the named table.
Filters rows by if value of `column` is in `iterable`.
@ -1335,7 +1453,7 @@ class DatabasePool(object):
keyvalues: dict of column names and values to select the rows with
Returns:
int: Number rows deleted
Number rows deleted
"""
if not iterable:
return 0
@ -1356,8 +1474,14 @@ class DatabasePool(object):
return txn.rowcount
def get_cache_dict(
self, db_conn, table, entity_column, stream_column, max_value, limit=100000
):
self,
db_conn: Connection,
table: str,
entity_column: str,
stream_column: str,
max_value: int,
limit: int = 100000,
) -> Tuple[Dict[Any, int], int]:
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
# It doesn't really matter how many we get, the StreamChangeCache will
# do the right thing to ensure it respects the max size of cache.
@ -1390,34 +1514,34 @@ class DatabasePool(object):
def simple_select_list_paginate(
self,
table,
orderby,
start,
limit,
retcols,
filters=None,
keyvalues=None,
order_direction="ASC",
desc="simple_select_list_paginate",
):
table: str,
orderby: str,
start: int,
limit: int,
retcols: Iterable[str],
filters: Optional[Dict[str, Any]] = None,
keyvalues: Optional[Dict[str, Any]] = None,
order_direction: str = "ASC",
desc: str = "simple_select_list_paginate",
) -> defer.Deferred:
"""
Executes a SELECT query on the named table with start and limit,
of row numbers, which may return zero or number of rows from start to limit,
returning the result as a list of dicts.
Args:
table (str): the table name
filters (dict[str, T] | None):
table: the table name
orderby: Column to order the results by.
start: Index to begin the query at.
limit: Number of results to return.
retcols: the names of the columns to return
filters:
column names and values to filter the rows with, or None to not
apply a WHERE ? LIKE ? clause.
keyvalues (dict[str, T] | None):
keyvalues:
column names and values to select the rows with, or None to not
apply a WHERE clause.
orderby (str): Column to order the results by.
start (int): Index to begin the query at.
limit (int): Number of results to return.
retcols (iterable[str]): the names of the columns to return
order_direction (str): Whether the results should be ordered "ASC" or "DESC".
order_direction: Whether the results should be ordered "ASC" or "DESC".
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
@ -1437,16 +1561,16 @@ class DatabasePool(object):
@classmethod
def simple_select_list_paginate_txn(
cls,
txn,
table,
orderby,
start,
limit,
retcols,
filters=None,
keyvalues=None,
order_direction="ASC",
):
txn: LoggingTransaction,
table: str,
orderby: str,
start: int,
limit: int,
retcols: Iterable[str],
filters: Optional[Dict[str, Any]] = None,
keyvalues: Optional[Dict[str, Any]] = None,
order_direction: str = "ASC",
) -> List[Dict[str, Any]]:
"""
Executes a SELECT query on the named table with start and limit,
of row numbers, which may return zero or number of rows from start to limit,
@ -1458,20 +1582,21 @@ class DatabasePool(object):
Args:
txn: Transaction object
table (str): the table name
orderby (str): Column to order the results by.
start (int): Index to begin the query at.
limit (int): Number of results to return.
retcols (iterable[str]): the names of the columns to return
filters (dict[str, T] | None):
table: the table name
orderby: Column to order the results by.
start: Index to begin the query at.
limit: Number of results to return.
retcols: the names of the columns to return
filters:
column names and values to filter the rows with, or None to not
apply a WHERE ? LIKE ? clause.
keyvalues (dict[str, T] | None):
keyvalues:
column names and values to select the rows with, or None to not
apply a WHERE clause.
order_direction (str): Whether the results should be ordered "ASC" or "DESC".
order_direction: Whether the results should be ordered "ASC" or "DESC".
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
The result as a list of dictionaries.
"""
if order_direction not in ["ASC", "DESC"]:
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
@ -1497,16 +1622,23 @@ class DatabasePool(object):
return cls.cursor_to_dict(txn)
def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"):
def simple_search_list(
self,
table: str,
term: Optional[str],
col: str,
retcols: Iterable[str],
desc="simple_search_list",
):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Args:
table (str): the table name
term (str | None):
term for searching the table matched to a column.
col (str): column to query term should be matched to
retcols (iterable[str]): the names of the columns to return
table: the table name
term: term for searching the table matched to a column.
col: column to query term should be matched to
retcols: the names of the columns to return
Returns:
defer.Deferred: resolves to list[dict[str, Any]] or None
"""
@ -1516,19 +1648,26 @@ class DatabasePool(object):
)
@classmethod
def simple_search_list_txn(cls, txn, table, term, col, retcols):
def simple_search_list_txn(
cls,
txn: LoggingTransaction,
table: str,
term: Optional[str],
col: str,
retcols: Iterable[str],
) -> Union[List[Dict[str, Any]], int]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Args:
txn: Transaction object
table (str): the table name
term (str | None):
term for searching the table matched to a column.
col (str): column to query term should be matched to
retcols (iterable[str]): the names of the columns to return
table: the table name
term: term for searching the table matched to a column.
col: column to query term should be matched to
retcols: the names of the columns to return
Returns:
defer.Deferred: resolves to list[dict[str, Any]] or None
0 if no term is given, otherwise a list of dictionaries.
"""
if term:
sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
@ -1541,7 +1680,7 @@ class DatabasePool(object):
def make_in_list_sql_clause(
database_engine, column: str, iterable: Iterable
database_engine: BaseDatabaseEngine, column: str, iterable: Iterable
) -> Tuple[str, list]:
"""Returns an SQL clause that checks the given column is in the iterable.

View File

@ -19,6 +19,7 @@ from canonicaljson import json
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import LoggingTransaction
from synapse.types import JsonDict
from synapse.util import stringutils as stringutils
@ -214,14 +215,16 @@ class UIAuthWorkerStore(SQLBaseStore):
value,
)
def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any):
def _set_ui_auth_session_data_txn(
self, txn: LoggingTransaction, session_id: str, key: str, value: Any
):
# Get the current value.
result = self.db_pool.simple_select_one_txn(
txn,
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
retcols=("serverdict",),
)
) # type: Dict[str, Any] # type: ignore
# Update it and add it back to the database.
serverdict = db_to_json(result["serverdict"])
@ -275,7 +278,9 @@ class UIAuthStore(UIAuthWorkerStore):
expiration_time,
)
def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int):
def _delete_old_ui_auth_sessions_txn(
self, txn: LoggingTransaction, expiration_time: int
):
# Get the expired sessions.
sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?"
txn.execute(sql, [expiration_time])