mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2024-12-25 08:59:31 -05:00
Add more types to synapse.storage.database. (#8127)
This commit is contained in:
parent
731dfff347
commit
5eac0b7e76
1
changelog.d/8127.misc
Normal file
1
changelog.d/8127.misc
Normal file
@ -0,0 +1 @@
|
||||
Add type hints to `synapse.storage.database`.
|
@ -28,6 +28,7 @@ from typing import (
|
||||
Optional,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from prometheus_client import Histogram
|
||||
@ -125,7 +126,7 @@ class LoggingTransaction:
|
||||
method.
|
||||
|
||||
Args:
|
||||
txn: The database transcation object to wrap.
|
||||
txn: The database transaction object to wrap.
|
||||
name: The name of this transactions for logging.
|
||||
database_engine
|
||||
after_callbacks: A list that callbacks will be appended to
|
||||
@ -160,7 +161,7 @@ class LoggingTransaction:
|
||||
self.after_callbacks = after_callbacks
|
||||
self.exception_callbacks = exception_callbacks
|
||||
|
||||
def call_after(self, callback: "Callable[..., None]", *args, **kwargs):
|
||||
def call_after(self, callback: "Callable[..., None]", *args: Any, **kwargs: Any):
|
||||
"""Call the given callback on the main twisted thread after the
|
||||
transaction has finished. Used to invalidate the caches on the
|
||||
correct thread.
|
||||
@ -171,7 +172,9 @@ class LoggingTransaction:
|
||||
assert self.after_callbacks is not None
|
||||
self.after_callbacks.append((callback, args, kwargs))
|
||||
|
||||
def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs):
|
||||
def call_on_exception(
|
||||
self, callback: "Callable[..., None]", *args: Any, **kwargs: Any
|
||||
):
|
||||
# if self.exception_callbacks is None, that means that whatever constructed the
|
||||
# LoggingTransaction isn't expecting there to be any callbacks; assert that
|
||||
# is not the case.
|
||||
@ -195,7 +198,7 @@ class LoggingTransaction:
|
||||
def description(self) -> Any:
|
||||
return self.txn.description
|
||||
|
||||
def execute_batch(self, sql, args):
|
||||
def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
from psycopg2.extras import execute_batch # type: ignore
|
||||
|
||||
@ -204,17 +207,17 @@ class LoggingTransaction:
|
||||
for val in args:
|
||||
self.execute(sql, val)
|
||||
|
||||
def execute(self, sql: str, *args: Any):
|
||||
def execute(self, sql: str, *args: Any) -> None:
|
||||
self._do_execute(self.txn.execute, sql, *args)
|
||||
|
||||
def executemany(self, sql: str, *args: Any):
|
||||
def executemany(self, sql: str, *args: Any) -> None:
|
||||
self._do_execute(self.txn.executemany, sql, *args)
|
||||
|
||||
def _make_sql_one_line(self, sql: str) -> str:
|
||||
"Strip newlines out of SQL so that the loggers in the DB are on one line"
|
||||
return " ".join(line.strip() for line in sql.splitlines() if line.strip())
|
||||
|
||||
def _do_execute(self, func, sql, *args):
|
||||
def _do_execute(self, func, sql: str, *args: Any) -> None:
|
||||
sql = self._make_sql_one_line(sql)
|
||||
|
||||
# TODO(paul): Maybe use 'info' and 'debug' for values?
|
||||
@ -240,7 +243,7 @@ class LoggingTransaction:
|
||||
sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
|
||||
sql_query_timer.labels(sql.split()[0]).observe(secs)
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
self.txn.close()
|
||||
|
||||
|
||||
@ -249,13 +252,13 @@ class PerformanceCounters(object):
|
||||
self.current_counters = {}
|
||||
self.previous_counters = {}
|
||||
|
||||
def update(self, key, duration_secs):
|
||||
def update(self, key: str, duration_secs: float) -> None:
|
||||
count, cum_time = self.current_counters.get(key, (0, 0))
|
||||
count += 1
|
||||
cum_time += duration_secs
|
||||
self.current_counters[key] = (count, cum_time)
|
||||
|
||||
def interval(self, interval_duration_secs, limit=3):
|
||||
def interval(self, interval_duration_secs: float, limit: int = 3) -> str:
|
||||
counters = []
|
||||
for name, (count, cum_time) in self.current_counters.items():
|
||||
prev_count, prev_time = self.previous_counters.get(name, (0, 0))
|
||||
@ -279,6 +282,9 @@ class PerformanceCounters(object):
|
||||
return top_n_counters
|
||||
|
||||
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class DatabasePool(object):
|
||||
"""Wraps a single physical database and connection pool.
|
||||
|
||||
@ -327,12 +333,12 @@ class DatabasePool(object):
|
||||
self._check_safe_to_upsert,
|
||||
)
|
||||
|
||||
def is_running(self):
|
||||
def is_running(self) -> bool:
|
||||
"""Is the database pool currently running
|
||||
"""
|
||||
return self._db_pool.running
|
||||
|
||||
async def _check_safe_to_upsert(self):
|
||||
async def _check_safe_to_upsert(self) -> None:
|
||||
"""
|
||||
Is it safe to use native UPSERT?
|
||||
|
||||
@ -363,7 +369,7 @@ class DatabasePool(object):
|
||||
self._check_safe_to_upsert,
|
||||
)
|
||||
|
||||
def start_profiling(self):
|
||||
def start_profiling(self) -> None:
|
||||
self._previous_loop_ts = monotonic_time()
|
||||
|
||||
def loop():
|
||||
@ -387,8 +393,15 @@ class DatabasePool(object):
|
||||
self._clock.looping_call(loop, 10000)
|
||||
|
||||
def new_transaction(
|
||||
self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
|
||||
):
|
||||
self,
|
||||
conn: Connection,
|
||||
desc: str,
|
||||
after_callbacks: List[_CallbackListEntry],
|
||||
exception_callbacks: List[_CallbackListEntry],
|
||||
func: "Callable[..., R]",
|
||||
*args: Any,
|
||||
**kwargs: Any
|
||||
) -> R:
|
||||
start = monotonic_time()
|
||||
txn_id = self._TXN_ID
|
||||
|
||||
@ -537,7 +550,9 @@ class DatabasePool(object):
|
||||
|
||||
return result
|
||||
|
||||
async def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any) -> Any:
|
||||
async def runWithConnection(
|
||||
self, func: "Callable[..., R]", *args: Any, **kwargs: Any
|
||||
) -> R:
|
||||
"""Wraps the .runWithConnection() method on the underlying db_pool.
|
||||
|
||||
Arguments:
|
||||
@ -576,7 +591,7 @@ class DatabasePool(object):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def cursor_to_dict(cursor):
|
||||
def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]:
|
||||
"""Converts a SQL cursor into an list of dicts.
|
||||
|
||||
Args:
|
||||
@ -588,7 +603,7 @@ class DatabasePool(object):
|
||||
results = [dict(zip(col_headers, row)) for row in cursor]
|
||||
return results
|
||||
|
||||
def execute(self, desc, decoder, query, *args):
|
||||
def execute(self, desc: str, decoder: Callable, query: str, *args: Any):
|
||||
"""Runs a single query for a result set.
|
||||
|
||||
Args:
|
||||
@ -597,7 +612,7 @@ class DatabasePool(object):
|
||||
query - The query string to execute
|
||||
*args - Query args.
|
||||
Returns:
|
||||
The result of decoder(results)
|
||||
Deferred which results to the result of decoder(results)
|
||||
"""
|
||||
|
||||
def interaction(txn):
|
||||
@ -612,7 +627,13 @@ class DatabasePool(object):
|
||||
# "Simple" SQL API methods that operate on a single table with no JOINs,
|
||||
# no complex WHERE clauses, just a dict of values for columns.
|
||||
|
||||
async def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
|
||||
async def simple_insert(
|
||||
self,
|
||||
table: str,
|
||||
values: Dict[str, Any],
|
||||
or_ignore: bool = False,
|
||||
desc: str = "simple_insert",
|
||||
) -> bool:
|
||||
"""Executes an INSERT query on the named table.
|
||||
|
||||
Args:
|
||||
@ -624,8 +645,7 @@ class DatabasePool(object):
|
||||
desc: string giving a description of the transaction
|
||||
|
||||
Returns:
|
||||
bool: Whether the row was inserted or not. Only useful when
|
||||
`or_ignore` is True
|
||||
Whether the row was inserted or not. Only useful when `or_ignore` is True
|
||||
"""
|
||||
try:
|
||||
await self.runInteraction(desc, self.simple_insert_txn, table, values)
|
||||
@ -638,7 +658,9 @@ class DatabasePool(object):
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def simple_insert_txn(txn, table, values):
|
||||
def simple_insert_txn(
|
||||
txn: LoggingTransaction, table: str, values: Dict[str, Any]
|
||||
) -> None:
|
||||
keys, vals = zip(*values.items())
|
||||
|
||||
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
|
||||
@ -649,11 +671,15 @@ class DatabasePool(object):
|
||||
|
||||
txn.execute(sql, vals)
|
||||
|
||||
def simple_insert_many(self, table, values, desc):
|
||||
def simple_insert_many(
|
||||
self, table: str, values: List[Dict[str, Any]], desc: str
|
||||
) -> defer.Deferred:
|
||||
return self.runInteraction(desc, self.simple_insert_many_txn, table, values)
|
||||
|
||||
@staticmethod
|
||||
def simple_insert_many_txn(txn, table, values):
|
||||
def simple_insert_many_txn(
|
||||
txn: LoggingTransaction, table: str, values: List[Dict[str, Any]]
|
||||
) -> None:
|
||||
if not values:
|
||||
return
|
||||
|
||||
@ -683,13 +709,13 @@ class DatabasePool(object):
|
||||
|
||||
async def simple_upsert(
|
||||
self,
|
||||
table,
|
||||
keyvalues,
|
||||
values,
|
||||
insertion_values={},
|
||||
desc="simple_upsert",
|
||||
lock=True,
|
||||
):
|
||||
table: str,
|
||||
keyvalues: Dict[str, Any],
|
||||
values: Dict[str, Any],
|
||||
insertion_values: Dict[str, Any] = {},
|
||||
desc: str = "simple_upsert",
|
||||
lock: bool = True,
|
||||
) -> Optional[bool]:
|
||||
"""
|
||||
|
||||
`lock` should generally be set to True (the default), but can be set
|
||||
@ -703,16 +729,14 @@ class DatabasePool(object):
|
||||
this table.
|
||||
|
||||
Args:
|
||||
table (str): The table to upsert into
|
||||
keyvalues (dict): The unique key columns and their new values
|
||||
values (dict): The nonunique columns and their new values
|
||||
insertion_values (dict): additional key/values to use only when
|
||||
inserting
|
||||
lock (bool): True to lock the table when doing the upsert.
|
||||
table: The table to upsert into
|
||||
keyvalues: The unique key columns and their new values
|
||||
values: The nonunique columns and their new values
|
||||
insertion_values: additional key/values to use only when inserting
|
||||
lock: True to lock the table when doing the upsert.
|
||||
Returns:
|
||||
None or bool: Native upserts always return None. Emulated
|
||||
upserts return True if a new entry was created, False if an existing
|
||||
one was updated.
|
||||
Native upserts always return None. Emulated upserts return True if a
|
||||
new entry was created, False if an existing one was updated.
|
||||
"""
|
||||
attempts = 0
|
||||
while True:
|
||||
@ -739,29 +763,34 @@ class DatabasePool(object):
|
||||
)
|
||||
|
||||
def simple_upsert_txn(
|
||||
self, txn, table, keyvalues, values, insertion_values={}, lock=True
|
||||
):
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
table: str,
|
||||
keyvalues: Dict[str, Any],
|
||||
values: Dict[str, Any],
|
||||
insertion_values: Dict[str, Any] = {},
|
||||
lock: bool = True,
|
||||
) -> Optional[bool]:
|
||||
"""
|
||||
Pick the UPSERT method which works best on the platform. Either the
|
||||
native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
|
||||
|
||||
Args:
|
||||
txn: The transaction to use.
|
||||
table (str): The table to upsert into
|
||||
keyvalues (dict): The unique key tables and their new values
|
||||
values (dict): The nonunique columns and their new values
|
||||
insertion_values (dict): additional key/values to use only when
|
||||
inserting
|
||||
lock (bool): True to lock the table when doing the upsert.
|
||||
table: The table to upsert into
|
||||
keyvalues: The unique key tables and their new values
|
||||
values: The nonunique columns and their new values
|
||||
insertion_values: additional key/values to use only when inserting
|
||||
lock: True to lock the table when doing the upsert.
|
||||
Returns:
|
||||
None or bool: Native upserts always return None. Emulated
|
||||
upserts return True if a new entry was created, False if an existing
|
||||
one was updated.
|
||||
Native upserts always return None. Emulated upserts return True if a
|
||||
new entry was created, False if an existing one was updated.
|
||||
"""
|
||||
if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
|
||||
return self.simple_upsert_txn_native_upsert(
|
||||
self.simple_upsert_txn_native_upsert(
|
||||
txn, table, keyvalues, values, insertion_values=insertion_values
|
||||
)
|
||||
return None
|
||||
else:
|
||||
return self.simple_upsert_txn_emulated(
|
||||
txn,
|
||||
@ -773,18 +802,23 @@ class DatabasePool(object):
|
||||
)
|
||||
|
||||
def simple_upsert_txn_emulated(
|
||||
self, txn, table, keyvalues, values, insertion_values={}, lock=True
|
||||
):
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
table: str,
|
||||
keyvalues: Dict[str, Any],
|
||||
values: Dict[str, Any],
|
||||
insertion_values: Dict[str, Any] = {},
|
||||
lock: bool = True,
|
||||
) -> bool:
|
||||
"""
|
||||
Args:
|
||||
table (str): The table to upsert into
|
||||
keyvalues (dict): The unique key tables and their new values
|
||||
values (dict): The nonunique columns and their new values
|
||||
insertion_values (dict): additional key/values to use only when
|
||||
inserting
|
||||
lock (bool): True to lock the table when doing the upsert.
|
||||
table: The table to upsert into
|
||||
keyvalues: The unique key tables and their new values
|
||||
values: The nonunique columns and their new values
|
||||
insertion_values: additional key/values to use only when inserting
|
||||
lock: True to lock the table when doing the upsert.
|
||||
Returns:
|
||||
bool: Return True if a new entry was created, False if an existing
|
||||
Returns True if a new entry was created, False if an existing
|
||||
one was updated.
|
||||
"""
|
||||
# We need to lock the table :(, unless we're *really* careful
|
||||
@ -842,19 +876,21 @@ class DatabasePool(object):
|
||||
return True
|
||||
|
||||
def simple_upsert_txn_native_upsert(
|
||||
self, txn, table, keyvalues, values, insertion_values={}
|
||||
):
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
table: str,
|
||||
keyvalues: Dict[str, Any],
|
||||
values: Dict[str, Any],
|
||||
insertion_values: Dict[str, Any] = {},
|
||||
) -> None:
|
||||
"""
|
||||
Use the native UPSERT functionality in recent PostgreSQL versions.
|
||||
|
||||
Args:
|
||||
table (str): The table to upsert into
|
||||
keyvalues (dict): The unique key tables and their new values
|
||||
values (dict): The nonunique columns and their new values
|
||||
insertion_values (dict): additional key/values to use only when
|
||||
inserting
|
||||
Returns:
|
||||
None
|
||||
table: The table to upsert into
|
||||
keyvalues: The unique key tables and their new values
|
||||
values: The nonunique columns and their new values
|
||||
insertion_values: additional key/values to use only when inserting
|
||||
"""
|
||||
allvalues = {} # type: Dict[str, Any]
|
||||
allvalues.update(keyvalues)
|
||||
@ -985,8 +1021,13 @@ class DatabasePool(object):
|
||||
return txn.execute_batch(sql, args)
|
||||
|
||||
def simple_select_one(
|
||||
self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one"
|
||||
):
|
||||
self,
|
||||
table: str,
|
||||
keyvalues: Dict[str, Any],
|
||||
retcols: Iterable[str],
|
||||
allow_none: bool = False,
|
||||
desc: str = "simple_select_one",
|
||||
) -> defer.Deferred:
|
||||
"""Executes a SELECT query on the named table, which is expected to
|
||||
return a single row, returning multiple columns from it.
|
||||
|
||||
@ -994,7 +1035,6 @@ class DatabasePool(object):
|
||||
table: string giving the table name
|
||||
keyvalues: dict of column names and values to select the row with
|
||||
retcols: list of strings giving the names of the columns to return
|
||||
|
||||
allow_none: If true, return None instead of failing if the SELECT
|
||||
statement returns no rows
|
||||
"""
|
||||
@ -1004,12 +1044,12 @@ class DatabasePool(object):
|
||||
|
||||
def simple_select_one_onecol(
|
||||
self,
|
||||
table,
|
||||
keyvalues,
|
||||
retcol,
|
||||
allow_none=False,
|
||||
desc="simple_select_one_onecol",
|
||||
):
|
||||
table: str,
|
||||
keyvalues: Dict[str, Any],
|
||||
retcol: Iterable[str],
|
||||
allow_none: bool = False,
|
||||
desc: str = "simple_select_one_onecol",
|
||||
) -> defer.Deferred:
|
||||
"""Executes a SELECT query on the named table, which is expected to
|
||||
return a single row, returning a single column from it.
|
||||
|
||||
@ -1017,6 +1057,9 @@ class DatabasePool(object):
|
||||
table: string giving the table name
|
||||
keyvalues: dict of column names and values to select the row with
|
||||
retcol: string giving the name of the column to return
|
||||
allow_none: If true, return None instead of failing if the SELECT
|
||||
statement returns no rows
|
||||
desc: description of the transaction, for logging and metrics
|
||||
"""
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
@ -1029,8 +1072,13 @@ class DatabasePool(object):
|
||||
|
||||
@classmethod
|
||||
def simple_select_one_onecol_txn(
|
||||
cls, txn, table, keyvalues, retcol, allow_none=False
|
||||
):
|
||||
cls,
|
||||
txn: LoggingTransaction,
|
||||
table: str,
|
||||
keyvalues: Dict[str, Any],
|
||||
retcol: Iterable[str],
|
||||
allow_none: bool = False,
|
||||
) -> Optional[Any]:
|
||||
ret = cls.simple_select_onecol_txn(
|
||||
txn, table=table, keyvalues=keyvalues, retcol=retcol
|
||||
)
|
||||
@ -1044,7 +1092,12 @@ class DatabasePool(object):
|
||||
raise StoreError(404, "No row found")
|
||||
|
||||
@staticmethod
|
||||
def simple_select_onecol_txn(txn, table, keyvalues, retcol):
|
||||
def simple_select_onecol_txn(
|
||||
txn: LoggingTransaction,
|
||||
table: str,
|
||||
keyvalues: Dict[str, Any],
|
||||
retcol: Iterable[str],
|
||||
) -> List[Any]:
|
||||
sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
|
||||
|
||||
if keyvalues:
|
||||
@ -1056,15 +1109,19 @@ class DatabasePool(object):
|
||||
return [r[0] for r in txn]
|
||||
|
||||
def simple_select_onecol(
|
||||
self, table, keyvalues, retcol, desc="simple_select_onecol"
|
||||
):
|
||||
self,
|
||||
table: str,
|
||||
keyvalues: Optional[Dict[str, Any]],
|
||||
retcol: str,
|
||||
desc: str = "simple_select_onecol",
|
||||
) -> defer.Deferred:
|
||||
"""Executes a SELECT query on the named table, which returns a list
|
||||
comprising of the values of the named column from the selected rows.
|
||||
|
||||
Args:
|
||||
table (str): table name
|
||||
keyvalues (dict|None): column names and values to select the rows with
|
||||
retcol (str): column whos value we wish to retrieve.
|
||||
table: table name
|
||||
keyvalues: column names and values to select the rows with
|
||||
retcol: column whos value we wish to retrieve.
|
||||
|
||||
Returns:
|
||||
Deferred: Results in a list
|
||||
@ -1073,16 +1130,22 @@ class DatabasePool(object):
|
||||
desc, self.simple_select_onecol_txn, table, keyvalues, retcol
|
||||
)
|
||||
|
||||
def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"):
|
||||
def simple_select_list(
|
||||
self,
|
||||
table: str,
|
||||
keyvalues: Optional[Dict[str, Any]],
|
||||
retcols: Iterable[str],
|
||||
desc: str = "simple_select_list",
|
||||
) -> defer.Deferred:
|
||||
"""Executes a SELECT query on the named table, which may return zero or
|
||||
more rows, returning the result as a list of dicts.
|
||||
|
||||
Args:
|
||||
table (str): the table name
|
||||
keyvalues (dict[str, Any] | None):
|
||||
table: the table name
|
||||
keyvalues:
|
||||
column names and values to select the rows with, or None to not
|
||||
apply a WHERE clause.
|
||||
retcols (iterable[str]): the names of the columns to return
|
||||
retcols: the names of the columns to return
|
||||
Returns:
|
||||
defer.Deferred: resolves to list[dict[str, Any]]
|
||||
"""
|
||||
@ -1091,17 +1154,23 @@ class DatabasePool(object):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def simple_select_list_txn(cls, txn, table, keyvalues, retcols):
|
||||
def simple_select_list_txn(
|
||||
cls,
|
||||
txn: LoggingTransaction,
|
||||
table: str,
|
||||
keyvalues: Optional[Dict[str, Any]],
|
||||
retcols: Iterable[str],
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Executes a SELECT query on the named table, which may return zero or
|
||||
more rows, returning the result as a list of dicts.
|
||||
|
||||
Args:
|
||||
txn: Transaction object
|
||||
table (str): the table name
|
||||
keyvalues (dict[str, T] | None):
|
||||
table: the table name
|
||||
keyvalues:
|
||||
column names and values to select the rows with, or None to not
|
||||
apply a WHERE clause.
|
||||
retcols (iterable[str]): the names of the columns to return
|
||||
retcols: the names of the columns to return
|
||||
"""
|
||||
if keyvalues:
|
||||
sql = "SELECT %s FROM %s WHERE %s" % (
|
||||
@ -1118,14 +1187,14 @@ class DatabasePool(object):
|
||||
|
||||
async def simple_select_many_batch(
|
||||
self,
|
||||
table,
|
||||
column,
|
||||
iterable,
|
||||
retcols,
|
||||
keyvalues={},
|
||||
desc="simple_select_many_batch",
|
||||
batch_size=100,
|
||||
):
|
||||
table: str,
|
||||
column: str,
|
||||
iterable: Iterable[Any],
|
||||
retcols: Iterable[str],
|
||||
keyvalues: Dict[str, Any] = {},
|
||||
desc: str = "simple_select_many_batch",
|
||||
batch_size: int = 100,
|
||||
) -> List[Any]:
|
||||
"""Executes a SELECT query on the named table, which may return zero or
|
||||
more rows, returning the result as a list of dicts.
|
||||
|
||||
@ -1165,7 +1234,15 @@ class DatabasePool(object):
|
||||
return results
|
||||
|
||||
@classmethod
|
||||
def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
|
||||
def simple_select_many_txn(
|
||||
cls,
|
||||
txn: LoggingTransaction,
|
||||
table: str,
|
||||
column: str,
|
||||
iterable: Iterable[Any],
|
||||
keyvalues: Dict[str, Any],
|
||||
retcols: Iterable[str],
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Executes a SELECT query on the named table, which may return zero or
|
||||
more rows, returning the result as a list of dicts.
|
||||
|
||||
@ -1198,13 +1275,24 @@ class DatabasePool(object):
|
||||
txn.execute(sql, values)
|
||||
return cls.cursor_to_dict(txn)
|
||||
|
||||
def simple_update(self, table, keyvalues, updatevalues, desc):
|
||||
def simple_update(
|
||||
self,
|
||||
table: str,
|
||||
keyvalues: Dict[str, Any],
|
||||
updatevalues: Dict[str, Any],
|
||||
desc: str,
|
||||
) -> defer.Deferred:
|
||||
return self.runInteraction(
|
||||
desc, self.simple_update_txn, table, keyvalues, updatevalues
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def simple_update_txn(txn, table, keyvalues, updatevalues):
|
||||
def simple_update_txn(
|
||||
txn: LoggingTransaction,
|
||||
table: str,
|
||||
keyvalues: Dict[str, Any],
|
||||
updatevalues: Dict[str, Any],
|
||||
) -> int:
|
||||
if keyvalues:
|
||||
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
|
||||
else:
|
||||
@ -1221,8 +1309,12 @@ class DatabasePool(object):
|
||||
return txn.rowcount
|
||||
|
||||
def simple_update_one(
|
||||
self, table, keyvalues, updatevalues, desc="simple_update_one"
|
||||
):
|
||||
self,
|
||||
table: str,
|
||||
keyvalues: Dict[str, Any],
|
||||
updatevalues: Dict[str, Any],
|
||||
desc: str = "simple_update_one",
|
||||
) -> defer.Deferred:
|
||||
"""Executes an UPDATE query on the named table, setting new values for
|
||||
columns in a row matching the key values.
|
||||
|
||||
@ -1230,22 +1322,19 @@ class DatabasePool(object):
|
||||
table: string giving the table name
|
||||
keyvalues: dict of column names and values to select the row with
|
||||
updatevalues: dict giving column names and values to update
|
||||
retcols : optional list of column names to return
|
||||
|
||||
If present, retcols gives a list of column names on which to perform
|
||||
a SELECT statement *before* performing the UPDATE statement. The values
|
||||
of these will be returned in a dict.
|
||||
|
||||
These are performed within the same transaction, allowing an atomic
|
||||
get-and-set. This can be used to implement compare-and-set by putting
|
||||
the update column in the 'keyvalues' dict as well.
|
||||
"""
|
||||
return self.runInteraction(
|
||||
desc, self.simple_update_one_txn, table, keyvalues, updatevalues
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
|
||||
def simple_update_one_txn(
|
||||
cls,
|
||||
txn: LoggingTransaction,
|
||||
table: str,
|
||||
keyvalues: Dict[str, Any],
|
||||
updatevalues: Dict[str, Any],
|
||||
) -> None:
|
||||
rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues)
|
||||
|
||||
if rowcount == 0:
|
||||
@ -1253,8 +1342,18 @@ class DatabasePool(object):
|
||||
if rowcount > 1:
|
||||
raise StoreError(500, "More than one row matched (%s)" % (table,))
|
||||
|
||||
# Ideally we could use the overload decorator here to specify that the
|
||||
# return type is only optional if allow_none is True, but this does not work
|
||||
# when you call a static method from an instance.
|
||||
# See https://github.com/python/mypy/issues/7781
|
||||
@staticmethod
|
||||
def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
|
||||
def simple_select_one_txn(
|
||||
txn: LoggingTransaction,
|
||||
table: str,
|
||||
keyvalues: Dict[str, Any],
|
||||
retcols: Iterable[str],
|
||||
allow_none: bool = False,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
select_sql = "SELECT %s FROM %s WHERE %s" % (
|
||||
", ".join(retcols),
|
||||
table,
|
||||
@ -1273,7 +1372,9 @@ class DatabasePool(object):
|
||||
|
||||
return dict(zip(retcols, row))
|
||||
|
||||
def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"):
|
||||
def simple_delete_one(
|
||||
self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one"
|
||||
) -> defer.Deferred:
|
||||
"""Executes a DELETE query on the named table, expecting to delete a
|
||||
single row.
|
||||
|
||||
@ -1284,7 +1385,9 @@ class DatabasePool(object):
|
||||
return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
|
||||
|
||||
@staticmethod
|
||||
def simple_delete_one_txn(txn, table, keyvalues):
|
||||
def simple_delete_one_txn(
|
||||
txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any]
|
||||
) -> None:
|
||||
"""Executes a DELETE query on the named table, expecting to delete a
|
||||
single row.
|
||||
|
||||
@ -1303,11 +1406,13 @@ class DatabasePool(object):
|
||||
if txn.rowcount > 1:
|
||||
raise StoreError(500, "More than one row matched (%s)" % (table,))
|
||||
|
||||
def simple_delete(self, table, keyvalues, desc):
|
||||
def simple_delete(self, table: str, keyvalues: Dict[str, Any], desc: str):
|
||||
return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues)
|
||||
|
||||
@staticmethod
|
||||
def simple_delete_txn(txn, table, keyvalues):
|
||||
def simple_delete_txn(
|
||||
txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any]
|
||||
) -> int:
|
||||
sql = "DELETE FROM %s WHERE %s" % (
|
||||
table,
|
||||
" AND ".join("%s = ?" % (k,) for k in keyvalues),
|
||||
@ -1316,13 +1421,26 @@ class DatabasePool(object):
|
||||
txn.execute(sql, list(keyvalues.values()))
|
||||
return txn.rowcount
|
||||
|
||||
def simple_delete_many(self, table, column, iterable, keyvalues, desc):
|
||||
def simple_delete_many(
|
||||
self,
|
||||
table: str,
|
||||
column: str,
|
||||
iterable: Iterable[Any],
|
||||
keyvalues: Dict[str, Any],
|
||||
desc: str,
|
||||
) -> defer.Deferred:
|
||||
return self.runInteraction(
|
||||
desc, self.simple_delete_many_txn, table, column, iterable, keyvalues
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def simple_delete_many_txn(txn, table, column, iterable, keyvalues):
|
||||
def simple_delete_many_txn(
|
||||
txn: LoggingTransaction,
|
||||
table: str,
|
||||
column: str,
|
||||
iterable: Iterable[Any],
|
||||
keyvalues: Dict[str, Any],
|
||||
) -> int:
|
||||
"""Executes a DELETE query on the named table.
|
||||
|
||||
Filters rows by if value of `column` is in `iterable`.
|
||||
@ -1335,7 +1453,7 @@ class DatabasePool(object):
|
||||
keyvalues: dict of column names and values to select the rows with
|
||||
|
||||
Returns:
|
||||
int: Number rows deleted
|
||||
Number rows deleted
|
||||
"""
|
||||
if not iterable:
|
||||
return 0
|
||||
@ -1356,8 +1474,14 @@ class DatabasePool(object):
|
||||
return txn.rowcount
|
||||
|
||||
def get_cache_dict(
|
||||
self, db_conn, table, entity_column, stream_column, max_value, limit=100000
|
||||
):
|
||||
self,
|
||||
db_conn: Connection,
|
||||
table: str,
|
||||
entity_column: str,
|
||||
stream_column: str,
|
||||
max_value: int,
|
||||
limit: int = 100000,
|
||||
) -> Tuple[Dict[Any, int], int]:
|
||||
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
|
||||
# It doesn't really matter how many we get, the StreamChangeCache will
|
||||
# do the right thing to ensure it respects the max size of cache.
|
||||
@ -1390,34 +1514,34 @@ class DatabasePool(object):
|
||||
|
||||
def simple_select_list_paginate(
|
||||
self,
|
||||
table,
|
||||
orderby,
|
||||
start,
|
||||
limit,
|
||||
retcols,
|
||||
filters=None,
|
||||
keyvalues=None,
|
||||
order_direction="ASC",
|
||||
desc="simple_select_list_paginate",
|
||||
):
|
||||
table: str,
|
||||
orderby: str,
|
||||
start: int,
|
||||
limit: int,
|
||||
retcols: Iterable[str],
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
keyvalues: Optional[Dict[str, Any]] = None,
|
||||
order_direction: str = "ASC",
|
||||
desc: str = "simple_select_list_paginate",
|
||||
) -> defer.Deferred:
|
||||
"""
|
||||
Executes a SELECT query on the named table with start and limit,
|
||||
of row numbers, which may return zero or number of rows from start to limit,
|
||||
returning the result as a list of dicts.
|
||||
|
||||
Args:
|
||||
table (str): the table name
|
||||
filters (dict[str, T] | None):
|
||||
table: the table name
|
||||
orderby: Column to order the results by.
|
||||
start: Index to begin the query at.
|
||||
limit: Number of results to return.
|
||||
retcols: the names of the columns to return
|
||||
filters:
|
||||
column names and values to filter the rows with, or None to not
|
||||
apply a WHERE ? LIKE ? clause.
|
||||
keyvalues (dict[str, T] | None):
|
||||
keyvalues:
|
||||
column names and values to select the rows with, or None to not
|
||||
apply a WHERE clause.
|
||||
orderby (str): Column to order the results by.
|
||||
start (int): Index to begin the query at.
|
||||
limit (int): Number of results to return.
|
||||
retcols (iterable[str]): the names of the columns to return
|
||||
order_direction (str): Whether the results should be ordered "ASC" or "DESC".
|
||||
order_direction: Whether the results should be ordered "ASC" or "DESC".
|
||||
Returns:
|
||||
defer.Deferred: resolves to list[dict[str, Any]]
|
||||
"""
|
||||
@ -1437,16 +1561,16 @@ class DatabasePool(object):
|
||||
@classmethod
|
||||
def simple_select_list_paginate_txn(
|
||||
cls,
|
||||
txn,
|
||||
table,
|
||||
orderby,
|
||||
start,
|
||||
limit,
|
||||
retcols,
|
||||
filters=None,
|
||||
keyvalues=None,
|
||||
order_direction="ASC",
|
||||
):
|
||||
txn: LoggingTransaction,
|
||||
table: str,
|
||||
orderby: str,
|
||||
start: int,
|
||||
limit: int,
|
||||
retcols: Iterable[str],
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
keyvalues: Optional[Dict[str, Any]] = None,
|
||||
order_direction: str = "ASC",
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Executes a SELECT query on the named table with start and limit,
|
||||
of row numbers, which may return zero or number of rows from start to limit,
|
||||
@ -1458,20 +1582,21 @@ class DatabasePool(object):
|
||||
|
||||
Args:
|
||||
txn: Transaction object
|
||||
table (str): the table name
|
||||
orderby (str): Column to order the results by.
|
||||
start (int): Index to begin the query at.
|
||||
limit (int): Number of results to return.
|
||||
retcols (iterable[str]): the names of the columns to return
|
||||
filters (dict[str, T] | None):
|
||||
table: the table name
|
||||
orderby: Column to order the results by.
|
||||
start: Index to begin the query at.
|
||||
limit: Number of results to return.
|
||||
retcols: the names of the columns to return
|
||||
filters:
|
||||
column names and values to filter the rows with, or None to not
|
||||
apply a WHERE ? LIKE ? clause.
|
||||
keyvalues (dict[str, T] | None):
|
||||
keyvalues:
|
||||
column names and values to select the rows with, or None to not
|
||||
apply a WHERE clause.
|
||||
order_direction (str): Whether the results should be ordered "ASC" or "DESC".
|
||||
order_direction: Whether the results should be ordered "ASC" or "DESC".
|
||||
|
||||
Returns:
|
||||
defer.Deferred: resolves to list[dict[str, Any]]
|
||||
The result as a list of dictionaries.
|
||||
"""
|
||||
if order_direction not in ["ASC", "DESC"]:
|
||||
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
|
||||
@ -1497,16 +1622,23 @@ class DatabasePool(object):
|
||||
|
||||
return cls.cursor_to_dict(txn)
|
||||
|
||||
def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"):
|
||||
def simple_search_list(
|
||||
self,
|
||||
table: str,
|
||||
term: Optional[str],
|
||||
col: str,
|
||||
retcols: Iterable[str],
|
||||
desc="simple_search_list",
|
||||
):
|
||||
"""Executes a SELECT query on the named table, which may return zero or
|
||||
more rows, returning the result as a list of dicts.
|
||||
|
||||
Args:
|
||||
table (str): the table name
|
||||
term (str | None):
|
||||
term for searching the table matched to a column.
|
||||
col (str): column to query term should be matched to
|
||||
retcols (iterable[str]): the names of the columns to return
|
||||
table: the table name
|
||||
term: term for searching the table matched to a column.
|
||||
col: column to query term should be matched to
|
||||
retcols: the names of the columns to return
|
||||
|
||||
Returns:
|
||||
defer.Deferred: resolves to list[dict[str, Any]] or None
|
||||
"""
|
||||
@ -1516,19 +1648,26 @@ class DatabasePool(object):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def simple_search_list_txn(cls, txn, table, term, col, retcols):
|
||||
def simple_search_list_txn(
|
||||
cls,
|
||||
txn: LoggingTransaction,
|
||||
table: str,
|
||||
term: Optional[str],
|
||||
col: str,
|
||||
retcols: Iterable[str],
|
||||
) -> Union[List[Dict[str, Any]], int]:
|
||||
"""Executes a SELECT query on the named table, which may return zero or
|
||||
more rows, returning the result as a list of dicts.
|
||||
|
||||
Args:
|
||||
txn: Transaction object
|
||||
table (str): the table name
|
||||
term (str | None):
|
||||
term for searching the table matched to a column.
|
||||
col (str): column to query term should be matched to
|
||||
retcols (iterable[str]): the names of the columns to return
|
||||
table: the table name
|
||||
term: term for searching the table matched to a column.
|
||||
col: column to query term should be matched to
|
||||
retcols: the names of the columns to return
|
||||
|
||||
Returns:
|
||||
defer.Deferred: resolves to list[dict[str, Any]] or None
|
||||
0 if no term is given, otherwise a list of dictionaries.
|
||||
"""
|
||||
if term:
|
||||
sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
|
||||
@ -1541,7 +1680,7 @@ class DatabasePool(object):
|
||||
|
||||
|
||||
def make_in_list_sql_clause(
|
||||
database_engine, column: str, iterable: Iterable
|
||||
database_engine: BaseDatabaseEngine, column: str, iterable: Iterable
|
||||
) -> Tuple[str, list]:
|
||||
"""Returns an SQL clause that checks the given column is in the iterable.
|
||||
|
||||
|
@ -19,6 +19,7 @@ from canonicaljson import json
|
||||
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import LoggingTransaction
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import stringutils as stringutils
|
||||
|
||||
@ -214,14 +215,16 @@ class UIAuthWorkerStore(SQLBaseStore):
|
||||
value,
|
||||
)
|
||||
|
||||
def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any):
|
||||
def _set_ui_auth_session_data_txn(
|
||||
self, txn: LoggingTransaction, session_id: str, key: str, value: Any
|
||||
):
|
||||
# Get the current value.
|
||||
result = self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
table="ui_auth_sessions",
|
||||
keyvalues={"session_id": session_id},
|
||||
retcols=("serverdict",),
|
||||
)
|
||||
) # type: Dict[str, Any] # type: ignore
|
||||
|
||||
# Update it and add it back to the database.
|
||||
serverdict = db_to_json(result["serverdict"])
|
||||
@ -275,7 +278,9 @@ class UIAuthStore(UIAuthWorkerStore):
|
||||
expiration_time,
|
||||
)
|
||||
|
||||
def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int):
|
||||
def _delete_old_ui_auth_sessions_txn(
|
||||
self, txn: LoggingTransaction, expiration_time: int
|
||||
):
|
||||
# Get the expired sessions.
|
||||
sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?"
|
||||
txn.execute(sql, [expiration_time])
|
||||
|
Loading…
Reference in New Issue
Block a user